diff --git a/dialect/pgdialect/sqltype.go b/dialect/pgdialect/sqltype.go index 4c2d8075d..dab0446ed 100644 --- a/dialect/pgdialect/sqltype.go +++ b/dialect/pgdialect/sqltype.go @@ -68,6 +68,10 @@ func fieldSQLType(field *schema.Field) string { } } + if field.DiscoveredSQLType == sqltype.Blob { + return pgTypeBytea + } + return sqlType(field.IndirectType) } diff --git a/dialect/sqlitedialect/dialect.go b/dialect/sqlitedialect/dialect.go index 01cff2e7a..74b278a5d 100644 --- a/dialect/sqlitedialect/dialect.go +++ b/dialect/sqlitedialect/dialect.go @@ -2,6 +2,7 @@ package sqlitedialect import ( "database/sql" + "encoding/hex" "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" @@ -47,14 +48,36 @@ func (d *Dialect) OnTable(table *schema.Table) { } func (d *Dialect) onField(field *schema.Field) { - // INTEGER PRIMARY KEY is an alias for the ROWID. - // It is safe to convert all ints to INTEGER, because SQLite types don't have size. - switch field.DiscoveredSQLType { - case sqltype.SmallInt, sqltype.BigInt: - field.DiscoveredSQLType = sqltype.Integer - } + field.DiscoveredSQLType = fieldSQLType(field) } func (d *Dialect) IdentQuote() byte { return '"' } + +func (d *Dialect) AppendBytes(b []byte, bs []byte) []byte { + if bs == nil { + return dialect.AppendNull(b) + } + + b = append(b, `X'`...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) + hex.Encode(b[s:], bs) + + b = append(b, '\'') + + return b +} + +func fieldSQLType(field *schema.Field) string { + switch field.DiscoveredSQLType { + case sqltype.SmallInt, sqltype.BigInt: + // INTEGER PRIMARY KEY is an alias for the ROWID. + // It is safe to convert all ints to INTEGER, because SQLite types don't have size. + return sqltype.Integer + default: + return field.DiscoveredSQLType + } +} diff --git a/dialect/sqltype/sqltype.go b/dialect/sqltype/sqltype.go index 84a51d26d..f58b2f1d1 100644 --- a/dialect/sqltype/sqltype.go +++ b/dialect/sqltype/sqltype.go @@ -8,6 +8,7 @@ const ( Real = "REAL" DoublePrecision = "DOUBLE PRECISION" VarChar = "VARCHAR" + Blob = "BLOB" Timestamp = "TIMESTAMP" JSON = "JSON" JSONB = "JSONB" diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 04620dd11..ae6b86596 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -227,11 +227,12 @@ func TestDB(t *testing.T) { {testFKViolation}, {testInterfaceAny}, {testInterfaceJSON}, - {testScanBytes}, + {testScanRawMessage}, {testPointers}, {testExists}, {testScanTimeIntoString}, {testModelNonPointer}, + {testBinaryData}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -828,7 +829,7 @@ func testInterfaceJSON(t *testing.T, db *bun.DB) { require.Equal(t, "hello", model.Value) } -func testScanBytes(t *testing.T, db *bun.DB) { +func testScanRawMessage(t *testing.T, db *bun.DB) { type Model struct { ID int64 Value json.RawMessage @@ -914,3 +915,23 @@ func testModelNonPointer(t *testing.T, db *bun.DB) { require.Error(t, err) require.Equal(t, "bun: Model(non-pointer dbtest_test.Model)", err.Error()) } + +func testBinaryData(t *testing.T, db *bun.DB) { + type Model struct { + ID int64 + Data []byte + } + + ctx := context.Background() + + err := db.ResetModel(ctx, (*Model)(nil)) + require.NoError(t, err) + + _, err = db.NewInsert().Model(&Model{Data: []byte("hello")}).Exec(ctx) + require.NoError(t, err) + + var model Model + err = db.NewSelect().Model(&model).Scan(ctx) + require.NoError(t, err) + require.Equal(t, []byte("hello"), model.Data) +} diff --git a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-60 b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-60 index 6442b3174..b8cc6d201 100644 --- a/internal/dbtest/testdata/snapshots/TestQuery-sqlite-60 +++ b/internal/dbtest/testdata/snapshots/TestQuery-sqlite-60 @@ -1 +1 @@ -INSERT INTO "models" ("bytes") VALUES ('\x00000000000000000000') +INSERT INTO "models" ("bytes") VALUES (X'00000000000000000000') diff --git a/schema/sqltype.go b/schema/sqltype.go index 76259a67b..90551d6aa 100644 --- a/schema/sqltype.go +++ b/schema/sqltype.go @@ -61,6 +61,14 @@ func DiscoverSQLType(typ reflect.Type) string { case nullStringType: return sqltype.VarChar } + + switch typ.Kind() { + case reflect.Slice: + if typ.Elem().Kind() == reflect.Uint8 { + return sqltype.Blob + } + } + return sqlTypes[typ.Kind()] }