diff --git a/main.go b/main.go index 978c5b2..3ab7e95 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,7 @@ type Server interface { func main() { var challenge int - flag.IntVar(&challenge, "challenge",4, "Challenge number") + flag.IntVar(&challenge, "challenge",6, "Challenge number") flag.Parse() var port uint16 @@ -29,6 +29,10 @@ func main() { server = NewChatServer(port); case 4: server = NewDatabaseServer(port); + case 5: + server = NewProxyServer(port); + case 6: + server = NewSpeedServer(port); default: fmt.Printf("Unknown challenge\n") os.Exit(1) diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..db17845 --- /dev/null +++ b/proxy.go @@ -0,0 +1,70 @@ +package main + +import ( +"net" +"os" +"fmt" +"bufio" +"regexp" +) + +type ProxyServer struct { + port uint16 +} + +func NewProxyServer(port uint16) *ProxyServer { + return &ProxyServer{port} +} + +func (s *ProxyServer) Run() { + addr := fmt.Sprintf("0.0.0.0:%d", s.port) + server, err := net.Listen("tcp", addr) + if err != nil { + fmt.Println("Error listening:", err.Error()) + os.Exit(1) + } + defer server.Close() + fmt.Println("ProxyServer waiting for client...") + for { + connection, err := server.Accept() + if err != nil { + fmt.Println("Error accepting: ", err.Error()) + os.Exit(1) + } + fmt.Println("client connected") + s.start(connection) + } +} + +func (s *ProxyServer) start(down net.Conn) { + addr := "chat.protohackers.com:16963" + up, err := net.Dial("tcp", addr) + if err != nil { + fmt.Println(err) + return + } + go s.stream(up, down) + go s.stream(down, up) +} + +func (s *ProxyServer) stream(input net.Conn, output net.Conn) { + r := bufio.NewReader(input) + var err error + for err == nil { + msg, err := r.ReadString('\n') + if err != nil { break } + msg = s.transform(msg) + _, err = output.Write([]byte(msg)) + } + input.Close() + output.Close() +} + +func (s *ProxyServer) transform(input string) string { + fmt.Printf(">> %s", input) + tony := "${1}7YWHMfk9JZe0LM0g1ZauHuiSxhI${3}" + r := regexp.MustCompile(`(^|\s)(7[a-zA-Z0-9]{25,34})($|\s)`) + output := r.ReplaceAllString(input, tony) + fmt.Printf("<< %s", output) + return output +} \ No newline at end of file diff --git a/speed.go b/speed.go new file mode 100644 index 0000000..9b6e2f1 --- /dev/null +++ b/speed.go @@ -0,0 +1,171 @@ +package main + +import ( +"net" +"os" +"fmt" +"time" +"encoding/binary" +) + +var BE = binary.BigEndian + +type SpeedMessage interface { + serialize() []byte +} + +type ErrorMessage struct { + err string +} + +func (m ErrorMessage) serialize() []byte { + return []byte{0x10} +} + +type HeartbeatMessage struct {} + +func (m HeartbeatMessage) serialize() []byte { + return []byte{0x41} +} + +type SpeedServer struct { + port uint16 + tickQ chan bool + clients map[uint]*SpeedClient + numClients uint +} + +func NewSpeedServer(port uint16) *SpeedServer { + s := &SpeedServer{ + port, + make(chan bool), + make(map[uint]*SpeedClient), + 0, + } + + return s +} + +type SpeedClient struct { + clientId uint + con net.Conn + q chan SpeedMessage + ctype uint8 + lastHeartbeat int64 + heartbeat int64 +} + +func NewSpeedClient(clientId uint, con net.Conn) *SpeedClient { + return &SpeedClient{ + clientId, + con, + make(chan SpeedMessage), + 0, + 0, + 0, + } +} + +func (s *SpeedServer) tick() { + cur := time.Now().UnixMilli() + m := HeartbeatMessage{} + for _, c := range s.clients { + if c.heartbeat == 0 { continue } + + if c.lastHeartbeat + c.heartbeat < cur { + c.q <- m + c.lastHeartbeat = cur + } + } +} + +func (s *SpeedServer) ticker() { + for { + s.tickQ <- true + time.Sleep(100 * time.Millisecond) + } +} + +func (s *SpeedServer) main() { + for { + select { + case _ = <- s.tickQ: s.tick() + } + } +} + +func (s *SpeedServer) Run() { + go s.listen() + go s.ticker() + s.main() +} + +func (s *SpeedServer) listen() { + addr := fmt.Sprintf("0.0.0.0:%d", s.port) + server, err := net.Listen("tcp", addr) + if err != nil { + fmt.Println("Error listening:", err.Error()) + os.Exit(1) + } + defer server.Close() + fmt.Println("SpeedServer waiting for client...") + for { + connection, err := server.Accept() + if err != nil { + fmt.Println("Error accepting: ", err.Error()) + os.Exit(1) + } + + s.start(connection) + } +} + +func (s *SpeedServer) start(con net.Conn) { + c := NewSpeedClient(s.numClients, con) + s.clients[s.numClients] = c + s.numClients += 1 + go c.sender() + go c.receiver() +} + +func (c *SpeedClient) sender() { + for m := range c.q { + c.con.Write(m.serialize()) + } + + c.con.Close() + fmt.Printf("Client %d closed\n", c.clientId) +} + +func (c *SpeedClient) receiver() { + var err error + fmt.Printf("client %d connected\n", c.clientId) + for { + var mType uint8 + err = binary.Read(c.con, BE, &mType); + if err != nil { break } + switch mType { + case 0x40: err = c.hdlWantHeartbeat() + + default: err = fmt.Errorf("Unknown message type 0x%x", mType) + } + if err != nil { break } + } + if err != nil { + fmt.Printf("Client %d error %s\n", c.clientId, err) + } else { + fmt.Printf("Client %d closing\n", c.clientId) + } + close(c.q) +} + +func (c *SpeedClient) hdlWantHeartbeat() error { + var wantedHb uint32 + err := binary.Read(c.con, BE, &wantedHb) + if err != nil { return err } + c.heartbeat = int64(wantedHb) * 100 + fmt.Printf("heartbeat for %d set to %d\n", c.clientId, wantedHb) + return nil +} + +