Skip to content

Commit

Permalink
refactor: updated wire parameters and prepared statement api
Browse files Browse the repository at this point in the history
  • Loading branch information
jeroenrinzema committed Nov 8, 2023
1 parent bdc34d1 commit a1020c9
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 105 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ func main() {
wire.ListenAndServe("127.0.0.1:5432", handler)
}

func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) {
func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
fmt.Println(query)

statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error {
statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
return writer.Complete("OK")
}
})

return statement, nil, nil, nil
return statement, nil
}
```

Expand Down
12 changes: 6 additions & 6 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type DefaultStatementCache struct {

// Set attempts to bind the given statement to the given name. Any
// previously defined statement is overridden.
func (cache *DefaultStatementCache) Set(ctx context.Context, name string, fn PreparedStatementFn, parameters []oid.Oid, columns Columns) error {
func (cache *DefaultStatementCache) Set(ctx context.Context, name string, stmt *PreparedStatement) error {
cache.mu.Lock()
defer cache.mu.Unlock()

Expand All @@ -30,9 +30,9 @@ func (cache *DefaultStatementCache) Set(ctx context.Context, name string, fn Pre
}

cache.statements[name] = &Statement{
fn: fn,
parameters: parameters,
columns: columns,
fn: stmt.fn,
parameters: stmt.parameters,
columns: stmt.columns,
}

return nil
Expand All @@ -58,15 +58,15 @@ func (cache *DefaultStatementCache) Get(ctx context.Context, name string) (*Stat

type portal struct {
statement *Statement
parameters []string
parameters []Parameter
}

type DefaultPortalCache struct {
portals map[string]portal
mu sync.RWMutex
}

func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []string) error {
func (cache *DefaultPortalCache) Bind(ctx context.Context, name string, stmt *Statement, parameters []Parameter) error {
cache.mu.Lock()
defer cache.mu.Unlock()

Expand Down
54 changes: 31 additions & 23 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
return readyForQuery(writer, types.ServerIdle)
}

statement, _, columns, err := srv.parse(ctx, query)
statement, err := srv.parse(ctx, query)
if err != nil {
return ErrorCode(writer, err)
}
Expand All @@ -250,12 +250,12 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
}

// NOTE: we have to define the column definitions before executing a simple query
err = columns.Define(ctx, writer)
err = statement.columns.Define(ctx, writer)
if err != nil {
return ErrorCode(writer, err)
}

err = statement(ctx, NewDataWriter(ctx, columns, writer), nil)
err = statement.fn(ctx, NewDataWriter(ctx, statement.columns, writer), nil)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down Expand Up @@ -295,14 +295,14 @@ func (srv *Server) handleParse(ctx context.Context, reader *buffer.Reader, write
// `reader.GetUint32()`
}

statement, params, columns, err := srv.parse(ctx, query)
statement, err := srv.parse(ctx, query)
if err != nil {
return ErrorCode(writer, err)
}

srv.logger.Debug("incoming extended query", slog.String("query", query), slog.String("name", name), slog.Int("parameters", len(params)))
srv.logger.Debug("incoming extended query", slog.String("query", query), slog.String("name", name), slog.Int("parameters", len(statement.parameters)))

err = srv.Statements.Set(ctx, name, statement, params, columns)
err = srv.Statements.Set(ctx, name, statement)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down Expand Up @@ -411,33 +411,35 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
// readParameters attempts to read all incoming parameters from the given
// reader. The parameters are parsed and returned.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([]string, error) {
// NOTE: read the total amount of parameter format codes that will
// be send by the client.
func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([]Parameter, error) {
// NOTE: read the total amount of parameter format length that will be send
// by the client. This can be zero to indicate that there are no parameters
// or that the parameters all use the default format (text); or one, in
// which case the specified format code is applied to all parameters; or it
// can equal the actual number of parameters.
length, err := reader.GetUint16()
if err != nil {
return nil, err
}

srv.logger.Debug("reading parameters format codes", slog.Uint64("length", uint64(length)))

defaultFormat := TextFormat
formats := make([]FormatCode, length)
for i := uint16(0); i < length; i++ {
// NOTE: we have to set the default format code to the given format code
// if only one is given according to the protocol specs.
if length == 1 {
defaultFormat = FormatCode(i)
break
}

format, err := reader.GetUint16()
if err != nil {
return nil, err
}

// NOTE: the parameter format codes. Each must presently be zero (text) or one (binary).
// https://www.postgresql.org/docs/14/protocol-message-formats.html
if format != 0 {
return nil, errors.New("unsupported binary parameter format, only text formatted parameter types are currently supported")
}

// TODO: Handle multiple parameter format codes.
//
// We are currently only supporting string parameters. We have to
// include support for binary parameters in the future.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
formats[i] = FormatCode(format)
}

// NOTE: read the total amount of parameter values that will be send
Expand All @@ -449,8 +451,8 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([

srv.logger.Debug("reading parameters values", slog.Uint64("length", uint64(length)))

parameters := make([]string, length)
for i := uint16(0); i < length; i++ {
parameters := make([]Parameter, length)
for i := 0; i < int(length); i++ {
length, err := reader.GetUint32()
if err != nil {
return nil, err
Expand All @@ -462,7 +464,13 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([
}

srv.logger.Debug("incoming parameter", slog.String("value", string(value)))
parameters[i] = string(value)

format := defaultFormat
if len(formats) > int(i) {
format = formats[i]
}

parameters[i] = NewParameter(format, value)
}

// NOTE: Read the total amount of result-column format that will be
Expand Down
20 changes: 12 additions & 8 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/lib/pq/oid"
"github.com/neilotoole/slogt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestMessageSizeExceeded(t *testing.T) {
Expand Down Expand Up @@ -67,19 +68,24 @@ func TestBindMessageParameters(t *testing.T) {
},
}

handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) {
statement := func(ctx context.Context, writer DataWriter, parameters []string) error {
handler := func(ctx context.Context, query string) (*PreparedStatement, error) {
statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
t.Log("serving query")

if len(parameters) != 2 {
return fmt.Errorf("unexpected amount of parameters %d, expected 2", len(parameters))
}

writer.Row([]any{parameters[0], parameters[1]}) //nolint:errcheck
first := string(parameters[0].value)
second := string(parameters[1].value)

writer.Row([]any{first, second}) //nolint:errcheck
return writer.Complete("SELECT 1")
}
})

return statement, ParseParameters(query), columns, nil
statement.WithParameters(ParseParameters(query))
statement.WithColumns(columns)
return statement, nil
}

server, err := NewServer(handler, Logger(slogt.New(t)))
Expand Down Expand Up @@ -111,9 +117,7 @@ func TestBindMessageParameters(t *testing.T) {
var answer string

err = rows.Scan(&name, &answer)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

t.Logf("scan result: %s, %s", name, answer)

Expand Down
10 changes: 5 additions & 5 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jeroenrinzema/psql-wire/codes"
psqlerr "github.com/jeroenrinzema/psql-wire/errors"
"github.com/lib/pq/oid"
"github.com/neilotoole/slogt"
"github.com/stretchr/testify/assert"
)

func TestErrorCode(t *testing.T) {
handler := func(ctx context.Context, query string) (PreparedStatementFn, []oid.Oid, Columns, error) {
statement := func(ctx context.Context, writer DataWriter, parameters []string) error {
handler := func(ctx context.Context, query string) (*PreparedStatement, error) {
statement := NewPreparedStatement(func(ctx context.Context, writer DataWriter, parameters []Parameter) error {
return psqlerr.WithSeverity(psqlerr.WithCode(errors.New("unimplemented feature"), codes.FeatureNotSupported), psqlerr.LevelFatal)
}
return statement, nil, nil, nil
})

return statement, nil
}

server, err := NewServer(handler, Logger(slogt.New(t)))
Expand Down
6 changes: 3 additions & 3 deletions examples/error/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ import (
wire "github.com/jeroenrinzema/psql-wire"
"github.com/jeroenrinzema/psql-wire/codes"
psqlerr "github.com/jeroenrinzema/psql-wire/errors"
"github.com/lib/pq/oid"
)

func main() {
log.Println("PostgreSQL server is up and running at [127.0.0.1:5432]")
wire.ListenAndServe("127.0.0.1:5432", handler)
}

func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) {
func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
log.Println("incoming SQL query:", query)

err := errors.New("unimplemented feature")
err = psqlerr.WithCode(err, codes.FeatureNotSupported)
err = psqlerr.WithSeverity(err, psqlerr.LevelFatal)
return nil, nil, nil, err

return nil, err
}
9 changes: 5 additions & 4 deletions examples/numeric/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,19 @@ var table = wire.Columns{
},
}

func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) {
func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
log.Println("incoming SQL query:", query)

statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error {
statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
balance, err := decimal.NewFromString("256.23")
if err != nil {
return err
}

writer.Row([]any{balance})
return writer.Complete("SELECT 1")
}
})

return statement, wire.ParseParameters(query), table, nil
statement.WithColumns(table)
return statement, nil
}
9 changes: 4 additions & 5 deletions examples/session/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"sync"

wire "github.com/jeroenrinzema/psql-wire"
"github.com/lib/pq/oid"
)

func main() {
Expand All @@ -33,13 +32,13 @@ func session(ctx context.Context) (context.Context, error) {
return context.WithValue(ctx, id, counter), nil
}

func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) {
func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
log.Println("incoming SQL query:", query)

statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error {
statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
session := ctx.Value(id).(int)
return writer.Complete(fmt.Sprintf("OK, session: %d", session))
}
})

return statement, wire.ParseParameters(query), nil, nil
return statement, nil
}
9 changes: 5 additions & 4 deletions examples/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@ var table = wire.Columns{
},
}

func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) {
func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
log.Println("incoming SQL query:", query)

statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error {
statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
writer.Row([]any{"John", true, 29})
writer.Row([]any{"Marry", false, 21})
return writer.Complete("SELECT 2")
}
})

return statement, wire.ParseParameters(query), table, nil
statement.WithColumns(table)
return statement, nil
}
9 changes: 4 additions & 5 deletions examples/tls/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"log/slog"

wire "github.com/jeroenrinzema/psql-wire"
"github.com/lib/pq/oid"
)

func main() {
Expand All @@ -34,12 +33,12 @@ func run() error {
return server.ListenAndServe("127.0.0.1:5432")
}

func handler(ctx context.Context, query string) (wire.PreparedStatementFn, []oid.Oid, wire.Columns, error) {
func handler(ctx context.Context, query string) (*wire.PreparedStatement, error) {
slog.Info("incoming SQL query", slog.String("query", query))

statement := func(ctx context.Context, writer wire.DataWriter, parameters []string) error {
statement := wire.NewPreparedStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
return writer.Complete("OK")
}
})

return statement, wire.ParseParameters(query), nil, nil
return statement, nil
}
Loading

0 comments on commit a1020c9

Please sign in to comment.