From a6f81ee48da409e81520d02de35c921e29bf0467 Mon Sep 17 00:00:00 2001 From: mattn Date: Sat, 26 Oct 2024 15:55:26 +0900 Subject: [PATCH] Support Oracle dialect (#995) * feat: support Oracle dialect * feat: update README.md --- README.md | 3 +- dialect/dialect.go | 3 + dialect/oracledialect/dialect.go | 126 +++++++++++++++++++++++++++++++ dialect/oracledialect/scan.go | 11 +++ dialect/oracledialect/version.go | 6 ++ query_base.go | 7 +- query_table_create.go | 9 ++- 7 files changed, 161 insertions(+), 4 deletions(-) create mode 100644 dialect/oracledialect/dialect.go create mode 100644 dialect/oracledialect/scan.go create mode 100644 dialect/oracledialect/version.go diff --git a/README.md b/README.md index 07a01aa61..dbe5bc0b4 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# SQL-first Golang ORM for PostgreSQL, MySQL, MSSQL, and SQLite +# SQL-first Golang ORM for PostgreSQL, MySQL, MSSQL, SQLite and Oracle [![build workflow](https://github.com/uptrace/bun/actions/workflows/build.yml/badge.svg)](https://github.com/uptrace/bun/actions) [![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun)](https://pkg.go.dev/github.com/uptrace/bun) @@ -19,6 +19,7 @@ [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql) (including MariaDB), [MSSQL](https://bun.uptrace.dev/guide/drivers.html#mssql), [SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite). + [Oracle](https://bun.uptrace.dev/guide/drivers.html#oracle). - [ORM-like](/example/basic/) experience using good old SQL. Bun supports structs, map, scalars, and slices of map/structs/scalars. - [Bulk inserts](https://bun.uptrace.dev/guide/query-insert.html). diff --git a/dialect/dialect.go b/dialect/dialect.go index 03b81fbbc..4dde63c92 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -12,6 +12,8 @@ func (n Name) String() string { return "mysql" case MSSQL: return "mssql" + case Oracle: + return "oracle" default: return "invalid" } @@ -23,4 +25,5 @@ const ( SQLite MySQL MSSQL + Oracle ) diff --git a/dialect/oracledialect/dialect.go b/dialect/oracledialect/dialect.go new file mode 100644 index 000000000..cc4806b3b --- /dev/null +++ b/dialect/oracledialect/dialect.go @@ -0,0 +1,126 @@ +package oracledialect + +import ( + "database/sql" + "encoding/hex" + "fmt" + "time" + + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect" + "github.com/uptrace/bun/dialect/feature" + "github.com/uptrace/bun/dialect/sqltype" + "github.com/uptrace/bun/schema" +) + +func init() { + if Version() != bun.Version() { + panic(fmt.Errorf("oracledialect and Bun must have the same version: v%s != v%s", + Version(), bun.Version())) + } +} + +type Dialect struct { + schema.BaseDialect + + tables *schema.Tables + features feature.Feature +} + +func New() *Dialect { + d := new(Dialect) + d.tables = schema.NewTables(d) + d.features = feature.CTE | + feature.WithValues | + feature.Returning | + //feature.InsertReturning | // TODO + //feature.Output | // TODO + feature.InsertOnConflict | + //feature.TableNotExists | + feature.SelectExists | + feature.AutoIncrement | + feature.CompositeIn + return d +} + +func (d *Dialect) Init(*sql.DB) {} + +func (d *Dialect) Name() dialect.Name { + return dialect.Oracle +} + +func (d *Dialect) Features() feature.Feature { + return d.features +} + +func (d *Dialect) Tables() *schema.Tables { + return d.tables +} + +func (d *Dialect) OnTable(table *schema.Table) { + for _, field := range table.FieldMap { + d.onField(field) + } +} + +func (d *Dialect) onField(field *schema.Field) { + field.DiscoveredSQLType = fieldSQLType(field) +} + +func (d *Dialect) IdentQuote() byte { + return '"' +} + +func (*Dialect) AppendBytes(b, bs []byte) []byte { + if bs == nil { + return dialect.AppendNull(b) + } + + b = append(b, "0x"...) + + s := len(b) + b = append(b, make([]byte, hex.EncodedLen(len(bs)))...) + hex.Encode(b[s:], bs) + + return b +} + +func (d *Dialect) DefaultVarcharLen() int { + return 255 +} + +func (d *Dialect) AppendSequence(b []byte, table *schema.Table, field *schema.Field) []byte { + return append(b, " GENERATED BY DEFAULT AS IDENTITY"...) +} + +func fieldSQLType(field *schema.Field) string { + switch field.DiscoveredSQLType { + case sqltype.SmallInt, sqltype.BigInt: + // INTEGER PRIMARY KEY is an alias for the ROWID. + // It is safe to convert all ints to INTEGER, because SQLite types don't have size. + return sqltype.Integer + case sqltype.Boolean: + return "number(1,0)" + default: + return field.DiscoveredSQLType + } +} + +func (*Dialect) AppendTime(b []byte, tm time.Time) []byte { + if tm.IsZero() { + b = append(b, "NULL"...) + return b + } + b = append(b, "TO_TIMESTAMP('"...) + b = tm.AppendFormat(b, "2006-01-02 15:04:05.999999") + b = append(b, "', 'YYYY-MM-DD HH24:MI:SS.FF')"...) + return b +} + +func (*Dialect) AppendBool(b []byte, v bool) []byte { + if v { + return append(b, '1') + } + + return append(b, '0') +} diff --git a/dialect/oracledialect/scan.go b/dialect/oracledialect/scan.go new file mode 100644 index 000000000..7239792b2 --- /dev/null +++ b/dialect/oracledialect/scan.go @@ -0,0 +1,11 @@ +package oracledialect + +import ( + "reflect" + + "github.com/uptrace/bun/schema" +) + +func scanner(typ reflect.Type) schema.ScannerFunc { + return schema.Scanner(typ) +} diff --git a/dialect/oracledialect/version.go b/dialect/oracledialect/version.go new file mode 100644 index 000000000..49d68c8a5 --- /dev/null +++ b/dialect/oracledialect/version.go @@ -0,0 +1,6 @@ +package oracledialect + +// Version is the current release version. +func Version() string { + return "1.2.1" +} diff --git a/query_base.go b/query_base.go index 2321a7537..8a26a4c8a 100644 --- a/query_base.go +++ b/query_base.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" @@ -418,7 +419,11 @@ func (q *baseQuery) _appendTables( } else { b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects)) if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects { - b = append(b, " AS "...) + if q.db.dialect.Name() == dialect.Oracle { + b = append(b, ' ') + } else { + b = append(b, " AS "...) + } b = append(b, q.table.SQLAlias...) } } diff --git a/query_table_create.go b/query_table_create.go index 3d98da07b..9b844e58f 100644 --- a/query_table_create.go +++ b/query_table_create.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" + "github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/internal" @@ -165,7 +166,7 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by b = append(b, field.SQLName...) b = append(b, " "...) b = q.appendSQLType(b, field) - if field.NotNull { + if field.NotNull && q.db.dialect.Name() != dialect.Oracle { b = append(b, " NOT NULL"...) } @@ -246,7 +247,11 @@ func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte { return append(b, field.CreateTableSQLType...) } - b = append(b, sqltype.VarChar...) + if q.db.dialect.Name() == dialect.Oracle { + b = append(b, "VARCHAR2"...) + } else { + b = append(b, sqltype.VarChar...) + } b = append(b, "("...) b = strconv.AppendInt(b, int64(q.varchar), 10) b = append(b, ")"...)