diff --git a/conn.go b/conn.go index 54d128ab0..410acad7a 100644 --- a/conn.go +++ b/conn.go @@ -98,6 +98,9 @@ var ErrNoRows = errors.New("no rows in result set") // ErrInvalidLogLevel occurs on attempt to set an invalid log level. var ErrInvalidLogLevel = errors.New("invalid log level") +var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") +var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") + // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. func Connect(ctx context.Context, connString string) (*Conn, error) { @@ -430,7 +433,7 @@ optionLoop: switch mode { case QueryExecModeCacheStatement: if c.statementCache == nil { - return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + return pgconn.CommandTag{}, errDisabledStatementCache } sd, err := c.statementCache.Get(ctx, sql) if err != nil { @@ -440,7 +443,7 @@ optionLoop: return c.execPrepared(ctx, sd, arguments) case QueryExecModeCacheDescribe: if c.descriptionCache == nil { - return pgconn.CommandTag{}, fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") + return pgconn.CommandTag{}, errDisabledDescriptionCache } sd, err := c.descriptionCache.Get(ctx, sql) if err != nil { @@ -536,24 +539,49 @@ func (c *Conn) execSQLParams(ctx context.Context, sql string, args []interface{} c.eqb.Reset() anynil.NormalizeSlice(args) + err := c.appendParamsForQueryExecModeExec(args) + if err != nil { + return pgconn.CommandTag{}, err + } - paramOIDs := make([]uint32, len(args)) + result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats).Read() + c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. + return result.CommandTag, result.Err +} +// appendParamsForQueryExecModeExec appends the args to c.eqb. +// +// Parameters must be encoded in the text format because of differences in type conversion between timestamps and +// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the +// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both +// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL +// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. +// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion +// before converting it to date. This means that dates can be shifted by one day. In text format without that double +// type conversion it takes the date directly and ignores time zone (i.e. it works). +// +// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is +// no way to safely use binary or to specify the parameter OIDs. +func (c *Conn) appendParamsForQueryExecModeExec(args []interface{}) error { for i := range args { - dt, ok := c.TypeMap().TypeForValue(args[i]) - if !ok { - return pgconn.CommandTag{}, &unknownArgumentTypeQueryExecModeExecError{arg: args[i]} - } - err := c.eqb.AppendParam(c.typeMap, dt.OID, args[i]) - if err != nil { - return pgconn.CommandTag{}, err + if args[i] == nil { + err := c.eqb.AppendParamFormat(c.typeMap, 0, TextFormatCode, args[i]) + if err != nil { + return err + } + } else { + dt, ok := c.TypeMap().TypeForValue(args[i]) + if !ok { + return &unknownArgumentTypeQueryExecModeExecError{arg: args[i]} + } + err := c.eqb.AppendParamFormat(c.typeMap, dt.OID, TextFormatCode, args[i]) + if err != nil { + return err + } } - paramOIDs[i] = dt.OID } - result := c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, paramOIDs, c.eqb.paramFormats, c.eqb.resultFormats).Read() - c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. - return result.CommandTag, result.Err + return nil } func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { @@ -589,14 +617,11 @@ const ( // when the the database schema is modified concurrently. QueryExecModeDescribeExec - // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended - // protocol. Queries are executed in a single round trip. Type mappings can be registered with - // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious. - // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use - // a map[string]string directly as an argument. This mode cannot. - // - // It may be necessary to specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. - // "SELECT $1::boolean". + // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol + // with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be + // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are + // unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know + // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. QueryExecModeExec // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. @@ -605,8 +630,13 @@ const ( // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use // a map[string]string directly as an argument. This mode cannot. // - // This mode uses client side parameter interpolation. All values are quoted and escaped. It may be necessary to - // specify the desired type of an argument in the SQL string when it cannot be inferred. e.g. "SELECT $1::boolean". + // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor + // exceptions such as behavior when multiple result returning queries are erroneously sent in a single string. + // + // QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer + // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol + // should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does + // not support the extended protocol. QueryExecModeSimpleProtocol ) @@ -640,13 +670,13 @@ type QueryResultFormatsByOID map[uint32]int16 // Err() on the returned Rows must be checked after the Rows is closed to determine if the query executed successfully // as some errors can only be detected by reading the entire response. e.g. A divide by zero error on the last row. // -// For extra control over how the query is executed, the types QuerySimpleProtocol, QueryResultFormats, and +// For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and // QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely // needed. See the documentation for those types for details. func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID - simpleProtocol := c.config.DefaultQueryExecMode == QueryExecModeSimpleProtocol + mode := c.config.DefaultQueryExecMode optionLoop: for len(args) > 0 { @@ -658,91 +688,118 @@ optionLoop: resultFormatsByOID = arg args = args[1:] case QueryExecMode: - simpleProtocol = arg == QueryExecModeSimpleProtocol + mode = arg args = args[1:] default: break optionLoop } } + c.eqb.Reset() + anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) var err error - sd, ok := c.preparedStatements[sql] - - if simpleProtocol && !ok { - sql, err = c.sanitizeForSimpleQuery(sql, args...) - if err != nil { - rows.fatal(err) - return rows, err + sd := c.preparedStatements[sql] + if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec { + if sd == nil { + switch mode { + case QueryExecModeCacheStatement: + if c.statementCache == nil { + err = errDisabledStatementCache + rows.fatal(err) + return rows, err + } + sd, err = c.statementCache.Get(ctx, sql) + if err != nil { + rows.fatal(err) + return rows, err + } + case QueryExecModeCacheDescribe: + if c.descriptionCache == nil { + err = errDisabledDescriptionCache + rows.fatal(err) + return rows, err + } + sd, err = c.descriptionCache.Get(ctx, sql) + if err != nil { + rows.fatal(err) + return rows, err + } + case QueryExecModeDescribeExec: + sd, err = c.Prepare(ctx, "", sql) + if err != nil { + rows.fatal(err) + return rows, err + } + } } - mrr := c.pgConn.Exec(ctx, sql) - if mrr.NextResult() { - rows.resultReader = mrr.ResultReader() - rows.multiResultReader = mrr - } else { - err = mrr.Close() - rows.fatal(err) - return rows, err + if len(sd.ParamOIDs) != len(args) { + rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) + return rows, rows.err } - return rows, nil - } + rows.sql = sd.SQL - c.eqb.Reset() - - if !ok { - if c.statementCache != nil { - sd, err = c.statementCache.Get(ctx, sql) + for i := range args { + err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) if err != nil { rows.fatal(err) return rows, rows.err } - } else { - sd, err = c.pgConn.Prepare(ctx, "", sql, nil) - if err != nil { - rows.fatal(err) - return rows, rows.err + } + + if resultFormatsByOID != nil { + resultFormats = make([]int16, len(sd.Fields)) + for i := range resultFormats { + resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] } } - } - if len(sd.ParamOIDs) != len(args) { - rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args))) - return rows, rows.err - } - rows.sql = sd.SQL + if resultFormats == nil { + for i := range sd.Fields { + c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + } - anynil.NormalizeSlice(args) + resultFormats = c.eqb.resultFormats + } - for i := range args { - err = c.eqb.AppendParam(c.typeMap, sd.ParamOIDs[i], args[i]) + if mode == QueryExecModeCacheDescribe { + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) + } else { + rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + } + } else if mode == QueryExecModeExec { + err := c.appendParamsForQueryExecModeExec(args) if err != nil { rows.fatal(err) return rows, rows.err } - } - if resultFormatsByOID != nil { - resultFormats = make([]int16, len(sd.Fields)) - for i := range resultFormats { - resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)] + rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, nil, c.eqb.paramFormats, c.eqb.resultFormats) + } else if mode == QueryExecModeSimpleProtocol { + sql, err = c.sanitizeForSimpleQuery(sql, args...) + if err != nil { + rows.fatal(err) + return rows, err } - } - if resultFormats == nil { - for i := range sd.Fields { - c.eqb.AppendResultFormat(c.TypeMap().FormatCodeForOID(sd.Fields[i].DataTypeOID)) + mrr := c.pgConn.Exec(ctx, sql) + if mrr.NextResult() { + rows.resultReader = mrr.ResultReader() + rows.multiResultReader = mrr + } else { + err = mrr.Close() + rows.fatal(err) + return rows, err } - resultFormats = c.eqb.resultFormats - } - - if c.statementCache != nil && c.statementCache.Mode() == stmtcache.ModeDescribe { - rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats) + return rows, nil } else { - rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) + err = fmt.Errorf("unknown QueryExecMode: %v", mode) + rows.fatal(err) + return rows, rows.err } c.eqb.Reset() // Allow c.eqb internal memory to be GC'ed as soon as possible. diff --git a/conn_test.go b/conn_test.go index 85b0da2b7..625d96937 100644 --- a/conn_test.go +++ b/conn_test.go @@ -256,15 +256,7 @@ func TestExecFailureWithArguments(t *testing.T) { assert.False(t, pgconn.SafeToRetry(err)) _, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2") - if conn.Config().DefaultQueryExecMode == pgx.QueryExecModeExec { - // The PostgreSQL server apparently doesn't care about receiving too many arguments and the only way to detect it - // locally would be to parse the SQL. The simple protocol path has to parse the SQL so it can cheaply do a check - // for the correct number of arguments. But since exec doesn't need to it doesn't make sense to waste time parsing - // the SQL. - require.NoError(t, err) - } else { - require.Error(t, err) - } + require.Error(t, err) }) } diff --git a/extended_query_builder.go b/extended_query_builder.go index 5409c0fd2..0b6e19625 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -14,9 +14,13 @@ type extendedQueryBuilder struct { func (eqb *extendedQueryBuilder) AppendParam(m *pgtype.Map, oid uint32, arg interface{}) error { f := eqb.chooseParameterFormatCode(m, oid, arg) - eqb.paramFormats = append(eqb.paramFormats, f) + return eqb.AppendParamFormat(m, oid, f, arg) +} + +func (eqb *extendedQueryBuilder) AppendParamFormat(m *pgtype.Map, oid uint32, format int16, arg interface{}) error { + eqb.paramFormats = append(eqb.paramFormats, format) - v, err := eqb.encodeExtendedParamValue(m, oid, f, arg) + v, err := eqb.encodeExtendedParamValue(m, oid, format, arg) if err != nil { return err } diff --git a/values_test.go b/values_test.go index b7d5c572e..f036b8a63 100644 --- a/values_test.go +++ b/values_test.go @@ -891,6 +891,19 @@ func TestEncodeTypeRename(t *testing.T) { inString := _string("foo") var outString _string + // pgx.QueryExecModeExec requires all types to be registered. + conn.TypeMap().RegisterDefaultPgType(inInt, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt8, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt16, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt32, "int8") + conn.TypeMap().RegisterDefaultPgType(inInt64, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint8, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint16, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint32, "int8") + conn.TypeMap().RegisterDefaultPgType(inUint64, "int8") + conn.TypeMap().RegisterDefaultPgType(inString, "text") + err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString)