Skip to content

Commit

Permalink
feat: introducing the binary column reader and improved copy reader i…
Browse files Browse the repository at this point in the history
…mplementation
  • Loading branch information
jeroenrinzema committed Nov 11, 2024
1 parent 76c5433 commit 62bc453
Show file tree
Hide file tree
Showing 8 changed files with 456 additions and 267 deletions.
9 changes: 2 additions & 7 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package wire
import (
"context"
"fmt"
"io"
"sync"

"github.com/jeroenrinzema/psql-wire/pkg/buffer"
Expand Down Expand Up @@ -102,11 +101,7 @@ func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Portal,
return portal, nil
}

func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer *buffer.Writer) (err error) {
return cache.ExecuteCopyIn(ctx, name, writer, nil)
}

func (cache *DefaultPortalCache) ExecuteCopyIn(ctx context.Context, name string, writer *buffer.Writer, copyData io.Reader) (err error) {
func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, reader *buffer.Reader, writer *buffer.Writer) (err error) {
defer func() {
r := recover()
if r != nil {
Expand All @@ -126,5 +121,5 @@ func (cache *DefaultPortalCache) ExecuteCopyIn(ctx context.Context, name string,
return nil
}

return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, writer, copyData), portal.parameters)
return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, reader, writer), portal.parameters)
}
255 changes: 93 additions & 162 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,13 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b
}

for {
if err = srv.consumeSingleCommand(ctx, reader, writer, srv.handleCommand(conn)); err != nil {
if err = srv.consumeSingleCommand(ctx, reader, writer, conn); err != nil {
return err
}
}
}

type commandHandler func(context.Context, types.ClientMessage, *buffer.Reader, *buffer.Writer) error

func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, handleCommand commandHandler) error {
func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
t, length, err := reader.ReadTypedMsg()
if err == io.EOF {
return nil
Expand Down Expand Up @@ -107,7 +105,7 @@ func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Read
// connections are not blocking a close.
srv.wg.Add(1)
srv.logger.Debug("<- incoming command", slog.Int("length", length), slog.String("type", t.String()))
err = handleCommand(ctx, t, reader, writer)
err = srv.handleCommand(ctx, conn, t, reader, writer)
srv.wg.Done()
if errors.Is(err, io.EOF) {
return nil
Expand Down Expand Up @@ -143,164 +141,101 @@ func handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exc
// message type and reader buffer containing the actual message. The type
// indecates a action executed by the client.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) handleCommand(conn net.Conn) commandHandler {
return func(ctx context.Context, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

switch t {
case types.ClientSimpleQuery:
return srv.handleSimpleQuery(ctx, reader, writer)
case types.ClientExecute:
return srv.handleExecute(ctx, reader, writer)
case types.ClientParse:
return srv.handleParse(ctx, reader, writer)
case types.ClientDescribe:
// The Describe message (portal variant) specifies the name of an
// existing portal (or an empty string for the unnamed portal). The
// response is a RowDescription message describing the rows that will be
// returned by executing the portal; or a NoData message if the portal
// does not contain a query that will return rows; or ErrorResponse if
// there is no such portal.
//
// The Describe message (statement variant) specifies the name of an
// existing prepared statement (or an empty string for the unnamed
// prepared statement). The response is a ParameterDescription message
// describing the parameters needed by the statement, followed by a
// RowDescription message describing the rows that will be returned when
// the statement is eventually executed (or a NoData message if the
// statement will not return rows). ErrorResponse is issued if there is
// no such prepared statement. Note that since Bind has not yet been
// issued, the formats to be used for returned columns are not yet known
// to the backend; the format code fields in the RowDescription message
// will be zeroes in this case.
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
return srv.handleDescribe(ctx, reader, writer)
case types.ClientSync:
// TODO: Include the ability to catch sync messages in order to
// close the current transaction.
//
// At completion of each series of extended-query messages, the frontend
// should issue a Sync message. This parameterless message causes the
// backend to close the current transaction if it's not inside a
// BEGIN/COMMIT transaction block (“close” meaning to commit if no
// error, or roll back if error). Then a ReadyForQuery response is
// issued. The purpose of Sync is to provide a resynchronization point
// for error recovery. When an error is detected while processing any
// extended-query message, the backend issues ErrorResponse, then reads
// and discards messages until a Sync is reached, then issues
// ReadyForQuery and returns to normal message processing. (But note
// that no skipping occurs if an error is detected while processing Sync
// — this ensures that there is one and only one ReadyForQuery sent for
// each Sync.)
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
return readyForQuery(writer, types.ServerIdle)
case types.ClientBind:
return srv.handleBind(ctx, reader, writer)
case types.ClientFlush:
// TODO: Flush all remaining rows inside connection buffer if
// any are remaining.
//
// The Flush message does not cause any specific
// output to be generated, but forces the backend to deliver any data
// pending in its output buffers. A Flush must be sent after any
// extended-query command except Sync, if the frontend wishes to examine
// the results of that command before issuing more commands. Without
// Flush, messages returned by the backend will be combined into the
// minimum possible number of packets to minimize network overhead.
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
return nil
case types.ClientCopyData, types.ClientCopyDone, types.ClientCopyFail:
// We're supposed to ignore these messages, per the protocol spec. This
// state will happen when an error occurs on the server-side during a copy
// operation: the server will send an error and a ready message back to
// the client, and must then ignore further copy messages. See:
// https://github.com/postgres/postgres/blob/6e1dd2773eb60a6ab87b27b8d9391b756e904ac3/src/backend/tcop/postgres.c#L4295
return nil
case types.ClientClose:
// TODO: close the statement or portal
writer.Start(types.ServerCloseComplete) //nolint:errcheck
writer.End() //nolint:errcheck
return nil
case types.ClientTerminate:
err := srv.handleConnTerminate(ctx)
if err != nil {
return err
}

err = conn.Close()
if err != nil {
return err
}

return io.EOF
default:
return ErrorCode(writer, NewErrUnimplementedMessageType(t))
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

switch t {
case types.ClientSimpleQuery:
return srv.handleSimpleQuery(ctx, reader, writer)
case types.ClientExecute:
return srv.handleExecute(ctx, reader, writer)
case types.ClientParse:
return srv.handleParse(ctx, reader, writer)
case types.ClientDescribe:
// The Describe message (portal variant) specifies the name of an
// existing portal (or an empty string for the unnamed portal). The
// response is a RowDescription message describing the rows that will be
// returned by executing the portal; or a NoData message if the portal
// does not contain a query that will return rows; or ErrorResponse if
// there is no such portal.
//
// The Describe message (statement variant) specifies the name of an
// existing prepared statement (or an empty string for the unnamed
// prepared statement). The response is a ParameterDescription message
// describing the parameters needed by the statement, followed by a
// RowDescription message describing the rows that will be returned when
// the statement is eventually executed (or a NoData message if the
// statement will not return rows). ErrorResponse is issued if there is
// no such prepared statement. Note that since Bind has not yet been
// issued, the formats to be used for returned columns are not yet known
// to the backend; the format code fields in the RowDescription message
// will be zeroes in this case.
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
return srv.handleDescribe(ctx, reader, writer)
case types.ClientSync:
// TODO: Include the ability to catch sync messages in order to
// close the current transaction.
//
// At completion of each series of extended-query messages, the frontend
// should issue a Sync message. This parameterless message causes the
// backend to close the current transaction if it's not inside a
// BEGIN/COMMIT transaction block (“close” meaning to commit if no
// error, or roll back if error). Then a ReadyForQuery response is
// issued. The purpose of Sync is to provide a resynchronization point
// for error recovery. When an error is detected while processing any
// extended-query message, the backend issues ErrorResponse, then reads
// and discards messages until a Sync is reached, then issues
// ReadyForQuery and returns to normal message processing. (But note
// that no skipping occurs if an error is detected while processing Sync
// — this ensures that there is one and only one ReadyForQuery sent for
// each Sync.)
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
return readyForQuery(writer, types.ServerIdle)
case types.ClientBind:
return srv.handleBind(ctx, reader, writer)
case types.ClientFlush:
// TODO: Flush all remaining rows inside connection buffer if
// any are remaining.
//
// The Flush message does not cause any specific
// output to be generated, but forces the backend to deliver any data
// pending in its output buffers. A Flush must be sent after any
// extended-query command except Sync, if the frontend wishes to examine
// the results of that command before issuing more commands. Without
// Flush, messages returned by the backend will be combined into the
// minimum possible number of packets to minimize network overhead.
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
return nil
case types.ClientCopyData, types.ClientCopyDone, types.ClientCopyFail:
// We're supposed to ignore these messages, per the protocol spec. This
// state will happen when an error occurs on the server-side during a copy
// operation: the server will send an error and a ready message back to
// the client, and must then ignore further copy messages. See:
// https://github.com/postgres/postgres/blob/6e1dd2773eb60a6ab87b27b8d9391b756e904ac3/src/backend/tcop/postgres.c#L4295
return nil
case types.ClientClose:
// TODO: close the statement or portal
writer.Start(types.ServerCloseComplete) //nolint:errcheck
writer.End() //nolint:errcheck
return nil
case types.ClientTerminate:
err := srv.handleConnTerminate(ctx)
if err != nil {
return err
}
}
}

func (srv *Server) copyData(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) io.Reader {
r := &copyDataReader{}
r.more = func() error {
err := srv.consumeSingleCommand(ctx, reader, writer, srv.handleCopyInCommand(r))
if err == errClientCopyDone {
return io.EOF
}
return err
}
return r
}

type copyDataReader struct {
buf []byte
more func() error
}

func (r *copyDataReader) Read(p []byte) (n int, err error) {
if len(r.buf) == 0 {
if err := r.more(); err != nil {
return 0, err
err = conn.Close()
if err != nil {
return err
}
}

n = copy(p, r.buf)
r.buf = r.buf[n:]
return n, nil
}

// handleCopyInCommand handles the given client message, while in CopyIn mode.
func (srv *Server) handleCopyInCommand(r *copyDataReader) commandHandler {
return func(ctx context.Context, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
switch t {
case types.ClientFlush, types.ClientSync:
// The backend will ignore Flush and Sync messages received during copy-in mode.
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-COPY
return nil
case types.ClientCopyData:
r.buf = reader.Msg
return nil
case types.ClientCopyDone:
return errClientCopyDone
case types.ClientCopyFail:
desc, err := reader.GetString()
if err != nil {
return err
}
return ErrorCode(writer, newErrClientCopyFailed(desc))
default:
// Receipt of any other non-copy message type constitutes an error that
// will abort the copy-in state as described above.
// https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-COPY
return ErrorCode(writer, NewErrUnimplementedMessageType(t))
}
return io.EOF
default:
return ErrorCode(writer, NewErrUnimplementedMessageType(t))
}
}

// errClientCopyDone internal sentinel error value distinct from [io.EOF], since
// that has special meaning in [commandLoop].
var errClientCopyDone = errors.New("client sent CopyDone")

func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
if srv.parse == nil {
return ErrorCode(writer, NewErrUnimplementedMessageType(types.ClientSimpleQuery))
Expand Down Expand Up @@ -343,7 +278,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
return ErrorCode(writer, err)
}

err = statements[index].fn(ctx, NewDataWriter(ctx, statements[index].columns, nil, writer, srv.copyData(ctx, reader, writer)), nil)
err = statements[index].fn(ctx, NewDataWriter(ctx, statements[index].columns, nil, reader, writer), nil)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down Expand Up @@ -622,11 +557,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri
}

srv.logger.Debug("executing", slog.String("name", name), slog.Uint64("limit", uint64(limit)))
if pcCopyIn, ok := srv.Portals.(PortalCacheCopyIn); ok {
err = pcCopyIn.ExecuteCopyIn(ctx, name, writer, srv.copyData(ctx, reader, writer))
} else {
err = srv.Portals.Execute(ctx, name, writer)
}
err = srv.Portals.Execute(ctx, name, reader, writer)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down
Loading

0 comments on commit 62bc453

Please sign in to comment.