From 13beb380f51e3d744e27e7d55a3078ac8753ccf9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 18 May 2024 17:17:46 -0500 Subject: [PATCH] Fix encode driver.Valuer on nil-able non-pointers https://github.com/jackc/pgx/issues/1566 https://github.com/jackc/pgx/issues/1860 https://github.com/jackc/pgx/pull/2019#discussion_r1605806751 --- extended_query_builder.go | 5 ++ internal/anynil/anynil.go | 20 +++---- pgtype/doc.go | 6 +- query_test.go | 119 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 130 insertions(+), 20 deletions(-) diff --git a/extended_query_builder.go b/extended_query_builder.go index 0056cec7c..522a70e38 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -198,6 +198,11 @@ func (eqb *ExtendedQueryBuilder) oidAndArgForQueryExecModeExec(m *pgtype.Map, ar if err != nil { return 0, nil, err } + + if v == nil { + return 0, nil, nil + } + if dt, ok := m.TypeForValue(v); ok { return dt.OID, v, nil } diff --git a/internal/anynil/anynil.go b/internal/anynil/anynil.go index 314e2bbb2..967dcf415 100644 --- a/internal/anynil/anynil.go +++ b/internal/anynil/anynil.go @@ -14,9 +14,8 @@ import ( // var valuerReflectType = reflect.TypeFor[driver.Valuer]() var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() -// Is returns true if value is any type of nil except a pointer that directly implements driver.Valuer. e.g. nil, -// []byte(nil), and a *T where T implements driver.Valuer get normalized to nil but a *T where *T implements -// driver.Valuer does not. +// Is returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement +// driver.Valuer if it is only implemented by T. func Is(value any) bool { if value == nil { return true @@ -30,14 +29,13 @@ func Is(value any) bool { return false } - if kind == reflect.Ptr { - if _, ok := value.(driver.Valuer); ok { - // The pointer will be considered to implement driver.Valuer even if it is actually implemented on the value. - // But we only want to consider it nil if it is implemented on the pointer. So check if what the pointer points - // to implements driver.Valuer. - if !refVal.Type().Elem().Implements(valuerReflectType) { - return false - } + if _, ok := value.(driver.Valuer); ok { + if kind == reflect.Ptr { + // The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on T + // to see if it is not implemented on *T. + return refVal.Type().Elem().Implements(valuerReflectType) + } else { + return false } } diff --git a/pgtype/doc.go b/pgtype/doc.go index 8e5023038..2039fcf1d 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -144,10 +144,10 @@ Encoding Typed Nils pgtype normalizes typed nils (e.g. []byte(nil)) into nil. nil is always encoded is the SQL NULL value without going through the Codec system. This means that Codecs and other encoding logic does not have to handle nil or *T(nil). -However, database/sql compatibility requires Value to be called on a pointer that implements driver.Valuer. See +However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore, +driver.Valuer values are not normalized to nil unless it is a *T(nil) where driver.Valuer is implemented on T. See https://github.com/golang/go/issues/8415 and -https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. Therefore, pointers that implement -driver.Valuer are not normalized to nil. +https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. Child Records diff --git a/query_test.go b/query_test.go index 550e8cb89..a6a26ad77 100644 --- a/query_test.go +++ b/query_test.go @@ -1173,12 +1173,12 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes ensureConnValid(t, conn) } -type nilAsEmptyJSONObject struct { +type nilPointerAsEmptyJSONObject struct { ID string Name string } -func (v *nilAsEmptyJSONObject) Value() (driver.Value, error) { +func (v *nilPointerAsEmptyJSONObject) Value() (driver.Value, error) { if v == nil { return "{}", nil } @@ -1187,7 +1187,7 @@ func (v *nilAsEmptyJSONObject) Value() (driver.Value, error) { } // https://github.com/jackc/pgx/issues/1566 -func TestConnQueryDatabaseSQLDriverValuerCalledOnPointerImplementers(t *testing.T) { +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilPointerImplementers(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -1195,7 +1195,7 @@ func TestConnQueryDatabaseSQLDriverValuerCalledOnPointerImplementers(t *testing. mustExec(t, conn, "create temporary table t(v json not null)") - var v *nilAsEmptyJSONObject + var v *nilPointerAsEmptyJSONObject commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) require.NoError(t, err) require.Equal(t, "INSERT 0 1", commandTag.String()) @@ -1208,12 +1208,119 @@ func TestConnQueryDatabaseSQLDriverValuerCalledOnPointerImplementers(t *testing. _, err = conn.Exec(context.Background(), `delete from t`) require.NoError(t, err) - v = &nilAsEmptyJSONObject{ID: "1", Name: "foo"} + v = &nilPointerAsEmptyJSONObject{ID: "1", Name: "foo"} commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) require.NoError(t, err) require.Equal(t, "INSERT 0 1", commandTag.String()) - var v2 *nilAsEmptyJSONObject + var v2 *nilPointerAsEmptyJSONObject + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + +type nilSliceAsEmptySlice []byte + +func (j nilSliceAsEmptySlice) Value() (driver.Value, error) { + if len(j) == 0 { + return []byte("[]"), nil + } + + return []byte(j), nil +} + +func (j *nilSliceAsEmptySlice) UnmarshalJSON(data []byte) error { + *j = bytes.Clone(data) + return nil +} + +// https://github.com/jackc/pgx/issues/1860 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilSliceImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v nilSliceAsEmptySlice + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "[]", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = nilSliceAsEmptySlice(`{"name": "foo"}`) + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 nilSliceAsEmptySlice + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + +type nilMapAsEmptyObject map[string]any + +func (j nilMapAsEmptyObject) Value() (driver.Value, error) { + if j == nil { + return []byte("{}"), nil + } + + return json.Marshal(j) +} + +func (j *nilMapAsEmptyObject) UnmarshalJSON(data []byte) error { + var m map[string]any + err := json.Unmarshal(data, &m) + if err != nil { + return err + } + + *j = m + + return nil +} + +// https://github.com/jackc/pgx/pull/2019#discussion_r1605806751 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilMapImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v nilMapAsEmptyObject + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "{}", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = nilMapAsEmptyObject{"name": "foo"} + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 nilMapAsEmptyObject err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) require.NoError(t, err) require.Equal(t, v, v2)