diff --git a/binutil.go b/binutil.go new file mode 100644 index 0000000..1cff537 --- /dev/null +++ b/binutil.go @@ -0,0 +1,61 @@ +package main + +import ( +"encoding/binary" +"io" +"fmt" +) + +var BE = binary.BigEndian + +func readU8(r io.Reader) (uint8, error) { + var result uint8 + err := binary.Read(r, BE, &result); + return result, err +} + +func readU16(r io.Reader) (uint16, error) { + var result uint16 + err := binary.Read(r, BE, &result); + return result, err +} +func readU32(r io.Reader) (uint32, error) { + var result uint32 + err := binary.Read(r, BE, &result); + return result, err +} +func readString(r io.Reader) (string, error) { + l, err := readU8(r) + if err == nil && l == 0 { + err = fmt.Errorf("invalid string length 0") + } + if err != nil { return "", err} + buf := make([]byte, l) + _, err = io.ReadFull(r, buf) + return string(buf), err +} + +func readLString(r io.Reader) (string, error) { + l, err := readU32(r) + if err == nil && l == 0 { + err = fmt.Errorf("invalid string length 0") + } + if err != nil { return "", err} + buf := make([]byte, l) + _, err = io.ReadFull(r, buf) + return string(buf), err +} + +func writeLString(w io.Writer, s string) error { + writeU32(w,uint32(len(s))) + _, err := w.Write([]byte(s)) + return err +} + +func writeU8(w io.Writer, i uint8) error { + return binary.Write(w, BE, i); +} + +func writeU32(w io.Writer, i uint32) error { + return binary.Write(w, BE, i); +} diff --git a/pest.go b/pest.go index 642135e..630ce75 100644 --- a/pest.go +++ b/pest.go @@ -4,8 +4,211 @@ import ( "net" "os" "fmt" +"bufio" +"bytes" ) +const PACKET_HELLO = 0x50 +const PACKET_SITE_VISIT = 0x58 + +type PacketReader struct { + count uint + sum byte + ptype byte + plen uint + base *bufio.Reader + started bool +} +func newPacketReader(r *bufio.Reader) *PacketReader { + return &PacketReader{ + 0, + 0, + 0, + 0, + r, + false, + } +} + +func (r *PacketReader) Read(b []byte) (int, error) { + bytesLeft := r.plen - r.count + if len(b) > int(bytesLeft) { + return 0, fmt.Errorf("not enough bytes left in package") + } + if ! r.started { + return 0, fmt.Errorf("read called but not started") + } + n, err := r.base.Read(b) + if err != nil { return 0, err } + for i := 0; i < n; i++ { + r.sum += b[i] + } + r.count += uint(len(b)) + return n, err +} + +func (r *PacketReader) start(ptype byte) error { + if r.started { + return fmt.Errorf("already started") + } + r.plen = 5 + r.sum = 0 + r.count = 0 + r.started = true + r.ptype, _ = readU8(r) + if r.ptype != ptype { + return fmt.Errorf("unexpexted packet type") + } + plen, _ := readU32(r) + r.plen = uint(plen) + return nil +} + +func (r *PacketReader) finish() error { + if ! r.started { + return fmt.Errorf("not started") + } + _, err := readU8(r) //checksum + r.started = false + if err != nil { return err } + if r.count != r.plen { + return fmt.Errorf("packet len mismatch exp %d read %d", r.plen, r.count) + } + if r.sum != 0 { + return fmt.Errorf("checksum error: %d", r.sum) + } + + return nil +} + +type PacketWriter struct { + count uint + sum byte + ptype byte + buf *bytes.Buffer + base *bufio.Writer + started bool +} + +func newPacketWriter(r *bufio.Writer) *PacketWriter { + return &PacketWriter{ + 0, + 0, + 0, + nil, + r, + false, + } +} + +func (w *PacketWriter) start(ptype byte) error { + if w.started { + return fmt.Errorf("already started") + } + w.buf = bytes.NewBuffer([]byte{}) + w.count = 0 + w.sum = ptype + w.started = true + w.ptype = ptype + return nil +} + +func (w *PacketWriter) finish() error { + if ! w.started { + return fmt.Errorf("not started") + } + w.started = false + plen := uint32(w.count + 6) + c := byte(plen) + byte(plen >> 8) + byte(plen >> 16) + byte(plen >> 24) + w.sum + c = byte(- c) + err := writeU8(w.base, w.ptype) + if err != nil { return err } + err = writeU32(w.base, plen) + if err != nil { return err } + _,err = w.base.Write(w.buf.Bytes()) + if err != nil { return err } + err = writeU8(w.base, c) + if err != nil { return err } + err = w.base.Flush() + return err +} + +func (w *PacketWriter) Write(b []byte) (int, error) { + if ! w.started { + return 0, fmt.Errorf("write called but not started") + } + n, err := w.buf.Write(b) + w.count += uint(n) + for i := 0; i < n; i++ { + w.sum += b[i] + } + return n, err +} + + +type HelloPacket struct { + protocol string + version uint +} + +func readHelloPacket(r *PacketReader) (HelloPacket, error) { + var p HelloPacket + var err error + err = r.start(PACKET_HELLO) + if err != nil { return p, err } + p.protocol, err = readLString(r) + if err != nil { return p, err } + version, err := readU32(r) + if err != nil { return p, err } + p.version = uint(version) + err = r.finish() + return p, err +} + +func writeHelloPacket(w *PacketWriter, p HelloPacket) error { + err := w.start(PACKET_HELLO) + if err != nil { return err } + err = writeLString(w, p.protocol) + if err != nil { return err } + err = writeU32(w, uint32(p.version)) + if err != nil { return err } + err = w.finish() + return err +} + +type Tally struct { + species string + count uint +} + +type SiteVisitPacket struct { + siteId uint + populations []Tally +} + +func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) { + var p SiteVisitPacket + var err error + err = r.start(PACKET_SITE_VISIT) + if err != nil { return p, err } + siteId, err := readU32(r) + if err != nil { return p, err } + p.siteId = uint(siteId) + talCount, err := readU32(r) + if err != nil { return p, err } + for i := 0 ; i < int(talCount) ; i++ { + var t Tally + var count uint32 + t.species, err = readLString(r) + if err != nil { return p, err } + count, err = readU32(r) + if err != nil { return p, err } + t.count = uint(count) + p.populations = append(p.populations, t) + } + err = r.finish() + return p, err +} type PestServer struct { port uint16 @@ -19,11 +222,15 @@ func NewPestServer(port uint16) *PestServer { type PestSession struct{ con net.Conn + r *PacketReader + w *PacketWriter } func NewPestSession(con net.Conn) *PestSession { return &PestSession{ con, + newPacketReader(bufio.NewReader(con)), + newPacketWriter(bufio.NewWriter(con)), } } @@ -52,7 +259,45 @@ func (s *PestServer) processClient(con net.Conn) { go session.pestHandler() } -func (s *PestSession) pestHandler() { - +func (s *PestSession) readHello() error { + hello, err := readHelloPacket(s.r) + if err != nil { return err } + fmt.Printf("%+v\n", hello) + if hello.protocol != "pestcontrol" { return fmt.Errorf("unknown protocol")} + if hello.version != 1 { + return fmt.Errorf("unknown version") + } + return nil } +func (s *PestSession) sendHello() error { + p := HelloPacket{"pestcontrol", 1} + err := writeHelloPacket(s.w, p) + return err +} + +func (s *PestSession) sendError(err error) error { + return nil +} + +func (s *PestSession) pestHandler() error { + err := s.readHello() + if err != nil { return err } + fmt.Println("got hello, sending back") + err = s.sendHello() + var v SiteVisitPacket + for err == nil { + v, err = readSiteVisitPacket(s.r) + if err != nil { return err } + fmt.Printf("%+v\n", v) + } + return err +} + +func (s *PestSession) run() { + err := s.pestHandler() + fmt.Printf("%e\n", err) + if err != nil { + s.sendError(err) + } +} diff --git a/speed.go b/speed.go index 93f22fb..8cded96 100644 --- a/speed.go +++ b/speed.go @@ -12,34 +12,6 @@ import ( "bytes" ) -var BE = binary.BigEndian - -func readU8(r io.Reader) (uint8, error) { - var result uint8 - err := binary.Read(r, BE, &result); - return result, err -} - -func readU16(r io.Reader) (uint16, error) { - var result uint16 - err := binary.Read(r, BE, &result); - return result, err -} -func readU32(r io.Reader) (uint32, error) { - var result uint32 - err := binary.Read(r, BE, &result); - return result, err -} -func readString(r io.Reader) (string, error) { - l, err := readU8(r) - if err == nil && l == 0 { - err = fmt.Errorf("invalid string length 0") - } - if err != nil { return "", err} - buf := make([]byte, l) - _, err = io.ReadFull(r, buf) - return string(buf), err -} type SpeedMessage interface { serialize() []byte } diff --git a/tests/pest.py b/tests/pest.py new file mode 100644 index 0000000..1d0095e --- /dev/null +++ b/tests/pest.py @@ -0,0 +1,46 @@ +import socket +from struct import pack, unpack +from time import time, sleep + +def sock(): + addr = ("localhost", 13370) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect(addr) + return s + +def recv(s): + ptype, plen = unpack(">BI", s.recv(5)) + data = s.recv(plen-6) + cs = s.recv(1) + print(f"I {ptype:X} {data}") + +def pstr(s): + if type(s) == str: + s = s.encode() + return pack(">I", len(s)) + s + +def snd(s, typ, dat): + p = pack(">BI", typ, len(dat)+6) + dat + csum = (-sum(p)) % 256 + c = bytes([ csum ]) + s.sendall(p + c) + +def hello(s): + h = pstr("pestcontrol") + h += pack(">I", 1) + snd(s, 0x50, h) + +def sivi(s): + h = pack(">II", 1337, 3) + h += pstr("green starred rat") + h += pack(">I", 765) + h += pstr("red footed elephant") + h += pack(">I", 6029) + h += pstr("black tailed unicorn") + h += pack(">I", 1234) + snd(s, 0x58, h) + +s = sock() +hello(s) +recv(s) +sivi(s) \ No newline at end of file