diff --git a/secure.go b/secure.go index 07c8be1..e09962e 100644 --- a/secure.go +++ b/secure.go @@ -6,9 +6,104 @@ import ( "fmt" "bufio" "strings" +"bytes" "strconv" +"math/bits" ) +type Op struct { + op byte + arg byte +} + +func parseSpec(c []byte) ([]Op, error) { + r := make([]Op, 0, 5) + i := 0 + for i < len(c) { + last := i == len(c) - 1 + opid := c[i] + arg := byte(0) + if opid == 0 { return r, nil} + if opid == 2 || opid == 4 { + if last { return r, fmt.Errorf("missing byte in cipherspec") } + i += 1 + arg = c[i] + } + op := Op{opid, arg} + r = append(r, op) + i += 1 + } + return r, nil +} + +func crypt(c []Op, b byte, o int, enc bool) (byte, error) { + i := 0 + for i < len(c) { + x := i + if ! enc { + x = len(c) - i -1 + } + oper := c[x] + + switch oper.op { + case 0: return b, nil + case 1: + b = bits.Reverse8(b) + case 2: + b = b ^ oper.arg + case 3: + b = b ^ byte(o) + case 4: + if enc { + b = b + oper.arg + } else { + b = b - oper.arg + } + case 5: + if enc { + b = b + byte(o) + } else { + b = b - byte(o) + } + default: + return 0, fmt.Errorf("invalid cipherspec %d", oper.op) + } + i += 1 + } + return b, nil +} + +func readEncryptedLine(r *bufio.Reader, c []Op, offset int) (string, error) { + off := 0 + buf := make([]byte, 0, 32) + for { + b, err := r.ReadByte() + if err != nil { + return "", fmt.Errorf("read error") + } + b, err = crypt(c, b, off+offset, false) + if err != nil { + return "", err + } + if b == 10 { break } + buf = append(buf, b) + off += 1 + } + return string(buf), nil +} + +func cryptBytes(c []Op, dat []byte, offset int, encrypt bool) ([]byte, error) { + buf := make([]byte, 0, 32) + for off := range dat { + b, err := crypt(c, dat[off], off+offset, encrypt) + if err != nil { + return []byte(""), err + } + buf = append(buf, b) + } + return buf, nil +} + func largestOrder(msg string) string { orders := strings.Split(msg, ",") largest := 0 @@ -37,11 +132,17 @@ func NewSecureServer(port uint16) *SecureServer { type SecureSession struct{ con net.Conn + cipspec []Op + sent int + recv int } func NewSecureSession(con net.Conn) *SecureSession { return &SecureSession{ con, + make([]Op, 0), + 0, + 0, } } @@ -74,21 +175,42 @@ func (s *SecureServer) processClient(con net.Conn) { } -func (s *SecureSession) setEncryption(r *bufio.Reader) error { - secdef, err := r.ReadBytes('\x00') +func (s *SecureSession) setCipher(r *bufio.Reader) error { + cipspecStr, err := r.ReadBytes('\x00') if err != nil { return err } - fmt.Printf("secdef: %s\n", secdef) + fmt.Printf("cipspec: %x\n", cipspecStr) + cipspec, err := parseSpec(cipspecStr) + if err != nil { return err } + s.cipspec = cipspec + dummy := []byte("abcd") + result, err := cryptBytes(cipspec, dummy, 0, true) + if err != nil { return err } + if bytes.Equal(result, dummy) { + return fmt.Errorf("transparent cipspec") + } + decrypt, err := cryptBytes(cipspec, result, 0, false) + if err != nil { return err } + if ! bytes.Equal(dummy, decrypt) { + return fmt.Errorf("Sanity decrypt failed: %s", decrypt) + } + fmt.Println("Spec approved") return nil } func (s *SecureSession) query(r *bufio.Reader) error { - msg, err := r.ReadString('\n') + fmt.Printf("Recv %d sent %d \n", s.recv, s.sent) + msg, err := readEncryptedLine(r, s.cipspec, s.recv) + s.recv += len(msg) + 1 if err != nil { return err } msg = strings.TrimSpace(msg) - fmt.Printf("Q: %s", msg) + fmt.Printf("Q: %s\n", msg) a := largestOrder(msg) fmt.Printf("A: %s\n", a) - s.con.Write(append([]byte(a), []byte("\n")...)) + resp := append([]byte(a), []byte("\n")...) + resp, err = cryptBytes(s.cipspec, resp, s.sent, true) + s.sent += len(resp) + if err != nil { return err } + s.con.Write(resp) return nil } @@ -100,9 +222,10 @@ func (s *SecureSession) close() { func (s *SecureSession) SecureReceiver() { defer s.close() r := bufio.NewReaderSize(s.con, 5000) - err := s.setEncryption(r) + err := s.setCipher(r) for err == nil { err = s.query(r) } + fmt.Printf("%e\n", err) } diff --git a/tests/secure.py b/tests/secure.py index 62ac546..d9d075b 100644 --- a/tests/secure.py +++ b/tests/secure.py @@ -1,6 +1,7 @@ import socket from struct import pack, unpack from time import time, sleep +from binascii import unhexlify def sock(): addr = ("localhost", 13370) @@ -9,8 +10,10 @@ def sock(): return s def test(s): - s.sendall(b"\x0010x snoeperfloep,123x sapperflap\n") + s.sendall(b"\x04\x01\x0010x snoeperfloep,123x sapperflap\x11") print(s.recv(1024)) s.close() -test(sock()) \ No newline at end of file +def insane(s): s.sendall(unhexlify("03050105050500")) + +insane(sock()) \ No newline at end of file