From 59c4c8b2c03a0c470aca578e86b37f8f65cfae3d Mon Sep 17 00:00:00 2001 From: Jeroen Rinzema Date: Mon, 11 Nov 2024 23:04:11 +0100 Subject: [PATCH] fix: isolate statements and portals --- cache.go | 8 ++++++++ command.go | 32 +++++++++++++++++++------------- examples/session/main.go | 2 +- options.go | 12 ++++++------ options_test.go | 6 +++--- wire.go | 16 +++++++++++----- 6 files changed, 48 insertions(+), 28 deletions(-) diff --git a/cache.go b/cache.go index 5934fcd..5cf219b 100644 --- a/cache.go +++ b/cache.go @@ -15,6 +15,10 @@ type Statement struct { columns Columns } +func DefaultStatementCacheFn() StatementCache { + return &DefaultStatementCache{} +} + type DefaultStatementCache struct { statements map[string]*Statement mu sync.RWMutex @@ -63,6 +67,10 @@ type Portal struct { formats []FormatCode } +func DefaultPortalCacheFn() PortalCache { + return &DefaultPortalCache{} +} + type DefaultPortalCache struct { portals map[string]*Portal mu sync.RWMutex diff --git a/command.go b/command.go index b34a28e..f789731 100644 --- a/command.go +++ b/command.go @@ -52,12 +52,18 @@ func newErrClientCopyFailed(desc string) error { return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Uncategorized), psqlerr.LevelError) } +type Session struct { + *Server + Statements StatementCache + Portals PortalCache +} + // consumeCommands consumes incoming commands sent over the Postgres wire connection. // Commands consumed from the connection are returned through a go channel. // Responses for the given message type are written back to the client. // This method keeps consuming messages until the client issues a close message // or the connection is terminated. -func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error { +func (srv *Session) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error { srv.logger.Debug("ready for query... starting to consume commands") // TODO: Include a value to identify unique connections @@ -77,7 +83,7 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b } } -func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error { +func (srv *Session) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error { t, length, err := reader.ReadTypedMsg() if err == io.EOF { return nil @@ -141,7 +147,7 @@ func handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exc // message type and reader buffer containing the actual message. The type // indecates a action executed by the client. // https://www.postgresql.org/docs/14/protocol-message-formats.html -func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error { +func (srv *Session) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -236,7 +242,7 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli } } -func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { +func (srv *Session) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { if srv.parse == nil { return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientSimpleQuery)) } @@ -287,7 +293,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, return readyForQuery(writer, types.ServerIdle) } -func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { +func (srv *Session) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { if srv.parse == nil || srv.Statements == nil { return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientParse)) } @@ -337,7 +343,7 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write return writer.End() } -func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { +func (srv *Session) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { d, err := reader.GetBytes(1) if err != nil { return err @@ -385,7 +391,7 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr } // https://www.postgresql.org/docs/15/protocol-message-formats.html -func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error { +func (srv *Session) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error { writer.Start(types.ServerParameterDescription) writer.AddInt16(int16(len(parameters))) @@ -400,7 +406,7 @@ func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters [ // back to the writer buffer. Information about the returned columns is written // to the client. // https://www.postgresql.org/docs/15/protocol-message-formats.html -func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error { +func (srv *Session) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error { if len(columns) == 0 { writer.Start(types.ServerNoData) return writer.End() @@ -409,7 +415,7 @@ func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Wr return columns.Define(ctx, writer, formats) } -func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { +func (srv *Session) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { name, err := reader.GetString() if err != nil { return err @@ -451,7 +457,7 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer // readParameters attempts to read all incoming parameters from the given // reader. The parameters are parsed and returned. // https://www.postgresql.org/docs/14/protocol-message-formats.html -func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) { +func (srv *Session) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) { // NOTE: read the total amount of parameter format length that will be send // by the client. This can be zero to indicate that there are no parameters // or that the parameters all use the default format (text); or one, in @@ -516,7 +522,7 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([ return parameters, nil } -func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) { +func (srv *Session) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) { length, err := reader.GetUint16() if err != nil { return nil, err @@ -537,7 +543,7 @@ func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) return columns, nil } -func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { +func (srv *Session) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { if srv.Statements == nil { return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientExecute)) } @@ -565,7 +571,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri return nil } -func (srv *Server) handleConnTerminate(ctx context.Context) error { +func (srv *Session) handleConnTerminate(ctx context.Context) error { if srv.TerminateConn == nil { return nil } diff --git a/examples/session/main.go b/examples/session/main.go index 6228dad..be55ba7 100644 --- a/examples/session/main.go +++ b/examples/session/main.go @@ -10,7 +10,7 @@ import ( ) func main() { - srv, err := wire.NewServer(handler, wire.Session(session)) + srv, err := wire.NewServer(handler, wire.SessionMiddleware(session)) if err != nil { panic(err) } diff --git a/options.go b/options.go index 15a9545..997f074 100644 --- a/options.go +++ b/options.go @@ -103,18 +103,18 @@ type OptionFn func(*Server) error // Statements sets the statement cache used to cache statements for later use. By // default [DefaultStatementCache] is used. -func Statements(cache StatementCache) OptionFn { +func Statements(handler func() StatementCache) OptionFn { return func(srv *Server) error { - srv.Statements = cache + srv.Statements = handler return nil } } // Portals sets the portals cache used to cache statements for later use. By // default [DefaultPortalCache] is used. -func Portals(cache PortalCache) OptionFn { +func Portals(handler func() PortalCache) OptionFn { return func(srv *Server) error { - srv.Portals = cache + srv.Portals = handler return nil } } @@ -199,10 +199,10 @@ func ExtendTypes(fn func(*pgtype.Map)) OptionFn { } } -// Session sets the given session handler within the underlying server. The +// SessionMiddleware sets the given session handler within the underlying server. The // session handler is called when a new connection is opened and authenticated // allowing for additional metadata to be wrapped around the connection context. -func Session(fn SessionHandler) OptionFn { +func SessionMiddleware(fn SessionHandler) OptionFn { return func(srv *Server) error { if srv.Session == nil { srv.Session = fn diff --git a/options_test.go b/options_test.go index 81d7fc1..d60c5c2 100644 --- a/options_test.go +++ b/options_test.go @@ -56,15 +56,15 @@ func TestSessionHandler(t *testing.T) { tests := map[string]test{ "single": { - Session(func(ctx context.Context) (context.Context, error) { + SessionMiddleware(func(ctx context.Context) (context.Context, error) { return context.WithValue(ctx, mock, value), nil }), }, "nested": { - Session(func(ctx context.Context) (context.Context, error) { + SessionMiddleware(func(ctx context.Context) (context.Context, error) { return ctx, nil }), - Session(func(ctx context.Context) (context.Context, error) { + SessionMiddleware(func(ctx context.Context) (context.Context, error) { return context.WithValue(ctx, mock, value), nil }), }, diff --git a/wire.go b/wire.go index 7249500..ccd3d08 100644 --- a/wire.go +++ b/wire.go @@ -35,8 +35,8 @@ func NewServer(parse ParseFn, options ...OptionFn) (*Server, error) { logger: slog.Default(), closer: make(chan struct{}), types: pgtype.NewMap(), - Statements: &DefaultStatementCache{}, - Portals: &DefaultPortalCache{}, + Statements: DefaultStatementCacheFn, + Portals: DefaultPortalCacheFn, Session: func(ctx context.Context) (context.Context, error) { return ctx, nil }, } @@ -62,8 +62,8 @@ type Server struct { TLSConfig *tls.Config parse ParseFn Session SessionHandler - Statements StatementCache - Portals PortalCache + Statements func() StatementCache + Portals func() PortalCache CloseConn CloseFn TerminateConn CloseFn Version string @@ -162,7 +162,13 @@ func (srv *Server) serve(ctx context.Context, conn net.Conn) error { return err } - return srv.consumeCommands(ctx, conn, reader, writer) + session := &Session{ + Server: srv, + Statements: srv.Statements(), + Portals: srv.Portals(), + } + + return session.consumeCommands(ctx, conn, reader, writer) } // Close gracefully closes the underlaying Postgres server.