pest first tests passing
This commit is contained in:
parent
fc6585ee30
commit
00fa90c6db
348
pest.go
348
pest.go
@ -9,7 +9,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const PACKET_HELLO = 0x50
|
const PACKET_HELLO = 0x50
|
||||||
|
const PACKET_ERROR = 0x51
|
||||||
const PACKET_SITE_VISIT = 0x58
|
const PACKET_SITE_VISIT = 0x58
|
||||||
|
const PACKET_DIAL_AUTHORITY = 0x53
|
||||||
|
const PACKET_TARGET_POPULATIONS = 0x54
|
||||||
|
const PACKET_CREATE_POLICY = 0x55
|
||||||
|
const PACKET_DELETE_POLICY = 0x56
|
||||||
|
const PACKET_DELETE_POLICY_OK = 0x52
|
||||||
|
const PACKET_CREATE_POLICY_OK = 0x57
|
||||||
|
|
||||||
type PacketReader struct {
|
type PacketReader struct {
|
||||||
count uint
|
count uint
|
||||||
@ -47,17 +54,34 @@ func (r *PacketReader) Read(b []byte) (int, error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *PacketReader) readError() error {
|
||||||
|
plen, err := readU32(r) //plen
|
||||||
|
r.plen = uint(plen)
|
||||||
|
if err != nil { return err }
|
||||||
|
d, err := readLString(r)
|
||||||
|
_, err = readU8(r) // checksum
|
||||||
|
r.started = false
|
||||||
|
if err != nil { return err }
|
||||||
|
return fmt.Errorf("Server error: %s", d)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *PacketReader) start(ptype byte) error {
|
func (r *PacketReader) start(ptype byte) error {
|
||||||
if r.started {
|
if r.started {
|
||||||
return fmt.Errorf("already started")
|
return fmt.Errorf("already started")
|
||||||
}
|
}
|
||||||
|
var err error
|
||||||
r.plen = 5
|
r.plen = 5
|
||||||
r.sum = 0
|
r.sum = 0
|
||||||
r.count = 0
|
r.count = 0
|
||||||
r.started = true
|
r.started = true
|
||||||
r.ptype, _ = readU8(r)
|
r.ptype, _ = readU8(r)
|
||||||
|
if r.ptype == PACKET_ERROR {
|
||||||
|
err = r.readError()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if r.ptype != ptype {
|
if r.ptype != ptype {
|
||||||
return fmt.Errorf("unexpexted packet type")
|
return fmt.Errorf("expected packet type %x but found %x", ptype, r.ptype)
|
||||||
}
|
}
|
||||||
plen, _ := readU32(r)
|
plen, _ := readU32(r)
|
||||||
r.plen = uint(plen)
|
r.plen = uint(plen)
|
||||||
@ -129,6 +153,7 @@ func (w *PacketWriter) finish() error {
|
|||||||
if err != nil { return err }
|
if err != nil { return err }
|
||||||
err = writeU8(w.base, c)
|
err = writeU8(w.base, c)
|
||||||
if err != nil { return err }
|
if err != nil { return err }
|
||||||
|
//fmt.Printf("Sending packet %x length %d, data %d\n", w.ptype, plen, w.buf.Len())
|
||||||
err = w.base.Flush()
|
err = w.base.Flush()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -176,7 +201,25 @@ func writeHelloPacket(w *PacketWriter, p HelloPacket) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readHello(r *PacketReader) error {
|
||||||
|
hello, err := readHelloPacket(r)
|
||||||
|
if err != nil { return err }
|
||||||
|
fmt.Printf("%+v\n", hello)
|
||||||
|
if hello.protocol != "pestcontrol" { return fmt.Errorf("unknown protocol")}
|
||||||
|
if hello.version != 1 {
|
||||||
|
return fmt.Errorf("unknown version")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendHello(w *PacketWriter) error {
|
||||||
|
p := HelloPacket{"pestcontrol", 1}
|
||||||
|
err := writeHelloPacket(w, p)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
type Tally struct {
|
type Tally struct {
|
||||||
|
site uint
|
||||||
species string
|
species string
|
||||||
count uint
|
count uint
|
||||||
}
|
}
|
||||||
@ -186,6 +229,18 @@ type SiteVisitPacket struct {
|
|||||||
populations []Tally
|
populations []Tally
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateSiteVisit(v SiteVisitPacket) error {
|
||||||
|
s := make(map[string]uint)
|
||||||
|
for _, t := range v.populations {
|
||||||
|
t2, ok := s[t.species]
|
||||||
|
if ok && t.count != t2 {
|
||||||
|
return fmt.Errorf("conflicting counts for %s", t.species)
|
||||||
|
}
|
||||||
|
s[t.species] = t.count
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) {
|
func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) {
|
||||||
var p SiteVisitPacket
|
var p SiteVisitPacket
|
||||||
var err error
|
var err error
|
||||||
@ -199,6 +254,7 @@ func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) {
|
|||||||
for i := 0 ; i < int(talCount) ; i++ {
|
for i := 0 ; i < int(talCount) ; i++ {
|
||||||
var t Tally
|
var t Tally
|
||||||
var count uint32
|
var count uint32
|
||||||
|
t.site = uint(siteId)
|
||||||
t.species, err = readLString(r)
|
t.species, err = readLString(r)
|
||||||
if err != nil { return p, err }
|
if err != nil { return p, err }
|
||||||
count, err = readU32(r)
|
count, err = readU32(r)
|
||||||
@ -210,31 +266,283 @@ func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) {
|
|||||||
return p, err
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DialAuthorityPacket struct {
|
||||||
|
siteId uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeDialAuthorityPacket(w *PacketWriter, p DialAuthorityPacket) error {
|
||||||
|
err := w.start(PACKET_DIAL_AUTHORITY)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeU32(w, uint32(p.siteId))
|
||||||
|
if err != nil { return err }
|
||||||
|
err = w.finish()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type MaxPop struct {
|
||||||
|
species string
|
||||||
|
max uint
|
||||||
|
min uint
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetPopulationsPacket struct {
|
||||||
|
siteId uint
|
||||||
|
populations []MaxPop
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTargetPopulationsPacket(r *PacketReader) (TargetPopulationsPacket, error) {
|
||||||
|
var p TargetPopulationsPacket
|
||||||
|
var err error
|
||||||
|
err = r.start(PACKET_TARGET_POPULATIONS)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
siteId, err := readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
p.siteId = uint(siteId)
|
||||||
|
talCount, err := readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
for i := 0 ; i < int(talCount) ; i++ {
|
||||||
|
var t MaxPop
|
||||||
|
var max uint32
|
||||||
|
var min uint32
|
||||||
|
t.species, err = readLString(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
max, err = readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
t.max = uint(max)
|
||||||
|
min, err = readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
t.min = uint(min)
|
||||||
|
p.populations = append(p.populations, t)
|
||||||
|
}
|
||||||
|
err = r.finish()
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreatePolicyPacket struct {
|
||||||
|
species string
|
||||||
|
policy byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeCreatePolicyPacket(w *PacketWriter, p CreatePolicyPacket) error {
|
||||||
|
err := w.start(PACKET_CREATE_POLICY)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeLString(w, p.species)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeU8(w, p.policy)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = w.finish()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeletePolicyPacket struct {
|
||||||
|
policyId uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeDeletePolicyPacket(w *PacketWriter, p DeletePolicyPacket) error {
|
||||||
|
err := w.start(PACKET_DELETE_POLICY)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeU32(w, uint32(p.policyId))
|
||||||
|
if err != nil { return err }
|
||||||
|
err = w.finish()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreatePolicyOkPacket struct {
|
||||||
|
policyId uint
|
||||||
|
}
|
||||||
|
func readCreatePolicyOkPacket(r *PacketReader) (CreatePolicyOkPacket, error) {
|
||||||
|
var err error
|
||||||
|
var p CreatePolicyOkPacket
|
||||||
|
err = r.start(PACKET_CREATE_POLICY_OK)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
pid, err := readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
p.policyId = uint(pid)
|
||||||
|
err = r.finish()
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func readDeletePolicyOkPacket(r *PacketReader) error {
|
||||||
|
var err error
|
||||||
|
err = r.start(PACKET_DELETE_POLICY_OK)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = r.finish()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeError(w *PacketWriter, e error) error {
|
||||||
|
err := w.start(PACKET_ERROR)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeLString(w, e.Error())
|
||||||
|
if err != nil { return err }
|
||||||
|
err = w.finish()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type Policy struct {
|
||||||
|
policyId uint
|
||||||
|
policy byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type Authority struct {
|
||||||
|
siteId uint
|
||||||
|
q chan SiteVisitPacket
|
||||||
|
r *PacketReader
|
||||||
|
w *PacketWriter
|
||||||
|
pops map[string]MaxPop
|
||||||
|
policies map[string]Policy
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthority(siteId uint) (*Authority) {
|
||||||
|
a := &Authority{
|
||||||
|
siteId,
|
||||||
|
make(chan SiteVisitPacket, 10),
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
make(map[string]MaxPop),
|
||||||
|
make(map[string]Policy),
|
||||||
|
}
|
||||||
|
go a.run()
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) connect() error {
|
||||||
|
addr := "pestcontrol.protohackers.com:20547"
|
||||||
|
con, err := net.Dial("tcp", addr)
|
||||||
|
if err != nil { return err }
|
||||||
|
a.r = newPacketReader(bufio.NewReader(con))
|
||||||
|
a.w = newPacketWriter(bufio.NewWriter(con))
|
||||||
|
sendHello(a.w)
|
||||||
|
err = readHello(a.r)
|
||||||
|
if err != nil { return err}
|
||||||
|
p := DialAuthorityPacket{a.siteId}
|
||||||
|
err = writeDialAuthorityPacket(a.w, p)
|
||||||
|
if err != nil { return err}
|
||||||
|
mp, err := readTargetPopulationsPacket(a.r)
|
||||||
|
if err != nil { return err}
|
||||||
|
for _, p := range mp.populations {
|
||||||
|
a.pops[p.species] = p
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMeasuredCount(v SiteVisitPacket, species string) uint {
|
||||||
|
for _, p := range v.populations {
|
||||||
|
if p.species == species {
|
||||||
|
return p.count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) deletePolicy(policyId uint) error {
|
||||||
|
p := DeletePolicyPacket{policyId}
|
||||||
|
err := writeDeletePolicyPacket(a.w, p)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = readDeletePolicyOkPacket(a.r)
|
||||||
|
fmt.Printf("delpol %d\n", policyId)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) createPolicy(species string, policy byte) (uint, error) {
|
||||||
|
p := CreatePolicyPacket{species, policy}
|
||||||
|
err := writeCreatePolicyPacket(a.w, p)
|
||||||
|
if err != nil { return 0, err }
|
||||||
|
okp, err := readCreatePolicyOkPacket(a.r)
|
||||||
|
fmt.Printf("pol %s set to %d id %d\n", species, policy, okp.policyId)
|
||||||
|
return okp.policyId, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) setPolicy(species string, policy byte) error {
|
||||||
|
var err error
|
||||||
|
pol, ok := a.policies[species]
|
||||||
|
// if the policy matches, done!
|
||||||
|
if pol.policy == policy {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
err = a.deletePolicy(pol.policyId)
|
||||||
|
}
|
||||||
|
if err != nil { return err }
|
||||||
|
|
||||||
|
pol.policy = policy
|
||||||
|
pol.policyId, err = a.createPolicy(species, policy)
|
||||||
|
a.policies[species] = pol
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func(a *Authority) hdlSiteVisit(v SiteVisitPacket) error {
|
||||||
|
for _, p := range a.pops {
|
||||||
|
cnt := getMeasuredCount(v, p.species)
|
||||||
|
policy := byte(0)
|
||||||
|
if cnt > p.max { policy = 0x90 }
|
||||||
|
if cnt < p.min { policy = 0xa0 }
|
||||||
|
err := a.setPolicy(p.species, policy)
|
||||||
|
if err != nil { return err }
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) run() {
|
||||||
|
fmt.Printf("connecting to site %d\n", a.siteId)
|
||||||
|
err := a.connect()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("ERROR authority connect fail %s\n", err)
|
||||||
|
return // todo retry later
|
||||||
|
}
|
||||||
|
for v := range a.q {
|
||||||
|
err = a.hdlSiteVisit(v)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("ERROR %s\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type PestServer struct {
|
type PestServer struct {
|
||||||
port uint16
|
port uint16
|
||||||
|
q chan SiteVisitPacket
|
||||||
|
authorities map[uint]*Authority
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPestServer(port uint16) *PestServer {
|
func NewPestServer(port uint16) *PestServer {
|
||||||
return &PestServer{
|
return &PestServer{
|
||||||
port,
|
port,
|
||||||
|
make(chan SiteVisitPacket, 10),
|
||||||
|
make(map[uint]*Authority),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type PestSession struct{
|
type PestSession struct{
|
||||||
con net.Conn
|
con net.Conn
|
||||||
|
backend chan SiteVisitPacket
|
||||||
r *PacketReader
|
r *PacketReader
|
||||||
w *PacketWriter
|
w *PacketWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPestSession(con net.Conn) *PestSession {
|
func NewPestSession(con net.Conn, backend chan SiteVisitPacket) *PestSession {
|
||||||
return &PestSession{
|
return &PestSession{
|
||||||
con,
|
con,
|
||||||
|
backend,
|
||||||
newPacketReader(bufio.NewReader(con)),
|
newPacketReader(bufio.NewReader(con)),
|
||||||
newPacketWriter(bufio.NewWriter(con)),
|
newPacketWriter(bufio.NewWriter(con)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PestServer) central() {
|
||||||
|
for t := range s.q {
|
||||||
|
a, ok := s.authorities[t.siteId]
|
||||||
|
if ! ok {
|
||||||
|
a = newAuthority(t.siteId)
|
||||||
|
s.authorities[t.siteId] = a
|
||||||
|
}
|
||||||
|
a.q <- t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PestServer) Run() {
|
func (s *PestServer) Run() {
|
||||||
|
go s.central()
|
||||||
addr := fmt.Sprintf("0.0.0.0:%d", s.port)
|
addr := fmt.Sprintf("0.0.0.0:%d", s.port)
|
||||||
server, err := net.Listen("tcp", addr)
|
server, err := net.Listen("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -255,41 +563,27 @@ func (s *PestServer) Run() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *PestServer) processClient(con net.Conn) {
|
func (s *PestServer) processClient(con net.Conn) {
|
||||||
session := NewPestSession(con)
|
session := NewPestSession(con, s.q)
|
||||||
go session.pestHandler()
|
go session.pestHandler()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PestSession) readHello() error {
|
|
||||||
hello, err := readHelloPacket(s.r)
|
|
||||||
if err != nil { return err }
|
|
||||||
fmt.Printf("%+v\n", hello)
|
|
||||||
if hello.protocol != "pestcontrol" { return fmt.Errorf("unknown protocol")}
|
|
||||||
if hello.version != 1 {
|
|
||||||
return fmt.Errorf("unknown version")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PestSession) sendHello() error {
|
|
||||||
p := HelloPacket{"pestcontrol", 1}
|
|
||||||
err := writeHelloPacket(s.w, p)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PestSession) sendError(err error) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PestSession) pestHandler() error {
|
func (s *PestSession) pestHandler() error {
|
||||||
err := s.readHello()
|
err := readHello(s.r)
|
||||||
if err != nil { return err }
|
if err != nil { return err }
|
||||||
fmt.Println("got hello, sending back")
|
fmt.Println("got hello, sending back")
|
||||||
err = s.sendHello()
|
err = sendHello(s.w)
|
||||||
var v SiteVisitPacket
|
var v SiteVisitPacket
|
||||||
for err == nil {
|
for err == nil {
|
||||||
v, err = readSiteVisitPacket(s.r)
|
v, err = readSiteVisitPacket(s.r)
|
||||||
if err != nil { return err }
|
if err != nil { return err }
|
||||||
fmt.Printf("%+v\n", v)
|
fmt.Printf("I %+v\n", v)
|
||||||
|
err := validateSiteVisit(v)
|
||||||
|
if err != nil {
|
||||||
|
writeError(s.w, err)
|
||||||
|
} else {
|
||||||
|
s.backend <- v
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -298,6 +592,6 @@ func (s *PestSession) run() {
|
|||||||
err := s.pestHandler()
|
err := s.pestHandler()
|
||||||
fmt.Printf("%e\n", err)
|
fmt.Printf("%e\n", err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.sendError(err)
|
writeError(s.w, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user