Skip to content

Commit

Permalink
fix: isolate statements and portals
Browse files Browse the repository at this point in the history
  • Loading branch information
jeroenrinzema committed Nov 11, 2024
1 parent bb11a1b commit 59c4c8b
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 28 deletions.
8 changes: 8 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ type Statement struct {
columns Columns
}

func DefaultStatementCacheFn() StatementCache {
return &DefaultStatementCache{}
}

type DefaultStatementCache struct {
statements map[string]*Statement
mu sync.RWMutex
Expand Down Expand Up @@ -63,6 +67,10 @@ type Portal struct {
formats []FormatCode
}

func DefaultPortalCacheFn() PortalCache {
return &DefaultPortalCache{}
}

type DefaultPortalCache struct {
portals map[string]*Portal
mu sync.RWMutex
Expand Down
32 changes: 19 additions & 13 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@ func newErrClientCopyFailed(desc string) error {
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Uncategorized), psqlerr.LevelError)
}

type Session struct {
*Server
Statements StatementCache
Portals PortalCache
}

// consumeCommands consumes incoming commands sent over the Postgres wire connection.
// Commands consumed from the connection are returned through a go channel.
// Responses for the given message type are written back to the client.
// This method keeps consuming messages until the client issues a close message
// or the connection is terminated.
func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
srv.logger.Debug("ready for query... starting to consume commands")

// TODO: Include a value to identify unique connections
Expand All @@ -77,7 +83,7 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b
}
}

func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
func (srv *Session) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
t, length, err := reader.ReadTypedMsg()
if err == io.EOF {
return nil
Expand Down Expand Up @@ -141,7 +147,7 @@ func handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exc
// message type and reader buffer containing the actual message. The type
// indecates a action executed by the client.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down Expand Up @@ -236,7 +242,7 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli
}
}

func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
if srv.parse == nil {
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientSimpleQuery))
}
Expand Down Expand Up @@ -287,7 +293,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
return readyForQuery(writer, types.ServerIdle)
}

func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleParse(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
if srv.parse == nil || srv.Statements == nil {
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientParse))
}
Expand Down Expand Up @@ -337,7 +343,7 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write
return writer.End()
}

func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleDescribe(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
d, err := reader.GetBytes(1)
if err != nil {
return err
Expand Down Expand Up @@ -385,7 +391,7 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr
}

// https://www.postgresql.org/docs/15/protocol-message-formats.html
func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error {
func (srv *Session) writeParameterDescription(writer *buffer.Writer, parameters []oid.Oid) error {
writer.Start(types.ServerParameterDescription)
writer.AddInt16(int16(len(parameters)))

Expand All @@ -400,7 +406,7 @@ func (srv *Server) writeParameterDescription(writer *buffer.Writer, parameters [
// back to the writer buffer. Information about the returned columns is written
// to the client.
// https://www.postgresql.org/docs/15/protocol-message-formats.html
func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error {
func (srv *Session) writeColumnDescription(ctx context.Context, writer *buffer.Writer, formats []FormatCode, columns Columns) error {
if len(columns) == 0 {
writer.Start(types.ServerNoData)
return writer.End()
Expand All @@ -409,7 +415,7 @@ func (srv *Server) writeColumnDescription(ctx context.Context, writer *buffer.Wr
return columns.Define(ctx, writer, formats)
}

func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleBind(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
name, err := reader.GetString()
if err != nil {
return err
Expand Down Expand Up @@ -451,7 +457,7 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
// readParameters attempts to read all incoming parameters from the given
// reader. The parameters are parsed and returned.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) {
func (srv *Session) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) {
// NOTE: read the total amount of parameter format length that will be send
// by the client. This can be zero to indicate that there are no parameters
// or that the parameters all use the default format (text); or one, in
Expand Down Expand Up @@ -516,7 +522,7 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([
return parameters, nil
}

func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) {
func (srv *Session) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) {
length, err := reader.GetUint16()
if err != nil {
return nil, err
Expand All @@ -537,7 +543,7 @@ func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error)
return columns, nil
}

func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Session) handleExecute(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
if srv.Statements == nil {
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientExecute))
}
Expand Down Expand Up @@ -565,7 +571,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri
return nil
}

func (srv *Server) handleConnTerminate(ctx context.Context) error {
func (srv *Session) handleConnTerminate(ctx context.Context) error {
if srv.TerminateConn == nil {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion examples/session/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func main() {
srv, err := wire.NewServer(handler, wire.Session(session))
srv, err := wire.NewServer(handler, wire.SessionMiddleware(session))
if err != nil {
panic(err)
}
Expand Down
12 changes: 6 additions & 6 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,18 @@ type OptionFn func(*Server) error

// Statements sets the statement cache used to cache statements for later use. By
// default [DefaultStatementCache] is used.
func Statements(cache StatementCache) OptionFn {
func Statements(handler func() StatementCache) OptionFn {
return func(srv *Server) error {
srv.Statements = cache
srv.Statements = handler
return nil
}
}

// Portals sets the portals cache used to cache statements for later use. By
// default [DefaultPortalCache] is used.
func Portals(cache PortalCache) OptionFn {
func Portals(handler func() PortalCache) OptionFn {
return func(srv *Server) error {
srv.Portals = cache
srv.Portals = handler
return nil
}
}
Expand Down Expand Up @@ -199,10 +199,10 @@ func ExtendTypes(fn func(*pgtype.Map)) OptionFn {
}
}

// Session sets the given session handler within the underlying server. The
// SessionMiddleware sets the given session handler within the underlying server. The
// session handler is called when a new connection is opened and authenticated
// allowing for additional metadata to be wrapped around the connection context.
func Session(fn SessionHandler) OptionFn {
func SessionMiddleware(fn SessionHandler) OptionFn {
return func(srv *Server) error {
if srv.Session == nil {
srv.Session = fn
Expand Down
6 changes: 3 additions & 3 deletions options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ func TestSessionHandler(t *testing.T) {

tests := map[string]test{
"single": {
Session(func(ctx context.Context) (context.Context, error) {
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
return context.WithValue(ctx, mock, value), nil
}),
},
"nested": {
Session(func(ctx context.Context) (context.Context, error) {
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
return ctx, nil
}),
Session(func(ctx context.Context) (context.Context, error) {
SessionMiddleware(func(ctx context.Context) (context.Context, error) {
return context.WithValue(ctx, mock, value), nil
}),
},
Expand Down
16 changes: 11 additions & 5 deletions wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func NewServer(parse ParseFn, options ...OptionFn) (*Server, error) {
logger: slog.Default(),
closer: make(chan struct{}),
types: pgtype.NewMap(),
Statements: &DefaultStatementCache{},
Portals: &DefaultPortalCache{},
Statements: DefaultStatementCacheFn,
Portals: DefaultPortalCacheFn,
Session: func(ctx context.Context) (context.Context, error) { return ctx, nil },
}

Expand All @@ -62,8 +62,8 @@ type Server struct {
TLSConfig *tls.Config
parse ParseFn
Session SessionHandler
Statements StatementCache
Portals PortalCache
Statements func() StatementCache
Portals func() PortalCache
CloseConn CloseFn
TerminateConn CloseFn
Version string
Expand Down Expand Up @@ -162,7 +162,13 @@ func (srv *Server) serve(ctx context.Context, conn net.Conn) error {
return err
}

return srv.consumeCommands(ctx, conn, reader, writer)
session := &Session{
Server: srv,
Statements: srv.Statements(),
Portals: srv.Portals(),
}

return session.consumeCommands(ctx, conn, reader, writer)
}

// Close gracefully closes the underlaying Postgres server.
Expand Down

0 comments on commit 59c4c8b

Please sign in to comment.