Skip to content

Commit

Permalink
Refactor simple / extended query protocol executions
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Nov 20, 2024
1 parent d2a4d65 commit 2c7853b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 75 deletions.
136 changes: 70 additions & 66 deletions src/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,77 +45,16 @@ func (postgres *Postgres) Run(proxy *Proxy) {
for {
message, err := postgres.backend.Receive()
if err != nil {
return
return // Terminate connection
}

switch message.(type) {
case *pgproto3.Query:
query := message.(*pgproto3.Query).String
LogDebug(postgres.config, "Received query:", query)
messages, err := proxy.HandleQuery(query)
if err != nil {
postgres.writeError("Internal error")
continue
}
messages = append(messages, &pgproto3.ReadyForQuery{TxStatus: PG_TX_STATUS_IDLE})
postgres.writeMessages(messages...)
case *pgproto3.Parse: // Extended query protocol
message := message.(*pgproto3.Parse)
LogDebug(postgres.config, "Parsing query", message.Query)
messages, preparedStatement, err := proxy.HandleParseQuery(message)
postgres.handleSimpleQuery(proxy, message.(*pgproto3.Query))
case *pgproto3.Parse:
err = postgres.handleExtendedQuery(proxy, message.(*pgproto3.Parse))
if err != nil {
postgres.writeError("Failed to parse query")
continue
}
postgres.writeMessages(messages...)

for {
message, err := postgres.backend.Receive()
if err != nil {
return
}
synced := false

switch message.(type) {
case *pgproto3.Bind:
message := message.(*pgproto3.Bind)
LogDebug(postgres.config, "Binding query", message.PreparedStatement)
messages, preparedStatement, err = proxy.HandleBindQuery(message, preparedStatement)
if err != nil {
postgres.writeError("Failed to bind query")
continue
}
postgres.writeMessages(messages...)
case *pgproto3.Describe:
message := message.(*pgproto3.Describe)
LogDebug(postgres.config, "Describing query", message.Name, "("+string(message.ObjectType)+")")
var messages []pgproto3.Message
messages, preparedStatement, err = proxy.HandleDescribeQuery(message, preparedStatement)
if err != nil {
postgres.writeError("Failed to describe query")
continue
}
postgres.writeMessages(messages...)
case *pgproto3.Execute:
message := message.(*pgproto3.Execute)
LogDebug(postgres.config, "Executing query", message.Portal)
messages, err := proxy.HandleExecuteQuery(message, preparedStatement)
if err != nil {
postgres.writeError("Failed to execute query")
continue
}
postgres.writeMessages(messages...)
case *pgproto3.Sync:
LogDebug(postgres.config, "Syncing query")
postgres.writeMessages(
&pgproto3.ReadyForQuery{TxStatus: PG_TX_STATUS_IDLE},
)
synced = true
}

if synced {
break
}
return // Terminate connection
}
case *pgproto3.Terminate:
LogDebug(postgres.config, "Client terminated connection")
Expand All @@ -130,6 +69,71 @@ func (postgres *Postgres) Close() error {
return (*postgres.conn).Close()
}

func (postgres *Postgres) handleSimpleQuery(proxy *Proxy, queryMessage *pgproto3.Query) {
LogDebug(postgres.config, "Received query:", queryMessage.String)
messages, err := proxy.HandleQuery(queryMessage.String)
if err != nil {
postgres.writeError("Internal error")
return
}
messages = append(messages, &pgproto3.ReadyForQuery{TxStatus: PG_TX_STATUS_IDLE})
postgres.writeMessages(messages...)
}

func (postgres *Postgres) handleExtendedQuery(proxy *Proxy, parseMessage *pgproto3.Parse) error {
LogDebug(postgres.config, "Parsing query", parseMessage.Query)
messages, preparedStatement, err := proxy.HandleParseQuery(parseMessage)
if err != nil {
postgres.writeError("Failed to parse query")
return nil
}
postgres.writeMessages(messages...)

for {
message, err := postgres.backend.Receive()
if err != nil {
return err
}

switch message.(type) {
case *pgproto3.Bind:
message := message.(*pgproto3.Bind)
LogDebug(postgres.config, "Binding query", message.PreparedStatement)
messages, preparedStatement, err = proxy.HandleBindQuery(message, preparedStatement)
if err != nil {
postgres.writeError("Failed to bind query")
continue
}
postgres.writeMessages(messages...)
case *pgproto3.Describe:
message := message.(*pgproto3.Describe)
LogDebug(postgres.config, "Describing query", message.Name, "("+string(message.ObjectType)+")")
var messages []pgproto3.Message
messages, preparedStatement, err = proxy.HandleDescribeQuery(message, preparedStatement)
if err != nil {
postgres.writeError("Failed to describe query")
continue
}
postgres.writeMessages(messages...)
case *pgproto3.Execute:
message := message.(*pgproto3.Execute)
LogDebug(postgres.config, "Executing query", message.Portal)
messages, err := proxy.HandleExecuteQuery(message, preparedStatement)
if err != nil {
postgres.writeError("Failed to execute query")
continue
}
postgres.writeMessages(messages...)
case *pgproto3.Sync:
LogDebug(postgres.config, "Syncing query")
postgres.writeMessages(
&pgproto3.ReadyForQuery{TxStatus: PG_TX_STATUS_IDLE},
)
return nil
}
}
}

func (postgres *Postgres) writeMessages(messages ...pgproto3.Message) {
var buf []byte
var err error
Expand Down
20 changes: 11 additions & 9 deletions src/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,9 @@ func (proxy *Proxy) HandleDescribeQuery(message *pgproto3.Describe, preparedStat
LogError(proxy.config, "Couldn't execute prepared statement via DuckDB:", preparedStatement.Query+"\n"+err.Error())
return nil, nil, err
}
defer rows.Close()

preparedStatement.Rows = rows

messages, err := proxy.rowsToDescriptionMessages(rows, preparedStatement.Query)
messages, err := proxy.rowsToDescriptionMessages(preparedStatement.Rows, preparedStatement.Query)
if err != nil {
return nil, nil, err
}
Expand All @@ -307,14 +305,18 @@ func (proxy *Proxy) HandleExecuteQuery(message *pgproto3.Execute, preparedStatem
return nil, errors.New("Portal mismatch")
}

rows, err := preparedStatement.Statement.QueryContext(context.Background(), preparedStatement.Variables...)
if err != nil {
LogError(proxy.config, "Couldn't execute prepared statement via DuckDB:", preparedStatement.Query+"\n"+err.Error())
return nil, err
if preparedStatement.Rows == nil {
rows, err := preparedStatement.Statement.QueryContext(context.Background(), preparedStatement.Variables...)
if err != nil {
LogError(proxy.config, "Couldn't execute prepared statement via DuckDB:", preparedStatement.Query+"\n"+err.Error())
return nil, err
}
preparedStatement.Rows = rows
}
defer rows.Close()

return proxy.rowsToDataMessages(rows, preparedStatement.Query)
defer preparedStatement.Rows.Close()

return proxy.rowsToDataMessages(preparedStatement.Rows, preparedStatement.Query)
}

func (proxy *Proxy) rowsToDescriptionMessages(rows *sql.Rows, query string) ([]pgproto3.Message, error) {
Expand Down
1 change: 1 addition & 0 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var KNOWN_SET_STATEMENTS = NewSet([]string{
"client_min_messages", // SET client_min_messages TO 'warning'
"standard_conforming_strings", // SET standard_conforming_strings = on
"intervalstyle", // SET intervalstyle = iso_8601
"timezone", // SET SESSION timezone TO 'UTC'
})

type SelectRemapper struct {
Expand Down

0 comments on commit 2c7853b

Please sign in to comment.