From e84773bf0782dcb48224a3ee74e8d7e89c42f9b4 Mon Sep 17 00:00:00 2001 From: Bowen Date: Sat, 16 Nov 2024 22:45:30 +0800 Subject: [PATCH 1/5] feat: [#280] Implement Mysql driver --- contracts/database/orm/orm.go | 2 + contracts/database/schema/grammar.go | 4 +- contracts/database/schema/schema.go | 25 +-- database/orm/orm.go | 5 + database/schema/blueprint.go | 4 + database/schema/common_schema.go | 4 +- database/schema/grammars/mysql.go | 217 ++++++++++++++++++++++++++ database/schema/grammars/postgres.go | 15 +- database/schema/grammars/sqlite.go | 15 +- database/schema/grammars/wrap.go | 17 +- database/schema/grammars/wrap_test.go | 2 +- database/schema/mysql_schema.go | 94 +++++++++++ database/schema/processors/mysql.go | 29 ++++ database/schema/schema.go | 4 +- database/schema/schema_test.go | 16 +- mocks/database/orm/Orm.go | 45 ++++++ mocks/database/schema/Grammar.go | 42 ++--- 17 files changed, 480 insertions(+), 60 deletions(-) create mode 100644 database/schema/grammars/mysql.go create mode 100644 database/schema/mysql_schema.go create mode 100644 database/schema/processors/mysql.go diff --git a/contracts/database/orm/orm.go b/contracts/database/orm/orm.go index faabd2bb4..d730f1b83 100644 --- a/contracts/database/orm/orm.go +++ b/contracts/database/orm/orm.go @@ -14,6 +14,8 @@ type Orm interface { DB() (*sql.DB, error) // Factory gets a new factory instance for the given model name. Factory() Factory + // DatabaseName gets the current database name. + DatabaseName() string // Name gets the current connection name. Name() string // Observe registers an observer with the Orm. diff --git a/contracts/database/schema/grammar.go b/contracts/database/schema/grammar.go index cb4c4ed30..dc2644472 100644 --- a/contracts/database/schema/grammar.go +++ b/contracts/database/schema/grammar.go @@ -24,11 +24,11 @@ type Grammar interface { // CompilePrimary Compile a primary key command. CompilePrimary(blueprint Blueprint, command *Command) string // CompileTables Compile the query to determine the tables. - CompileTables() string + CompileTables(database string) string // CompileTypes Compile the query to determine the types. CompileTypes() string // CompileViews Compile the query to determine the views. - CompileViews() string + CompileViews(database string) string // GetAttributeCommands Get the commands for the schema build. GetAttributeCommands() []string // TypeBigInteger Create the column definition for a big integer type. diff --git a/contracts/database/schema/schema.go b/contracts/database/schema/schema.go index f89ee0cae..056502d5b 100644 --- a/contracts/database/schema/schema.go +++ b/contracts/database/schema/schema.go @@ -70,18 +70,19 @@ type Connection interface { } type Command struct { - Algorithm string - Column ColumnDefinition - Columns []string - From string - Index string - On string - OnDelete string - OnUpdate string - Name string - To string - References []string - Value string + Algorithm string + Column ColumnDefinition + Columns []string + From string + Index string + On string + OnDelete string + OnUpdate string + Name string + To string + References []string + ShouldBeSkipped bool + Value string } type Index struct { diff --git a/database/orm/orm.go b/database/orm/orm.go index 4cb753be9..272350b81 100644 --- a/database/orm/orm.go +++ b/database/orm/orm.go @@ -3,6 +3,7 @@ package orm import ( "context" "database/sql" + "fmt" "sync" "github.com/goravel/framework/contracts/config" @@ -89,6 +90,10 @@ func (r *Orm) Factory() contractsorm.Factory { return factory.NewFactoryImpl(r.Query()) } +func (r *Orm) DatabaseName() string { + return r.config.GetString(fmt.Sprintf("database.connections.%s.database", r.connection)) +} + func (r *Orm) Name() string { return r.connection } diff --git a/database/schema/blueprint.go b/database/schema/blueprint.go index 9102bffc4..db81f0cad 100644 --- a/database/schema/blueprint.go +++ b/database/schema/blueprint.go @@ -149,6 +149,10 @@ func (r *Blueprint) ToSql(grammar schema.Grammar) []string { var statements []string for _, command := range r.commands { + if command.ShouldBeSkipped { + continue + } + switch command.Name { case constants.CommandAdd: statements = append(statements, grammar.CompileAdd(r, command)) diff --git a/database/schema/common_schema.go b/database/schema/common_schema.go index c145f6b35..c0b2b8fb1 100644 --- a/database/schema/common_schema.go +++ b/database/schema/common_schema.go @@ -19,7 +19,7 @@ func NewCommonSchema(grammar schema.Grammar, orm orm.Orm) *CommonSchema { func (r *CommonSchema) GetTables() ([]schema.Table, error) { var tables []schema.Table - if err := r.orm.Query().Raw(r.grammar.CompileTables()).Scan(&tables); err != nil { + if err := r.orm.Query().Raw(r.grammar.CompileTables(r.orm.DatabaseName())).Scan(&tables); err != nil { return nil, err } @@ -28,7 +28,7 @@ func (r *CommonSchema) GetTables() ([]schema.Table, error) { func (r *CommonSchema) GetViews() ([]schema.View, error) { var views []schema.View - if err := r.orm.Query().Raw(r.grammar.CompileViews()).Scan(&views); err != nil { + if err := r.orm.Query().Raw(r.grammar.CompileViews(r.orm.DatabaseName())).Scan(&views); err != nil { return nil, err } diff --git a/database/schema/grammars/mysql.go b/database/schema/grammars/mysql.go new file mode 100644 index 000000000..b008e1bb3 --- /dev/null +++ b/database/schema/grammars/mysql.go @@ -0,0 +1,217 @@ +package grammars + +import ( + "fmt" + "slices" + "strings" + + contractsdatabase "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/contracts/database/schema" + "github.com/goravel/framework/database/schema/constants" +) + +type Mysql struct { + attributeCommands []string + modifiers []func(schema.Blueprint, schema.ColumnDefinition) string + serials []string + wrap *Wrap +} + +func NewMysql(tablePrefix string) *Mysql { + postgres := &Mysql{ + attributeCommands: []string{constants.CommandComment}, + serials: []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}, + wrap: NewWrap(contractsdatabase.DriverMysql, tablePrefix), + } + postgres.modifiers = []func(schema.Blueprint, schema.ColumnDefinition) string{ + postgres.ModifyDefault, + postgres.ModifyIncrement, + postgres.ModifyNullable, + } + + return postgres +} + +func (r *Mysql) CompileAdd(blueprint schema.Blueprint, command *schema.Command) string { + return fmt.Sprintf("alter table %s add %s", r.wrap.Table(blueprint.GetTableName()), r.getColumn(blueprint, command.Column)) +} + +func (r *Mysql) CompileCreate(blueprint schema.Blueprint) string { + columns := r.getColumns(blueprint) + primaryCommand := getCommandByName(blueprint.GetCommands(), "primary") + if primaryCommand != nil { + var algorithm string + if primaryCommand.Algorithm != "" { + algorithm = "using " + primaryCommand.Algorithm + } + columns = append(columns, fmt.Sprintf("primary key %s(%s)", algorithm, r.wrap.Columnize(primaryCommand.Columns))) + + primaryCommand.ShouldBeSkipped = true + } + + return fmt.Sprintf("create table %s (%s)", r.wrap.Table(blueprint.GetTableName()), strings.Join(columns, ", ")) +} + +func (r *Mysql) CompileDisableForeignKeyConstraints() string { + return "SET FOREIGN_KEY_CHECKS=0;" +} + +func (r *Mysql) CompileDropAllDomains(domains []string) string { + return "" +} + +func (r *Mysql) CompileDropAllTables(tables []string) string { + return fmt.Sprintf("drop table %s", strings.Join(r.wrap.Columns(tables), ", ")) +} + +func (r *Mysql) CompileDropAllTypes(types []string) string { + return "" +} + +func (r *Mysql) CompileDropAllViews(views []string) string { + return fmt.Sprintf("drop view %s", strings.Join(r.wrap.Columns(views), ", ")) +} + +func (r *Mysql) CompileDropIfExists(blueprint schema.Blueprint) string { + return fmt.Sprintf("drop table if exists %s", r.wrap.Table(blueprint.GetTableName())) +} + +func (r *Mysql) CompileEnableForeignKeyConstraints() string { + return "SET FOREIGN_KEY_CHECKS=1;" +} + +func (r *Mysql) CompileForeign(blueprint schema.Blueprint, command *schema.Command) string { + sql := fmt.Sprintf("alter table %s add constraint %s foreign key (%s) references %s (%s)", + r.wrap.Table(blueprint.GetTableName()), + r.wrap.Column(command.Index), + r.wrap.Columnize(command.Columns), + r.wrap.Table(command.On), + r.wrap.Columnize(command.References)) + if command.OnDelete != "" { + sql += " on delete " + command.OnDelete + } + if command.OnUpdate != "" { + sql += " on update " + command.OnUpdate + } + + return sql +} + +func (r *Mysql) CompileIndex(blueprint schema.Blueprint, command *schema.Command) string { + var algorithm string + if command.Algorithm != "" { + algorithm = " using " + command.Algorithm + } + + return fmt.Sprintf("alter table %s add %s %s%s(%s)", + r.wrap.Table(blueprint.GetTableName()), + "index", + r.wrap.Column(command.Index), + algorithm, + r.wrap.Columnize(command.Columns), + ) +} + +func (r *Mysql) CompileIndexes(schema, table string) string { + return fmt.Sprintf( + "select index_name as `name`, group_concat(column_name order by seq_in_index) as `columns`, "+ + "index_type as `type`, not non_unique as `unique` "+ + "from information_schema.statistics where table_schema = %s and table_name = %s "+ + "group by index_name, index_type, non_unique", + r.wrap.Quote(schema), + r.wrap.Quote(table), + ) +} + +func (r *Mysql) CompilePrimary(blueprint schema.Blueprint, command *schema.Command) string { + var algorithm string + if command.Algorithm != "" { + algorithm = "using " + command.Algorithm + } + + return fmt.Sprintf("alter table %s add primary key %s(%s)", r.wrap.Table(blueprint.GetTableName()), algorithm, r.wrap.Columnize(command.Columns)) +} + +func (r *Mysql) CompileTables(database string) string { + return fmt.Sprintf("select table_name as `name`, (data_length + index_length) as `size`, "+ + "table_comment as `comment`, engine as `engine`, table_collation as `collation` "+ + "from information_schema.tables where table_schema = %s and table_type in ('BASE TABLE', 'SYSTEM VERSIONED') "+ + "order by table_name", r.wrap.Quote(database)) +} + +func (r *Mysql) CompileTypes() string { + return "" +} + +func (r *Mysql) CompileViews(database string) string { + return fmt.Sprintf("select table_name as `name`, view_definition as `definition` "+ + "from information_schema.views where table_schema = %s "+ + "order by table_name", r.wrap.Quote(database)) +} + +func (r *Mysql) GetAttributeCommands() []string { + return r.attributeCommands +} + +func (r *Mysql) ModifyDefault(blueprint schema.Blueprint, column schema.ColumnDefinition) string { + if column.GetDefault() != nil { + return fmt.Sprintf(" default %s", getDefaultValue(column.GetDefault())) + } + + return "" +} + +func (r *Mysql) ModifyNullable(blueprint schema.Blueprint, column schema.ColumnDefinition) string { + if column.GetNullable() { + return " null" + } else { + return " not null" + } +} + +func (r *Mysql) ModifyIncrement(blueprint schema.Blueprint, column schema.ColumnDefinition) string { + if slices.Contains(r.serials, column.GetType()) && column.GetAutoIncrement() { + if blueprint.HasCommand("primary") { + return "auto_increment" + } + return " auto_increment primary key" + } + + return "" +} + +func (r *Mysql) TypeBigInteger(column schema.ColumnDefinition) string { + return "bigint" +} + +func (r *Mysql) TypeInteger(column schema.ColumnDefinition) string { + return "int" +} + +func (r *Mysql) TypeString(column schema.ColumnDefinition) string { + length := column.GetLength() + if length > 0 { + return fmt.Sprintf("varchar(%d)", length) + } + + return "varchar" +} + +func (r *Mysql) getColumns(blueprint schema.Blueprint) []string { + var columns []string + for _, column := range blueprint.GetAddedColumns() { + columns = append(columns, r.getColumn(blueprint, column)) + } + + return columns +} + +func (r *Mysql) getColumn(blueprint schema.Blueprint, column schema.ColumnDefinition) string { + sql := fmt.Sprintf("%s %s", r.wrap.Column(column.GetName()), getType(r, column)) + + for _, modifier := range r.modifiers { + sql += modifier(blueprint, column) + } + + return sql +} diff --git a/database/schema/grammars/postgres.go b/database/schema/grammars/postgres.go index 24a4e39c1..3cab01978 100644 --- a/database/schema/grammars/postgres.go +++ b/database/schema/grammars/postgres.go @@ -5,6 +5,7 @@ import ( "slices" "strings" + contractsdatabase "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/database/schema" "github.com/goravel/framework/database/schema/constants" ) @@ -20,7 +21,7 @@ func NewPostgres(tablePrefix string) *Postgres { postgres := &Postgres{ attributeCommands: []string{constants.CommandComment}, serials: []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}, - wrap: NewWrap(tablePrefix), + wrap: NewWrap(contractsdatabase.DriverPostgres, tablePrefix), } postgres.modifiers = []func(schema.Blueprint, schema.ColumnDefinition) string{ postgres.ModifyDefault, @@ -63,9 +64,9 @@ func (r *Postgres) CompileForeign(blueprint schema.Blueprint, command *schema.Co sql := fmt.Sprintf("alter table %s add constraint %s foreign key (%s) references %s (%s)", r.wrap.Table(blueprint.GetTableName()), r.wrap.Column(command.Index), - r.wrap.Columns(command.Columns), + r.wrap.Columnize(command.Columns), r.wrap.Table(command.On), - r.wrap.Columns(command.References)) + r.wrap.Columnize(command.References)) if command.OnDelete != "" { sql += " on delete " + command.OnDelete } @@ -86,7 +87,7 @@ func (r *Postgres) CompileIndex(blueprint schema.Blueprint, command *schema.Comm r.wrap.Column(command.Index), r.wrap.Table(blueprint.GetTableName()), algorithm, - r.wrap.Columns(command.Columns), + r.wrap.Columnize(command.Columns), ) } @@ -109,10 +110,10 @@ func (r *Postgres) CompileIndexes(schema, table string) string { } func (r *Postgres) CompilePrimary(blueprint schema.Blueprint, command *schema.Command) string { - return fmt.Sprintf("alter table %s add primary key (%s)", r.wrap.Table(blueprint.GetTableName()), r.wrap.Columns(command.Columns)) + return fmt.Sprintf("alter table %s add primary key (%s)", r.wrap.Table(blueprint.GetTableName()), r.wrap.Columnize(command.Columns)) } -func (r *Postgres) CompileTables() string { +func (r *Postgres) CompileTables(database string) string { return "select c.relname as name, n.nspname as schema, pg_total_relation_size(c.oid) as size, " + "obj_description(c.oid, 'pg_class') as comment from pg_class c, pg_namespace n " + "where c.relkind in ('r', 'p') and n.oid = c.relnamespace and n.nspname not in ('pg_catalog', 'information_schema') " + @@ -132,7 +133,7 @@ func (r *Postgres) CompileTypes() string { and n.nspname not in ('pg_catalog', 'information_schema')` } -func (r *Postgres) CompileViews() string { +func (r *Postgres) CompileViews(database string) string { return "select viewname as name, schemaname as schema, definition from pg_views where schemaname not in ('pg_catalog', 'information_schema') order by viewname" } diff --git a/database/schema/grammars/sqlite.go b/database/schema/grammars/sqlite.go index ad4009889..07228fdd7 100644 --- a/database/schema/grammars/sqlite.go +++ b/database/schema/grammars/sqlite.go @@ -5,6 +5,7 @@ import ( "slices" "strings" + contractsdatabase "github.com/goravel/framework/contracts/database" "github.com/goravel/framework/contracts/database/schema" ) @@ -21,7 +22,7 @@ func NewSqlite(tablePrefix string) *Sqlite { attributeCommands: []string{}, serials: []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}, tablePrefix: tablePrefix, - wrap: NewWrap(tablePrefix), + wrap: NewWrap(contractsdatabase.DriverSqlite, tablePrefix), } sqlite.modifiers = []func(schema.Blueprint, schema.ColumnDefinition) string{ sqlite.ModifyDefault, @@ -80,7 +81,7 @@ func (r *Sqlite) CompileIndex(blueprint schema.Blueprint, command *schema.Comman return fmt.Sprintf("create index %s on %s (%s)", r.wrap.Column(command.Index), r.wrap.Table(blueprint.GetTableName()), - r.wrap.Columns(command.Columns), + r.wrap.Columnize(command.Columns), ) } @@ -106,7 +107,7 @@ func (r *Sqlite) CompileRebuild() string { return "vacuum" } -func (r *Sqlite) CompileTables() string { +func (r *Sqlite) CompileTables(database string) string { return "select name from sqlite_master where type = 'table' and name not like 'sqlite_%' order by name" } @@ -114,7 +115,7 @@ func (r *Sqlite) CompileTypes() string { return "" } -func (r *Sqlite) CompileViews() string { +func (r *Sqlite) CompileViews(database string) string { return "select name, sql as definition from sqlite_master where type = 'view' order by name" } @@ -177,7 +178,7 @@ func (r *Sqlite) addPrimaryKeys(command *schema.Command) string { return "" } - return fmt.Sprintf(", primary key (%s)", r.wrap.Columns(command.Columns)) + return fmt.Sprintf(", primary key (%s)", r.wrap.Columnize(command.Columns)) } func (r *Sqlite) getColumns(blueprint schema.Blueprint) []string { @@ -201,9 +202,9 @@ func (r *Sqlite) getColumn(blueprint schema.Blueprint, column schema.ColumnDefin func (r *Sqlite) getForeignKey(command *schema.Command) string { sql := fmt.Sprintf(", foreign key(%s) references %s(%s)", - r.wrap.Columns(command.Columns), + r.wrap.Columnize(command.Columns), r.wrap.Table(command.On), - r.wrap.Columns(command.References)) + r.wrap.Columnize(command.References)) if command.OnDelete != "" { sql += " on delete " + command.OnDelete diff --git a/database/schema/grammars/wrap.go b/database/schema/grammars/wrap.go index e8eddb710..f49d6d300 100644 --- a/database/schema/grammars/wrap.go +++ b/database/schema/grammars/wrap.go @@ -3,14 +3,18 @@ package grammars import ( "fmt" "strings" + + contractsdatabase "github.com/goravel/framework/contracts/database" ) type Wrap struct { + driver contractsdatabase.Driver tablePrefix string } -func NewWrap(tablePrefix string) *Wrap { +func NewWrap(driver contractsdatabase.Driver, tablePrefix string) *Wrap { return &Wrap{ + driver: driver, tablePrefix: tablePrefix, } } @@ -23,11 +27,17 @@ func (r *Wrap) Column(column string) string { return r.Segments(strings.Split(column, ".")) } -func (r *Wrap) Columns(columns []string) string { +func (r *Wrap) Columns(columns []string) []string { for i, column := range columns { columns[i] = r.Column(column) } + return columns +} + +func (r *Wrap) Columnize(columns []string) string { + columns = r.Columns(columns) + return strings.Join(columns, ", ") } @@ -67,6 +77,9 @@ func (r *Wrap) Table(table string) string { func (r *Wrap) Value(value string) string { if value != "*" { + if r.driver == contractsdatabase.DriverMysql { + return "`" + strings.ReplaceAll(value, "`", "``") + "`" + } return `"` + strings.ReplaceAll(value, `"`, `""`) + `"` } diff --git a/database/schema/grammars/wrap_test.go b/database/schema/grammars/wrap_test.go index 3c80b8cf4..b43ae581b 100644 --- a/database/schema/grammars/wrap_test.go +++ b/database/schema/grammars/wrap_test.go @@ -30,7 +30,7 @@ func (s *WrapTestSuite) ColumnWithoutAlias() { } func (s *WrapTestSuite) ColumnsWithMultipleColumns() { - result := s.wrap.Columns([]string{"column1", "column2 as alias2"}) + result := s.wrap.Columnize([]string{"column1", "column2 as alias2"}) s.Equal(`"column1", "column2" as "prefix_alias2"`, result) } diff --git a/database/schema/mysql_schema.go b/database/schema/mysql_schema.go new file mode 100644 index 000000000..6259d7bd1 --- /dev/null +++ b/database/schema/mysql_schema.go @@ -0,0 +1,94 @@ +package schema + +import ( + "github.com/goravel/framework/contracts/database/orm" + contractsschema "github.com/goravel/framework/contracts/database/schema" + "github.com/goravel/framework/database/schema/grammars" + "github.com/goravel/framework/database/schema/processors" +) + +type MysqlSchema struct { + contractsschema.CommonSchema + + grammar *grammars.Mysql + orm orm.Orm + prefix string + processor processors.Mysql +} + +func NewMysqlSchema(grammar *grammars.Mysql, orm orm.Orm, prefix string) *MysqlSchema { + return &MysqlSchema{ + CommonSchema: NewCommonSchema(grammar, orm), + grammar: grammar, + orm: orm, + prefix: prefix, + processor: processors.NewMysql(), + } +} + +func (r *MysqlSchema) DropAllTables() error { + tables, err := r.GetTables() + if err != nil { + return err + } + + if len(tables) == 0 { + return nil + } + + if _, err = r.orm.Query().Exec(r.grammar.CompileDisableForeignKeyConstraints()); err != nil { + return err + } + + var dropTables []string + for _, table := range tables { + dropTables = append(dropTables, table.Name) + } + if _, err = r.orm.Query().Exec(r.grammar.CompileDropAllTables(dropTables)); err != nil { + return err + } + + if _, err = r.orm.Query().Exec(r.grammar.CompileEnableForeignKeyConstraints()); err != nil { + return err + } + + return err +} + +func (r *MysqlSchema) DropAllTypes() error { + return nil +} + +func (r *MysqlSchema) DropAllViews() error { + views, err := r.GetViews() + if err != nil { + return err + } + if len(views) == 0 { + return nil + } + + var dropViews []string + for _, view := range views { + dropViews = append(dropViews, view.Name) + } + + _, err = r.orm.Query().Exec(r.grammar.CompileDropAllViews(dropViews)) + + return err +} + +func (r *MysqlSchema) GetIndexes(table string) ([]contractsschema.Index, error) { + table = r.prefix + table + + var dbIndexes []processors.DBIndex + if err := r.orm.Query().Raw(r.grammar.CompileIndexes(r.orm.DatabaseName(), table)).Scan(&dbIndexes); err != nil { + return nil, err + } + + return r.processor.ProcessIndexes(dbIndexes), nil +} + +func (r *MysqlSchema) GetTypes() ([]contractsschema.Type, error) { + return nil, nil +} diff --git a/database/schema/processors/mysql.go b/database/schema/processors/mysql.go new file mode 100644 index 000000000..22ee8ed0d --- /dev/null +++ b/database/schema/processors/mysql.go @@ -0,0 +1,29 @@ +package processors + +import ( + "strings" + + "github.com/goravel/framework/contracts/database/schema" +) + +type Mysql struct { +} + +func NewMysql() Mysql { + return Mysql{} +} + +func (r Mysql) ProcessIndexes(dbIndexes []DBIndex) []schema.Index { + var indexes []schema.Index + for _, dbIndex := range dbIndexes { + indexes = append(indexes, schema.Index{ + Columns: strings.Split(dbIndex.Columns, ","), + Name: strings.ToLower(dbIndex.Name), + Type: strings.ToLower(dbIndex.Type), + Primary: dbIndex.Primary, + Unique: dbIndex.Unique, + }) + } + + return indexes +} diff --git a/database/schema/schema.go b/database/schema/schema.go index ae5327dad..a3a92d5e0 100644 --- a/database/schema/schema.go +++ b/database/schema/schema.go @@ -45,7 +45,9 @@ func NewSchema(config config.Config, log log.Log, orm contractsorm.Orm, migratio driverSchema = NewPostgresSchema(postgresGrammar, orm, schema, prefix) grammar = postgresGrammar case contractsdatabase.DriverMysql: - // TODO Optimize here when implementing Mysql driver + mysqlGrammar := grammars.NewMysql(prefix) + driverSchema = NewMysqlSchema(mysqlGrammar, orm, prefix) + grammar = mysqlGrammar case contractsdatabase.DriverSqlserver: // TODO Optimize here when implementing Sqlserver driver case contractsdatabase.DriverSqlite: diff --git a/database/schema/schema_test.go b/database/schema/schema_test.go index fbe25462d..5c6189f2b 100644 --- a/database/schema/schema_test.go +++ b/database/schema/schema_test.go @@ -27,15 +27,19 @@ func TestSchemaSuite(t *testing.T) { func (s *SchemaSuite) SetupTest() { // TODO Add other drivers - postgresDocker := docker.Postgres() - postgresQuery := gorm.NewTestQuery(postgresDocker, true) + //postgresDocker := docker.Postgres() + //postgresQuery := gorm.NewTestQuery(postgresDocker, true) + // + //sqliteDocker := docker.Sqlite() + //sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) - sqliteDocker := docker.Sqlite() - sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) + mysqlDocker := docker.Mysql() + mysqlQuery := gorm.NewTestQuery(mysqlDocker, true) s.driverToTestQuery = map[database.Driver]*gorm.TestQuery{ - database.DriverPostgres: postgresQuery, - database.DriverSqlite: sqliteQuery, + //database.DriverPostgres: postgresQuery, + //database.DriverSqlite: sqliteQuery, + database.DriverMysql: mysqlQuery, } } diff --git a/mocks/database/orm/Orm.go b/mocks/database/orm/Orm.go index efdbd61d4..c5ca83a50 100644 --- a/mocks/database/orm/Orm.go +++ b/mocks/database/orm/Orm.go @@ -129,6 +129,51 @@ func (_c *Orm_DB_Call) RunAndReturn(run func() (*sql.DB, error)) *Orm_DB_Call { return _c } +// DatabaseName provides a mock function with given fields: +func (_m *Orm) DatabaseName() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for DatabaseName") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Orm_DatabaseName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DatabaseName' +type Orm_DatabaseName_Call struct { + *mock.Call +} + +// DatabaseName is a helper method to define mock.On call +func (_e *Orm_Expecter) DatabaseName() *Orm_DatabaseName_Call { + return &Orm_DatabaseName_Call{Call: _e.mock.On("DatabaseName")} +} + +func (_c *Orm_DatabaseName_Call) Run(run func()) *Orm_DatabaseName_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Orm_DatabaseName_Call) Return(_a0 string) *Orm_DatabaseName_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Orm_DatabaseName_Call) RunAndReturn(run func() string) *Orm_DatabaseName_Call { + _c.Call.Return(run) + return _c +} + // Factory provides a mock function with given fields: func (_m *Orm) Factory() orm.Factory { ret := _m.Called() diff --git a/mocks/database/schema/Grammar.go b/mocks/database/schema/Grammar.go index 73fd8ddcb..0cc577b29 100644 --- a/mocks/database/schema/Grammar.go +++ b/mocks/database/schema/Grammar.go @@ -531,17 +531,17 @@ func (_c *Grammar_CompilePrimary_Call) RunAndReturn(run func(schema.Blueprint, * return _c } -// CompileTables provides a mock function with given fields: -func (_m *Grammar) CompileTables() string { - ret := _m.Called() +// CompileTables provides a mock function with given fields: database +func (_m *Grammar) CompileTables(database string) string { + ret := _m.Called(database) if len(ret) == 0 { panic("no return value specified for CompileTables") } var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(database) } else { r0 = ret.Get(0).(string) } @@ -555,13 +555,14 @@ type Grammar_CompileTables_Call struct { } // CompileTables is a helper method to define mock.On call -func (_e *Grammar_Expecter) CompileTables() *Grammar_CompileTables_Call { - return &Grammar_CompileTables_Call{Call: _e.mock.On("CompileTables")} +// - database string +func (_e *Grammar_Expecter) CompileTables(database interface{}) *Grammar_CompileTables_Call { + return &Grammar_CompileTables_Call{Call: _e.mock.On("CompileTables", database)} } -func (_c *Grammar_CompileTables_Call) Run(run func()) *Grammar_CompileTables_Call { +func (_c *Grammar_CompileTables_Call) Run(run func(database string)) *Grammar_CompileTables_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(string)) }) return _c } @@ -571,7 +572,7 @@ func (_c *Grammar_CompileTables_Call) Return(_a0 string) *Grammar_CompileTables_ return _c } -func (_c *Grammar_CompileTables_Call) RunAndReturn(run func() string) *Grammar_CompileTables_Call { +func (_c *Grammar_CompileTables_Call) RunAndReturn(run func(string) string) *Grammar_CompileTables_Call { _c.Call.Return(run) return _c } @@ -621,17 +622,17 @@ func (_c *Grammar_CompileTypes_Call) RunAndReturn(run func() string) *Grammar_Co return _c } -// CompileViews provides a mock function with given fields: -func (_m *Grammar) CompileViews() string { - ret := _m.Called() +// CompileViews provides a mock function with given fields: database +func (_m *Grammar) CompileViews(database string) string { + ret := _m.Called(database) if len(ret) == 0 { panic("no return value specified for CompileViews") } var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(database) } else { r0 = ret.Get(0).(string) } @@ -645,13 +646,14 @@ type Grammar_CompileViews_Call struct { } // CompileViews is a helper method to define mock.On call -func (_e *Grammar_Expecter) CompileViews() *Grammar_CompileViews_Call { - return &Grammar_CompileViews_Call{Call: _e.mock.On("CompileViews")} +// - database string +func (_e *Grammar_Expecter) CompileViews(database interface{}) *Grammar_CompileViews_Call { + return &Grammar_CompileViews_Call{Call: _e.mock.On("CompileViews", database)} } -func (_c *Grammar_CompileViews_Call) Run(run func()) *Grammar_CompileViews_Call { +func (_c *Grammar_CompileViews_Call) Run(run func(database string)) *Grammar_CompileViews_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(string)) }) return _c } @@ -661,7 +663,7 @@ func (_c *Grammar_CompileViews_Call) Return(_a0 string) *Grammar_CompileViews_Ca return _c } -func (_c *Grammar_CompileViews_Call) RunAndReturn(run func() string) *Grammar_CompileViews_Call { +func (_c *Grammar_CompileViews_Call) RunAndReturn(run func(string) string) *Grammar_CompileViews_Call { _c.Call.Return(run) return _c } From ffa84e42d59b0a7df918259e5c3aa59147d35b6c Mon Sep 17 00:00:00 2001 From: Bowen Date: Sun, 17 Nov 2024 11:31:31 +0800 Subject: [PATCH 2/5] Add unit tests --- contracts/database/schema/blueprint.go | 10 +- database/schema/grammars/mysql.go | 4 +- database/schema/grammars/mysql_test.go | 305 +++++++++++++++++++++++ database/schema/grammars/wrap_test.go | 30 ++- database/schema/processors/mysql_test.go | 33 +++ database/schema/schema_test.go | 18 +- mocks/database/schema/Blueprint.go | 144 +++++++++++ 7 files changed, 520 insertions(+), 24 deletions(-) create mode 100644 database/schema/grammars/mysql_test.go create mode 100644 database/schema/processors/mysql_test.go diff --git a/contracts/database/schema/blueprint.go b/contracts/database/schema/blueprint.go index 1e81bc257..7acc54c0a 100644 --- a/contracts/database/schema/blueprint.go +++ b/contracts/database/schema/blueprint.go @@ -5,6 +5,10 @@ import ( ) type Blueprint interface { + // BigIncrements Create a new auto-incrementing big integer (8-byte) column on the table. + BigIncrements(column string) ColumnDefinition + // BigInteger Create a new big integer (8-byte) column on the table. + BigInteger(column string) ColumnDefinition // Build Execute the blueprint to build / modify the table. Build(query orm.Query, grammar Grammar) error // Create Indicate that the table needs to be created. @@ -21,20 +25,22 @@ type Blueprint interface { GetTableName() string // HasCommand Determine if the blueprint has a specific command. HasCommand(command string) bool - // Primary Specify the primary key(s) for the table. - Primary(column ...string) // ID Create a new auto-incrementing big integer (8-byte) column on the table. ID(column ...string) ColumnDefinition // Index Specify an index for the table. Index(column ...string) IndexDefinition // Integer Create a new integer (4-byte) column on the table. Integer(column string) ColumnDefinition + // Primary Specify the primary key(s) for the table. + Primary(column ...string) // SetTable Set the table that the blueprint operates on. SetTable(name string) // String Create a new string column on the table. String(column string, length ...int) ColumnDefinition // ToSql Get the raw SQL statements for the blueprint. ToSql(grammar Grammar) []string + // UnsignedBigInteger Create a new unsigned big integer (8-byte) column on the table. + UnsignedBigInteger(column string) ColumnDefinition } type IndexConfig struct { diff --git a/database/schema/grammars/mysql.go b/database/schema/grammars/mysql.go index b008e1bb3..cd102ba4e 100644 --- a/database/schema/grammars/mysql.go +++ b/database/schema/grammars/mysql.go @@ -61,7 +61,7 @@ func (r *Mysql) CompileDropAllDomains(domains []string) string { } func (r *Mysql) CompileDropAllTables(tables []string) string { - return fmt.Sprintf("drop table %s", strings.Join(r.wrap.Columns(tables), ", ")) + return fmt.Sprintf("drop table %s", r.wrap.Columnize(tables)) } func (r *Mysql) CompileDropAllTypes(types []string) string { @@ -69,7 +69,7 @@ func (r *Mysql) CompileDropAllTypes(types []string) string { } func (r *Mysql) CompileDropAllViews(views []string) string { - return fmt.Sprintf("drop view %s", strings.Join(r.wrap.Columns(views), ", ")) + return fmt.Sprintf("drop view %s", r.wrap.Columnize(views)) } func (r *Mysql) CompileDropIfExists(blueprint schema.Blueprint) string { diff --git a/database/schema/grammars/mysql_test.go b/database/schema/grammars/mysql_test.go new file mode 100644 index 000000000..9b4c600df --- /dev/null +++ b/database/schema/grammars/mysql_test.go @@ -0,0 +1,305 @@ +package grammars + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + contractsschema "github.com/goravel/framework/contracts/database/schema" + mocksschema "github.com/goravel/framework/mocks/database/schema" +) + +type MysqlSuite struct { + suite.Suite + grammar *Mysql +} + +func TestMysqlSuite(t *testing.T) { + suite.Run(t, &MysqlSuite{}) +} + +func (s *MysqlSuite) SetupTest() { + s.grammar = NewMysql("goravel_") +} + +func (s *MysqlSuite) TestCompileAdd() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + mockColumn := mocksschema.NewColumnDefinition(s.T()) + + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + mockColumn.EXPECT().GetName().Return("name").Once() + mockColumn.EXPECT().GetType().Return("string").Twice() + mockColumn.EXPECT().GetDefault().Return("goravel").Twice() + mockColumn.EXPECT().GetNullable().Return(false).Once() + mockColumn.EXPECT().GetLength().Return(1).Once() + + sql := s.grammar.CompileAdd(mockBlueprint, &contractsschema.Command{ + Column: mockColumn, + }) + + s.Equal("alter table `goravel_users` add `name` varchar(1) default 'goravel' not null", sql) +} + +func (s *MysqlSuite) TestCompileCreate() { + mockColumn1 := mocksschema.NewColumnDefinition(s.T()) + mockColumn2 := mocksschema.NewColumnDefinition(s.T()) + mockBlueprint := mocksschema.NewBlueprint(s.T()) + + // postgres.go::CompileCreate + primaryCommand := &contractsschema.Command{ + Name: "primary", + Columns: []string{"role_id", "user_id"}, + Algorithm: "btree", + } + mockBlueprint.EXPECT().GetCommands().Return([]*contractsschema.Command{ + primaryCommand, + }).Once() + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + // utils.go::getColumns + mockBlueprint.EXPECT().GetAddedColumns().Return([]contractsschema.ColumnDefinition{ + mockColumn1, mockColumn2, + }).Once() + // utils.go::getColumns + mockColumn1.EXPECT().GetName().Return("id").Once() + // utils.go::getType + mockColumn1.EXPECT().GetType().Return("integer").Once() + // postgres.go::TypeInteger + mockColumn1.EXPECT().GetAutoIncrement().Return(true).Once() + // postgres.go::ModifyDefault + mockColumn1.EXPECT().GetDefault().Return(nil).Once() + // postgres.go::ModifyIncrement + mockBlueprint.EXPECT().HasCommand("primary").Return(false).Once() + mockColumn1.EXPECT().GetType().Return("integer").Once() + // postgres.go::ModifyNullable + mockColumn1.EXPECT().GetNullable().Return(false).Once() + + // utils.go::getColumns + mockColumn2.EXPECT().GetName().Return("name").Once() + // utils.go::getType + mockColumn2.EXPECT().GetType().Return("string").Once() + // postgres.go::TypeString + mockColumn2.EXPECT().GetLength().Return(100).Once() + // postgres.go::ModifyDefault + mockColumn2.EXPECT().GetDefault().Return(nil).Once() + // postgres.go::ModifyIncrement + mockColumn2.EXPECT().GetType().Return("string").Once() + // postgres.go::ModifyNullable + mockColumn2.EXPECT().GetNullable().Return(true).Once() + + s.Equal("create table `goravel_users` (`id` int auto_increment primary key not null, `name` varchar(100) null, primary key using btree(`role_id`, `user_id`))", + s.grammar.CompileCreate(mockBlueprint)) + s.True(primaryCommand.ShouldBeSkipped) +} + +func (s *MysqlSuite) TestCompileDropAllTables() { + s.Equal("drop table `domain`, `email`", s.grammar.CompileDropAllTables([]string{"domain", "email"})) +} + +func (s *MysqlSuite) TestCompileDropAllViews() { + s.Equal("drop view `domain`, `email`", s.grammar.CompileDropAllViews([]string{"domain", "email"})) +} + +func (s *MysqlSuite) TestCompileDropIfExists() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + + s.Equal("drop table if exists `goravel_users`", s.grammar.CompileDropIfExists(mockBlueprint)) +} + +func (s *MysqlSuite) TestCompileForeign() { + var mockBlueprint *mocksschema.Blueprint + + beforeEach := func() { + mockBlueprint = mocksschema.NewBlueprint(s.T()) + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + } + + tests := []struct { + name string + command *contractsschema.Command + expectSql string + }{ + { + name: "with on delete and on update", + command: &contractsschema.Command{ + Index: "fk_users_role_id", + Columns: []string{"role_id", "user_id"}, + On: "roles", + References: []string{"id", "user_id"}, + OnDelete: "cascade", + OnUpdate: "restrict", + }, + expectSql: "alter table `goravel_users` add constraint `fk_users_role_id` foreign key (`role_id`, `user_id`) references `goravel_roles` (`id`, `user_id`) on delete cascade on update restrict", + }, + { + name: "without on delete and on update", + command: &contractsschema.Command{ + Index: "fk_users_role_id", + Columns: []string{"role_id", "user_id"}, + On: "roles", + References: []string{"id", "user_id"}, + }, + expectSql: "alter table `goravel_users` add constraint `fk_users_role_id` foreign key (`role_id`, `user_id`) references `goravel_roles` (`id`, `user_id`)", + }, + } + + for _, test := range tests { + s.Run(test.name, func() { + beforeEach() + + sql := s.grammar.CompileForeign(mockBlueprint, test.command) + s.Equal(test.expectSql, sql) + }) + } +} + +func (s *MysqlSuite) TestCompileIndex() { + var mockBlueprint *mocksschema.Blueprint + + beforeEach := func() { + mockBlueprint = mocksschema.NewBlueprint(s.T()) + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + } + + tests := []struct { + name string + command *contractsschema.Command + expectSql string + }{ + { + name: "with Algorithm", + command: &contractsschema.Command{ + Index: "fk_users_role_id", + Columns: []string{"role_id", "user_id"}, + Algorithm: "btree", + }, + expectSql: "alter table `goravel_users` add index `fk_users_role_id` using btree(`role_id`, `user_id`)", + }, + { + name: "without Algorithm", + command: &contractsschema.Command{ + Index: "fk_users_role_id", + Columns: []string{"role_id", "user_id"}, + }, + expectSql: "alter table `goravel_users` add index `fk_users_role_id`(`role_id`, `user_id`)", + }, + } + + for _, test := range tests { + s.Run(test.name, func() { + beforeEach() + + sql := s.grammar.CompileIndex(mockBlueprint, test.command) + s.Equal(test.expectSql, sql) + }) + } +} + +func (s *MysqlSuite) TestCompilePrimary() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + + s.Equal("alter table `goravel_users` add primary key (`role_id`, `user_id`)", s.grammar.CompilePrimary(mockBlueprint, &contractsschema.Command{ + Columns: []string{"role_id", "user_id"}, + })) +} + +func (s *MysqlSuite) TestGetColumns() { + mockColumn1 := mocksschema.NewColumnDefinition(s.T()) + mockColumn2 := mocksschema.NewColumnDefinition(s.T()) + mockBlueprint := mocksschema.NewBlueprint(s.T()) + + mockBlueprint.EXPECT().GetAddedColumns().Return([]contractsschema.ColumnDefinition{ + mockColumn1, mockColumn2, + }).Once() + mockBlueprint.EXPECT().HasCommand("primary").Return(false).Once() + + mockColumn1.EXPECT().GetName().Return("id").Once() + mockColumn1.EXPECT().GetType().Return("integer").Twice() + mockColumn1.EXPECT().GetDefault().Return(nil).Once() + mockColumn1.EXPECT().GetNullable().Return(false).Once() + mockColumn1.EXPECT().GetAutoIncrement().Return(true).Once() + + mockColumn2.EXPECT().GetName().Return("name").Once() + mockColumn2.EXPECT().GetType().Return("string").Twice() + mockColumn2.EXPECT().GetDefault().Return("goravel").Twice() + mockColumn2.EXPECT().GetNullable().Return(true).Once() + mockColumn2.EXPECT().GetLength().Return(10).Once() + + s.Equal([]string{"`id` int auto_increment primary key not null", "`name` varchar(10) default 'goravel' null"}, s.grammar.getColumns(mockBlueprint)) +} + +func (s *MysqlSuite) TestModifyDefault() { + var ( + mockBlueprint *mocksschema.Blueprint + mockColumn *mocksschema.ColumnDefinition + ) + + tests := []struct { + name string + setup func() + expectSql string + }{ + { + name: "without change and default is nil", + setup: func() { + mockColumn.EXPECT().GetDefault().Return(nil).Once() + }, + }, + { + name: "without change and default is not nil", + setup: func() { + mockColumn.EXPECT().GetDefault().Return("goravel").Twice() + }, + expectSql: " default 'goravel'", + }, + } + + for _, test := range tests { + s.Run(test.name, func() { + mockBlueprint = mocksschema.NewBlueprint(s.T()) + mockColumn = mocksschema.NewColumnDefinition(s.T()) + + test.setup() + + sql := s.grammar.ModifyDefault(mockBlueprint, mockColumn) + + s.Equal(test.expectSql, sql) + }) + } +} + +func (s *MysqlSuite) TestModifyNullable() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + mockColumn := mocksschema.NewColumnDefinition(s.T()) + mockColumn.EXPECT().GetNullable().Return(true).Once() + + s.Equal(" null", s.grammar.ModifyNullable(mockBlueprint, mockColumn)) + + mockColumn.EXPECT().GetNullable().Return(false).Once() + + s.Equal(" not null", s.grammar.ModifyNullable(mockBlueprint, mockColumn)) +} + +func (s *MysqlSuite) TestModifyIncrement() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + + mockColumn := mocksschema.NewColumnDefinition(s.T()) + mockBlueprint.EXPECT().HasCommand("primary").Return(false).Once() + mockColumn.EXPECT().GetType().Return("bigInteger").Once() + mockColumn.EXPECT().GetAutoIncrement().Return(true).Once() + + s.Equal(" auto_increment primary key", s.grammar.ModifyIncrement(mockBlueprint, mockColumn)) +} + +func (s *MysqlSuite) TestTypeString() { + mockColumn1 := mocksschema.NewColumnDefinition(s.T()) + mockColumn1.EXPECT().GetLength().Return(100).Once() + + s.Equal("varchar(100)", s.grammar.TypeString(mockColumn1)) + + mockColumn2 := mocksschema.NewColumnDefinition(s.T()) + mockColumn2.EXPECT().GetLength().Return(0).Once() + + s.Equal("varchar", s.grammar.TypeString(mockColumn2)) +} diff --git a/database/schema/grammars/wrap_test.go b/database/schema/grammars/wrap_test.go index b43ae581b..503e3953d 100644 --- a/database/schema/grammars/wrap_test.go +++ b/database/schema/grammars/wrap_test.go @@ -4,6 +4,8 @@ import ( "testing" "github.com/stretchr/testify/suite" + + "github.com/goravel/framework/contracts/database" ) type WrapTestSuite struct { @@ -16,55 +18,61 @@ func TestWrapSuite(t *testing.T) { } func (s *WrapTestSuite) SetupTest() { - s.wrap = NewWrap("prefix_") + s.wrap = NewWrap(database.DriverPostgres, "prefix_") } -func (s *WrapTestSuite) ColumnWithAlias() { +func (s *WrapTestSuite) TestColumnWithAlias() { result := s.wrap.Column("column as alias") s.Equal(`"column" as "prefix_alias"`, result) } -func (s *WrapTestSuite) ColumnWithoutAlias() { +func (s *WrapTestSuite) TestColumnWithoutAlias() { result := s.wrap.Column("column") s.Equal(`"column"`, result) } -func (s *WrapTestSuite) ColumnsWithMultipleColumns() { +func (s *WrapTestSuite) TestColumnsWithMultipleColumns() { result := s.wrap.Columnize([]string{"column1", "column2 as alias2"}) s.Equal(`"column1", "column2" as "prefix_alias2"`, result) } -func (s *WrapTestSuite) QuoteWithNonEmptyValue() { +func (s *WrapTestSuite) TestQuoteWithNonEmptyValue() { result := s.wrap.Quote("value") s.Equal("'value'", result) } -func (s *WrapTestSuite) QuoteWithEmptyValue() { +func (s *WrapTestSuite) TestQuoteWithEmptyValue() { result := s.wrap.Quote("") s.Equal("", result) } -func (s *WrapTestSuite) SegmentsWithMultipleSegments() { +func (s *WrapTestSuite) TestSegmentsWithMultipleSegments() { result := s.wrap.Segments([]string{"table", "column"}) s.Equal(`"prefix_table"."column"`, result) } -func (s *WrapTestSuite) TableWithAlias() { +func (s *WrapTestSuite) TestTableWithAlias() { result := s.wrap.Table("table as alias") s.Equal(`"prefix_table" as "prefix_alias"`, result) } -func (s *WrapTestSuite) TableWithoutAlias() { +func (s *WrapTestSuite) TestTableWithoutAlias() { result := s.wrap.Table("table") s.Equal(`"prefix_table"`, result) } -func (s *WrapTestSuite) ValueWithAsterisk() { +func (s *WrapTestSuite) TestValueWithAsterisk() { result := s.wrap.Value("*") s.Equal("*", result) } -func (s *WrapTestSuite) ValueWithNonAsterisk() { +func (s *WrapTestSuite) TestValueWithNonAsterisk() { result := s.wrap.Value("value") s.Equal(`"value"`, result) } + +func (s *WrapTestSuite) TestValueOfMysql() { + s.wrap.driver = database.DriverMysql + result := s.wrap.Value("value") + s.Equal("`value`", result) +} diff --git a/database/schema/processors/mysql_test.go b/database/schema/processors/mysql_test.go new file mode 100644 index 000000000..7ac53bea3 --- /dev/null +++ b/database/schema/processors/mysql_test.go @@ -0,0 +1,33 @@ +package processors + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/goravel/framework/contracts/database/schema" +) + +func TestMysqlProcessIndexes(t *testing.T) { + // Test with valid indexes + input := []DBIndex{ + {Name: "INDEX_A", Type: "BTREE", Columns: "a,b"}, + {Name: "INDEX_B", Type: "HASH", Columns: "c,d"}, + } + expected := []schema.Index{ + {Name: "index_a", Type: "btree", Columns: []string{"a", "b"}}, + {Name: "index_b", Type: "hash", Columns: []string{"c", "d"}}, + } + + postgres := NewMysql() + result := postgres.ProcessIndexes(input) + + assert.Equal(t, expected, result) + + // Test with empty input + input = []DBIndex{} + + result = postgres.ProcessIndexes(input) + + assert.Nil(t, result) +} diff --git a/database/schema/schema_test.go b/database/schema/schema_test.go index 5c6189f2b..d62677752 100644 --- a/database/schema/schema_test.go +++ b/database/schema/schema_test.go @@ -27,19 +27,19 @@ func TestSchemaSuite(t *testing.T) { func (s *SchemaSuite) SetupTest() { // TODO Add other drivers - //postgresDocker := docker.Postgres() - //postgresQuery := gorm.NewTestQuery(postgresDocker, true) - // - //sqliteDocker := docker.Sqlite() - //sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) + postgresDocker := docker.Postgres() + postgresQuery := gorm.NewTestQuery(postgresDocker, true) + + sqliteDocker := docker.Sqlite() + sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) mysqlDocker := docker.Mysql() mysqlQuery := gorm.NewTestQuery(mysqlDocker, true) s.driverToTestQuery = map[database.Driver]*gorm.TestQuery{ - //database.DriverPostgres: postgresQuery, - //database.DriverSqlite: sqliteQuery, - database.DriverMysql: mysqlQuery, + database.DriverPostgres: postgresQuery, + database.DriverSqlite: sqliteQuery, + database.DriverMysql: mysqlQuery, } } @@ -114,7 +114,7 @@ func (s *SchemaSuite) TestForeign() { err = schema.Create(table2, func(table contractsschema.Blueprint) { table.ID() table.String("name") - table.Integer("foreign1_id") + table.BigInteger("foreign1_id") table.Foreign("foreign1_id").References("id").On(table1) }) diff --git a/mocks/database/schema/Blueprint.go b/mocks/database/schema/Blueprint.go index 32e2aa0b7..c4c6134bd 100644 --- a/mocks/database/schema/Blueprint.go +++ b/mocks/database/schema/Blueprint.go @@ -21,6 +21,102 @@ func (_m *Blueprint) EXPECT() *Blueprint_Expecter { return &Blueprint_Expecter{mock: &_m.Mock} } +// BigIncrements provides a mock function with given fields: column +func (_m *Blueprint) BigIncrements(column string) schema.ColumnDefinition { + ret := _m.Called(column) + + if len(ret) == 0 { + panic("no return value specified for BigIncrements") + } + + var r0 schema.ColumnDefinition + if rf, ok := ret.Get(0).(func(string) schema.ColumnDefinition); ok { + r0 = rf(column) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(schema.ColumnDefinition) + } + } + + return r0 +} + +// Blueprint_BigIncrements_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BigIncrements' +type Blueprint_BigIncrements_Call struct { + *mock.Call +} + +// BigIncrements is a helper method to define mock.On call +// - column string +func (_e *Blueprint_Expecter) BigIncrements(column interface{}) *Blueprint_BigIncrements_Call { + return &Blueprint_BigIncrements_Call{Call: _e.mock.On("BigIncrements", column)} +} + +func (_c *Blueprint_BigIncrements_Call) Run(run func(column string)) *Blueprint_BigIncrements_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Blueprint_BigIncrements_Call) Return(_a0 schema.ColumnDefinition) *Blueprint_BigIncrements_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Blueprint_BigIncrements_Call) RunAndReturn(run func(string) schema.ColumnDefinition) *Blueprint_BigIncrements_Call { + _c.Call.Return(run) + return _c +} + +// BigInteger provides a mock function with given fields: column +func (_m *Blueprint) BigInteger(column string) schema.ColumnDefinition { + ret := _m.Called(column) + + if len(ret) == 0 { + panic("no return value specified for BigInteger") + } + + var r0 schema.ColumnDefinition + if rf, ok := ret.Get(0).(func(string) schema.ColumnDefinition); ok { + r0 = rf(column) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(schema.ColumnDefinition) + } + } + + return r0 +} + +// Blueprint_BigInteger_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BigInteger' +type Blueprint_BigInteger_Call struct { + *mock.Call +} + +// BigInteger is a helper method to define mock.On call +// - column string +func (_e *Blueprint_Expecter) BigInteger(column interface{}) *Blueprint_BigInteger_Call { + return &Blueprint_BigInteger_Call{Call: _e.mock.On("BigInteger", column)} +} + +func (_c *Blueprint_BigInteger_Call) Run(run func(column string)) *Blueprint_BigInteger_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Blueprint_BigInteger_Call) Return(_a0 schema.ColumnDefinition) *Blueprint_BigInteger_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Blueprint_BigInteger_Call) RunAndReturn(run func(string) schema.ColumnDefinition) *Blueprint_BigInteger_Call { + _c.Call.Return(run) + return _c +} + // Build provides a mock function with given fields: query, grammar func (_m *Blueprint) Build(query orm.Query, grammar schema.Grammar) error { ret := _m.Called(query, grammar) @@ -738,6 +834,54 @@ func (_c *Blueprint_ToSql_Call) RunAndReturn(run func(schema.Grammar) []string) return _c } +// UnsignedBigInteger provides a mock function with given fields: column +func (_m *Blueprint) UnsignedBigInteger(column string) schema.ColumnDefinition { + ret := _m.Called(column) + + if len(ret) == 0 { + panic("no return value specified for UnsignedBigInteger") + } + + var r0 schema.ColumnDefinition + if rf, ok := ret.Get(0).(func(string) schema.ColumnDefinition); ok { + r0 = rf(column) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(schema.ColumnDefinition) + } + } + + return r0 +} + +// Blueprint_UnsignedBigInteger_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UnsignedBigInteger' +type Blueprint_UnsignedBigInteger_Call struct { + *mock.Call +} + +// UnsignedBigInteger is a helper method to define mock.On call +// - column string +func (_e *Blueprint_Expecter) UnsignedBigInteger(column interface{}) *Blueprint_UnsignedBigInteger_Call { + return &Blueprint_UnsignedBigInteger_Call{Call: _e.mock.On("UnsignedBigInteger", column)} +} + +func (_c *Blueprint_UnsignedBigInteger_Call) Run(run func(column string)) *Blueprint_UnsignedBigInteger_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Blueprint_UnsignedBigInteger_Call) Return(_a0 schema.ColumnDefinition) *Blueprint_UnsignedBigInteger_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Blueprint_UnsignedBigInteger_Call) RunAndReturn(run func(string) schema.ColumnDefinition) *Blueprint_UnsignedBigInteger_Call { + _c.Call.Return(run) + return _c +} + // NewBlueprint creates a new instance of Blueprint. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewBlueprint(t interface { From 895b613114eba3e94b72ae4ca1cf43c4cc482c8e Mon Sep 17 00:00:00 2001 From: Bowen Date: Sun, 17 Nov 2024 11:50:48 +0800 Subject: [PATCH 3/5] fix test --- database/schema/schema_test.go | 24 ++++++++++++++---------- support/docker/mysql.go | 8 ++++++++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/database/schema/schema_test.go b/database/schema/schema_test.go index d62677752..23b144015 100644 --- a/database/schema/schema_test.go +++ b/database/schema/schema_test.go @@ -27,19 +27,19 @@ func TestSchemaSuite(t *testing.T) { func (s *SchemaSuite) SetupTest() { // TODO Add other drivers - postgresDocker := docker.Postgres() - postgresQuery := gorm.NewTestQuery(postgresDocker, true) - - sqliteDocker := docker.Sqlite() - sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) + //postgresDocker := docker.Postgres() + //postgresQuery := gorm.NewTestQuery(postgresDocker, true) + // + //sqliteDocker := docker.Sqlite() + //sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) mysqlDocker := docker.Mysql() mysqlQuery := gorm.NewTestQuery(mysqlDocker, true) s.driverToTestQuery = map[database.Driver]*gorm.TestQuery{ - database.DriverPostgres: postgresQuery, - database.DriverSqlite: sqliteQuery, - database.DriverMysql: mysqlQuery, + //database.DriverPostgres: postgresQuery, + //database.DriverSqlite: sqliteQuery, + database.DriverMysql: mysqlQuery, } } @@ -137,10 +137,14 @@ func (s *SchemaSuite) TestPrimary() { })) s.Require().True(schema.HasTable(table)) - if driver != database.DriverSqlite { - // SQLite does not support set primary index separately + + // SQLite does not support set primary index separately + if driver == database.DriverPostgres { s.Require().True(schema.HasIndex(table, "goravel_primaries_pkey")) } + if driver == database.DriverMysql { + s.Require().True(schema.HasIndex(table, "primary")) + } }) } } diff --git a/support/docker/mysql.go b/support/docker/mysql.go index a07424ed5..7d46b905c 100644 --- a/support/docker/mysql.go +++ b/support/docker/mysql.go @@ -97,6 +97,10 @@ func (receiver *MysqlImpl) Fresh() error { return fmt.Errorf("get tables of Mysql error: %v", res.Error) } + if res := instance.Exec("SET FOREIGN_KEY_CHECKS=0;"); res.Error != nil { + return fmt.Errorf("disable foreign key check of Mysql error: %v", res.Error) + } + for _, table := range tables { res = instance.Exec(table) if res.Error != nil { @@ -104,6 +108,10 @@ func (receiver *MysqlImpl) Fresh() error { } } + if res := instance.Exec("SET FOREIGN_KEY_CHECKS=1;"); res.Error != nil { + return fmt.Errorf("enable foreign key check of Mysql error: %v", res.Error) + } + return nil } From e691fd28754de3ae8b8f2f02053c337b037c6bb1 Mon Sep 17 00:00:00 2001 From: Bowen Date: Sun, 17 Nov 2024 12:14:42 +0800 Subject: [PATCH 4/5] fix AI comments --- database/schema/grammars/mysql.go | 14 +++++++------- database/schema/grammars/mysql_test.go | 2 +- database/schema/processors/mysql_test.go | 6 +++--- database/schema/schema_test.go | 16 ++++++++-------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/database/schema/grammars/mysql.go b/database/schema/grammars/mysql.go index cd102ba4e..3281ca547 100644 --- a/database/schema/grammars/mysql.go +++ b/database/schema/grammars/mysql.go @@ -18,18 +18,18 @@ type Mysql struct { } func NewMysql(tablePrefix string) *Mysql { - postgres := &Mysql{ + mysql := &Mysql{ attributeCommands: []string{constants.CommandComment}, serials: []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}, wrap: NewWrap(contractsdatabase.DriverMysql, tablePrefix), } - postgres.modifiers = []func(schema.Blueprint, schema.ColumnDefinition) string{ - postgres.ModifyDefault, - postgres.ModifyIncrement, - postgres.ModifyNullable, + mysql.modifiers = []func(schema.Blueprint, schema.ColumnDefinition) string{ + mysql.ModifyDefault, + mysql.ModifyIncrement, + mysql.ModifyNullable, } - return postgres + return mysql } func (r *Mysql) CompileAdd(blueprint schema.Blueprint, command *schema.Command) string { @@ -194,7 +194,7 @@ func (r *Mysql) TypeString(column schema.ColumnDefinition) string { return fmt.Sprintf("varchar(%d)", length) } - return "varchar" + return "varchar(255)" } func (r *Mysql) getColumns(blueprint schema.Blueprint) []string { diff --git a/database/schema/grammars/mysql_test.go b/database/schema/grammars/mysql_test.go index 9b4c600df..1ee5989a7 100644 --- a/database/schema/grammars/mysql_test.go +++ b/database/schema/grammars/mysql_test.go @@ -301,5 +301,5 @@ func (s *MysqlSuite) TestTypeString() { mockColumn2 := mocksschema.NewColumnDefinition(s.T()) mockColumn2.EXPECT().GetLength().Return(0).Once() - s.Equal("varchar", s.grammar.TypeString(mockColumn2)) + s.Equal("varchar(255)", s.grammar.TypeString(mockColumn2)) } diff --git a/database/schema/processors/mysql_test.go b/database/schema/processors/mysql_test.go index 7ac53bea3..97dab50bb 100644 --- a/database/schema/processors/mysql_test.go +++ b/database/schema/processors/mysql_test.go @@ -19,15 +19,15 @@ func TestMysqlProcessIndexes(t *testing.T) { {Name: "index_b", Type: "hash", Columns: []string{"c", "d"}}, } - postgres := NewMysql() - result := postgres.ProcessIndexes(input) + mysql := NewMysql() + result := mysql.ProcessIndexes(input) assert.Equal(t, expected, result) // Test with empty input input = []DBIndex{} - result = postgres.ProcessIndexes(input) + result = mysql.ProcessIndexes(input) assert.Nil(t, result) } diff --git a/database/schema/schema_test.go b/database/schema/schema_test.go index 23b144015..fb4aeaa75 100644 --- a/database/schema/schema_test.go +++ b/database/schema/schema_test.go @@ -27,19 +27,19 @@ func TestSchemaSuite(t *testing.T) { func (s *SchemaSuite) SetupTest() { // TODO Add other drivers - //postgresDocker := docker.Postgres() - //postgresQuery := gorm.NewTestQuery(postgresDocker, true) - // - //sqliteDocker := docker.Sqlite() - //sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) + postgresDocker := docker.Postgres() + postgresQuery := gorm.NewTestQuery(postgresDocker, true) + + sqliteDocker := docker.Sqlite() + sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) mysqlDocker := docker.Mysql() mysqlQuery := gorm.NewTestQuery(mysqlDocker, true) s.driverToTestQuery = map[database.Driver]*gorm.TestQuery{ - //database.DriverPostgres: postgresQuery, - //database.DriverSqlite: sqliteQuery, - database.DriverMysql: mysqlQuery, + database.DriverPostgres: postgresQuery, + database.DriverSqlite: sqliteQuery, + database.DriverMysql: mysqlQuery, } } From 89a166c54420b23be3e2978604eed0a0bf2457ef Mon Sep 17 00:00:00 2001 From: Bowen Date: Sun, 17 Nov 2024 21:36:33 +0800 Subject: [PATCH 5/5] add ready method --- database/schema/schema_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/database/schema/schema_test.go b/database/schema/schema_test.go index 5cf1aafc5..4f59e0372 100644 --- a/database/schema/schema_test.go +++ b/database/schema/schema_test.go @@ -36,6 +36,8 @@ func (s *SchemaSuite) SetupTest() { sqliteQuery := gorm.NewTestQuery(sqliteDocker, true) mysqlDocker := docker.Mysql() + s.Require().NoError(mysqlDocker.Ready()) + mysqlQuery := gorm.NewTestQuery(mysqlDocker, true) s.driverToTestQuery = map[database.Driver]*gorm.TestQuery{