pest: client interaction complete
This commit is contained in:
parent
f0b7401ed7
commit
fc6585ee30
61
binutil.go
Normal file
61
binutil.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var BE = binary.BigEndian
|
||||||
|
|
||||||
|
func readU8(r io.Reader) (uint8, error) {
|
||||||
|
var result uint8
|
||||||
|
err := binary.Read(r, BE, &result);
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func readU16(r io.Reader) (uint16, error) {
|
||||||
|
var result uint16
|
||||||
|
err := binary.Read(r, BE, &result);
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
func readU32(r io.Reader) (uint32, error) {
|
||||||
|
var result uint32
|
||||||
|
err := binary.Read(r, BE, &result);
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
func readString(r io.Reader) (string, error) {
|
||||||
|
l, err := readU8(r)
|
||||||
|
if err == nil && l == 0 {
|
||||||
|
err = fmt.Errorf("invalid string length 0")
|
||||||
|
}
|
||||||
|
if err != nil { return "", err}
|
||||||
|
buf := make([]byte, l)
|
||||||
|
_, err = io.ReadFull(r, buf)
|
||||||
|
return string(buf), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLString(r io.Reader) (string, error) {
|
||||||
|
l, err := readU32(r)
|
||||||
|
if err == nil && l == 0 {
|
||||||
|
err = fmt.Errorf("invalid string length 0")
|
||||||
|
}
|
||||||
|
if err != nil { return "", err}
|
||||||
|
buf := make([]byte, l)
|
||||||
|
_, err = io.ReadFull(r, buf)
|
||||||
|
return string(buf), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeLString(w io.Writer, s string) error {
|
||||||
|
writeU32(w,uint32(len(s)))
|
||||||
|
_, err := w.Write([]byte(s))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeU8(w io.Writer, i uint8) error {
|
||||||
|
return binary.Write(w, BE, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeU32(w io.Writer, i uint32) error {
|
||||||
|
return binary.Write(w, BE, i);
|
||||||
|
}
|
249
pest.go
249
pest.go
@ -4,8 +4,211 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const PACKET_HELLO = 0x50
|
||||||
|
const PACKET_SITE_VISIT = 0x58
|
||||||
|
|
||||||
|
type PacketReader struct {
|
||||||
|
count uint
|
||||||
|
sum byte
|
||||||
|
ptype byte
|
||||||
|
plen uint
|
||||||
|
base *bufio.Reader
|
||||||
|
started bool
|
||||||
|
}
|
||||||
|
func newPacketReader(r *bufio.Reader) *PacketReader {
|
||||||
|
return &PacketReader{
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
r,
|
||||||
|
false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *PacketReader) Read(b []byte) (int, error) {
|
||||||
|
bytesLeft := r.plen - r.count
|
||||||
|
if len(b) > int(bytesLeft) {
|
||||||
|
return 0, fmt.Errorf("not enough bytes left in package")
|
||||||
|
}
|
||||||
|
if ! r.started {
|
||||||
|
return 0, fmt.Errorf("read called but not started")
|
||||||
|
}
|
||||||
|
n, err := r.base.Read(b)
|
||||||
|
if err != nil { return 0, err }
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
r.sum += b[i]
|
||||||
|
}
|
||||||
|
r.count += uint(len(b))
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *PacketReader) start(ptype byte) error {
|
||||||
|
if r.started {
|
||||||
|
return fmt.Errorf("already started")
|
||||||
|
}
|
||||||
|
r.plen = 5
|
||||||
|
r.sum = 0
|
||||||
|
r.count = 0
|
||||||
|
r.started = true
|
||||||
|
r.ptype, _ = readU8(r)
|
||||||
|
if r.ptype != ptype {
|
||||||
|
return fmt.Errorf("unexpexted packet type")
|
||||||
|
}
|
||||||
|
plen, _ := readU32(r)
|
||||||
|
r.plen = uint(plen)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *PacketReader) finish() error {
|
||||||
|
if ! r.started {
|
||||||
|
return fmt.Errorf("not started")
|
||||||
|
}
|
||||||
|
_, err := readU8(r) //checksum
|
||||||
|
r.started = false
|
||||||
|
if err != nil { return err }
|
||||||
|
if r.count != r.plen {
|
||||||
|
return fmt.Errorf("packet len mismatch exp %d read %d", r.plen, r.count)
|
||||||
|
}
|
||||||
|
if r.sum != 0 {
|
||||||
|
return fmt.Errorf("checksum error: %d", r.sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketWriter struct {
|
||||||
|
count uint
|
||||||
|
sum byte
|
||||||
|
ptype byte
|
||||||
|
buf *bytes.Buffer
|
||||||
|
base *bufio.Writer
|
||||||
|
started bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPacketWriter(r *bufio.Writer) *PacketWriter {
|
||||||
|
return &PacketWriter{
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
nil,
|
||||||
|
r,
|
||||||
|
false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *PacketWriter) start(ptype byte) error {
|
||||||
|
if w.started {
|
||||||
|
return fmt.Errorf("already started")
|
||||||
|
}
|
||||||
|
w.buf = bytes.NewBuffer([]byte{})
|
||||||
|
w.count = 0
|
||||||
|
w.sum = ptype
|
||||||
|
w.started = true
|
||||||
|
w.ptype = ptype
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *PacketWriter) finish() error {
|
||||||
|
if ! w.started {
|
||||||
|
return fmt.Errorf("not started")
|
||||||
|
}
|
||||||
|
w.started = false
|
||||||
|
plen := uint32(w.count + 6)
|
||||||
|
c := byte(plen) + byte(plen >> 8) + byte(plen >> 16) + byte(plen >> 24) + w.sum
|
||||||
|
c = byte(- c)
|
||||||
|
err := writeU8(w.base, w.ptype)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeU32(w.base, plen)
|
||||||
|
if err != nil { return err }
|
||||||
|
_,err = w.base.Write(w.buf.Bytes())
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeU8(w.base, c)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = w.base.Flush()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *PacketWriter) Write(b []byte) (int, error) {
|
||||||
|
if ! w.started {
|
||||||
|
return 0, fmt.Errorf("write called but not started")
|
||||||
|
}
|
||||||
|
n, err := w.buf.Write(b)
|
||||||
|
w.count += uint(n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
w.sum += b[i]
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
type HelloPacket struct {
|
||||||
|
protocol string
|
||||||
|
version uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func readHelloPacket(r *PacketReader) (HelloPacket, error) {
|
||||||
|
var p HelloPacket
|
||||||
|
var err error
|
||||||
|
err = r.start(PACKET_HELLO)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
p.protocol, err = readLString(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
version, err := readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
p.version = uint(version)
|
||||||
|
err = r.finish()
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeHelloPacket(w *PacketWriter, p HelloPacket) error {
|
||||||
|
err := w.start(PACKET_HELLO)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeLString(w, p.protocol)
|
||||||
|
if err != nil { return err }
|
||||||
|
err = writeU32(w, uint32(p.version))
|
||||||
|
if err != nil { return err }
|
||||||
|
err = w.finish()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tally struct {
|
||||||
|
species string
|
||||||
|
count uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type SiteVisitPacket struct {
|
||||||
|
siteId uint
|
||||||
|
populations []Tally
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSiteVisitPacket(r *PacketReader) (SiteVisitPacket, error) {
|
||||||
|
var p SiteVisitPacket
|
||||||
|
var err error
|
||||||
|
err = r.start(PACKET_SITE_VISIT)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
siteId, err := readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
p.siteId = uint(siteId)
|
||||||
|
talCount, err := readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
for i := 0 ; i < int(talCount) ; i++ {
|
||||||
|
var t Tally
|
||||||
|
var count uint32
|
||||||
|
t.species, err = readLString(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
count, err = readU32(r)
|
||||||
|
if err != nil { return p, err }
|
||||||
|
t.count = uint(count)
|
||||||
|
p.populations = append(p.populations, t)
|
||||||
|
}
|
||||||
|
err = r.finish()
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
type PestServer struct {
|
type PestServer struct {
|
||||||
port uint16
|
port uint16
|
||||||
@ -19,11 +222,15 @@ func NewPestServer(port uint16) *PestServer {
|
|||||||
|
|
||||||
type PestSession struct{
|
type PestSession struct{
|
||||||
con net.Conn
|
con net.Conn
|
||||||
|
r *PacketReader
|
||||||
|
w *PacketWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPestSession(con net.Conn) *PestSession {
|
func NewPestSession(con net.Conn) *PestSession {
|
||||||
return &PestSession{
|
return &PestSession{
|
||||||
con,
|
con,
|
||||||
|
newPacketReader(bufio.NewReader(con)),
|
||||||
|
newPacketWriter(bufio.NewWriter(con)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +259,45 @@ func (s *PestServer) processClient(con net.Conn) {
|
|||||||
go session.pestHandler()
|
go session.pestHandler()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PestSession) pestHandler() {
|
func (s *PestSession) readHello() error {
|
||||||
|
hello, err := readHelloPacket(s.r)
|
||||||
|
if err != nil { return err }
|
||||||
|
fmt.Printf("%+v\n", hello)
|
||||||
|
if hello.protocol != "pestcontrol" { return fmt.Errorf("unknown protocol")}
|
||||||
|
if hello.version != 1 {
|
||||||
|
return fmt.Errorf("unknown version")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PestSession) sendHello() error {
|
||||||
|
p := HelloPacket{"pestcontrol", 1}
|
||||||
|
err := writeHelloPacket(s.w, p)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PestSession) sendError(err error) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PestSession) pestHandler() error {
|
||||||
|
err := s.readHello()
|
||||||
|
if err != nil { return err }
|
||||||
|
fmt.Println("got hello, sending back")
|
||||||
|
err = s.sendHello()
|
||||||
|
var v SiteVisitPacket
|
||||||
|
for err == nil {
|
||||||
|
v, err = readSiteVisitPacket(s.r)
|
||||||
|
if err != nil { return err }
|
||||||
|
fmt.Printf("%+v\n", v)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PestSession) run() {
|
||||||
|
err := s.pestHandler()
|
||||||
|
fmt.Printf("%e\n", err)
|
||||||
|
if err != nil {
|
||||||
|
s.sendError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
28
speed.go
28
speed.go
@ -12,34 +12,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
)
|
)
|
||||||
|
|
||||||
var BE = binary.BigEndian
|
|
||||||
|
|
||||||
func readU8(r io.Reader) (uint8, error) {
|
|
||||||
var result uint8
|
|
||||||
err := binary.Read(r, BE, &result);
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func readU16(r io.Reader) (uint16, error) {
|
|
||||||
var result uint16
|
|
||||||
err := binary.Read(r, BE, &result);
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
func readU32(r io.Reader) (uint32, error) {
|
|
||||||
var result uint32
|
|
||||||
err := binary.Read(r, BE, &result);
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
func readString(r io.Reader) (string, error) {
|
|
||||||
l, err := readU8(r)
|
|
||||||
if err == nil && l == 0 {
|
|
||||||
err = fmt.Errorf("invalid string length 0")
|
|
||||||
}
|
|
||||||
if err != nil { return "", err}
|
|
||||||
buf := make([]byte, l)
|
|
||||||
_, err = io.ReadFull(r, buf)
|
|
||||||
return string(buf), err
|
|
||||||
}
|
|
||||||
type SpeedMessage interface {
|
type SpeedMessage interface {
|
||||||
serialize() []byte
|
serialize() []byte
|
||||||
}
|
}
|
||||||
|
46
tests/pest.py
Normal file
46
tests/pest.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import socket
|
||||||
|
from struct import pack, unpack
|
||||||
|
from time import time, sleep
|
||||||
|
|
||||||
|
def sock():
|
||||||
|
addr = ("localhost", 13370)
|
||||||
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
s.connect(addr)
|
||||||
|
return s
|
||||||
|
|
||||||
|
def recv(s):
|
||||||
|
ptype, plen = unpack(">BI", s.recv(5))
|
||||||
|
data = s.recv(plen-6)
|
||||||
|
cs = s.recv(1)
|
||||||
|
print(f"I {ptype:X} {data}")
|
||||||
|
|
||||||
|
def pstr(s):
|
||||||
|
if type(s) == str:
|
||||||
|
s = s.encode()
|
||||||
|
return pack(">I", len(s)) + s
|
||||||
|
|
||||||
|
def snd(s, typ, dat):
|
||||||
|
p = pack(">BI", typ, len(dat)+6) + dat
|
||||||
|
csum = (-sum(p)) % 256
|
||||||
|
c = bytes([ csum ])
|
||||||
|
s.sendall(p + c)
|
||||||
|
|
||||||
|
def hello(s):
|
||||||
|
h = pstr("pestcontrol")
|
||||||
|
h += pack(">I", 1)
|
||||||
|
snd(s, 0x50, h)
|
||||||
|
|
||||||
|
def sivi(s):
|
||||||
|
h = pack(">II", 1337, 3)
|
||||||
|
h += pstr("green starred rat")
|
||||||
|
h += pack(">I", 765)
|
||||||
|
h += pstr("red footed elephant")
|
||||||
|
h += pack(">I", 6029)
|
||||||
|
h += pstr("black tailed unicorn")
|
||||||
|
h += pack(">I", 1234)
|
||||||
|
snd(s, 0x58, h)
|
||||||
|
|
||||||
|
s = sock()
|
||||||
|
hello(s)
|
||||||
|
recv(s)
|
||||||
|
sivi(s)
|
Loading…
Reference in New Issue
Block a user