diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cd3d225 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +logs \ No newline at end of file diff --git a/analyze/reverse.py b/analyze/reverse.py new file mode 100644 index 0000000..5f14c47 --- /dev/null +++ b/analyze/reverse.py @@ -0,0 +1,24 @@ +import sys + +ses = {} + +fil = open(sys.argv[1]) +for l in fil: + l = l.strip() + p = l.split("]") + tst = p[1][2:] + txt = p[2].strip() + tp = txt.split(" ") + print(txt) + if txt.startswith("NOTE:successfully connected with session"): + sid = int(tp[-1]) + ses[sid] = [sid, tst, False] + if txt.startswith("NOTE:closed session"): + sid = int(tp[2]) + ses[sid][2] = True +print("unclosed sessions:") +for v in ses.values(): + if v[2] == False: + print(v[1], v[0]) + + \ No newline at end of file diff --git a/main.go b/main.go index 3ab7e95..b57f269 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ type Server interface { func main() { var challenge int - flag.IntVar(&challenge, "challenge",6, "Challenge number") + flag.IntVar(&challenge, "challenge",7, "Challenge number") flag.Parse() var port uint16 @@ -33,6 +33,8 @@ func main() { server = NewProxyServer(port); case 6: server = NewSpeedServer(port); + case 7: + server = NewReverseServer(port); default: fmt.Printf("Unknown challenge\n") os.Exit(1) diff --git a/reverse.go b/reverse.go new file mode 100644 index 0000000..8db2167 --- /dev/null +++ b/reverse.go @@ -0,0 +1,435 @@ +package main + +import ( +"net" +"fmt" +"os" +"bufio" +"bytes" +"time" +) + +const MAXPACKET = 999 + +func b(s string) []byte { + return []byte(s) +} + +func Min(a uint32, b uint32) uint32 { + if a < b { return a} + return b +} + +func Max(a uint32, b uint32) uint32 { + if a > b { return a} + return b +} + +func b2i(s []byte) (uint32, error) { + var num uint32 + _, err := fmt.Sscanf(string(s), "%d", &num) + return num, err +} + +func reverse(s []byte) []byte{ + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } + return s +} + +func splitSlash(s []byte) ([]byte, []byte) { + parts := bytes.SplitAfterN(s, b("/"), 2) + if len(parts) != 2 { + return b(""),b("") + } + p := parts[0] + return p[:len(p)-1], parts[1] +} + +func calcMaxEscaped(b []byte, limit uint32) uint32 { + bLen := len(b) + cnt := 0 + pos := 0 + for cnt + pos < int(limit) && pos < bLen { + if b[pos] == '/' || b[pos] =='\\' { + cnt += 1 + } + pos += 1 + + } + return uint32(pos) +} + +func escape(s []byte) []byte { + s = bytes.Replace(s, b("\\"), b("\\\\"), -1) + return bytes.Replace(s, b("/"), b("\\/"), -1) +} + +func unescape(s []byte) []byte { + s = bytes.Replace(s, b("\\/"), b("/"), -1) + return bytes.Replace(s, b("\\\\"), b("\\"), -1) +} + +func checkscape(s []byte) bool { + i := 0 + for i < len(s) - 1 { + if s[i] == '/' { + return false + } + if s[i] == '\\' { + i += 1 + } + i += 1 + } + return true +} + +type ReverseServer struct { + port uint16 + +} + +func NewReverseServer(port uint16) *ReverseServer { + return &ReverseServer{ + port, + } +} + +func (s *ReverseServer) Run() { + sock := NewLrcpServer(s.port) + for { + ses := sock.Accept() + go s.handleSession(ses) + } +} + +func (s *ReverseServer) handleSession(ses *LrcpSession) { + r := bufio.NewReader(ses) + for { + data, err := r.ReadBytes('\n') + if err != nil { return } + data = reverse(data[:len(data)-1]) + ses.Write(append(data, '\n')) + } +} + +type ReceivedData struct { + data []byte + pos uint32 +} + +type LrcpSession struct { + sessionId uint32 + ack uint32 + seq uint32 + acked uint32 + readPtr uint32 + readBuf uint32 + pc net.PacketConn + addr net.Addr + recv map[uint32]*ReceivedData + dataQ chan uint32 + ackQ chan uint32 + closed bool + last int64 +} + +func NewLrcpSession(sessionId uint32, pc net.PacketConn, addr net.Addr) *LrcpSession { + return &LrcpSession{ + sessionId, + 0, + 0, + 0, + 0, + 0, + pc, + addr, + make(map[uint32]*ReceivedData), + make(chan uint32, 100), + make(chan uint32, 100), + false, + 0, + } +} + +func (s *LrcpSession) snd(buf []byte) error { + var err error + fmt.Printf("> %s\n", bytes.Replace(buf, b("\n"), b("\\n"), -1)) + _, err = s.pc.WriteTo(buf, s.addr) + return err + +} +func (s *LrcpSession) Read(buf []byte) (int, error) { + //fmt.Printf("reading len %d\n", len(buf)) + for s.readPtr >= s.ack { + fmt.Printf("ack %d\n", s.ack) + select { + case <-s.dataQ: + case <-time.After(60 * time.Second): + s.sendClose() + return 0, fmt.Errorf("timeout") + } + } + //fmt.Println("read has data") + r, ok := s.recv[s.readBuf] + if ! ok { + return 0, fmt.Errorf("recv buf %d not found", s.readBuf) + } + b := r.data + posInBuf := s.readPtr - s.readBuf + bytesLeft := uint32(len(b)) - posInBuf + l := Min(uint32(len(buf)), bytesLeft) + //fmt.Printf("Read returning %d bytes\n", l) + copy(buf, b[posInBuf:posInBuf+l]) + bytesLeft -= l + s.readPtr += l + if bytesLeft == 0 { + s.readBuf += uint32(len(b)) + } + return int(l), nil +} + +func (s *LrcpSession) waitAck() bool { + acked := false + + for ! acked { + select { + case <- s.ackQ: acked = s.seq <= s.acked + case <-time.After(3 * time.Second): return false + } + } + + return acked +} + +func (s *LrcpSession) WriteString(b string) (int, error) { + return s.Write([]byte(b)) +} + +func (s *LrcpSession) Write(b []byte) (int, error) { + bytesLeft := uint32(len(b)) + pos := uint32(0) + for bytesLeft > 0 { + // max packed len is 1k + // 5 slashes + "data" + 2 decimal ints =~ 23, call it 30 + if s.closed { + return int(pos), fmt.Errorf("closed") + } + l := calcMaxEscaped(b[pos:pos+bytesLeft], MAXPACKET-30) + err := s.writeBlock(b[pos:pos+l]) + if err != nil { return int(pos), err} + pos += l + bytesLeft -= l + } + return int(pos), nil +} + +func (s *LrcpSession) writeBlock(b []byte) error { + bLen := uint32(len(b)) + b = escape(b) + hdr := fmt.Sprintf("/data/%d/%d/", s.sessionId, s.seq) + + packet := append([]byte(hdr), b...) + packet = append(packet, '/') + s.seq += bLen + acked := false + var err error + retries := 30 + for ! acked && retries > 0{ + if s.closed { + return fmt.Errorf("closed") + } + //fmt.Printf("waiting for seq %d acked %d\n", s.seq, s.acked) + err = s.snd(packet) + if err != nil { return err } + acked = s.waitAck() + retries -= 1 + } + return err +} + +func (s *LrcpSession) sendAck() { + m := fmt.Sprintf("/ack/%d/%d/", s.sessionId, s.ack) + s.snd([]byte(m)) +} + +func (s *LrcpSession) receiveData( pos uint32, data []byte ){ + // todo if pos in recv but len bigger, make a separate entry with higher pos and smalle data + //fmt.Printf("receiving %d at %d\n", len(data), pos) + //fmt.Printf("%+v\n", s.recv) + d, ok := s.recv[pos] + for ok { + dlen := len(d.data) + if dlen >= len(data) { + //fmt.Printf("found %d bytes at %d which is more than %d so bailing out\n", dlen, pos, len(data)) + s.sendAck() + return + } + pos += uint32(dlen) + data = data[dlen:] + //fmt.Printf("found %d bytes at pos, remaining %d at %d\n", dlen, len(data), pos) + d, ok = s.recv[pos] + } + r := &ReceivedData{data, pos} + s.recv[pos] = r + s.processNewData() +} + +func (s *LrcpSession) processNewData() { + r, ok := s.recv[s.ack] + newData := ok + + for ok { + s.ack = r.pos + uint32(len(r.data)) + r, ok = s.recv[s.ack] + } + s.sendAck() + if newData { + s.dataQ <- s.ack + } +} + +func (s *LrcpSession) close() { + s.closed = true +} + +func (s *LrcpSession) sendClose() { + m := fmt.Sprintf("/close/%d/", s.sessionId) + s.snd([]byte(m)) +} + + +type LrcpServer struct { + port uint16 + pc net.PacketConn + sessions map[uint32]*LrcpSession + acceptQ chan *LrcpSession +} + +func NewLrcpServer(port uint16) *LrcpServer { + s := &LrcpServer{ + port, + nil, + make(map[uint32]*LrcpSession), + make(chan *LrcpSession, 100), + } + go s.run() + return s +} + +func (s *LrcpServer) run() { + addr := fmt.Sprintf("0.0.0.0:%d", s.port) + var err error + s.pc, err = net.ListenPacket("udp", addr) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + defer s.pc.Close() + buf := make([]byte, 1000) + fmt.Println("Reverse server listening for packets") + for { + n, addr, err := s.pc.ReadFrom(buf) + if err != nil { + continue + } + s.handle(addr, buf[:n]) + } +} + +func (s *LrcpServer) Accept() *LrcpSession { + ses := <- s.acceptQ + //fmt.Printf("Started session %d\n", ses.sessionId) + ses.sendAck() + return ses +} + +func (s *LrcpServer) handle(addr net.Addr, req []byte) { + fmt.Printf("< %s\n", bytes.Replace(req, b("\n"), []byte("\\n"), -1)) + if ! bytes.HasPrefix(req, b("/")) || ! bytes.HasSuffix(req, b("/")) { + return + } + _, req = splitSlash(req) + cmd, args := splitSlash(req) + var err error + strSessionId, args := splitSlash(args) + sessionId, err := b2i(strSessionId) + if err != nil { return } + switch string(cmd) { + case "connect": err = s.hdlConnect(sessionId, addr) + case "data": err = s.hdlData(sessionId, args, addr) + case "ack": err = s.hdlAck(sessionId, args) + case "close": err = s.hdlClose(sessionId) + default: err = fmt.Errorf("unimplemented command %s\n", cmd) + } + if err != nil { + fmt.Println(err) + + } +} + +func (s *LrcpServer) hdlConnect(sessionId uint32, addr net.Addr) error { + + ses, ok := s.sessions[sessionId] + if ok && ! ses.closed { + ses.sendAck() + return nil + } + ses = NewLrcpSession(sessionId, s.pc, addr) + ses.last = time.Now().Unix() + s.sessions[sessionId] = ses + s.acceptQ <- ses + return nil +} + + +func (s *LrcpServer) hdlData(sessionId uint32, req []byte, addr net.Addr) error { + ses, ok := s.sessions[sessionId] + if !ok { + s.sendClose(sessionId, addr) + return fmt.Errorf("unknown session %d", sessionId)} + ses.last = time.Now().Unix() + strPos, req := splitSlash(req) + pos, err:= b2i(strPos) + if err != nil { return err } + + data := req[:len(req)-1] + if !checkscape(data) { + return fmt.Errorf("invalid data") + } + data = unescape(data) + //fmt.Printf("Data len %d for session %d pos %d\n", len(data), sessionId, pos) + if len(data) == 0 { + ses.sendAck() + } else { + ses.receiveData(pos, data) + } + return nil +} + +func (s *LrcpServer) hdlAck(sessionId uint32, req []byte) error { + ses, ok := s.sessions[sessionId] + if !ok { return fmt.Errorf("unknown session %d", sessionId)} + strPos, req := splitSlash(req) + pos, err:= b2i(strPos) + if err != nil { return err } + ses.acked = Max(ses.acked, pos) + ses.ackQ <- pos + return nil +} + +func (s *LrcpServer) hdlClose(sessionId uint32) error { + ses, ok := s.sessions[sessionId] + if !ok { return fmt.Errorf("unknown session %d", sessionId)} + s.sendClose(ses.sessionId, ses.addr) + ses.close() + return nil +} + +func (s *LrcpServer) sendClose(sessionId uint32, addr net.Addr) { + m := fmt.Sprintf("/close/%d/", sessionId) + fmt.Printf("> %s\n", m) + + s.pc.WriteTo([]byte(m), addr) +} diff --git a/tests/reverse.py b/tests/reverse.py new file mode 100644 index 0000000..8325f85 --- /dev/null +++ b/tests/reverse.py @@ -0,0 +1,69 @@ +import socket +from time import time + +addr = ("localhost", 13370) +ses = int(time()) +def sock(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + return s + +def connect(s): + s.sendto(f"/connect/{ses}/".encode(), addr) + +def data(s): + s.sendto(f"/data/{ses}/0/12/".encode(), addr) +def data2(s): + s.sendto(f"/data/{ses}/2/Hello world this is a long long test!\n/".encode(), addr) + +def escdata(s): + s.sendto(f"/data/{ses}/0/foo\/bar\/baz\nfoo\\bar\\baz\n/".encode(), addr) + +def chunk1(s): + s.sendto(f"/data/{ses}/0/snack/".encode(), addr) + +def chunk2(s): + s.sendto(f"/data/{ses}/0/snack snoep\nsnap/".encode(), addr) + +def chunk3(s): + s.sendto(f"/data/{ses}/0/snack snoep\nsnap snep\nslap\n/".encode(), addr) + +def simple(s): + connect(s) + print(s.recv(1024).decode()) + data2(s) + print(s.recv(1024).decode()) + data2(s) + print(s.recv(1024).decode()) + data(s) + print(s.recv(1024).decode()) + print(s.recv(1024).decode()) + print(s.recv(1024).decode()) + +def escaper(s): + connect(s) + print(s.recv(1024).decode()) + escdata(s) + + print(s.recv(1024).decode()) + +def waiter(s): + connect(s) + print(s.recv(1024).decode()) + + +def chunker(s): + connect(s) + print(s.recv(1024).decode()) + chunk1(s) + print(s.recv(1024).decode()) + chunk2(s) + print(s.recv(1024).decode()) + chunk3(s) + print(s.recv(1024).decode()) + print(s.recv(1024).decode()) + + +#simple(sock()) +#escaper(sock()) +#waiter(sock()) +chunker(sock()) \ No newline at end of file