diff --git a/CHANGELOG.md b/CHANGELOG.md index 5033928..af6d35b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [1.5.1] - !fix(orderby): use BuildOption instead of allowedColumns (#46) +- feat(string): added nullable String/Null for sql/json (#47) ## [1.5.0] - 2024-04-30 ### Changed diff --git a/bool_test.go b/bool_test.go index 5c479e6..fc35062 100644 --- a/bool_test.go +++ b/bool_test.go @@ -11,7 +11,7 @@ func TestBool(t *testing.T) { d, err := sql.Open("sqlite3", "file::memory:") require.NoError(t, err) - _, err = d.Exec("CREATE TABLE `users` (`id` id NOT NULL,`status` BIT(1), PRIMARY KEY (`id`))") + _, err = d.Exec("CREATE TABLE `users` (`id` int NOT NULL,`status` BIT(1), PRIMARY KEY (`id`))") require.NoError(t, err) result, err := d.Exec("INSERT INTO `users`(`id`, `status`) VALUES(?, ?)", 10, Bool(true)) diff --git a/null.go b/null.go new file mode 100644 index 0000000..334f9ad --- /dev/null +++ b/null.go @@ -0,0 +1,61 @@ +package sqle + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" +) + +var nullJsonBytes = []byte("null") + +const nullJson = "null" + +type Null[T any] struct { + sql.Null[T] +} + +func NewNull[T any](v T, valid bool) Null[T] { + return Null[T]{Null: sql.Null[T]{V: v, Valid: valid}} +} + +// Scan implements the [sql.Scanner] interface. +func (t *Null[T]) Scan(value any) error { // skipcq: GO-W1029 + return t.Null.Scan(value) +} + +// Value implements the [driver.Valuer] interface. +func (t Null[T]) Value() (driver.Value, error) { // skipcq: GO-W1029 + return t.Null.Value() +} + +// TValue returns the underlying value of the Null struct. +func (t *Null[T]) TValue() T { // skipcq: GO-W1029 + return t.Null.V +} + +// MarshalJSON implements the json.Marshaler interface +func (t Null[T]) MarshalJSON() ([]byte, error) { // skipcq: GO-W1029 + if t.Valid { + return json.Marshal(t.Null.V) + } + return nullJsonBytes, nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (t *Null[T]) UnmarshalJSON(data []byte) error { // skipcq: GO-W1029 + if len(data) == 0 || string(data) == nullJson { + t.Null.Valid = false + return nil + } + + var v T + err := json.Unmarshal(data, &v) + if err != nil { + return err + } + + t.Null.V = v + t.Null.Valid = true + + return nil +} diff --git a/null_test.go b/null_test.go new file mode 100644 index 0000000..d751de0 --- /dev/null +++ b/null_test.go @@ -0,0 +1,112 @@ +package sqle + +import ( + "database/sql" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNullInSQL(t *testing.T) { + + v := float64(10.5) + d, err := sql.Open("sqlite3", "file::memory:") + require.NoError(t, err) + + _, err = d.Exec("CREATE TABLE `nulls` (`id` int NOT NULL,`value` DECIMAL(10, 2), PRIMARY KEY (`id`))") + require.NoError(t, err) + + result, err := d.Exec("INSERT INTO `nulls`(`id`) VALUES(?)", 10) + require.NoError(t, err) + + rows, err := result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + result, err = d.Exec("INSERT INTO `nulls`(`id`, `value`) VALUES(?, ?)", 20, v) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + var v10 Null[float64] + err = d.QueryRow("SELECT `value` FROM `nulls` WHERE id=?", 10).Scan(&v10) + require.NoError(t, err) + + require.EqualValues(t, false, v10.Valid) + + var v20 Null[float64] + err = d.QueryRow("SELECT `value` FROM `nulls` WHERE id=?", 20).Scan(&v20) + require.NoError(t, err) + + require.EqualValues(t, true, v20.Valid) + require.EqualValues(t, v, v20.TValue()) + + result, err = d.Exec("INSERT INTO `nulls`(`id`,`value`) VALUES(?, ?)", 11, v10) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + result, err = d.Exec("INSERT INTO `nulls`(`id`, `value`) VALUES(?, ?)", 21, v20) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + var v11 Null[float64] + err = d.QueryRow("SELECT `value` FROM `nulls` WHERE id=?", 11).Scan(&v11) + require.NoError(t, err) + + require.EqualValues(t, false, v11.Valid) + + var v21 Null[float64] + err = d.QueryRow("SELECT `value` FROM `nulls` WHERE id=?", 21).Scan(&v21) + require.NoError(t, err) + + require.EqualValues(t, true, v21.Valid) + require.EqualValues(t, v, v21.TValue()) + +} + +func TestNullInJSON(t *testing.T) { + + sysValue := 10.5 + + bufSysValue, err := json.Marshal(sysValue) + require.NoError(t, err) + + sqleNull := NewNull(sysValue, true) + + bufSqleNull, err := json.Marshal(sqleNull) + require.NoError(t, err) + + require.Equal(t, bufSysValue, bufSqleNull) + + var jsSqleValue Null[float64] + // Unmarshal sqle.Time from time.Time json bytes + err = json.Unmarshal(bufSysValue, &jsSqleValue) + require.NoError(t, err) + + require.Equal(t, sysValue, jsSqleValue.TValue()) + require.Equal(t, true, jsSqleValue.Valid) + + var jsSysValue float64 + // Unmarshal time.Time from sqle.Time json bytes + err = json.Unmarshal(bufSqleNull, &jsSysValue) + require.NoError(t, err) + require.Equal(t, sysValue, jsSysValue) + + var nullValue Null[float64] + err = json.Unmarshal([]byte("null"), &nullValue) + require.NoError(t, err) + require.Equal(t, false, nullValue.Valid) + + bufNull, err := json.Marshal(nullValue) + require.NoError(t, err) + require.Equal(t, []byte("null"), bufNull) +} diff --git a/string.go b/string.go new file mode 100644 index 0000000..6786552 --- /dev/null +++ b/string.go @@ -0,0 +1,56 @@ +package sqle + +import ( + "database/sql/driver" + "encoding/json" +) + +type String struct { + Null[string] +} + +func NewString(s string) String { + return String{Null: NewNull(s, true)} +} + +// Scan implements the [sql.Scanner] interface. +func (t *String) Scan(value any) error { // skipcq: GO-W1029 + return t.Null.Scan(value) +} + +// Value implements the [driver.Valuer] interface. +func (t String) Value() (driver.Value, error) { // skipcq: GO-W1029 + return t.Null.Value() +} + +// Time returns the underlying time.Time value of the Time struct. +func (t *String) String() string { // skipcq: GO-W1029 + return t.TValue() +} + +// MarshalJSON implements the json.Marshaler interface +func (t String) MarshalJSON() ([]byte, error) { // skipcq: GO-W1029 + if t.Valid { + return json.Marshal(t.TValue()) + } + return nullJsonBytes, nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (t *String) UnmarshalJSON(data []byte) error { // skipcq: GO-W1029 + if len(data) == 0 || string(data) == nullJson { + t.Null.Valid = false + return nil + } + + var v string + err := json.Unmarshal(data, &v) + if err != nil { + return err + } + + t.Null.V = v + t.Null.Valid = true + + return nil +} diff --git a/string_test.go b/string_test.go new file mode 100644 index 0000000..4e1950d --- /dev/null +++ b/string_test.go @@ -0,0 +1,112 @@ +package sqle + +import ( + "database/sql" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStringInSQL(t *testing.T) { + + v := "has value" + d, err := sql.Open("sqlite3", "file::memory:") + require.NoError(t, err) + + _, err = d.Exec("CREATE TABLE `strings` (`id` int NOT NULL,`name` varchar(125), PRIMARY KEY (`id`))") + require.NoError(t, err) + + result, err := d.Exec("INSERT INTO `strings`(`id`) VALUES(?)", 10) + require.NoError(t, err) + + rows, err := result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + result, err = d.Exec("INSERT INTO `strings`(`id`, `name`) VALUES(?, ?)", 20, v) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + var v10 String + err = d.QueryRow("SELECT `name` FROM `strings` WHERE id=?", 10).Scan(&v10) + require.NoError(t, err) + + require.EqualValues(t, false, v10.Valid) + + var v20 String + err = d.QueryRow("SELECT `name` FROM `strings` WHERE id=?", 20).Scan(&v20) + require.NoError(t, err) + + require.EqualValues(t, true, v20.Valid) + require.EqualValues(t, v, v20.String()) + + result, err = d.Exec("INSERT INTO `strings`(`id`,`name`) VALUES(?, ?)", 11, v10) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + result, err = d.Exec("INSERT INTO `strings`(`id`, `name`) VALUES(?, ?)", 21, v20) + require.NoError(t, err) + + rows, err = result.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(1), rows) + + var v11 String + err = d.QueryRow("SELECT `name` FROM `strings` WHERE id=?", 11).Scan(&v11) + require.NoError(t, err) + + require.EqualValues(t, false, v11.Valid) + + var v21 String + err = d.QueryRow("SELECT `name` FROM `strings` WHERE id=?", 21).Scan(&v21) + require.NoError(t, err) + + require.EqualValues(t, true, v21.Valid) + require.EqualValues(t, v, v21.String()) + +} + +func TestStringInJSON(t *testing.T) { + + sysString := "has value" + + bufSysString, err := json.Marshal(sysString) + require.NoError(t, err) + + sqleString := NewString(sysString) + + bufSqleString, err := json.Marshal(sqleString) + require.NoError(t, err) + + require.Equal(t, bufSysString, bufSqleString) + + var jsSqleString String + // Unmarshal sqle.Time from time.Time json bytes + err = json.Unmarshal(bufSysString, &jsSqleString) + require.NoError(t, err) + + require.Equal(t, sysString, jsSqleString.String()) + require.Equal(t, true, jsSqleString.Valid) + + var jsSysString string + // Unmarshal time.Time from sqle.Time json bytes + err = json.Unmarshal(bufSqleString, &jsSysString) + require.NoError(t, err) + require.Equal(t, sysString, jsSysString) + + var nullString String + err = json.Unmarshal([]byte("null"), &nullString) + require.NoError(t, err) + require.Equal(t, false, nullString.Valid) + + bufNull, err := json.Marshal(nullString) + require.NoError(t, err) + require.Equal(t, []byte("null"), bufNull) +} diff --git a/time.go b/time.go index d28b8b7..2c43489 100644 --- a/time.go +++ b/time.go @@ -7,10 +7,6 @@ import ( "time" ) -var nullTimeJsonBytes = []byte("null") - -const nullTimeJson = "null" - // Time represents a nullable time value. type Time struct { sql.NullTime @@ -41,12 +37,12 @@ func (t Time) MarshalJSON() ([]byte, error) { // skipcq: GO-W1029 if t.Valid { return json.Marshal(t.NullTime.Time) } - return nullTimeJsonBytes, nil + return nullJsonBytes, nil } // UnmarshalJSON implements the json.Unmarshaler interface func (t *Time) UnmarshalJSON(data []byte) error { // skipcq: GO-W1029 - if len(data) == 0 || string(data) == nullTimeJson { + if len(data) == 0 || string(data) == nullJson { t.NullTime.Time = time.Time{} t.NullTime.Valid = false return nil