From 52170029dae93c8ef725f67411e27a09a387f34b Mon Sep 17 00:00:00 2001 From: K Date: Sun, 15 Sep 2024 00:50:20 +0330 Subject: [PATCH] fix(server): replace x/net with gorilla (#32) --- go.mod | 2 +- go.sum | 4 ++-- server/server.go | 57 ++++++++++++++++++++++++------------------------ 3 files changed, 31 insertions(+), 32 deletions(-) diff --git a/go.mod b/go.mod index 785fbf1..9ecf942 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,10 @@ go 1.22.5 require ( github.com/btcsuite/btcd/btcec/v2 v2.3.4 + github.com/gorilla/websocket v1.5.3 github.com/mailru/easyjson v0.7.7 github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.3 - golang.org/x/net v0.29.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 0aa1fcd..f5a0495 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/decred/dcrd/crypto/blake256 v1.0.0 h1:/8DMNYp9SGi5f0w7uCm6d6M4OU2rGFK github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 h1:YLtO71vCjJRCBcrPMtQ9nqBsqpA1m5sE92cU+pd5Mcc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= @@ -22,8 +24,6 @@ github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/server/server.go b/server/server.go index 0aa8d3e..a7dab3a 100644 --- a/server/server.go +++ b/server/server.go @@ -1,9 +1,7 @@ package server import ( - "errors" "fmt" - "io" "net" "net/http" "strconv" @@ -11,11 +9,12 @@ import ( "github.com/dezh-tech/immortal/types/filter" "github.com/dezh-tech/immortal/types/message" - "golang.org/x/net/websocket" + "github.com/gorilla/websocket" ) -// TODO::: replace with https://github.com/coder/websocket. -// TODO::: replace `log` with main logger. +var upgrader = websocket.Upgrader{ + CheckOrigin: func(_ *http.Request) bool { return true }, +} // Server represents a websocket serer which keeps track of client connections and handle them. type Server struct { @@ -34,7 +33,7 @@ func NewServer(cfg Config) *Server { // Start strats a new server instance. func (s *Server) Start() error { - http.Handle("/", websocket.Handler(s.handleWS)) + http.Handle("/", s) err := http.ListenAndServe(net.JoinHostPort(s.config.Bind, //nolint strconv.Itoa(int(s.config.Port))), nil) @@ -42,30 +41,30 @@ func (s *Server) Start() error { } // handleWS is WebSocket handler. -func (s *Server) handleWS(ws *websocket.Conn) { +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + s.connsLock.Lock() - s.conns[ws] = make(map[string]filter.Filters) + s.conns[conn] = make(map[string]filter.Filters) s.connsLock.Unlock() - s.readLoop(ws) + s.readLoop(conn) } // readLoop reads incoming messages from a client and answer to them. func (s *Server) readLoop(ws *websocket.Conn) { - buf := make([]byte, 1024) for { - n, err := ws.Read(buf) + _, buf, err := ws.ReadMessage() if err != nil { - if errors.Is(err, io.EOF) { - break - } - - continue + break } - msg := message.ParseMessage(buf[:n]) + msg := message.ParseMessage(buf) if msg == nil { - _, _ = ws.Write(message.MakeNotice("error: can't parse message.")) + _ = ws.WriteMessage(1, message.MakeNotice("error: can't parse message.")) continue } @@ -91,7 +90,7 @@ func (s *Server) handleReq(ws *websocket.Conn, m message.Message) { msg, ok := m.(*message.Req) if !ok { - _, _ = ws.Write(message.MakeNotice("error: can't parse REQ message.")) + _ = ws.WriteMessage(1, message.MakeNotice("error: can't parse REQ message.")) return } @@ -101,7 +100,7 @@ func (s *Server) handleReq(ws *websocket.Conn, m message.Message) { subs, ok := s.conns[ws] if !ok { - _, _ = ws.Write(message.MakeNotice(fmt.Sprintf("error: can't find connection %s.", + _ = ws.WriteMessage(1, message.MakeNotice(fmt.Sprintf("error: can't find connection %s.", ws.RemoteAddr()))) return @@ -119,30 +118,30 @@ func (s *Server) handleEvent(ws *websocket.Conn, m message.Message) { "error: can't parse EVENT message.", ) - _, _ = ws.Write(okm) + _ = ws.WriteMessage(1, okm) return } if !msg.Event.IsValid() { okm := message.MakeOK(false, - msg.SubscriptionID, + msg.Event.ID, "invalid: id or sig is not correct.", ) - _, _ = ws.Write(okm) + _ = ws.WriteMessage(1, okm) return } - _, _ = ws.Write(message.MakeOK(true, msg.SubscriptionID, "")) + _ = ws.WriteMessage(1, message.MakeOK(true, msg.Event.ID, "")) for conn, subs := range s.conns { for id, filters := range subs { if !filters.Match(msg.Event) { return } - _, _ = conn.Write(message.MakeEvent(id, msg.Event)) + _ = conn.WriteMessage(1, message.MakeEvent(id, msg.Event)) } } } @@ -151,7 +150,7 @@ func (s *Server) handleEvent(ws *websocket.Conn, m message.Message) { func (s *Server) handleClose(ws *websocket.Conn, m message.Message) { msg, ok := m.(*message.Close) if !ok { - _, _ = ws.Write(message.MakeNotice("error: can't parse CLOSE message.")) + _ = ws.WriteMessage(1, message.MakeNotice("error: can't parse CLOSE message.")) return } @@ -161,14 +160,14 @@ func (s *Server) handleClose(ws *websocket.Conn, m message.Message) { conn, ok := s.conns[ws] if !ok { - _, _ = ws.Write(message.MakeNotice(fmt.Sprintf("error: can't find connection %s.", + _ = ws.WriteMessage(1, message.MakeNotice(fmt.Sprintf("error: can't find connection %s.", ws.RemoteAddr()))) return } delete(conn, msg.String()) - _, _ = ws.Write(message.MakeClosed(msg.String(), "ok: closed successfully.")) + _ = ws.WriteMessage(1, message.MakeClosed(msg.String(), "ok: closed successfully.")) } // Stop shutdowns the server gracefully. @@ -179,7 +178,7 @@ func (s *Server) Stop() error { for wsConn, subs := range s.conns { // close all subscriptions. for id := range subs { - _, _ = wsConn.Write(message.MakeClosed(id, "error: shutdowning the relay.")) + _ = wsConn.WriteMessage(1, message.MakeClosed(id, "error: shutdowning the relay.")) } // close connection.