From 00fa90c6db0abd24849743fea5530e78d3654642 Mon Sep 17 00:00:00 2001 From: Richard Date: Thu, 30 Nov 2023 06:13:57 +0100 Subject: [PATCH] pest first tests passing --- pest.go | 350 +++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 322 insertions(+), 28 deletions(-) diff --git a/pest.go b/pest.go index 630ce75..e0c66e8 100644 --- a/pest.go +++ b/pest.go @@ -9,7 +9,14 @@ import ( ) const PACKET_HELLO = 0x50 +const PACKET_ERROR = 0x51 const PACKET_SITE_VISIT = 0x58 +const PACKET_DIAL_AUTHORITY = 0x53 +const PACKET_TARGET_POPULATIONS = 0x54 +const PACKET_CREATE_POLICY = 0x55 +const PACKET_DELETE_POLICY = 0x56 +const PACKET_DELETE_POLICY_OK = 0x52 +const PACKET_CREATE_POLICY_OK = 0x57 type PacketReader struct { count uint @@ -47,17 +54,34 @@ func (r *PacketReader) Read(b []byte) (int, error) { return n, err } +func (r *PacketReader) readError() error { + plen, err := readU32(r) //plen + r.plen = uint(plen) + if err != nil { return err } + d, err := readLString(r) + _, err = readU8(r) // checksum + r.started = false + if err != nil { return err } + return fmt.Errorf("Server error: %s", d) +} + func (r *PacketReader) start(ptype byte) error { if r.started { return fmt.Errorf("already started") } + var err error r.plen = 5 r.sum = 0 r.count = 0 r.started = true r.ptype, _ = readU8(r) + if r.ptype == PACKET_ERROR { + err = r.readError() + return err + } + if r.ptype != ptype { - return fmt.Errorf("unexpexted packet type") + return fmt.Errorf("expected packet type %x but found %x", ptype, r.ptype) } plen, _ := readU32(r) r.plen = uint(plen) @@ -129,6 +153,7 @@ func (w *PacketWriter) finish() error { if err != nil { return err } err = writeU8(w.base, c) if err != nil { return err } + //fmt.Printf("Sending packet %x length %d, data %d\n", w.ptype, plen, w.buf.Len()) err = w.base.Flush() return err } @@ -176,7 +201,25 @@ func writeHelloPacket(w *PacketWriter, p HelloPacket) error { return err } +func readHello(r *PacketReader) error { + hello, err := readHelloPacket(r) + if err != nil { return err } + fmt.Printf("%+v\n", hello) + if hello.protocol != "pestcontrol" { return fmt.Errorf("unknown protocol")} + if hello.version != 1 { + return fmt.Errorf("unknown version") + } + return nil +} + +func sendHello(w *PacketWriter) error { + p := HelloPacket{"pestcontrol", 1} + err := writeHelloPacket(w, p) + return err +} + type Tally struct { + site uint species string count uint } @@ -186,6 +229,18 @@ type SiteVisitPacket struct { populations []Tally } +func validateSiteVisit(v SiteVisitPacket) error { + s := make(map[string]uint) + for _, t := range v.populations { + t2, ok := s[t.species] + if ok && t.count != t2 { + return fmt.Errorf("conflicting counts for %s", t.species) + } + s[t.species] = t.count + } + return nil +} + func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) { var p SiteVisitPacket var err error @@ -199,6 +254,7 @@ func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) { for i := 0 ; i < int(talCount) ; i++ { var t Tally var count uint32 + t.site = uint(siteId) t.species, err = readLString(r) if err != nil { return p, err } count, err = readU32(r) @@ -210,31 +266,283 @@ func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) { return p, err } +type DialAuthorityPacket struct { + siteId uint +} + +func writeDialAuthorityPacket(w *PacketWriter, p DialAuthorityPacket) error { + err := w.start(PACKET_DIAL_AUTHORITY) + if err != nil { return err } + err = writeU32(w, uint32(p.siteId)) + if err != nil { return err } + err = w.finish() + return err +} + +type MaxPop struct { + species string + max uint + min uint + +} + +type TargetPopulationsPacket struct { + siteId uint + populations []MaxPop +} + +func readTargetPopulationsPacket(r *PacketReader) (TargetPopulationsPacket, error) { + var p TargetPopulationsPacket + var err error + err = r.start(PACKET_TARGET_POPULATIONS) + if err != nil { return p, err } + siteId, err := readU32(r) + if err != nil { return p, err } + p.siteId = uint(siteId) + talCount, err := readU32(r) + if err != nil { return p, err } + for i := 0 ; i < int(talCount) ; i++ { + var t MaxPop + var max uint32 + var min uint32 + t.species, err = readLString(r) + if err != nil { return p, err } + max, err = readU32(r) + if err != nil { return p, err } + t.max = uint(max) + min, err = readU32(r) + if err != nil { return p, err } + t.min = uint(min) + p.populations = append(p.populations, t) + } + err = r.finish() + return p, err +} + +type CreatePolicyPacket struct { + species string + policy byte +} + +func writeCreatePolicyPacket(w *PacketWriter, p CreatePolicyPacket) error { + err := w.start(PACKET_CREATE_POLICY) + if err != nil { return err } + err = writeLString(w, p.species) + if err != nil { return err } + err = writeU8(w, p.policy) + if err != nil { return err } + err = w.finish() + return err +} + +type DeletePolicyPacket struct { + policyId uint +} + +func writeDeletePolicyPacket(w *PacketWriter, p DeletePolicyPacket) error { + err := w.start(PACKET_DELETE_POLICY) + if err != nil { return err } + err = writeU32(w, uint32(p.policyId)) + if err != nil { return err } + err = w.finish() + return err +} + +type CreatePolicyOkPacket struct { + policyId uint +} +func readCreatePolicyOkPacket(r *PacketReader) (CreatePolicyOkPacket, error) { + var err error + var p CreatePolicyOkPacket + err = r.start(PACKET_CREATE_POLICY_OK) + if err != nil { return p, err } + pid, err := readU32(r) + if err != nil { return p, err } + p.policyId = uint(pid) + err = r.finish() + return p, err +} + +func readDeletePolicyOkPacket(r *PacketReader) error { + var err error + err = r.start(PACKET_DELETE_POLICY_OK) + if err != nil { return err } + err = r.finish() + return err +} + +func writeError(w *PacketWriter, e error) error { + err := w.start(PACKET_ERROR) + if err != nil { return err } + err = writeLString(w, e.Error()) + if err != nil { return err } + err = w.finish() + return err +} + +type Policy struct { + policyId uint + policy byte +} + +type Authority struct { + siteId uint + q chan SiteVisitPacket + r *PacketReader + w *PacketWriter + pops map[string]MaxPop + policies map[string]Policy + +} + +func newAuthority(siteId uint) (*Authority) { + a := &Authority{ + siteId, + make(chan SiteVisitPacket, 10), + nil, + nil, + make(map[string]MaxPop), + make(map[string]Policy), + } + go a.run() + return a +} + +func (a *Authority) connect() error { + addr := "pestcontrol.protohackers.com:20547" + con, err := net.Dial("tcp", addr) + if err != nil { return err } + a.r = newPacketReader(bufio.NewReader(con)) + a.w = newPacketWriter(bufio.NewWriter(con)) + sendHello(a.w) + err = readHello(a.r) + if err != nil { return err} + p := DialAuthorityPacket{a.siteId} + err = writeDialAuthorityPacket(a.w, p) + if err != nil { return err} + mp, err := readTargetPopulationsPacket(a.r) + if err != nil { return err} + for _, p := range mp.populations { + a.pops[p.species] = p + } + return nil +} + +func getMeasuredCount(v SiteVisitPacket, species string) uint { + for _, p := range v.populations { + if p.species == species { + return p.count + } + } + return 0 +} + +func (a *Authority) deletePolicy(policyId uint) error { + p := DeletePolicyPacket{policyId} + err := writeDeletePolicyPacket(a.w, p) + if err != nil { return err } + err = readDeletePolicyOkPacket(a.r) + fmt.Printf("delpol %d\n", policyId) + return err +} + +func (a *Authority) createPolicy(species string, policy byte) (uint, error) { + p := CreatePolicyPacket{species, policy} + err := writeCreatePolicyPacket(a.w, p) + if err != nil { return 0, err } + okp, err := readCreatePolicyOkPacket(a.r) + fmt.Printf("pol %s set to %d id %d\n", species, policy, okp.policyId) + return okp.policyId, err +} + +func (a *Authority) setPolicy(species string, policy byte) error { + var err error + pol, ok := a.policies[species] + // if the policy matches, done! + if pol.policy == policy { + return nil + } + if ok { + err = a.deletePolicy(pol.policyId) + } + if err != nil { return err } + + pol.policy = policy + pol.policyId, err = a.createPolicy(species, policy) + a.policies[species] = pol + return err +} + +func(a *Authority) hdlSiteVisit(v SiteVisitPacket) error { + for _, p := range a.pops { + cnt := getMeasuredCount(v, p.species) + policy := byte(0) + if cnt > p.max { policy = 0x90 } + if cnt < p.min { policy = 0xa0 } + err := a.setPolicy(p.species, policy) + if err != nil { return err } + } + return nil +} + +func (a *Authority) run() { + fmt.Printf("connecting to site %d\n", a.siteId) + err := a.connect() + if err != nil { + fmt.Printf("ERROR authority connect fail %s\n", err) + return // todo retry later + } + for v := range a.q { + err = a.hdlSiteVisit(v) + if err != nil { + fmt.Printf("ERROR %s\n", err) + } + } +} + type PestServer struct { port uint16 + q chan SiteVisitPacket + authorities map[uint]*Authority } func NewPestServer(port uint16) *PestServer { return &PestServer{ port, + make(chan SiteVisitPacket, 10), + make(map[uint]*Authority), } } type PestSession struct{ con net.Conn + backend chan SiteVisitPacket r *PacketReader w *PacketWriter } -func NewPestSession(con net.Conn) *PestSession { +func NewPestSession(con net.Conn, backend chan SiteVisitPacket) *PestSession { return &PestSession{ con, - newPacketReader(bufio.NewReader(con)), + backend, + newPacketReader(bufio.NewReader(con)), newPacketWriter(bufio.NewWriter(con)), } } +func (s *PestServer) central() { + for t := range s.q { + a, ok := s.authorities[t.siteId] + if ! ok { + a = newAuthority(t.siteId) + s.authorities[t.siteId] = a + } + a.q <- t + } +} + func (s *PestServer) Run() { + go s.central() addr := fmt.Sprintf("0.0.0.0:%d", s.port) server, err := net.Listen("tcp", addr) if err != nil { @@ -255,41 +563,27 @@ func (s *PestServer) Run() { } func (s *PestServer) processClient(con net.Conn) { - session := NewPestSession(con) + session := NewPestSession(con, s.q) go session.pestHandler() } -func (s *PestSession) readHello() error { - hello, err := readHelloPacket(s.r) - if err != nil { return err } - fmt.Printf("%+v\n", hello) - if hello.protocol != "pestcontrol" { return fmt.Errorf("unknown protocol")} - if hello.version != 1 { - return fmt.Errorf("unknown version") - } - return nil -} - -func (s *PestSession) sendHello() error { - p := HelloPacket{"pestcontrol", 1} - err := writeHelloPacket(s.w, p) - return err -} - -func (s *PestSession) sendError(err error) error { - return nil -} func (s *PestSession) pestHandler() error { - err := s.readHello() + err := readHello(s.r) if err != nil { return err } fmt.Println("got hello, sending back") - err = s.sendHello() + err = sendHello(s.w) var v SiteVisitPacket for err == nil { v, err = readSiteVisitPacket(s.r) if err != nil { return err } - fmt.Printf("%+v\n", v) + fmt.Printf("I %+v\n", v) + err := validateSiteVisit(v) + if err != nil { + writeError(s.w, err) + } else { + s.backend <- v + } } return err } @@ -298,6 +592,6 @@ func (s *PestSession) run() { err := s.pestHandler() fmt.Printf("%e\n", err) if err != nil { - s.sendError(err) + writeError(s.w, err) } }