Proto/pest.go
2023-11-30 21:26:30 +01:00

608 lines
13 KiB
Go

package main
import (
"net"
"os"
"fmt"
"bufio"
"bytes"
"time"
)
// Packet type bytes of the pest-control wire protocol, listed in numeric order.
// They are untyped constants so they can be compared against a byte or stored
// in a uint without conversion.
const (
	PACKET_HELLO              = 0x50
	PACKET_ERROR              = 0x51
	PACKET_DELETE_POLICY_OK   = 0x52 // read after sending a DeletePolicy packet
	PACKET_DIAL_AUTHORITY     = 0x53
	PACKET_TARGET_POPULATIONS = 0x54
	PACKET_CREATE_POLICY      = 0x55
	PACKET_DELETE_POLICY      = 0x56
	PACKET_CREATE_POLICY_OK   = 0x57 // read after sending a CreatePolicy packet
	PACKET_SITE_VISIT         = 0x58
)
// PacketReader reads one framed packet at a time from base. It tracks how many
// bytes of the current packet have been consumed and their running byte sum,
// so that finish can validate the packet length and checksum.
type PacketReader struct {
// count is the number of bytes of the current packet read so far
// (reset by start, incremented by Read).
count uint
// sum is the running (mod 256) sum of every byte read in the current packet;
// finish requires it to be zero.
sum byte
// ptype is the type byte of the current packet, as read by start.
ptype byte
// plen bounds how many bytes Read may consume for the current packet:
// 5 while only the header may be read, then the declared packet length.
plen uint
// base is the underlying buffered stream.
base *bufio.Reader
// started reports whether a packet is currently being read.
started bool
}
// newPacketReader wraps br in an idle PacketReader: every counter is zero and
// no packet is in progress until start is called.
func newPacketReader(br *bufio.Reader) *PacketReader {
	return &PacketReader{base: br}
}
// Read implements io.Reader over the packet currently being read. It refuses
// to read past the packet's length bound (plen) and adds every consumed byte
// to the running checksum.
//
// It returns an error if no packet has been started, or if b is larger than
// the bytes remaining in the packet. Like any io.Reader it may return n > 0
// together with a non-nil error; those n bytes are still counted.
func (r *PacketReader) Read(b []byte) (int, error) {
	if !r.started {
		return 0, fmt.Errorf("read called but not started")
	}
	// plen comes from the wire, so it may already be smaller than count.
	// Check that first so the unsigned subtraction cannot wrap around.
	if r.count > r.plen || uint(len(b)) > r.plen-r.count {
		return 0, fmt.Errorf("not enough bytes left in package")
	}
	n, err := r.base.Read(b)
	// Account for the bytes actually consumed even when err != nil, so that
	// count and sum stay in step with the stream.
	for _, c := range b[:n] {
		r.sum += c
	}
	r.count += uint(n)
	return n, err
}
// readError consumes the remainder of an Error packet whose type byte start
// has already read. It reads the u32 length and the message, then lets finish
// consume and validate the checksum.
//
// It always returns a non-nil error. On success that error wraps the peer's
// message as "Server error: <msg>"; otherwise it is the read or validation
// error. In every case no packet is left in progress afterwards.
func (r *PacketReader) readError() error {
	plen, err := readU32(r)
	if err != nil {
		r.started = false
		return err
	}
	// Smallest valid packet: type (1) + length (4) + checksum (1).
	if plen < 6 {
		r.started = false
		return fmt.Errorf("packet length %d too small", plen)
	}
	r.plen = uint(plen)
	msg, err := readLString(r)
	if err != nil {
		r.started = false
		return err
	}
	// finish reads the checksum byte, clears started, and checks length and sum.
	if err := r.finish(); err != nil {
		return err
	}
	return fmt.Errorf("Server error: %s", msg)
}
// start begins reading a packet that must be of type ptype. It reads the type
// byte and the u32 total length, leaving the reader positioned at the packet
// body.
//
// If the incoming packet is an Error packet, it is consumed and returned as an
// error (see readError). Any failure clears the in-progress state, so a later
// start does not report "already started".
func (r *PacketReader) start(ptype byte) error {
	if r.started {
		return fmt.Errorf("already started")
	}
	// Only the 5-byte header (type + u32 length) may be read until the real
	// length is known.
	r.plen = 5
	r.sum = 0
	r.count = 0
	r.started = true
	got, err := readU8(r)
	if err != nil {
		r.started = false
		return err
	}
	r.ptype = got
	if r.ptype == PACKET_ERROR {
		return r.readError()
	}
	if r.ptype != ptype {
		r.started = false
		return fmt.Errorf("expected packet type %x but found %x", ptype, r.ptype)
	}
	plen, err := readU32(r)
	if err != nil {
		r.started = false
		return err
	}
	// Smallest valid packet: type (1) + length (4) + checksum (1).
	if plen < 6 {
		r.started = false
		return fmt.Errorf("packet length %d too small", plen)
	}
	r.plen = uint(plen)
	return nil
}
// finish ends the current packet. It reads the trailing checksum byte, marks
// the reader idle, and reports an error when the read fails, when the number
// of bytes consumed differs from the declared length, or when the byte sum of
// the whole packet is not zero.
func (r *PacketReader) finish() error {
	if !r.started {
		return fmt.Errorf("not started")
	}
	// The checksum byte goes through r so that it is included in sum and count.
	_, readErr := readU8(r)
	r.started = false
	switch {
	case readErr != nil:
		return readErr
	case r.count != r.plen:
		return fmt.Errorf("packet len mismatch exp %d read %d", r.plen, r.count)
	case r.sum != 0:
		return fmt.Errorf("checksum error: %d", r.sum)
	}
	return nil
}
// PacketWriter buffers the body of one outgoing packet. On finish it frames the
// body with the type byte, the u32 length and a checksum, then writes it to base.
type PacketWriter struct {
// count is the number of body bytes buffered since start.
count uint
// sum is the running (mod 256) sum of the type byte and all body bytes.
sum byte
// ptype is the type byte of the packet being built.
ptype byte
// buf holds the packet body until finish; it is replaced on every start.
buf *bytes.Buffer
// base is the underlying buffered stream; finish flushes it.
base *bufio.Writer
// started reports whether a packet is currently being built.
started bool
}
// newPacketWriter wraps bw in an idle PacketWriter. It has no body buffer
// until start is called.
func newPacketWriter(bw *bufio.Writer) *PacketWriter {
	return &PacketWriter{base: bw}
}
// start begins building a packet of type ptype with a fresh, empty body
// buffer. It fails if a packet is already in progress.
func (w *PacketWriter) start(ptype byte) error {
	if w.started {
		return fmt.Errorf("already started")
	}
	w.ptype = ptype
	w.sum = ptype // the type byte is part of the checksum
	w.count = 0
	w.buf = new(bytes.Buffer)
	w.started = true
	return nil
}
// finish frames and transmits the packet built since start. It writes the
// type byte, the u32 total length, the buffered body, and a trailing checksum
// byte chosen so that all bytes of the packet sum to zero (mod 256). It then
// flushes the underlying writer.
func (w *PacketWriter) finish() error {
	if !w.started {
		return fmt.Errorf("not started")
	}
	w.started = false
	// Total length = type (1) + length field (4) + body + checksum (1).
	total := uint32(w.count + 6)
	sum := w.sum
	for shift := uint(0); shift < 32; shift += 8 {
		sum += byte(total >> shift)
	}
	checksum := -sum
	if err := writeU8(w.base, w.ptype); err != nil {
		return err
	}
	if err := writeU32(w.base, total); err != nil {
		return err
	}
	if _, err := w.base.Write(w.buf.Bytes()); err != nil {
		return err
	}
	if err := writeU8(w.base, checksum); err != nil {
		return err
	}
	return w.base.Flush()
}
// Write implements io.Writer by appending b to the body of the packet in
// progress. Each byte written is added to the running checksum and to count.
func (w *PacketWriter) Write(b []byte) (int, error) {
	if !w.started {
		return 0, fmt.Errorf("write called but not started")
	}
	written, err := w.buf.Write(b)
	for _, c := range b[:written] {
		w.sum += c
	}
	w.count += uint(written)
	return written, err
}
// HelloPacket is the Hello (PACKET_HELLO) message that both sides of a
// connection send in this file.
type HelloPacket struct {
// protocol is the protocol name; readHello requires "pestcontrol".
protocol string
// version is the protocol version; readHello requires 1.
version uint
}
// readHelloPacket reads one Hello packet: an lstring protocol name followed by
// a u32 version. Fields decoded before a failure are still returned with the
// error.
func readHelloPacket(r *PacketReader) (HelloPacket, error) {
	var (
		hello HelloPacket
		ver   uint32
		err   error
	)
	if err = r.start(PACKET_HELLO); err != nil {
		return hello, err
	}
	if hello.protocol, err = readLString(r); err != nil {
		return hello, err
	}
	if ver, err = readU32(r); err != nil {
		return hello, err
	}
	hello.version = uint(ver)
	return hello, r.finish()
}
// writeHelloPacket sends p as a Hello packet: protocol as an lstring, then
// version as a u32.
func writeHelloPacket(w *PacketWriter, p HelloPacket) error {
	if err := w.start(PACKET_HELLO); err != nil {
		return err
	}
	if err := writeLString(w, p.protocol); err != nil {
		return err
	}
	if err := writeU32(w, uint32(p.version)); err != nil {
		return err
	}
	return w.finish()
}
// readHello reads the peer's Hello, logs it, and accepts it only if it names
// protocol "pestcontrol" at version 1.
func readHello(r *PacketReader) error {
	hello, err := readHelloPacket(r)
	if err != nil {
		return err
	}
	fmt.Printf("%+v\n", hello)
	switch {
	case hello.protocol != "pestcontrol":
		return fmt.Errorf("unknown protocol")
	case hello.version != 1:
		return fmt.Errorf("unknown version")
	}
	return nil
}
// sendHello announces this side as protocol "pestcontrol", version 1.
func sendHello(w *PacketWriter) error {
	return writeHelloPacket(w, HelloPacket{protocol: "pestcontrol", version: 1})
}
// Tally is one species count taken from a site visit.
type Tally struct {
	// site is the id of the site the visit was for.
	site uint
	// species names the counted population.
	species string
	// count is the observed population size.
	count uint
}

// SiteVisitPacket is a decoded SiteVisit (PACKET_SITE_VISIT) message.
type SiteVisitPacket struct {
	siteId      uint
	populations []Tally
}

// validateSiteVisit rejects a visit that reports the same species more than
// once with different counts. Repeats with an identical count are accepted.
func validateSiteVisit(v SiteVisitPacket) error {
	seen := make(map[string]uint, len(v.populations))
	for i := range v.populations {
		entry := &v.populations[i]
		if prev, dup := seen[entry.species]; dup && prev != entry.count {
			return fmt.Errorf("conflicting counts for %s", entry.species)
		}
		seen[entry.species] = entry.count
	}
	return nil
}
// readSiteVisitPacket reads a SiteVisit packet: a u32 site id, then a u32
// count of {species lstring, count u32} entries. Every Tally is stamped with
// the packet's site id. Entries decoded before a failure are returned with
// the error.
func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) {
	var (
		visit SiteVisitPacket
		site  uint32
		n     uint32
		err   error
	)
	if err = r.start(PACKET_SITE_VISIT); err != nil {
		return visit, err
	}
	if site, err = readU32(r); err != nil {
		return visit, err
	}
	visit.siteId = uint(site)
	if n, err = readU32(r); err != nil {
		return visit, err
	}
	for i := 0; i < int(n); i++ {
		tally := Tally{site: uint(site)}
		if tally.species, err = readLString(r); err != nil {
			return visit, err
		}
		var cnt uint32
		if cnt, err = readU32(r); err != nil {
			return visit, err
		}
		tally.count = uint(cnt)
		visit.populations = append(visit.populations, tally)
	}
	return visit, r.finish()
}
// DialAuthorityPacket is the DialAuthority (PACKET_DIAL_AUTHORITY) message.
// connect sends it with the site id after the Hello exchange.
type DialAuthorityPacket struct {
// siteId is written on the wire as a u32.
siteId uint
}
// writeDialAuthorityPacket sends p as a DialAuthority packet carrying the
// site id as a u32.
func writeDialAuthorityPacket(w *PacketWriter, p DialAuthorityPacket) error {
	if err := w.start(PACKET_DIAL_AUTHORITY); err != nil {
		return err
	}
	if err := writeU32(w, uint32(p.siteId)); err != nil {
		return err
	}
	return w.finish()
}
// MaxPop is one species' target population range, decoded from a
// TargetPopulations packet. hdlSiteVisit compares measured counts against it.
type MaxPop struct {
species string
// max is the upper bound; hdlSiteVisit requests policy 0x90 for counts above it.
max uint
// min is the lower bound; hdlSiteVisit requests policy 0xa0 for counts below it.
min uint
}
// TargetPopulationsPacket is a decoded TargetPopulations
// (PACKET_TARGET_POPULATIONS) message: a site id and its per-species target
// ranges.
type TargetPopulationsPacket struct {
siteId uint
populations []MaxPop
}
// readTargetPopulationsPacket reads a TargetPopulations packet: a u32 site id,
// then a u32 count of {species lstring, min u32, max u32} entries. Entries
// decoded before a failure are returned with the error.
func readTargetPopulationsPacket(r *PacketReader) (TargetPopulationsPacket, error) {
	var p TargetPopulationsPacket
	if err := r.start(PACKET_TARGET_POPULATIONS); err != nil {
		return p, err
	}
	siteId, err := readU32(r)
	if err != nil {
		return p, err
	}
	p.siteId = uint(siteId)
	count, err := readU32(r)
	if err != nil {
		return p, err
	}
	for i := uint32(0); i < count; i++ {
		var t MaxPop
		if t.species, err = readLString(r); err != nil {
			return p, err
		}
		// The lower bound precedes the upper bound on the wire. Reading them
		// in the other order would invert every range check in hdlSiteVisit.
		var lo, hi uint32
		if lo, err = readU32(r); err != nil {
			return p, err
		}
		if hi, err = readU32(r); err != nil {
			return p, err
		}
		t.min = uint(lo)
		t.max = uint(hi)
		p.populations = append(p.populations, t)
	}
	return p, r.finish()
}
// CreatePolicyPacket is the CreatePolicy (PACKET_CREATE_POLICY) message.
type CreatePolicyPacket struct {
// species is written as an lstring.
species string
// policy is the action byte written as a u8 (hdlSiteVisit uses 0x90 and 0xa0).
policy byte
}
// writeCreatePolicyPacket sends p as a CreatePolicy packet: the species as an
// lstring followed by the action as a single byte.
func writeCreatePolicyPacket(w *PacketWriter, p CreatePolicyPacket) error {
	if err := w.start(PACKET_CREATE_POLICY); err != nil {
		return err
	}
	if err := writeLString(w, p.species); err != nil {
		return err
	}
	if err := writeU8(w, p.policy); err != nil {
		return err
	}
	return w.finish()
}
// DeletePolicyPacket is the DeletePolicy (PACKET_DELETE_POLICY) message.
type DeletePolicyPacket struct {
// policyId is an id previously read from a CreatePolicyOk reply; written as a u32.
policyId uint
}
// writeDeletePolicyPacket sends p as a DeletePolicy packet carrying the policy
// id as a u32.
func writeDeletePolicyPacket(w *PacketWriter, p DeletePolicyPacket) error {
	if err := w.start(PACKET_DELETE_POLICY); err != nil {
		return err
	}
	if err := writeU32(w, uint32(p.policyId)); err != nil {
		return err
	}
	return w.finish()
}
// CreatePolicyOkPacket is the reply read after sending CreatePolicy.
type CreatePolicyOkPacket struct {
// policyId identifies the created policy; deletePolicy later sends it back.
policyId uint
}
// readCreatePolicyOkPacket reads the CreatePolicy reply, whose only field is
// the u32 policy id.
func readCreatePolicyOkPacket(r *PacketReader) (CreatePolicyOkPacket, error) {
	var reply CreatePolicyOkPacket
	if err := r.start(PACKET_CREATE_POLICY_OK); err != nil {
		return reply, err
	}
	id, err := readU32(r)
	if err != nil {
		return reply, err
	}
	reply.policyId = uint(id)
	return reply, r.finish()
}
// readDeletePolicyOkPacket reads the empty-bodied reply that follows a
// DeletePolicy request.
func readDeletePolicyOkPacket(r *PacketReader) error {
	if err := r.start(PACKET_DELETE_POLICY_OK); err != nil {
		return err
	}
	return r.finish()
}
// writeError logs e and sends its text to the peer as an Error packet
// (a single lstring). e must be non-nil.
func writeError(w *PacketWriter, e error) error {
	fmt.Printf("sending error: %s\n", e)
	if err := w.start(PACKET_ERROR); err != nil {
		return err
	}
	if err := writeLString(w, e.Error()); err != nil {
		return err
	}
	return w.finish()
}
// Policy records a policy this client created: the id read from the
// CreatePolicyOk reply and the action byte that was requested.
type Policy struct {
policyId uint
policy byte
}
// Authority is the client-side connection for one site. Its run goroutine
// consumes queued site visits and adjusts policies over r/w.
type Authority struct {
// siteId is the site this authority connection serves.
siteId uint
// q receives site visits for this site from PestServer.central.
q chan SiteVisitPacket
// r and w are set by connect.
r *PacketReader
w *PacketWriter
// pops holds the target range per species, read during connect.
pops map[string]MaxPop
// policies holds the policy currently recorded for each species by setPolicy.
policies map[string]Policy
}
// newAuthority creates the Authority for siteId with a 10-slot visit queue and
// starts its run goroutine, which connects and then processes queued visits.
func newAuthority(siteId uint) *Authority {
	auth := &Authority{
		siteId:   siteId,
		q:        make(chan SiteVisitPacket, 10),
		pops:     make(map[string]MaxPop),
		policies: make(map[string]Policy),
	}
	go auth.run()
	return auth
}
// connect dials the authority server and performs the handshake:
//  1. exchange Hello,
//  2. send DialAuthority for a.siteId,
//  3. read the TargetPopulations reply.
//
// On success it installs the reader/writer in a.r/a.w and stores the target
// range for each species in a.pops. On any failure the socket is closed and
// a is left untouched.
func (a *Authority) connect() (err error) {
	const addr = "pestcontrol.protohackers.com:20547"
	con, err := net.Dial("tcp", addr)
	if err != nil {
		return err
	}
	// Release the socket if any later handshake step fails.
	defer func() {
		if err != nil {
			con.Close()
		}
	}()
	r := newPacketReader(bufio.NewReader(con))
	w := newPacketWriter(bufio.NewWriter(con))
	if err = sendHello(w); err != nil {
		return err
	}
	if err = readHello(r); err != nil {
		return err
	}
	if err = writeDialAuthorityPacket(w, DialAuthorityPacket{siteId: a.siteId}); err != nil {
		return err
	}
	targets, err := readTargetPopulationsPacket(r)
	if err != nil {
		return err
	}
	a.r = r
	a.w = w
	for _, target := range targets.populations {
		a.pops[target.species] = target
	}
	return nil
}
// getMeasuredCount returns the count of the first entry in v for species.
// A species absent from the visit counts as 0.
func getMeasuredCount(v SiteVisitPacket, species string) uint {
	for i := range v.populations {
		if entry := v.populations[i]; entry.species == species {
			return entry.count
		}
	}
	return 0
}
// deletePolicy asks the authority to remove policy policyId and waits for its
// reply. The deletion is logged once the reply has been read, whether or not
// reading it succeeded.
func (a *Authority) deletePolicy(policyId uint) error {
	if err := writeDeletePolicyPacket(a.w, DeletePolicyPacket{policyId: policyId}); err != nil {
		return err
	}
	replyErr := readDeletePolicyOkPacket(a.r)
	fmt.Printf("delpol %d\n", policyId)
	return replyErr
}
// createPolicy asks the authority to apply action policy to species and
// returns the id from the reply (0 alongside any read error).
func (a *Authority) createPolicy(species string, policy byte) (uint, error) {
	req := CreatePolicyPacket{species: species, policy: policy}
	if err := writeCreatePolicyPacket(a.w, req); err != nil {
		return 0, err
	}
	reply, err := readCreatePolicyOkPacket(a.r)
	fmt.Printf("pol %s set to %d id %d\n", species, policy, reply.policyId)
	return reply.policyId, err
}
// setPolicy makes the policy held for species match policy, where 0 means
// "no policy" (hdlSiteVisit's default).
//
// An existing policy with a different action is deleted first. A new policy is
// created only for a non-zero action. a.policies is updated only after the
// corresponding server request succeeds, so it never records a policy that
// was not actually created.
func (a *Authority) setPolicy(species string, policy byte) error {
	// An absent entry reads as the zero Policy, so "no policy wanted and none
	// held" also returns here.
	current, held := a.policies[species]
	if current.policy == policy {
		return nil
	}
	if held {
		if err := a.deletePolicy(current.policyId); err != nil {
			return err
		}
		delete(a.policies, species)
	}
	// 0 is only a local marker; never send it to the server as an action.
	if policy == 0 {
		return nil
	}
	id, err := a.createPolicy(species, policy)
	if err != nil {
		return err
	}
	a.policies[species] = Policy{policyId: id, policy: policy}
	return nil
}
// hdlSiteVisit picks an action for every species with a target range:
//   - 0xa0 when the measured count is below min,
//   - 0x90 when it is above max,
//   - 0 otherwise.
//
// It applies each action with setPolicy and stops at the first error. A
// species missing from the visit is measured as 0.
func (a *Authority) hdlSiteVisit(v SiteVisitPacket) error {
	for _, target := range a.pops {
		var action byte
		switch measured := getMeasuredCount(v, target.species); {
		case measured < target.min:
			action = 0xa0
		case measured > target.max:
			action = 0x90
		}
		if err := a.setPolicy(target.species, action); err != nil {
			return err
		}
	}
	return nil
}
// run is the per-site worker goroutine: it processes every site visit queued
// on a.q, connecting to the authority first if needed.
//
// If connecting fails, the visit that triggered the attempt is dropped (and
// logged), and the next queued visit retries the connection. a.q is therefore
// always drained, so a failed authority can never block PestServer.central.
func (a *Authority) run() {
	connected := false
	for v := range a.q {
		if !connected {
			fmt.Printf("connecting to site %d\n", a.siteId)
			if err := a.connect(); err != nil {
				fmt.Printf("ERROR authority connect fail %s\n", err)
				fmt.Printf("ERROR dropping visit for site %d\n", v.siteId)
				continue
			}
			connected = true
		}
		if err := a.hdlSiteVisit(v); err != nil {
			fmt.Printf("ERROR %s\n", err)
		}
	}
}
// PestServer accepts client sessions and routes their site visits to the
// per-site Authority workers.
type PestServer struct {
// port is the TCP port Run listens on.
port uint16
// q carries validated site visits from sessions to central.
q chan SiteVisitPacket
// authorities maps site id to its Authority; central creates entries on first use.
authorities map[uint]*Authority
}
// NewPestServer returns a server for port with a 10-slot visit queue and no
// authorities yet.
func NewPestServer(port uint16) *PestServer {
	srv := new(PestServer)
	srv.port = port
	srv.q = make(chan SiteVisitPacket, 10)
	srv.authorities = make(map[uint]*Authority)
	return srv
}
// PestSession is one accepted client connection.
type PestSession struct{
// con is the client socket.
con net.Conn
// backend receives every site visit that passes validateSiteVisit.
backend chan SiteVisitPacket
// r and w are packet reader/writer over con.
r *PacketReader
w *PacketWriter
}
// NewPestSession wraps con in a session whose validated visits are sent to
// backend.
func NewPestSession(con net.Conn, backend chan SiteVisitPacket) *PestSession {
	sess := &PestSession{con: con, backend: backend}
	sess.r = newPacketReader(bufio.NewReader(con))
	sess.w = newPacketWriter(bufio.NewWriter(con))
	return sess
}
// central is the single goroutine that owns s.authorities. It forwards each
// visit to its site's Authority, creating the Authority on the first visit
// for that site.
func (s *PestServer) central() {
	for visit := range s.q {
		auth := s.authorities[visit.siteId]
		if auth == nil {
			auth = newAuthority(visit.siteId)
			s.authorities[visit.siteId] = auth
		}
		auth.q <- visit
	}
}
// Run starts the central router, listens on all interfaces at s.port, and
// hands every accepted connection to processClient. It exits the process if
// listening or accepting fails.
func (s *PestServer) Run() {
	go s.central()
	listenAddr := fmt.Sprintf("0.0.0.0:%d", s.port)
	listener, listenErr := net.Listen("tcp", listenAddr)
	if listenErr != nil {
		fmt.Println("Error listening:", listenErr.Error())
		os.Exit(1)
	}
	defer listener.Close()
	fmt.Println("PestServer waiting for client...")
	for {
		con, acceptErr := listener.Accept()
		if acceptErr != nil {
			fmt.Println("Error accepting: ", acceptErr.Error())
			os.Exit(1)
		}
		fmt.Println("client connected")
		s.processClient(con)
	}
}
// processClient serves con on its own goroutine and closes the connection when
// the session ends. Nothing else on the session path closes it, so skipping
// this would leak one socket per client.
func (s *PestServer) processClient(con net.Conn) {
	session := NewPestSession(con, s.q)
	go func() {
		defer con.Close()
		// pestHandler logs and reports its own errors; nothing more to do here.
		session.pestHandler()
	}()
}
// pestHandler runs one client session:
//  1. exchange Hello,
//  2. repeatedly read SiteVisit packets,
//  3. forward the valid ones to s.backend.
//
// An invalid visit is answered with an Error packet but does not end the
// session. A read failure (including the 30-second idle deadline) is reported
// to the client, logged, and returned. pestHandler does not close the
// connection.
func (s *PestSession) pestHandler() error {
	const idleTimeout = 30 * time.Second
	// Send our Hello before waiting for the peer's. If the peer also waits for
	// a Hello before sending its own, reading first would deadlock both sides.
	if err := sendHello(s.w); err != nil {
		fmt.Printf("ERROR: %s\n", err)
		return err
	}
	s.con.SetReadDeadline(time.Now().Add(idleTimeout))
	if err := readHello(s.r); err != nil {
		writeError(s.w, err)
		return err
	}
	fmt.Println("got hello")
	for {
		s.con.SetReadDeadline(time.Now().Add(idleTimeout))
		v, err := readSiteVisitPacket(s.r)
		if err != nil {
			writeError(s.w, err)
			fmt.Printf("ERROR: %s\n", err)
			return err
		}
		fmt.Printf("I %+v\n", v)
		if verr := validateSiteVisit(v); verr != nil {
			writeError(s.w, verr)
			continue
		}
		s.backend <- v
	}
}
// run executes the session and logs the error that ended it. pestHandler
// already sends that error to the peer before returning, so it is not written
// again here; doing so would emit a duplicate Error packet.
func (s *PestSession) run() {
	if err := s.pestHandler(); err != nil {
		// %v, not %e: %e is a floating-point verb and mangles error values.
		fmt.Printf("%v\n", err)
	}
}