diff --git a/config/application_test.go b/config/application_test.go index 421bb091f..f3d29173f 100644 --- a/config/application_test.go +++ b/config/application_test.go @@ -24,8 +24,8 @@ DB_PORT=3306 `)) temp, err := os.CreateTemp("", "goravel.env") assert.Nil(t, err) - defer os.Remove(temp.Name()) defer temp.Close() + defer os.Remove(temp.Name()) _, err = temp.Write([]byte(` APP_KEY=12345678901234567890123456789012 diff --git a/database/gorm/conditions.go b/database/gorm/conditions.go new file mode 100644 index 000000000..163678c6d --- /dev/null +++ b/database/gorm/conditions.go @@ -0,0 +1,57 @@ +package gorm + +import ( + ormcontract "github.com/goravel/framework/contracts/database/orm" +) + +type Conditions struct { + distinct []any + group string + having *Having + join []Join + limit *int + lockForUpdate bool + model any + offset *int + omit []string + order []any + scopes []func(ormcontract.Query) ormcontract.Query + selectColumns *Select + sharedLock bool + table *Table + where []Where + with []With + withoutEvents bool + withTrashed bool +} + +type Having struct { + query any + args []any +} + +type Join struct { + query string + args []any +} + +type Select struct { + query any + args []any +} + +type Table struct { + name string + args []any +} + +type Where struct { + query any + args []any + or bool +} + +type With struct { + query string + args []any +} diff --git a/database/gorm/cursor.go b/database/gorm/cursor.go index 01ea48db5..dd8c6fd02 100644 --- a/database/gorm/cursor.go +++ b/database/gorm/cursor.go @@ -38,8 +38,12 @@ func (c *CursorImpl) Scan(value any) error { return err } - for relation, args := range c.query.with { - if err := c.query.origin.Load(value, relation, args...); err != nil { + for _, item := range c.query.conditions.with { + // Need to new a query, avoid to clear the conditions + query := c.query.new(c.query.instance) + // The new query must be cleared + query.clearConditions() + if err := query.Load(value, item.query, item.args...); err != nil { return err } } diff --git a/database/gorm/event.go b/database/gorm/event.go index bfeee1cf6..ebcaf5932 100644 --- a/database/gorm/event.go +++ b/database/gorm/event.go @@ -28,6 +28,102 @@ func NewEvent(query *QueryImpl, model, dest any) *Event { } } +func (e *Event) ColumnNamesWithDbColumnNames() map[string]string { + if e.columnNamesWithDbColumnNames != nil { + return e.columnNamesWithDbColumnNames + } + + res := make(map[string]string) + var modelType reflect.Type + var modelValue reflect.Value + + if e.model != nil { + modelType = reflect.TypeOf(e.model) + modelValue = reflect.ValueOf(e.model) + } else { + modelType = reflect.TypeOf(e.dest) + modelValue = reflect.ValueOf(e.dest) + } + if modelType.Kind() == reflect.Pointer { + modelType = modelType.Elem() + modelValue = modelValue.Elem() + } + + for i := 0; i < modelType.NumField(); i++ { + if !modelType.Field(i).IsExported() { + continue + } + if modelType.Field(i).Name == "Model" && modelValue.Field(i).Type().Kind() == reflect.Struct { + structField := modelValue.Field(i).Type() + for j := 0; j < structField.NumField(); j++ { + if !structField.Field(i).IsExported() { + continue + } + dbColumn := structNameToDbColumnName(structField.Field(j).Name, structField.Field(j).Tag.Get("gorm")) + res[structField.Field(j).Name] = dbColumn + res[dbColumn] = dbColumn + } + } + + dbColumn := structNameToDbColumnName(modelType.Field(i).Name, modelType.Field(i).Tag.Get("gorm")) + res[modelType.Field(i).Name] = dbColumn + res[dbColumn] = dbColumn + } + + return res +} + +func (e *Event) Context() context.Context { + return e.query.instance.Statement.Context +} + +func (e *Event) DestOfMap() map[string]any { + if e.destOfMap != nil { + return e.destOfMap + } + + var destOfMap map[string]any + if destMap, ok := e.dest.(map[string]any); ok { + destOfMap = destMap + } else { + destType := reflect.TypeOf(e.dest) + if destType.Kind() == reflect.Pointer { + destType = destType.Elem() + } + if destType.Kind() == reflect.Struct { + destOfMap = structToMap(e.dest) + } + } + + e.destOfMap = destOfMap + + return e.destOfMap +} + +func (e *Event) GetAttribute(key string) any { + destOfMap := e.DestOfMap() + value, exist := destOfMap[e.toDBColumnName(key)] + if exist && e.validColumn(key) && e.validValue(key, value) { + return value + } + + return e.GetOriginal(key) +} + +func (e *Event) GetOriginal(key string, def ...any) any { + modelOfMap := e.ModelOfMap() + value, exist := modelOfMap[e.toDBColumnName(key)] + if exist { + return value + } + + if len(def) > 0 { + return def[0] + } + + return nil +} + func (e *Event) IsDirty(columns ...string) bool { destOfMap := e.DestOfMap() @@ -63,15 +159,22 @@ func (e *Event) IsClean(fields ...string) bool { return !e.IsDirty(fields...) } -func (e *Event) Query() orm.Query { - return NewQueryImplByInstance(e.query.instance.Session(&gorm.Session{NewDB: true}), &QueryImpl{ - config: e.query.config, - withoutEvents: false, - }) +func (e *Event) ModelOfMap() map[string]any { + if e.modelOfMap != nil { + return e.modelOfMap + } + + if e.model == nil { + return map[string]any{} + } + + e.modelOfMap = structToMap(e.model) + + return e.modelOfMap } -func (e *Event) Context() context.Context { - return e.query.instance.Statement.Context +func (e *Event) Query() orm.Query { + return NewQueryImpl(e.query.ctx, e.query.config, e.query.connection, e.query.instance.Session(&gorm.Session{NewDB: true}), nil) } func (e *Event) SetAttribute(key string, value any) { @@ -113,28 +216,35 @@ func (e *Event) SetAttribute(key string, value any) { } } -func (e *Event) GetAttribute(key string) any { - destOfMap := e.DestOfMap() - value, exist := destOfMap[e.toDBColumnName(key)] - if exist && e.validColumn(key) && e.validValue(key, value) { - return value +func (e *Event) dirty(destColumn string, destValue any) bool { + modelOfMap := e.ModelOfMap() + dbDestColumn := e.toDBColumnName(destColumn) + + if modelValue, exist := modelOfMap[dbDestColumn]; exist { + return !reflect.DeepEqual(modelValue, destValue) } - return e.GetOriginal(key) + return true } -func (e *Event) GetOriginal(key string, def ...any) any { - modelOfMap := e.ModelOfMap() - value, exist := modelOfMap[e.toDBColumnName(key)] - if exist { - return value +func (e *Event) equalColumnName(origin, source string) bool { + originDbColumnName := e.toDBColumnName(origin) + sourceDbColumnName := e.toDBColumnName(source) + + if originDbColumnName == "" || sourceDbColumnName == "" { + return false } - if len(def) > 0 { - return def[0] + return originDbColumnName == sourceDbColumnName +} + +func (e *Event) toDBColumnName(name string) string { + dbColumnName, exist := e.ColumnNamesWithDbColumnNames()[name] + if exist { + return dbColumnName } - return nil + return "" } func (e *Event) validColumn(name string) bool { @@ -197,119 +307,6 @@ func (e *Event) validValue(name string, value any) bool { return !valueValue.IsZero() } -func (e *Event) dirty(destColumn string, destValue any) bool { - modelOfMap := e.ModelOfMap() - dbDestColumn := e.toDBColumnName(destColumn) - - if modelValue, exist := modelOfMap[dbDestColumn]; exist { - return !reflect.DeepEqual(modelValue, destValue) - } - - return true -} - -func (e *Event) equalColumnName(origin, source string) bool { - originDbColumnName := e.toDBColumnName(origin) - sourceDbColumnName := e.toDBColumnName(source) - - if originDbColumnName == "" || sourceDbColumnName == "" { - return false - } - - return originDbColumnName == sourceDbColumnName -} - -func (e *Event) toDBColumnName(name string) string { - dbColumnName, exist := e.ColumnNamesWithDbColumnNames()[name] - if exist { - return dbColumnName - } - - return "" -} - -func (e *Event) ModelOfMap() map[string]any { - if e.modelOfMap != nil { - return e.modelOfMap - } - - if e.model == nil { - return map[string]any{} - } - - e.modelOfMap = structToMap(e.model) - - return e.modelOfMap -} - -func (e *Event) DestOfMap() map[string]any { - if e.destOfMap != nil { - return e.destOfMap - } - - var destOfMap map[string]any - if destMap, ok := e.dest.(map[string]any); ok { - destOfMap = destMap - } else { - destType := reflect.TypeOf(e.dest) - if destType.Kind() == reflect.Pointer { - destType = destType.Elem() - } - if destType.Kind() == reflect.Struct { - destOfMap = structToMap(e.dest) - } - } - - e.destOfMap = destOfMap - - return e.destOfMap -} - -func (e *Event) ColumnNamesWithDbColumnNames() map[string]string { - if e.columnNamesWithDbColumnNames != nil { - return e.columnNamesWithDbColumnNames - } - - res := make(map[string]string) - var modelType reflect.Type - var modelValue reflect.Value - - if e.model != nil { - modelType = reflect.TypeOf(e.model) - modelValue = reflect.ValueOf(e.model) - } else { - modelType = reflect.TypeOf(e.dest) - modelValue = reflect.ValueOf(e.dest) - } - if modelType.Kind() == reflect.Pointer { - modelType = modelType.Elem() - modelValue = modelValue.Elem() - } - - for i := 0; i < modelType.NumField(); i++ { - if !modelType.Field(i).IsExported() { - continue - } - if modelType.Field(i).Name == "Model" && modelValue.Field(i).Type().Kind() == reflect.Struct { - structField := modelValue.Field(i).Type() - for j := 0; j < structField.NumField(); j++ { - if !structField.Field(i).IsExported() { - continue - } - dbColumn := structNameToDbColumnName(structField.Field(j).Name, structField.Field(j).Tag.Get("gorm")) - res[structField.Field(j).Name] = dbColumn - res[dbColumn] = dbColumn - } - } - - dbColumn := structNameToDbColumnName(modelType.Field(i).Name, modelType.Field(i).Tag.Get("gorm")) - res[modelType.Field(i).Name] = dbColumn - res[dbColumn] = dbColumn - } - - return res -} - func structToMap(data any) map[string]any { res := make(map[string]any) modelType := reflect.TypeOf(data) diff --git a/database/gorm/event_test.go b/database/gorm/event_test.go index 1ec5dc76e..c8febb446 100644 --- a/database/gorm/event_test.go +++ b/database/gorm/event_test.go @@ -21,15 +21,14 @@ type TestEventModel struct { var testNow = time.Now().Add(-1 * time.Second) var testEventModel = TestEventModel{Name: "name", Avatar: "avatar", IsAdmin: true, IsManage: 0, AdminAt: testNow, ManageAt: testNow, high: 1} -var testQuery = NewQueryImplByInstance(&gorm.DB{ - Statement: &gorm.Statement{ - Selects: []string{}, - Omits: []string{}, +var testQuery = &QueryImpl{ + instance: &gorm.DB{ + Statement: &gorm.Statement{ + Selects: []string{}, + Omits: []string{}, + }, }, -}, &QueryImpl{ - config: nil, - withoutEvents: false, -}) +} type EventTestSuite struct { suite.Suite @@ -51,16 +50,15 @@ func (s *EventTestSuite) SetupTest() { func (s *EventTestSuite) TestSetAttribute() { dest := map[string]any{"avatar": "avatar1"} - query := NewQueryImplByInstance(&gorm.DB{ - Statement: &gorm.Statement{ - Selects: []string{}, - Omits: []string{}, - Dest: dest, + query := &QueryImpl{ + instance: &gorm.DB{ + Statement: &gorm.Statement{ + Selects: []string{}, + Omits: []string{}, + Dest: dest, + }, }, - }, &QueryImpl{ - config: nil, - withoutEvents: false, - }) + } event := NewEvent(query, &testEventModel, dest) @@ -154,29 +152,27 @@ func (s *EventTestSuite) TestValidColumn() { s.True(event.validColumn("manage")) s.False(event.validColumn("age")) - event.query = NewQueryImplByInstance(&gorm.DB{ - Statement: &gorm.Statement{ - Selects: []string{"name"}, - Omits: []string{}, + event.query = &QueryImpl{ + instance: &gorm.DB{ + Statement: &gorm.Statement{ + Selects: []string{"name"}, + Omits: []string{}, + }, }, - }, &QueryImpl{ - config: nil, - withoutEvents: false, - }) + } s.True(event.validColumn("Name")) s.True(event.validColumn("name")) s.False(event.validColumn("avatar")) s.False(event.validColumn("Avatar")) - event.query = NewQueryImplByInstance(&gorm.DB{ - Statement: &gorm.Statement{ - Selects: []string{}, - Omits: []string{"name"}, + event.query = &QueryImpl{ + instance: &gorm.DB{ + Statement: &gorm.Statement{ + Selects: []string{}, + Omits: []string{"name"}, + }, }, - }, &QueryImpl{ - config: nil, - withoutEvents: false, - }) + } s.False(event.validColumn("Name")) s.False(event.validColumn("name")) s.True(event.validColumn("avatar")) diff --git a/database/gorm/query.go b/database/gorm/query.go index 78f108fa2..9894c3cb7 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "database/sql" "errors" "fmt" "reflect" @@ -15,26 +16,42 @@ import ( "gorm.io/gorm/clause" "github.com/goravel/framework/contracts/config" - gormcontract "github.com/goravel/framework/contracts/database/gorm" + "github.com/goravel/framework/contracts/database/gorm" ormcontract "github.com/goravel/framework/contracts/database/orm" "github.com/goravel/framework/database/gorm/hints" "github.com/goravel/framework/database/orm" "github.com/goravel/framework/support/database" ) -var QuerySet = wire.NewSet(NewQueryImpl, wire.Bind(new(ormcontract.Query), new(*QueryImpl))) +var QuerySet = wire.NewSet(BuildQueryImpl, wire.Bind(new(ormcontract.Query), new(*QueryImpl))) var _ ormcontract.Query = &QueryImpl{} type QueryImpl struct { - config config.Config - ctx context.Context - instance *gormio.DB - origin *QueryImpl - with map[string][]any - withoutEvents bool + conditions Conditions + config config.Config + connection string + ctx context.Context + instance *gormio.DB + queries map[string]*QueryImpl } -func NewQueryImpl(ctx context.Context, config config.Config, gorm gormcontract.Gorm) (*QueryImpl, error) { +func NewQueryImpl(ctx context.Context, config config.Config, connection string, db *gormio.DB, conditions *Conditions) *QueryImpl { + queryImpl := &QueryImpl{ + config: config, + connection: connection, + ctx: ctx, + instance: db, + queries: make(map[string]*QueryImpl), + } + + if conditions != nil { + queryImpl.conditions = *conditions + } + + return queryImpl +} + +func BuildQueryImpl(ctx context.Context, config config.Config, connection string, gorm gorm.Gorm) (*QueryImpl, error) { db, err := gorm.Make() if err != nil { return nil, err @@ -43,32 +60,19 @@ func NewQueryImpl(ctx context.Context, config config.Config, gorm gormcontract.G db = db.WithContext(ctx) } - return &QueryImpl{ - instance: db, - config: config, - ctx: ctx, - }, nil -} - -func NewQueryImplByInstance(db *gormio.DB, instance *QueryImpl) *QueryImpl { - queryImpl := &QueryImpl{config: instance.config, ctx: db.Statement.Context, instance: db, origin: instance.origin, with: instance.with, withoutEvents: instance.withoutEvents} - - // The origin is used by the With method to load the relationship. - if instance.origin == nil && instance.instance != nil { - queryImpl.origin = instance - } - - return queryImpl + return NewQueryImpl(ctx, config, connection, db, nil), nil } func (r *QueryImpl) Association(association string) ormcontract.Association { - return r.instance.Association(association) + query := r.buildConditions() + + return query.instance.Association(association) } func (r *QueryImpl) Begin() (ormcontract.Transaction, error) { tx := r.instance.Begin() - return NewTransaction(tx, r.config), tx.Error + return NewTransaction(tx, r.config, r.connection), tx.Error } func (r *QueryImpl) Driver() ormcontract.Driver { @@ -76,33 +80,43 @@ func (r *QueryImpl) Driver() ormcontract.Driver { } func (r *QueryImpl) Count(count *int64) error { - return r.instance.Count(count).Error + query := r.buildConditions() + + return query.instance.Count(count).Error } func (r *QueryImpl) Create(value any) error { - if err := r.refreshConnection(value); err != nil { + query, err := r.refreshConnection(value) + if err != nil { return err } - if len(r.instance.Statement.Selects) > 0 && len(r.instance.Statement.Omits) > 0 { + query = query.buildConditions() + + if len(query.instance.Statement.Selects) > 0 && len(query.instance.Statement.Omits) > 0 { return errors.New("cannot set Select and Omits at the same time") } - if len(r.instance.Statement.Selects) > 0 { - return r.selectCreate(value) + if len(query.instance.Statement.Selects) > 0 { + return query.selectCreate(value) } - if len(r.instance.Statement.Omits) > 0 { - return r.omitCreate(value) + if len(query.instance.Statement.Omits) > 0 { + return query.omitCreate(value) } - return r.create(value) + return query.create(value) } func (r *QueryImpl) Cursor() (chan ormcontract.Cursor, error) { + with := r.conditions.with + query := r.buildConditions() + r.conditions.with = with + var err error cursorChan := make(chan ormcontract.Cursor) go func() { - rows, err := r.instance.Rows() + var rows *sql.Rows + rows, err = query.instance.Rows() if err != nil { return } @@ -110,11 +124,11 @@ func (r *QueryImpl) Cursor() (chan ormcontract.Cursor, error) { for rows.Next() { val := make(map[string]any) - err := r.instance.ScanRows(rows, val) + err = query.instance.ScanRows(rows, val) if err != nil { return } - cursorChan <- &CursorImpl{row: val, query: r} + cursorChan <- &CursorImpl{query: r, row: val} } close(cursorChan) }() @@ -122,19 +136,22 @@ func (r *QueryImpl) Cursor() (chan ormcontract.Cursor, error) { } func (r *QueryImpl) Delete(dest any, conds ...any) (*ormcontract.Result, error) { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return nil, err } - if err := r.deleting(dest); err != nil { + query = query.buildConditions() + + if err := query.deleting(dest); err != nil { return nil, err } - res := r.instance.Delete(dest, conds...) + res := query.instance.Delete(dest, conds...) if res.Error != nil { return nil, res.Error } - if err := r.deleted(dest); err != nil { + if err := query.deleted(dest); err != nil { return nil, err } @@ -144,13 +161,15 @@ func (r *QueryImpl) Delete(dest any, conds ...any) (*ormcontract.Result, error) } func (r *QueryImpl) Distinct(args ...any) ormcontract.Query { - tx := r.instance.Distinct(args...) + conditions := r.conditions + conditions.distinct = append(conditions.distinct, args...) - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Exec(sql string, values ...any) (*ormcontract.Result, error) { - result := r.instance.Exec(sql, values...) + query := r.buildConditions() + result := query.instance.Exec(sql, values...) return &ormcontract.Result{ RowsAffected: result.RowsAffected, @@ -158,32 +177,42 @@ func (r *QueryImpl) Exec(sql string, values ...any) (*ormcontract.Result, error) } func (r *QueryImpl) Exists(exists *bool) error { - return r.instance.Select("1").Limit(1).Find(exists).Error + query := r.buildConditions() + + return query.instance.Select("1").Limit(1).Find(exists).Error } func (r *QueryImpl) Find(dest any, conds ...any) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } + + query = query.buildConditions() + if err := filterFindConditions(conds...); err != nil { return err } - if err := r.instance.Find(dest, conds...).Error; err != nil { + if err := query.instance.Find(dest, conds...).Error; err != nil { return err } - return r.retrieved(dest) + return query.retrieved(dest) } func (r *QueryImpl) FindOrFail(dest any, conds ...any) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } + + query = query.buildConditions() + if err := filterFindConditions(conds...); err != nil { return err } - res := r.instance.Find(dest, conds...) + res := query.instance.Find(dest, conds...) if err := res.Error; err != nil { return err } @@ -192,14 +221,18 @@ func (r *QueryImpl) FindOrFail(dest any, conds ...any) error { return orm.ErrRecordNotFound } - return r.retrieved(dest) + return query.retrieved(dest) } func (r *QueryImpl) First(dest any) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } - res := r.instance.First(dest) + + query = query.buildConditions() + + res := query.instance.First(dest) if res.Error != nil { if errors.Is(res.Error, gormio.ErrRecordNotFound) { return nil @@ -208,15 +241,18 @@ func (r *QueryImpl) First(dest any) error { return res.Error } - return r.retrieved(dest) + return query.retrieved(dest) } func (r *QueryImpl) FirstOr(dest any, callback func() error) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } - err := r.instance.First(dest).Error - if err != nil { + + query = query.buildConditions() + + if err := query.instance.First(dest).Error; err != nil { if errors.Is(err, gormio.ErrRecordNotFound) { return callback() } @@ -224,40 +260,47 @@ func (r *QueryImpl) FirstOr(dest any, callback func() error) error { return err } - return r.retrieved(dest) + return query.retrieved(dest) } func (r *QueryImpl) FirstOrCreate(dest any, conds ...any) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } + + query = query.buildConditions() + if len(conds) == 0 { return errors.New("query condition is require") } var res *gormio.DB if len(conds) > 1 { - res = r.instance.Attrs(conds[1]).FirstOrInit(dest, conds[0]) + res = query.instance.Attrs(conds[1]).FirstOrInit(dest, conds[0]) } else { - res = r.instance.FirstOrInit(dest, conds[0]) + res = query.instance.FirstOrInit(dest, conds[0]) } if res.Error != nil { return res.Error } if res.RowsAffected > 0 { - return r.retrieved(dest) + return query.retrieved(dest) } - return r.Create(dest) + return query.Create(dest) } func (r *QueryImpl) FirstOrFail(dest any) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } - err := r.instance.First(dest).Error - if err != nil { + + query = query.buildConditions() + + if err := query.instance.First(dest).Error; err != nil { if errors.Is(err, gormio.ErrRecordNotFound) { return orm.ErrRecordNotFound } @@ -265,45 +308,53 @@ func (r *QueryImpl) FirstOrFail(dest any) error { return err } - return r.retrieved(dest) + return query.retrieved(dest) } func (r *QueryImpl) FirstOrNew(dest any, attributes any, values ...any) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } + + query = query.buildConditions() + var res *gormio.DB if len(values) > 0 { - res = r.instance.Attrs(values[0]).FirstOrInit(dest, attributes) + res = query.instance.Attrs(values[0]).FirstOrInit(dest, attributes) } else { - res = r.instance.FirstOrInit(dest, attributes) + res = query.instance.FirstOrInit(dest, attributes) } if res.Error != nil { return res.Error } if res.RowsAffected > 0 { - return r.retrieved(dest) + return query.retrieved(dest) } return nil } func (r *QueryImpl) ForceDelete(value any, conds ...any) (*ormcontract.Result, error) { - if err := r.refreshConnection(value); err != nil { + query, err := r.refreshConnection(value) + if err != nil { return nil, err } - if err := r.forceDeleting(value); err != nil { + + query = query.buildConditions() + + if err := query.forceDeleting(value); err != nil { return nil, err } - res := r.instance.Unscoped().Delete(value, conds...) + res := query.instance.Unscoped().Delete(value, conds...) if res.Error != nil { return nil, res.Error } if res.RowsAffected > 0 { - if err := r.forceDeleted(value); err != nil { + if err := query.forceDeleted(value); err != nil { return nil, err } } @@ -318,15 +369,20 @@ func (r *QueryImpl) Get(dest any) error { } func (r *QueryImpl) Group(name string) ormcontract.Query { - tx := r.instance.Group(name) + conditions := r.conditions + conditions.group = name - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Having(query any, args ...any) ormcontract.Query { - tx := r.instance.Having(query, args...) + conditions := r.conditions + conditions.having = &Having{ + query: query, + args: args, + } - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Instance() *gormio.DB { @@ -334,14 +390,20 @@ func (r *QueryImpl) Instance() *gormio.DB { } func (r *QueryImpl) Join(query string, args ...any) ormcontract.Query { - tx := r.instance.Joins(query, args...) - return NewQueryImplByInstance(tx, r) + conditions := r.conditions + conditions.join = append(conditions.join, Join{ + query: query, + args: args, + }) + + return r.setConditions(conditions) } func (r *QueryImpl) Limit(limit int) ormcontract.Query { - tx := r.instance.Limit(limit) + conditions := r.conditions + conditions.limit = &limit - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Load(model any, relation string, args ...any) error { @@ -410,49 +472,38 @@ func (r *QueryImpl) LoadMissing(model any, relation string, args ...any) error { } func (r *QueryImpl) LockForUpdate() ormcontract.Query { - driver := r.instance.Name() - mysqlDialector := mysql.Dialector{} - postgresqlDialector := postgres.Dialector{} - sqlserverDialector := sqlserver.Dialector{} - - if driver == mysqlDialector.Name() || driver == postgresqlDialector.Name() { - tx := r.instance.Clauses(clause.Locking{Strength: "UPDATE"}) - - return NewQueryImplByInstance(tx, r) - } else if driver == sqlserverDialector.Name() { - tx := r.instance.Clauses(hints.With("rowlock", "updlock", "holdlock")) + conditions := r.conditions + conditions.lockForUpdate = true - return NewQueryImplByInstance(tx, r) - } - - return r + return r.setConditions(conditions) } func (r *QueryImpl) Model(value any) ormcontract.Query { - if err := r.refreshConnection(value); err != nil { - return nil - } - tx := r.instance.Model(value) + conditions := r.conditions + conditions.model = value - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Offset(offset int) ormcontract.Query { - tx := r.instance.Offset(offset) + conditions := r.conditions + conditions.offset = &offset - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Omit(columns ...string) ormcontract.Query { - tx := r.instance.Omit(columns...) + conditions := r.conditions + conditions.omit = columns - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Order(value any) ormcontract.Query { - tx := r.instance.Order(value) + conditions := r.conditions + conditions.order = append(r.conditions.order, value) - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) OrderBy(column string, direction ...string) ormcontract.Query { @@ -485,87 +536,103 @@ func (r *QueryImpl) InRandomOrder() ormcontract.Query { } func (r *QueryImpl) OrWhere(query any, args ...any) ormcontract.Query { - tx := r.instance.Or(query, args...) + conditions := r.conditions + conditions.where = append(r.conditions.where, Where{ + query: query, + args: args, + or: true, + }) - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Paginate(page, limit int, dest any, total *int64) error { + query, err := r.refreshConnection(dest) + if err != nil { + return err + } + + query = query.buildConditions() + offset := (page - 1) * limit if total != nil { - if r.instance.Statement.Table == "" && r.instance.Statement.Model == nil { - if err := r.Model(dest).Count(total); err != nil { + if query.conditions.table == nil && query.conditions.model == nil { + if err := query.Model(dest).Count(total); err != nil { return err } } else { - if err := r.Count(total); err != nil { + if err := query.Count(total); err != nil { return err } } } - return r.Offset(offset).Limit(limit).Find(dest) + return query.Offset(offset).Limit(limit).Find(dest) } func (r *QueryImpl) Pluck(column string, dest any) error { - return r.instance.Pluck(column, dest).Error + query := r.buildConditions() + + return query.instance.Pluck(column, dest).Error } func (r *QueryImpl) Raw(sql string, values ...any) ormcontract.Query { - tx := r.instance.Raw(sql, values...) - - return NewQueryImplByInstance(tx, r) + return r.new(r.instance.Raw(sql, values...)) } func (r *QueryImpl) Save(value any) error { - if err := r.refreshConnection(value); err != nil { + query, err := r.refreshConnection(value) + if err != nil { return err } - if len(r.instance.Statement.Selects) > 0 && len(r.instance.Statement.Omits) > 0 { + + query = query.buildConditions() + + if len(query.instance.Statement.Selects) > 0 && len(query.instance.Statement.Omits) > 0 { return errors.New("cannot set Select and Omits at the same time") } - model := r.instance.Statement.Model + model := query.instance.Statement.Model id := database.GetID(value) update := id != nil - if err := r.saving(model, value); err != nil { + if err := query.saving(model, value); err != nil { return err } if update { - if err := r.updating(model, value); err != nil { + if err := query.updating(model, value); err != nil { return err } } else { - if err := r.creating(value); err != nil { + if err := query.creating(value); err != nil { return err } } - if len(r.instance.Statement.Selects) > 0 { - if err := r.selectSave(value); err != nil { + if len(query.instance.Statement.Selects) > 0 { + if err := query.selectSave(value); err != nil { return err } - } else if len(r.instance.Statement.Omits) > 0 { - if err := r.omitSave(value); err != nil { + } else if len(query.instance.Statement.Omits) > 0 { + if err := query.omitSave(value); err != nil { return err } } else { - if err := r.save(value); err != nil { + if err := query.save(value); err != nil { return err } } if update { - if err := r.updated(model, value); err != nil { + if err := query.updated(model, value); err != nil { return err } } else { - if err := r.created(value); err != nil { + if err := query.created(value); err != nil { return err } } - if err := r.saved(model, value); err != nil { + if err := query.saved(model, value); err != nil { return err } @@ -577,98 +644,93 @@ func (r *QueryImpl) SaveQuietly(value any) error { } func (r *QueryImpl) Scan(dest any) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } - return r.instance.Scan(dest).Error -} - -func (r *QueryImpl) Select(query any, args ...any) ormcontract.Query { - tx := r.instance.Select(query, args...) + query = query.buildConditions() - return NewQueryImplByInstance(tx, r) + return query.instance.Scan(dest).Error } func (r *QueryImpl) Scopes(funcs ...func(ormcontract.Query) ormcontract.Query) ormcontract.Query { - var gormFuncs []func(*gormio.DB) *gormio.DB - for _, item := range funcs { - gormFuncs = append(gormFuncs, func(tx *gormio.DB) *gormio.DB { - item(NewQueryImplByInstance(tx, r)) + conditions := r.conditions + conditions.scopes = append(r.conditions.scopes, funcs...) - return tx - }) - } + return r.setConditions(conditions) +} - tx := r.instance.Scopes(gormFuncs...) +func (r *QueryImpl) Select(query any, args ...any) ormcontract.Query { + conditions := r.conditions + conditions.selectColumns = &Select{ + query: query, + args: args, + } - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) SharedLock() ormcontract.Query { - driver := r.instance.Name() - mysqlDialector := mysql.Dialector{} - postgresqlDialector := postgres.Dialector{} - sqlserverDialector := sqlserver.Dialector{} - - if driver == mysqlDialector.Name() || driver == postgresqlDialector.Name() { - tx := r.instance.Clauses(clause.Locking{Strength: "SHARE"}) - - return NewQueryImplByInstance(tx, r) - } else if driver == sqlserverDialector.Name() { - tx := r.instance.Clauses(hints.With("rowlock", "holdlock")) - - return NewQueryImplByInstance(tx, r) - } + conditions := r.conditions + conditions.sharedLock = true - return r + return r.setConditions(conditions) } func (r *QueryImpl) Sum(column string, dest any) error { - return r.instance.Select("SUM(" + column + ")").Row().Scan(dest) + query := r.buildConditions() + + return query.instance.Select("SUM(" + column + ")").Row().Scan(dest) } func (r *QueryImpl) Table(name string, args ...any) ormcontract.Query { - tx := r.instance.Table(name, args...) + conditions := r.conditions + conditions.table = &Table{ + name: name, + args: args, + } - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) Update(column any, value ...any) (*ormcontract.Result, error) { + query := r.buildConditions() + if _, ok := column.(string); !ok && len(value) > 0 { return nil, errors.New("parameter error, please check the document") } var singleUpdate bool - model := r.instance.Statement.Model + model := query.instance.Statement.Model if model != nil { id := database.GetID(model) singleUpdate = id != nil } if c, ok := column.(string); ok && len(value) > 0 { - r.instance.Statement.Dest = map[string]any{c: value[0]} + query.instance.Statement.Dest = map[string]any{c: value[0]} } if len(value) == 0 { - r.instance.Statement.Dest = column + query.instance.Statement.Dest = column } if singleUpdate { - if err := r.saving(model, r.instance.Statement.Dest); err != nil { + if err := query.saving(model, query.instance.Statement.Dest); err != nil { return nil, err } - if err := r.updating(model, r.instance.Statement.Dest); err != nil { + if err := query.updating(model, query.instance.Statement.Dest); err != nil { return nil, err } } - res, err := r.updates(r.instance.Statement.Dest) + res, err := query.updates(query.instance.Statement.Dest) if singleUpdate && err == nil { - if err := r.updated(model, r.instance.Statement.Dest); err != nil { + if err := query.updated(model, query.instance.Statement.Dest); err != nil { return nil, err } - if err := r.saved(model, r.instance.Statement.Dest); err != nil { + if err := query.saved(model, query.instance.Statement.Dest); err != nil { return nil, err } } @@ -677,24 +739,32 @@ func (r *QueryImpl) Update(column any, value ...any) (*ormcontract.Result, error } func (r *QueryImpl) UpdateOrCreate(dest any, attributes any, values any) error { - if err := r.refreshConnection(dest); err != nil { + query, err := r.refreshConnection(dest) + if err != nil { return err } - res := r.instance.Assign(values).FirstOrInit(dest, attributes) + + query = query.buildConditions() + + res := query.instance.Assign(values).FirstOrInit(dest, attributes) if res.Error != nil { return res.Error } if res.RowsAffected > 0 { - return r.Save(dest) + return query.Save(dest) } - return r.Create(dest) + return query.Create(dest) } func (r *QueryImpl) Where(query any, args ...any) ormcontract.Query { - tx := r.instance.Where(query, args...) + conditions := r.conditions + conditions.where = append(r.conditions.where, Where{ + query: query, + args: args, + }) - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } func (r *QueryImpl) WhereIn(column string, values []any) ormcontract.Query { @@ -737,94 +807,316 @@ func (r *QueryImpl) WhereNotNull(column string) ormcontract.Query { return r.Where(fmt.Sprintf("%s IS NOT NULL", column)) } -func (r *QueryImpl) WithoutEvents() ormcontract.Query { - return NewQueryImplByInstance(r.instance, &QueryImpl{ - config: r.config, - withoutEvents: true, +func (r *QueryImpl) With(query string, args ...any) ormcontract.Query { + conditions := r.conditions + conditions.with = append(r.conditions.with, With{ + query: query, + args: args, }) + + return r.setConditions(conditions) +} + +func (r *QueryImpl) WithoutEvents() ormcontract.Query { + conditions := r.conditions + conditions.withoutEvents = true + + return r.setConditions(conditions) } func (r *QueryImpl) WithTrashed() ormcontract.Query { - tx := r.instance.Unscoped() + conditions := r.conditions + conditions.withTrashed = true - return NewQueryImplByInstance(tx, r) + return r.setConditions(conditions) } -func (r *QueryImpl) With(query string, args ...any) ormcontract.Query { - if len(args) == 1 { - switch arg := args[0].(type) { - case func(ormcontract.Query) ormcontract.Query: - newArgs := []any{ - func(tx *gormio.DB) *gormio.DB { - query := arg(NewQueryImplByInstance(tx, r)) - - return query.(*QueryImpl).instance - }, - } +func (r *QueryImpl) buildConditions() *QueryImpl { + query := r.buildModel() + db := query.instance + db = query.buildDistinct(db) + db = query.buildGroup(db) + db = query.buildHaving(db) + db = query.buildJoin(db) + db = query.buildLockForUpdate(db) + db = query.buildLimit(db) + db = query.buildOrder(db) + db = query.buildOffset(db) + db = query.buildOmit(db) + db = query.buildScopes(db) + db = query.buildSelectColumns(db) + db = query.buildSharedLock(db) + db = query.buildTable(db) + db = query.buildWith(db) + db = query.buildWithTrashed(db) + db = query.buildWhere(db) - tx := r.instance.Preload(query, newArgs...) + return query.new(db) +} - return NewQueryImplByInstance(tx, r) - } +func (r *QueryImpl) buildDistinct(db *gormio.DB) *gormio.DB { + if len(r.conditions.distinct) == 0 { + return db } - tx := r.instance.Preload(query, args...) + db = db.Distinct(r.conditions.distinct...) + r.conditions.distinct = nil + + return db +} - queryImpl := NewQueryImplByInstance(tx, r) - if queryImpl.with == nil { - queryImpl.with = make(map[string][]any) +func (r *QueryImpl) buildGroup(db *gormio.DB) *gormio.DB { + if r.conditions.group == "" { + return db } - queryImpl.with[query] = args + db = db.Group(r.conditions.group) + r.conditions.group = "" - return queryImpl + return db } -func (r *QueryImpl) refreshConnection(value any) error { - model, ok := value.(ormcontract.ConnectionModel) - if !ok { +func (r *QueryImpl) buildHaving(db *gormio.DB) *gormio.DB { + if r.conditions.having == nil { + return db + } + + db = db.Having(r.conditions.having.query, r.conditions.having.args...) + r.conditions.having = nil + + return db +} + +func (r *QueryImpl) buildJoin(db *gormio.DB) *gormio.DB { + if r.conditions.join == nil { + return db + } + + for _, item := range r.conditions.join { + db = db.Joins(item.query, item.args...) + } + + r.conditions.join = nil + + return db +} + +func (r *QueryImpl) buildLimit(db *gormio.DB) *gormio.DB { + if r.conditions.limit == nil { + return db + } + + db = db.Limit(*r.conditions.limit) + r.conditions.limit = nil + + return db +} + +func (r *QueryImpl) buildLockForUpdate(db *gormio.DB) *gormio.DB { + if !r.conditions.lockForUpdate { + return db + } + + driver := r.instance.Name() + mysqlDialector := mysql.Dialector{} + postgresqlDialector := postgres.Dialector{} + sqlserverDialector := sqlserver.Dialector{} + + if driver == mysqlDialector.Name() || driver == postgresqlDialector.Name() { + return db.Clauses(clause.Locking{Strength: "UPDATE"}) + } else if driver == sqlserverDialector.Name() { + return db.Clauses(hints.With("rowlock", "updlock", "holdlock")) + } + + r.conditions.lockForUpdate = false + + return db +} + +func (r *QueryImpl) buildModel() *QueryImpl { + if r.conditions.model == nil { + return r + } + + query, err := r.refreshConnection(r.conditions.model) + if err != nil { return nil } - conn := model.Connection() - if conn == "" { - conn = r.config.GetString("database.default") + + return query.new(query.instance.Model(r.conditions.model)) +} + +func (r *QueryImpl) buildOffset(db *gormio.DB) *gormio.DB { + if r.conditions.offset == nil { + return db } - driver := driver2gorm(r.config.GetString(fmt.Sprintf("database.connections.%s.driver", conn))) - if driver == "" { - return fmt.Errorf("connection %s driver is not supported", conn) + + db = db.Offset(*r.conditions.offset) + r.conditions.offset = nil + + return db +} + +func (r *QueryImpl) buildOmit(db *gormio.DB) *gormio.DB { + if len(r.conditions.omit) == 0 { + return db } - // if a driver is not the same, we need to refresh the connection - if driver != r.instance.Name() { - query, err := InitializeQuery(r.ctx, r.config, conn) - if err != nil { - return err - } - dbInstance := query.instance - stmt := r.instance.Statement - stmt.DB = dbInstance.Statement.DB - stmt.ConnPool = dbInstance.ConnPool - if r.ctx != nil { - dbInstance = dbInstance.WithContext(r.ctx) + + db = db.Omit(r.conditions.omit...) + r.conditions.omit = nil + + return db +} + +func (r *QueryImpl) buildOrder(db *gormio.DB) *gormio.DB { + if len(r.conditions.order) == 0 { + return db + } + + for _, order := range r.conditions.order { + db = db.Order(order) + } + + r.conditions.order = nil + + return db +} + +func (r *QueryImpl) buildSelectColumns(db *gormio.DB) *gormio.DB { + if r.conditions.selectColumns == nil { + return db + } + + db = db.Select(r.conditions.selectColumns.query, r.conditions.selectColumns.args...) + r.conditions.selectColumns = nil + + return db +} + +func (r *QueryImpl) buildScopes(db *gormio.DB) *gormio.DB { + if len(r.conditions.scopes) == 0 { + return db + } + + var gormFuncs []func(*gormio.DB) *gormio.DB + for _, scope := range r.conditions.scopes { + gormFuncs = append(gormFuncs, func(tx *gormio.DB) *gormio.DB { + queryImpl := r.new(tx) + query := scope(queryImpl) + queryImpl = query.(*QueryImpl) + queryImpl = queryImpl.buildConditions() + + return queryImpl.instance + }) + } + + db = db.Scopes(gormFuncs...) + r.conditions.scopes = nil + + return db +} + +func (r *QueryImpl) buildSharedLock(db *gormio.DB) *gormio.DB { + if !r.conditions.sharedLock { + return db + } + + driver := r.instance.Name() + mysqlDialector := mysql.Dialector{} + postgresqlDialector := postgres.Dialector{} + sqlserverDialector := sqlserver.Dialector{} + + if driver == mysqlDialector.Name() || driver == postgresqlDialector.Name() { + return db.Clauses(clause.Locking{Strength: "SHARE"}) + } else if driver == sqlserverDialector.Name() { + return db.Clauses(hints.With("rowlock", "holdlock")) + } + + r.conditions.sharedLock = false + + return db +} + +func (r *QueryImpl) buildTable(db *gormio.DB) *gormio.DB { + if r.conditions.table == nil { + return db + } + + db = db.Table(r.conditions.table.name, r.conditions.table.args...) + r.conditions.table = nil + + return db +} + +func (r *QueryImpl) buildWhere(db *gormio.DB) *gormio.DB { + if len(r.conditions.where) == 0 { + return db + } + + for _, item := range r.conditions.where { + if item.or { + db = db.Or(item.query, item.args...) + } else { + db = db.Where(item.query, item.args...) } - dbInstance.Statement = stmt - r.instance = dbInstance } - return nil + + r.conditions.where = nil + + return db } -func (r *QueryImpl) selectCreate(value any) error { - if len(r.instance.Statement.Selects) > 1 { - for _, val := range r.instance.Statement.Selects { - if val == orm.Associations { - return errors.New("cannot set orm.Associations and other fields at the same time") +func (r *QueryImpl) buildWith(db *gormio.DB) *gormio.DB { + if len(r.conditions.with) == 0 { + return db + } + + for _, item := range r.conditions.with { + isSet := false + if len(item.args) == 1 { + if arg, ok := item.args[0].(func(ormcontract.Query) ormcontract.Query); ok { + newArgs := []any{ + func(tx *gormio.DB) *gormio.DB { + queryImpl := NewQueryImpl(r.ctx, r.config, r.connection, tx, nil) + query := arg(queryImpl) + queryImpl = query.(*QueryImpl) + queryImpl = queryImpl.buildConditions() + + return queryImpl.instance + }, + } + + db = db.Preload(item.query, newArgs...) + isSet = true } } + + if !isSet { + db = db.Preload(item.query, item.args...) + } } - if len(r.instance.Statement.Selects) == 1 && r.instance.Statement.Selects[0] == orm.Associations { - r.instance.Statement.Selects = []string{} + r.conditions.with = nil + + return db +} + +func (r *QueryImpl) buildWithTrashed(db *gormio.DB) *gormio.DB { + if !r.conditions.withTrashed { + return db } + db = db.Unscoped() + r.conditions.withTrashed = false + + return db +} + +func (r *QueryImpl) clearConditions() { + r.conditions = Conditions{} +} + +func (r *QueryImpl) create(value any) error { if err := r.saving(nil, value); err != nil { return err } @@ -832,7 +1124,7 @@ func (r *QueryImpl) selectCreate(value any) error { return err } - if err := r.instance.Create(value).Error; err != nil { + if err := r.instance.Omit(orm.Associations).Create(value).Error; err != nil { return err } @@ -846,6 +1138,81 @@ func (r *QueryImpl) selectCreate(value any) error { return nil } +func (r *QueryImpl) created(dest any) error { + return r.event(ormcontract.EventCreated, nil, dest) +} + +func (r *QueryImpl) creating(dest any) error { + return r.event(ormcontract.EventCreating, nil, dest) +} + +func (r *QueryImpl) deleting(dest any) error { + return r.event(ormcontract.EventDeleting, nil, dest) +} + +func (r *QueryImpl) deleted(dest any) error { + return r.event(ormcontract.EventDeleted, nil, dest) +} + +func (r *QueryImpl) forceDeleting(dest any) error { + return r.event(ormcontract.EventForceDeleting, nil, dest) +} + +func (r *QueryImpl) forceDeleted(dest any) error { + return r.event(ormcontract.EventForceDeleted, nil, dest) +} + +func (r *QueryImpl) event(event ormcontract.EventType, model, dest any) error { + if r.conditions.withoutEvents { + return nil + } + + instance := NewEvent(r, model, dest) + + if dispatchesEvents, exist := dest.(ormcontract.DispatchesEvents); exist { + if event, exist := dispatchesEvents.DispatchesEvents()[event]; exist { + return event(instance) + } + + return nil + } + if model != nil { + if dispatchesEvents, exist := model.(ormcontract.DispatchesEvents); exist { + if event, exist := dispatchesEvents.DispatchesEvents()[event]; exist { + return event(instance) + } + } + + return nil + } + + if observer := observer(dest); observer != nil { + if observerEvent := observerEvent(event, observer); observerEvent != nil { + return observerEvent(instance) + } + + return nil + } + + if model != nil { + if observer := observer(model); observer != nil { + if observerEvent := observerEvent(event, observer); observerEvent != nil { + return observerEvent(instance) + } + + return nil + } + } + + return nil +} + +func (r *QueryImpl) new(db *gormio.DB) *QueryImpl { + query := NewQueryImpl(r.ctx, r.config, r.connection, db, &r.conditions) + + return query +} + func (r *QueryImpl) omitCreate(value any) error { if len(r.instance.Statement.Omits) > 1 { for _, val := range r.instance.Statement.Omits { @@ -886,7 +1253,73 @@ func (r *QueryImpl) omitCreate(value any) error { return nil } -func (r *QueryImpl) create(value any) error { +func (r *QueryImpl) omitSave(value any) error { + for _, val := range r.instance.Statement.Omits { + if val == orm.Associations { + return r.instance.Omit(orm.Associations).Save(value).Error + } + } + + return r.instance.Save(value).Error +} + +func (r *QueryImpl) refreshConnection(model any) (*QueryImpl, error) { + connection, err := getModelConnection(model) + if err != nil { + return nil, err + } + if connection == "" || connection == r.connection { + return r, nil + } + + query, ok := r.queries[connection] + if !ok { + var err error + query, err = InitializeQuery(r.ctx, r.config, connection) + if err != nil { + return nil, err + } + + if r.queries == nil { + r.queries = make(map[string]*QueryImpl) + } + r.queries[connection] = query + } + + query.conditions = r.conditions + + return query, nil +} + +func (r *QueryImpl) retrieved(dest any) error { + return r.event(ormcontract.EventRetrieved, nil, dest) +} + +func (r *QueryImpl) save(value any) error { + return r.instance.Omit(orm.Associations).Save(value).Error +} + +func (r *QueryImpl) saved(model, dest any) error { + return r.event(ormcontract.EventSaved, model, dest) +} + +func (r *QueryImpl) saving(model, dest any) error { + return r.event(ormcontract.EventSaving, model, dest) +} + +func (r *QueryImpl) selectCreate(value any) error { + if len(r.instance.Statement.Selects) > 1 { + for _, val := range r.instance.Statement.Selects { + if val == orm.Associations { + return errors.New("cannot set orm.Associations and other fields at the same time") + } + } + } + + if len(r.instance.Statement.Selects) == 1 && r.instance.Statement.Selects[0] == orm.Associations { + r.instance.Statement.Selects = []string{} + } + if err := r.saving(nil, value); err != nil { return err } @@ -894,7 +1327,7 @@ func (r *QueryImpl) create(value any) error { return err } - if err := r.instance.Omit(orm.Associations).Create(value).Error; err != nil { + if err := r.instance.Create(value).Error; err != nil { return err } @@ -922,18 +1355,19 @@ func (r *QueryImpl) selectSave(value any) error { return nil } -func (r *QueryImpl) omitSave(value any) error { - for _, val := range r.instance.Statement.Omits { - if val == orm.Associations { - return r.instance.Omit(orm.Associations).Save(value).Error - } - } +func (r *QueryImpl) setConditions(conditions Conditions) *QueryImpl { + query := r.new(r.instance) + query.conditions = conditions - return r.instance.Save(value).Error + return query } -func (r *QueryImpl) save(value any) error { - return r.instance.Omit(orm.Associations).Save(value).Error +func (r *QueryImpl) updating(model, dest any) error { + return r.event(ormcontract.EventUpdating, model, dest) +} + +func (r *QueryImpl) updated(model, dest any) error { + return r.event(ormcontract.EventUpdated, model, dest) } func (r *QueryImpl) updates(values any) (*ormcontract.Result, error) { @@ -981,95 +1415,6 @@ func (r *QueryImpl) updates(values any) (*ormcontract.Result, error) { }, result.Error } -func (r *QueryImpl) retrieved(dest any) error { - return r.event(ormcontract.EventRetrieved, nil, dest) -} - -func (r *QueryImpl) updating(model, dest any) error { - return r.event(ormcontract.EventUpdating, model, dest) -} - -func (r *QueryImpl) updated(model, dest any) error { - return r.event(ormcontract.EventUpdated, model, dest) -} - -func (r *QueryImpl) saving(model, dest any) error { - return r.event(ormcontract.EventSaving, model, dest) -} - -func (r *QueryImpl) saved(model, dest any) error { - return r.event(ormcontract.EventSaved, model, dest) -} - -func (r *QueryImpl) creating(dest any) error { - return r.event(ormcontract.EventCreating, nil, dest) -} - -func (r *QueryImpl) created(dest any) error { - return r.event(ormcontract.EventCreated, nil, dest) -} - -func (r *QueryImpl) deleting(dest any) error { - return r.event(ormcontract.EventDeleting, nil, dest) -} - -func (r *QueryImpl) deleted(dest any) error { - return r.event(ormcontract.EventDeleted, nil, dest) -} - -func (r *QueryImpl) forceDeleting(dest any) error { - return r.event(ormcontract.EventForceDeleting, nil, dest) -} - -func (r *QueryImpl) forceDeleted(dest any) error { - return r.event(ormcontract.EventForceDeleted, nil, dest) -} - -func (r *QueryImpl) event(event ormcontract.EventType, model, dest any) error { - if r.withoutEvents { - return nil - } - - instance := NewEvent(r, model, dest) - - if dispatchesEvents, exist := dest.(ormcontract.DispatchesEvents); exist { - if event, exist := dispatchesEvents.DispatchesEvents()[event]; exist { - return event(instance) - } - - return nil - } - if model != nil { - if dispatchesEvents, exist := model.(ormcontract.DispatchesEvents); exist { - if event, exist := dispatchesEvents.DispatchesEvents()[event]; exist { - return event(instance) - } - } - - return nil - } - - if observer := observer(dest); observer != nil { - if observerEvent := observerEvent(event, observer); observerEvent != nil { - return observerEvent(instance) - } - - return nil - } - - if model != nil { - if observer := observer(model); observer != nil { - if observerEvent := observerEvent(event, observer); observerEvent != nil { - return observerEvent(instance) - } - - return nil - } - } - - return nil -} - func filterFindConditions(conds ...any) error { if len(conds) > 0 { switch cond := conds[0].(type) { @@ -1091,6 +1436,37 @@ func filterFindConditions(conds ...any) error { return nil } +func getModelConnection(model any) (string, error) { + value1 := reflect.ValueOf(model) + if value1.Kind() == reflect.Ptr && value1.IsNil() { + value1 = reflect.New(value1.Type().Elem()) + } + modelType := reflect.Indirect(value1).Type() + + if modelType.Kind() == reflect.Interface { + modelType = reflect.Indirect(reflect.ValueOf(model)).Elem().Type() + } + + for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + if modelType.PkgPath() == "" { + return "", errors.New("invalid model") + } + return "", fmt.Errorf("%s: %s.%s", "invalid model", modelType.PkgPath(), modelType.Name()) + } + + modelValue := reflect.New(modelType) + connectionModel, ok := modelValue.Interface().(ormcontract.ConnectionModel) + if !ok { + return "", nil + } + + return connectionModel.Connection(), nil +} + func observer(dest any) ormcontract.Observer { destType := reflect.TypeOf(dest) if destType.Kind() == reflect.Pointer { @@ -1138,18 +1514,3 @@ func observerEvent(event ormcontract.EventType, observer ormcontract.Observer) f return nil } - -func driver2gorm(driver string) string { - switch driver { - case "mysql": - return "mysql" - case "postgresql": - return "postgres" - case "sqlite": - return "sqlite" - case "sqlserver": - return "sqlserver" - default: - return "" - } -} diff --git a/database/gorm/query_test.go b/database/gorm/query_test.go index 465473807..65cc87ab2 100644 --- a/database/gorm/query_test.go +++ b/database/gorm/query_test.go @@ -14,327 +14,23 @@ import ( _ "gorm.io/driver/postgres" ormcontract "github.com/goravel/framework/contracts/database/orm" + contractstesting "github.com/goravel/framework/contracts/testing" databasedb "github.com/goravel/framework/database/db" "github.com/goravel/framework/database/orm" + configmocks "github.com/goravel/framework/mocks/config" supportdocker "github.com/goravel/framework/support/docker" "github.com/goravel/framework/support/env" "github.com/goravel/framework/support/file" ) -type contextKey int - -const testContextKey contextKey = 0 - -type User struct { - orm.Model - orm.SoftDeletes - Name string - Bio *string - Avatar string - Address *Address - Books []*Book - House *House `gorm:"polymorphic:Houseable"` - Phones []*Phone `gorm:"polymorphic:Phoneable"` - Roles []*Role `gorm:"many2many:role_user"` - age int -} - -func (u *User) DispatchesEvents() map[ormcontract.EventType]func(ormcontract.Event) error { - return map[ormcontract.EventType]func(ormcontract.Event) error{ - ormcontract.EventCreating: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil { - if name.(string) == "event_creating_name" { - event.SetAttribute("avatar", "event_creating_avatar") - } - if name.(string) == "event_creating_FirstOrCreate_name" { - event.SetAttribute("avatar", "event_creating_FirstOrCreate_avatar") - } - if name.(string) == "event_creating_IsDirty_name" { - if event.IsDirty("name") { - event.SetAttribute("avatar", "event_creating_IsDirty_avatar") - } - } - if name.(string) == "event_context" { - val := event.Context().Value(testContextKey) - event.SetAttribute("avatar", val.(string)) - } - if name.(string) == "event_query" { - _ = event.Query().Create(&User{Name: "event_query1"}) - } - } - - return nil - }, - ormcontract.EventCreated: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil { - if name.(string) == "event_created_name" { - event.SetAttribute("avatar", "event_created_avatar") - } - if name.(string) == "event_created_FirstOrCreate_name" { - event.SetAttribute("avatar", "event_created_FirstOrCreate_avatar") - } - } - - return nil - }, - ormcontract.EventSaving: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil { - if name.(string) == "event_saving_create_name" { - event.SetAttribute("avatar", "event_saving_create_avatar") - } - if name.(string) == "event_saving_save_name" { - event.SetAttribute("avatar", "event_saving_save_avatar") - } - if name.(string) == "event_saving_FirstOrCreate_name" { - event.SetAttribute("avatar", "event_saving_FirstOrCreate_avatar") - } - if name.(string) == "event_save_without_name" { - event.SetAttribute("avatar", "event_save_without_avatar") - } - if name.(string) == "event_save_quietly_name" { - event.SetAttribute("avatar", "event_save_quietly_avatar") - } - if name.(string) == "event_saving_IsDirty_name" { - if event.IsDirty("name") { - event.SetAttribute("avatar", "event_saving_IsDirty_avatar") - } - } - } - - avatar := event.GetAttribute("avatar") - if avatar != nil && avatar.(string) == "event_saving_single_update_avatar" { - event.SetAttribute("avatar", "event_saving_single_update_avatar1") - } - - return nil - }, - ormcontract.EventSaved: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil { - if name.(string) == "event_saved_create_name" { - event.SetAttribute("avatar", "event_saved_create_avatar") - } - if name.(string) == "event_saved_save_name" { - event.SetAttribute("avatar", "event_saved_save_avatar") - } - if name.(string) == "event_saved_FirstOrCreate_name" { - event.SetAttribute("avatar", "event_saved_FirstOrCreate_avatar") - } - if name.(string) == "event_save_without_name" { - event.SetAttribute("avatar", "event_saved_without_avatar") - } - if name.(string) == "event_save_quietly_name" { - event.SetAttribute("avatar", "event_saved_quietly_avatar") - } - } - - avatar := event.GetAttribute("avatar") - if avatar != nil && avatar.(string) == "event_saved_map_update_avatar" { - event.SetAttribute("avatar", "event_saved_map_update_avatar1") - } - - return nil - }, - ormcontract.EventUpdating: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil { - if name.(string) == "event_updating_create_name" { - event.SetAttribute("avatar", "event_updating_create_avatar") - } - if name.(string) == "event_updating_save_name" { - event.SetAttribute("avatar", "event_updating_save_avatar") - } - if name.(string) == "event_updating_single_update_IsDirty_name1" { - if event.IsDirty("name") { - name := event.GetAttribute("name") - if name != "event_updating_single_update_IsDirty_name1" { - return errors.New("error") - } - - event.SetAttribute("avatar", "event_updating_single_update_IsDirty_avatar") - } - } - if name.(string) == "event_updating_map_update_IsDirty_name1" { - if event.IsDirty("name") { - name := event.GetAttribute("name") - if name != "event_updating_map_update_IsDirty_name1" { - return errors.New("error") - } - - event.SetAttribute("avatar", "event_updating_map_update_IsDirty_avatar") - } - } - if name.(string) == "event_updating_model_update_IsDirty_name1" { - if event.IsDirty("name") { - name := event.GetAttribute("name") - if name != "event_updating_model_update_IsDirty_name1" { - return errors.New("error") - } - event.SetAttribute("avatar", "event_updating_model_update_IsDirty_avatar") - } - } - } - - avatar := event.GetAttribute("avatar") - if avatar != nil { - if avatar.(string) == "event_updating_save_avatar" { - event.SetAttribute("avatar", "event_updating_save_avatar1") - } - if avatar.(string) == "event_updating_model_update_avatar" { - event.SetAttribute("avatar", "event_updating_model_update_avatar1") - } - } - - return nil - }, - ormcontract.EventUpdated: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil { - if name.(string) == "event_updated_create_name" { - event.SetAttribute("avatar", "event_updated_create_avatar") - } - if name.(string) == "event_updated_save_name" { - event.SetAttribute("avatar", "event_updated_save_avatar") - } - } - - avatar := event.GetAttribute("avatar") - if avatar != nil { - if avatar.(string) == "event_updated_save_avatar" { - event.SetAttribute("avatar", "event_updated_save_avatar1") - } - if avatar.(string) == "event_updated_model_update_avatar" { - event.SetAttribute("avatar", "event_updated_model_update_avatar1") - } - } - - return nil - }, - ormcontract.EventDeleting: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil && name.(string) == "event_deleting_name" { - return errors.New("deleting error") - } - - return nil - }, - ormcontract.EventDeleted: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil && name.(string) == "event_deleted_name" { - return errors.New("deleted error") - } - - return nil - }, - ormcontract.EventForceDeleting: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil && name.(string) == "event_force_deleting_name" { - return errors.New("force deleting error") - } - - return nil - }, - ormcontract.EventForceDeleted: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil && name.(string) == "event_force_deleted_name" { - return errors.New("force deleted error") - } - - return nil - }, - ormcontract.EventRetrieved: func(event ormcontract.Event) error { - name := event.GetAttribute("name") - if name != nil && name.(string) == "event_retrieved_name" { - event.SetAttribute("name", "event_retrieved_name1") - } - - return nil - }, - } -} - -type Role struct { - orm.Model - Name string - Users []*User `gorm:"many2many:role_user"` -} - -type Address struct { - orm.Model - UserID uint - Name string - Province string - User *User -} - -type Book struct { - orm.Model - UserID uint - Name string - User *User - Author *Author -} - -type Author struct { - orm.Model - BookID uint - Name string -} - -type House struct { - orm.Model - Name string - HouseableID uint - HouseableType string -} - -func (h *House) Factory() string { - return "house" -} - -type Phone struct { - orm.Model - Name string - PhoneableID uint - PhoneableType string -} - -type Product struct { - orm.Model - orm.SoftDeletes - Name string -} - -func (p *Product) Connection() string { - return "postgresql" -} - -type Review struct { - orm.Model - orm.SoftDeletes - Body string -} - -func (r *Review) Connection() string { - return "" -} - -type Person struct { - orm.Model - orm.SoftDeletes - Name string -} - -func (p *Person) Connection() string { - return "dummy" -} - type QueryTestSuite struct { suite.Suite - queries map[ormcontract.Driver]ormcontract.Query + queries map[ormcontract.Driver]ormcontract.Query + mysqlDocker *MysqlDocker + mysqlDocker1 *MysqlDocker + postgresqlDocker *PostgresqlDocker + sqliteDocker *SqliteDocker + sqlserverDocker *SqlserverDocker } func TestQueryTestSuite(t *testing.T) { @@ -355,6 +51,12 @@ func TestQueryTestSuite(t *testing.T) { log.Fatalf("Init mysql error: %s", err) } + mysqlDocker1 := NewMysql1Docker(testDatabaseDocker) + _, err = mysqlDocker1.New() + if err != nil { + log.Fatalf("Init mysql1 error: %s", err) + } + postgresqlDocker := NewPostgresqlDocker(testDatabaseDocker) postgresqlQuery, err := postgresqlDocker.New() if err != nil { @@ -380,6 +82,11 @@ func TestQueryTestSuite(t *testing.T) { ormcontract.DriverSqlite: sqliteQuery, ormcontract.DriverSqlserver: sqlserverQuery, }, + mysqlDocker: mysqlDocker, + mysqlDocker1: mysqlDocker1, + postgresqlDocker: postgresqlDocker, + sqliteDocker: sqliteDocker, + sqlserverDocker: sqlserverDocker, }) } @@ -394,7 +101,7 @@ func (s *QueryTestSuite) TestAssociation() { { name: "Find", setup: func() { - user := &User{ + user := User{ Name: "association_find_name", Address: &Address{ Name: "association_find_address", @@ -419,7 +126,7 @@ func (s *QueryTestSuite) TestAssociation() { { name: "hasOne Append", setup: func() { - user := &User{ + user := User{ Name: "association_has_one_append_name", Address: &Address{ Name: "association_has_one_append_address", @@ -443,7 +150,7 @@ func (s *QueryTestSuite) TestAssociation() { { name: "hasMany Append", setup: func() { - user := &User{ + user := User{ Name: "association_has_many_append_name", Books: []*Book{ {Name: "association_has_many_append_address1"}, @@ -469,7 +176,7 @@ func (s *QueryTestSuite) TestAssociation() { { name: "hasOne Replace", setup: func() { - user := &User{ + user := User{ Name: "association_has_one_append_name", Address: &Address{ Name: "association_has_one_append_address", @@ -493,7 +200,7 @@ func (s *QueryTestSuite) TestAssociation() { { name: "hasMany Replace", setup: func() { - user := &User{ + user := User{ Name: "association_has_many_replace_name", Books: []*Book{ {Name: "association_has_many_replace_address1"}, @@ -519,7 +226,7 @@ func (s *QueryTestSuite) TestAssociation() { { name: "Delete", setup: func() { - user := &User{ + user := User{ Name: "association_delete_name", Address: &Address{ Name: "association_delete_address", @@ -555,7 +262,7 @@ func (s *QueryTestSuite) TestAssociation() { { name: "Clear", setup: func() { - user := &User{ + user := User{ Name: "association_clear_name", Address: &Address{ Name: "association_clear_address", @@ -579,7 +286,7 @@ func (s *QueryTestSuite) TestAssociation() { { name: "Count", setup: func() { - user := &User{ + user := User{ Name: "association_count_name", Books: []*Book{ {Name: "association_count_address1"}, @@ -653,11 +360,25 @@ func (s *QueryTestSuite) TestCount() { } func (s *QueryTestSuite) TestCreate() { - for _, query := range s.queries { + for driver, query := range s.queries { tests := []struct { name string setup func() }{ + { + name: "success when refresh connection", + setup: func() { + s.mockDummyConnection(driver) + + people := People{Body: "create_people"} + s.Nil(query.Create(&people)) + s.True(people.ID > 0) + + people1 := People{Body: "create_people1"} + s.Nil(query.Model(&People{}).Create(&people1)) + s.True(people1.ID > 0) + }, + }, { name: "success when create with no relationships", setup: func() { @@ -814,8 +535,35 @@ func (s *QueryTestSuite) TestCursor() { } } +func (s *QueryTestSuite) TestDBRaw() { + userName := "db_raw" + for driver, query := range s.queries { + s.Run(driver.String(), func() { + user := User{Name: userName} + + s.Nil(query.Create(&user)) + s.True(user.ID > 0) + switch driver { + case ormcontract.DriverSqlserver, ormcontract.DriverMysql: + res, err := query.Model(&user).Update("Name", databasedb.Raw("concat(name, ?)", driver.String())) + s.Nil(err) + s.Equal(int64(1), res.RowsAffected) + default: + res, err := query.Model(&user).Update("Name", databasedb.Raw("name || ?", driver.String())) + s.Nil(err) + s.Equal(int64(1), res.RowsAffected) + } + + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.True(user1.Name == userName+driver.String()) + }) + } +} + func (s *QueryTestSuite) TestDelete() { - for _, query := range s.queries { + for driver, query := range s.queries { tests := []struct { name string setup func() @@ -836,6 +584,37 @@ func (s *QueryTestSuite) TestDelete() { s.Equal(uint(0), user1.ID) }, }, + { + name: "success when refresh connection", + setup: func() { + user := User{Name: "delete_user", Avatar: "delete_avatar"} + s.Nil(query.Create(&user)) + s.True(user.ID > 0) + + res, err := query.Delete(&user) + s.Equal(int64(1), res.RowsAffected) + s.Nil(err) + + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.Equal(uint(0), user1.ID) + + // refresh connection + s.mockDummyConnection(driver) + + people := People{Body: "delete_people"} + s.Nil(query.Create(&people)) + s.True(people.ID > 0) + + res, err = query.Delete(&people) + s.Equal(int64(1), res.RowsAffected) + s.Nil(err) + + var people1 People + s.Nil(query.Find(&people1, people.ID)) + s.Equal(uint(0), people1.ID) + }, + }, { name: "success by id", setup: func() { @@ -1522,6 +1301,7 @@ func (s *QueryTestSuite) TestEvent_Query() { user := User{Name: "event_query"} s.Nil(query.Create(&user)) s.True(user.ID > 0) + s.Equal("event_query", user.Name) var user1 User s.Nil(query.Where("name", "event_query1").Find(&user1)) @@ -1576,36 +1356,21 @@ func (s *QueryTestSuite) TestExists() { func (s *QueryTestSuite) TestFind() { for _, query := range s.queries { - tests := []struct { - name string - setup func() - }{ - { - name: "success", - setup: func() { - user := User{Name: "find_user"} - s.Nil(query.Create(&user)) - s.True(user.ID > 0) + user := User{Name: "find_user"} + s.Nil(query.Create(&user)) + s.True(user.ID > 0) - var user2 User - s.Nil(query.Find(&user2, user.ID)) - s.True(user2.ID > 0) + var user2 User + s.Nil(query.Find(&user2, user.ID)) + s.True(user2.ID > 0) - var user3 []User - s.Nil(query.Find(&user3, []uint{user.ID})) - s.Equal(1, len(user3)) + var user3 []User + s.Nil(query.Find(&user3, []uint{user.ID})) + s.Equal(1, len(user3)) - var user4 []User - s.Nil(query.Where("id in ?", []uint{user.ID}).Find(&user4)) - s.Equal(1, len(user4)) - }, - }, - } - for _, test := range tests { - s.Run(test.name, func() { - test.setup() - }) - } + var user4 []User + s.Nil(query.Where("id in ?", []uint{user.ID}).Find(&user4)) + s.Equal(1, len(user4)) } } @@ -1644,29 +1409,25 @@ func (s *QueryTestSuite) TestFindOrFail() { } func (s *QueryTestSuite) TestFirst() { - for _, query := range s.queries { - tests := []struct { - name string - setup func() - }{ - { - name: "success", - setup: func() { - user := User{Name: "first_user"} - s.Nil(query.Create(&user)) - s.True(user.ID > 0) + for driver, query := range s.queries { + user := User{Name: "first_user"} + s.Nil(query.Create(&user)) + s.True(user.ID > 0) - var user1 User - s.Nil(query.Where("name", "first_user").First(&user1)) - s.True(user1.ID > 0) - }, - }, - } - for _, test := range tests { - s.Run(test.name, func() { - test.setup() - }) - } + var user1 User + s.Nil(query.Where("name", "first_user").First(&user1)) + s.True(user1.ID > 0) + + // refresh connection + s.mockDummyConnection(driver) + + people := People{Body: "first_people"} + s.Nil(query.Create(&people)) + s.True(people.ID > 0) + + var people1 People + s.Nil(query.Where("id in ?", []uint{people.ID}).First(&people1)) + s.True(people1.ID > 0) } } @@ -1876,7 +1637,23 @@ func (s *QueryTestSuite) TestGet() { var user1 []User s.Nil(query.Where("id in ?", []uint{user.ID}).Get(&user1)) s.Equal(1, len(user1)) + + // refresh connection + s.mockDummyConnection(driver) + + people := People{Body: "get_people"} + s.Nil(query.Create(&people)) + s.True(people.ID > 0) + + var people1 []People + s.Nil(query.Where("id in ?", []uint{people.ID}).Get(&people1)) + s.Equal(1, len(people1)) + + var user2 []User + s.Nil(query.Where("id in ?", []uint{user.ID}).Get(&user2)) + s.Equal(1, len(user2)) }) + break } } @@ -2062,21 +1839,26 @@ func (s *QueryTestSuite) TestPaginate() { s.True(user3.ID > 0) var users []User - var total int64 s.Nil(query.Where("name = ?", "paginate_user").Paginate(1, 3, &users, nil)) s.Equal(3, len(users)) - s.Nil(query.Where("name = ?", "paginate_user").Paginate(2, 3, &users, &total)) - s.Equal(1, len(users)) - s.Equal(int64(4), total) - - s.Nil(query.Model(User{}).Where("name = ?", "paginate_user").Paginate(1, 3, &users, &total)) - s.Equal(3, len(users)) - s.Equal(int64(4), total) + var users1 []User + var total1 int64 + s.Nil(query.Where("name = ?", "paginate_user").Paginate(2, 3, &users1, &total1)) + s.Equal(1, len(users1)) + s.Equal(int64(4), total1) - s.Nil(query.Table("users").Where("name = ?", "paginate_user").Paginate(1, 3, &users, &total)) - s.Equal(3, len(users)) - s.Equal(int64(4), total) + var users2 []User + var total2 int64 + s.Nil(query.Model(User{}).Where("name = ?", "paginate_user").Paginate(1, 3, &users2, &total2)) + s.Equal(3, len(users2)) + s.Equal(int64(4), total2) + + var users3 []User + var total3 int64 + s.Nil(query.Table("users").Where("name = ?", "paginate_user").Paginate(1, 3, &users3, &total3)) + s.Equal(3, len(users3)) + s.Equal(int64(4), total3) }) } } @@ -2254,97 +2036,97 @@ func (s *QueryTestSuite) TestLimit() { } func (s *QueryTestSuite) TestLoad() { - for driver, query := range s.queries { - s.Run(driver.String(), func() { - user := User{Name: "load_user", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} - user.Address.Name = "load_address" - user.Books[0].Name = "load_book0" - user.Books[1].Name = "load_book1" - s.Nil(query.Select(orm.Associations).Create(&user)) - s.True(user.ID > 0) - s.True(user.Address.ID > 0) - s.True(user.Books[0].ID > 0) - s.True(user.Books[1].ID > 0) + for _, query := range s.queries { + user := User{Name: "load_user", Address: &Address{}, Books: []*Book{&Book{}, &Book{}}} + user.Address.Name = "load_address" + user.Books[0].Name = "load_book0" + user.Books[1].Name = "load_book1" + s.Nil(query.Select(orm.Associations).Create(&user)) + s.True(user.ID > 0) + s.True(user.Address.ID > 0) + s.True(user.Books[0].ID > 0) + s.True(user.Books[1].ID > 0) - tests := []struct { - description string - setup func(description string) - }{ - { - description: "simple load relationship", - setup: func(description string) { - var user1 User - s.Nil(query.Find(&user1, user.ID)) - s.True(user1.ID > 0) - s.Nil(user1.Address) - s.True(len(user1.Books) == 0) - s.Nil(query.Load(&user1, "Address")) - s.True(user1.Address.ID > 0) - s.True(len(user1.Books) == 0) - s.Nil(query.Load(&user1, "Books")) - s.True(user1.Address.ID > 0) - s.True(len(user1.Books) == 2) - }, + tests := []struct { + description string + setup func(description string) + }{ + { + description: "simple load relationship", + setup: func(description string) { + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.True(len(user1.Books) == 0) + s.Nil(query.Load(&user1, "Address")) + s.True(user1.Address.ID > 0) + s.True(len(user1.Books) == 0) + s.Nil(query.Load(&user1, "Books")) + s.True(user1.Address.ID > 0) + s.True(len(user1.Books) == 2) }, - { - description: "load relationship with simple condition", - setup: func(description string) { - var user1 User - s.Nil(query.Find(&user1, user.ID)) - s.True(user1.ID > 0) - s.Nil(user1.Address) - s.Equal(0, len(user1.Books)) - s.Nil(query.Load(&user1, "Books", "name = ?", "load_book0")) - s.True(user1.ID > 0) - s.Nil(user1.Address) - s.Equal(1, len(user1.Books)) - s.Equal("load_book0", user.Books[0].Name) - }, + }, + { + description: "load relationship with simple condition", + setup: func(description string) { + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(0, len(user1.Books)) + s.Nil(query.Load(&user1, "Books", "name = ?", "load_book0")) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(1, len(user1.Books)) + s.Equal("load_book0", user.Books[0].Name) }, - { - description: "load relationship with func condition", - setup: func(description string) { - var user1 User - s.Nil(query.Find(&user1, user.ID)) - s.True(user1.ID > 0) - s.Nil(user1.Address) - s.Equal(0, len(user1.Books)) - s.Nil(query.Load(&user1, "Books", func(query ormcontract.Query) ormcontract.Query { - return query.Where("name = ?", "load_book0") - })) - s.True(user1.ID > 0) - s.Nil(user1.Address) - s.Equal(1, len(user1.Books)) - s.Equal("load_book0", user.Books[0].Name) - }, + }, + { + description: "load relationship with func condition", + setup: func(description string) { + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(0, len(user1.Books)) + s.Nil(query.Load(&user1, "Books", func(query ormcontract.Query) ormcontract.Query { + return query.Where("name = ?", "load_book0") + })) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(1, len(user1.Books)) + s.Equal("load_book0", user.Books[0].Name) }, - { - description: "error when relation is empty", - setup: func(description string) { - var user1 User - s.Nil(query.Find(&user1, user.ID)) - s.True(user1.ID > 0) - s.Nil(user1.Address) - s.Equal(0, len(user1.Books)) - s.EqualError(query.Load(&user1, ""), "relation cannot be empty") - }, + }, + { + description: "error when relation is empty", + setup: func(description string) { + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.True(user1.ID > 0) + s.Nil(user1.Address) + s.Equal(0, len(user1.Books)) + s.EqualError(query.Load(&user1, ""), "relation cannot be empty") }, - { - description: "error when id is nil", - setup: func(description string) { - type UserNoID struct { - Name string - Avatar string - } - var userNoID UserNoID - s.EqualError(query.Load(&userNoID, "Book"), "id cannot be empty") - }, + }, + { + description: "error when id is nil", + setup: func(description string) { + type UserNoID struct { + Name string + Avatar string + } + var userNoID UserNoID + s.EqualError(query.Load(&userNoID, "Book"), "id cannot be empty") }, - } - for _, test := range tests { + }, + } + for _, test := range tests { + s.Run(test.description, func() { test.setup(test.description) - } - }) + }) + } } } @@ -2419,6 +2201,107 @@ func (s *QueryTestSuite) TestRaw() { } } +func (s *QueryTestSuite) TestReuse() { + for _, query := range s.queries { + users := []User{{Name: "reuse_user", Avatar: "reuse_avatar"}, {Name: "reuse_user1", Avatar: "reuse_avatar1"}} + s.Nil(query.Create(&users)) + s.True(users[0].ID > 0) + s.True(users[1].ID > 0) + + q := query.Where("name", "reuse_user") + + var users1 User + s.Nil(q.Where("avatar", "reuse_avatar").Find(&users1)) + s.True(users1.ID > 0) + + var users2 User + s.Nil(q.Where("avatar", "reuse_avatar1").Find(&users2)) + s.True(users2.ID == 0) + + var users3 User + s.Nil(query.Where("avatar", "reuse_avatar1").Find(&users3)) + s.True(users3.ID > 0) + } +} + +func (s *QueryTestSuite) TestRefreshConnection() { + tests := []struct { + name string + model any + setup func() + expectConnection string + expectErr string + }{ + { + name: "invalid model", + model: func() any { + var product string + return product + }(), + setup: func() {}, + expectErr: "invalid model", + }, + { + name: "the connection of model is empty", + model: func() any { + var review Review + return review + }(), + setup: func() {}, + expectConnection: "mysql", + }, + { + name: "the connection of model is same as current connection", + model: func() any { + var box Box + return box + }(), + setup: func() {}, + expectConnection: "mysql", + }, + { + name: "connections are different, but drivers are same", + model: func() any { + var people People + return people + }(), + setup: func() { + mockDummyConnection(s.mysqlDocker.MockConfig, testDatabaseDocker.Mysql1.Config()) + }, + expectConnection: "dummy", + }, + { + name: "connections and drivers are different", + model: func() any { + var product Product + return product + }(), + setup: func() { + mockPostgresqlConnection(s.mysqlDocker.MockConfig, testDatabaseDocker.Postgresql.Config()) + }, + expectConnection: "postgresql", + }, + } + + for _, test := range tests { + s.Run(test.name, func() { + test.setup() + queryImpl := s.queries[ormcontract.DriverMysql].(*QueryImpl) + query, err := queryImpl.refreshConnection(test.model) + if test.expectErr != "" { + s.EqualError(err, test.expectErr) + } else { + s.Nil(err) + } + if test.expectConnection == "" { + s.Nil(query) + } else { + s.Equal(test.expectConnection, query.connection) + } + }) + } +} + func (s *QueryTestSuite) TestSave() { for _, query := range s.queries { tests := []struct { @@ -2463,31 +2346,16 @@ func (s *QueryTestSuite) TestSave() { func (s *QueryTestSuite) TestSaveQuietly() { for _, query := range s.queries { - tests := []struct { - name string - setup func() - }{ - { - name: "success", - setup: func() { - user := User{Name: "event_save_quietly_name", Avatar: "save_quietly_avatar"} - s.Nil(query.SaveQuietly(&user)) - s.True(user.ID > 0) - s.Equal("event_save_quietly_name", user.Name) - s.Equal("save_quietly_avatar", user.Avatar) + user := User{Name: "event_save_quietly_name", Avatar: "save_quietly_avatar"} + s.Nil(query.SaveQuietly(&user)) + s.True(user.ID > 0) + s.Equal("event_save_quietly_name", user.Name) + s.Equal("save_quietly_avatar", user.Avatar) - var user1 User - s.Nil(query.Find(&user1, user.ID)) - s.Equal("event_save_quietly_name", user1.Name) - s.Equal("save_quietly_avatar", user1.Avatar) - }, - }, - } - for _, test := range tests { - s.Run(test.name, func() { - test.setup() - }) - } + var user1 User + s.Nil(query.Find(&user1, user.ID)) + s.Equal("event_save_quietly_name", user1.Name) + s.Equal("save_quietly_avatar", user1.Avatar) } } @@ -2783,7 +2651,7 @@ func (s *QueryTestSuite) TestWhere() { var user2 []User s.Nil(query.Where("name = ?", "where_user").OrWhere("avatar = ?", "where_avatar1").Find(&user2)) - s.True(len(user2) > 0) + s.Equal(2, len(user2)) var user3 User s.Nil(query.Where("name = 'where_user'").Find(&user3)) @@ -3148,30 +3016,16 @@ func (s *QueryTestSuite) TestWithNesting() { } } -func (s *QueryTestSuite) TestDBRaw() { - userName := "db_raw" - for driver, query := range s.queries { - s.Run(driver.String(), func() { - user := User{Name: userName} - - s.Nil(query.Create(&user)) - s.True(user.ID > 0) - switch driver { - case ormcontract.DriverSqlserver, ormcontract.DriverMysql: - res, err := query.Model(&user).Update("Name", databasedb.Raw("concat(name, ?)", driver.String())) - s.Nil(err) - s.Equal(int64(1), res.RowsAffected) - default: - res, err := query.Model(&user).Update("Name", databasedb.Raw("name || ?", driver.String())) - s.Nil(err) - s.Equal(int64(1), res.RowsAffected) - } - - var user1 User - s.Nil(query.Find(&user1, user.ID)) - s.True(user1.ID > 0) - s.True(user1.Name == userName+driver.String()) - }) +func (s *QueryTestSuite) mockDummyConnection(driver ormcontract.Driver) { + switch driver { + case ormcontract.DriverMysql: + mockDummyConnection(s.mysqlDocker.MockConfig, testDatabaseDocker.Mysql1.Config()) + case ormcontract.DriverPostgresql: + mockDummyConnection(s.postgresqlDocker.MockConfig, testDatabaseDocker.Mysql1.Config()) + case ormcontract.DriverSqlite: + mockDummyConnection(s.sqliteDocker.MockConfig, testDatabaseDocker.Mysql1.Config()) + case ormcontract.DriverSqlserver: + mockDummyConnection(s.sqlserverDocker.MockConfig, testDatabaseDocker.Mysql1.Config()) } } @@ -3203,19 +3057,7 @@ func TestCustomConnection(t *testing.T) { assert.Nil(t, query.Where("body", "create_review").First(&review1)) assert.True(t, review1.ID > 0) - config := testDatabaseDocker.Postgresql.Config() - mysqlDocker.MockConfig.On("Get", "database.connections.postgresql.read").Return(nil) - mysqlDocker.MockConfig.On("Get", "database.connections.postgresql.write").Return(nil) - mysqlDocker.MockConfig.On("GetString", "database.connections.postgresql.host").Return("localhost") - mysqlDocker.MockConfig.On("GetString", "database.connections.postgresql.username").Return(config.Username) - mysqlDocker.MockConfig.On("GetString", "database.connections.postgresql.password").Return(config.Password) - mysqlDocker.MockConfig.On("GetString", "database.connections.postgresql.driver").Return(ormcontract.DriverPostgresql.String()) - mysqlDocker.MockConfig.On("GetString", "database.connections.postgresql.database").Return(config.Database) - mysqlDocker.MockConfig.On("GetString", "database.connections.postgresql.sslmode").Return("disable") - mysqlDocker.MockConfig.On("GetString", "database.connections.postgresql.timezone").Return("UTC") - mysqlDocker.MockConfig.On("GetString", "database.connections.postgresql.prefix").Return("") - mysqlDocker.MockConfig.On("GetBool", "database.connections.postgresql.singular").Return(false) - mysqlDocker.MockConfig.On("GetInt", "database.connections.postgresql.port").Return(config.Port) + mockPostgresqlConnection(mysqlDocker.MockConfig, testDatabaseDocker.Postgresql.Config()) product := Product{Name: "create_product"} assert.Nil(t, query.Create(&product)) @@ -3229,13 +3071,135 @@ func TestCustomConnection(t *testing.T) { assert.Nil(t, query.Where("name", "create_product1").First(&product2)) assert.True(t, product2.ID == 0) - mysqlDocker.MockConfig.On("GetString", "database.connections.dummy.driver").Return("") + mockDummyConnection(mysqlDocker.MockConfig, testDatabaseDocker.Mysql.Config()) person := Person{Name: "create_person"} assert.NotNil(t, query.Create(&person)) assert.True(t, person.ID == 0) } +func TestFilterFindConditions(t *testing.T) { + tests := []struct { + name string + conditions []any + expectErr error + }{ + { + name: "condition is empty", + }, + { + name: "condition is empty string", + conditions: []any{""}, + expectErr: ErrorMissingWhereClause, + }, + { + name: "condition is empty slice", + conditions: []any{[]string{}}, + expectErr: ErrorMissingWhereClause, + }, + { + name: "condition has value", + conditions: []any{"name = ?", "test"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := filterFindConditions(test.conditions...) + if test.expectErr != nil { + assert.Equal(t, err, test.expectErr) + } else { + assert.Nil(t, err) + } + }) + } +} + +func TestGetModelConnection(t *testing.T) { + tests := []struct { + name string + model any + expectErr string + expectConnection string + }{ + { + name: "invalid model", + model: func() any { + var product string + return product + }(), + expectErr: "invalid model", + }, + { + name: "not ConnectionModel", + model: func() any { + var phone Phone + return phone + }(), + }, + { + name: "the connection of model is empty", + model: func() any { + var review Review + return review + }(), + }, + { + name: "the connection of model is not empty", + model: func() any { + var product Product + return product + }(), + expectConnection: "postgresql", + }, + { + name: "the connection of model is not empty and model is slice", + model: func() any { + var products []Product + return products + }(), + expectConnection: "postgresql", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + connection, err := getModelConnection(test.model) + if test.expectErr != "" { + assert.EqualError(t, err, test.expectErr) + } else { + assert.Nil(t, err) + } + assert.Equal(t, test.expectConnection, connection) + }) + } +} + +func TestObserver(t *testing.T) { + orm.Observers = append(orm.Observers, orm.Observer{ + Model: User{}, + Observer: &UserObserver{}, + }) + + assert.Nil(t, observer(Product{})) + assert.Equal(t, &UserObserver{}, observer(User{})) +} + +func TestObserverEvent(t *testing.T) { + assert.EqualError(t, observerEvent(ormcontract.EventRetrieved, &UserObserver{})(nil), "retrieved") + assert.EqualError(t, observerEvent(ormcontract.EventCreating, &UserObserver{})(nil), "creating") + assert.EqualError(t, observerEvent(ormcontract.EventCreated, &UserObserver{})(nil), "created") + assert.EqualError(t, observerEvent(ormcontract.EventUpdating, &UserObserver{})(nil), "updating") + assert.EqualError(t, observerEvent(ormcontract.EventUpdated, &UserObserver{})(nil), "updated") + assert.EqualError(t, observerEvent(ormcontract.EventSaving, &UserObserver{})(nil), "saving") + assert.EqualError(t, observerEvent(ormcontract.EventSaved, &UserObserver{})(nil), "saved") + assert.EqualError(t, observerEvent(ormcontract.EventDeleting, &UserObserver{})(nil), "deleting") + assert.EqualError(t, observerEvent(ormcontract.EventDeleted, &UserObserver{})(nil), "deleted") + assert.EqualError(t, observerEvent(ormcontract.EventForceDeleting, &UserObserver{})(nil), "forceDeleting") + assert.EqualError(t, observerEvent(ormcontract.EventForceDeleted, &UserObserver{})(nil), "forceDeleted") + assert.Nil(t, observerEvent("error", &UserObserver{})) +} + func TestReadWriteSeparate(t *testing.T) { if env.IsWindows() { t.Skip("Skipping tests of using docker") @@ -3430,3 +3394,79 @@ func paginator(page string, limit string) func(methods ormcontract.Query) ormcon return query.Offset(offset).Limit(limit) } } + +func mockDummyConnection(mockConfig *configmocks.Config, databaseConfig contractstesting.DatabaseConfig) { + mockConfig.On("GetString", "database.connections.dummy.prefix").Return("") + mockConfig.On("GetBool", "database.connections.dummy.singular").Return(false) + mockConfig.On("Get", "database.connections.dummy.read").Return(nil) + mockConfig.On("Get", "database.connections.dummy.write").Return(nil) + mockConfig.On("GetString", "database.connections.dummy.host").Return("127.0.0.1") + mockConfig.On("GetString", "database.connections.dummy.username").Return(databaseConfig.Username) + mockConfig.On("GetString", "database.connections.dummy.password").Return(databaseConfig.Password) + mockConfig.On("GetInt", "database.connections.dummy.port").Return(databaseConfig.Port) + mockConfig.On("GetString", "database.connections.dummy.driver").Return(ormcontract.DriverMysql.String()) + mockConfig.On("GetString", "database.connections.dummy.charset").Return("utf8mb4") + mockConfig.On("GetString", "database.connections.dummy.loc").Return("Local") + mockConfig.On("GetString", "database.connections.dummy.database").Return(databaseConfig.Database) +} + +func mockPostgresqlConnection(mockConfig *configmocks.Config, databaseConfig contractstesting.DatabaseConfig) { + mockConfig.On("GetString", "database.connections.postgresql.prefix").Return("") + mockConfig.On("GetBool", "database.connections.postgresql.singular").Return(false) + mockConfig.On("Get", "database.connections.postgresql.read").Return(nil) + mockConfig.On("Get", "database.connections.postgresql.write").Return(nil) + mockConfig.On("GetString", "database.connections.postgresql.host").Return("127.0.0.1") + mockConfig.On("GetString", "database.connections.postgresql.username").Return(databaseConfig.Username) + mockConfig.On("GetString", "database.connections.postgresql.password").Return(databaseConfig.Password) + mockConfig.On("GetInt", "database.connections.postgresql.port").Return(databaseConfig.Port) + mockConfig.On("GetString", "database.connections.postgresql.driver").Return(ormcontract.DriverPostgresql.String()) + mockConfig.On("GetString", "database.connections.postgresql.sslmode").Return("disable") + mockConfig.On("GetString", "database.connections.postgresql.timezone").Return("UTC") + mockConfig.On("GetString", "database.connections.postgresql.database").Return(databaseConfig.Database) +} + +type UserObserver struct{} + +func (u *UserObserver) Retrieved(event ormcontract.Event) error { + return errors.New("retrieved") +} + +func (u *UserObserver) Creating(event ormcontract.Event) error { + return errors.New("creating") +} + +func (u *UserObserver) Created(event ormcontract.Event) error { + return errors.New("created") +} + +func (u *UserObserver) Updating(event ormcontract.Event) error { + return errors.New("updating") +} + +func (u *UserObserver) Updated(event ormcontract.Event) error { + return errors.New("updated") +} + +func (u *UserObserver) Saving(event ormcontract.Event) error { + return errors.New("saving") +} + +func (u *UserObserver) Saved(event ormcontract.Event) error { + return errors.New("saved") +} + +func (u *UserObserver) Deleting(event ormcontract.Event) error { + return errors.New("deleting") +} + +func (u *UserObserver) Deleted(event ormcontract.Event) error { + return errors.New("deleted") +} + +func (u *UserObserver) ForceDeleting(event ormcontract.Event) error { + return errors.New("forceDeleting") +} + +func (u *UserObserver) ForceDeleted(event ormcontract.Event) error { + return errors.New("forceDeleted") +} diff --git a/database/gorm/test_models.go b/database/gorm/test_models.go new file mode 100644 index 000000000..8a9bcec33 --- /dev/null +++ b/database/gorm/test_models.go @@ -0,0 +1,339 @@ +package gorm + +import ( + "errors" + + ormcontract "github.com/goravel/framework/contracts/database/orm" + "github.com/goravel/framework/database/orm" +) + +type contextKey int + +const testContextKey contextKey = 0 + +type User struct { + orm.Model + orm.SoftDeletes + Name string + Bio *string + Avatar string + Address *Address + Books []*Book + House *House `gorm:"polymorphic:Houseable"` + Phones []*Phone `gorm:"polymorphic:Phoneable"` + Roles []*Role `gorm:"many2many:role_user"` + age int +} + +func (u *User) DispatchesEvents() map[ormcontract.EventType]func(ormcontract.Event) error { + return map[ormcontract.EventType]func(ormcontract.Event) error{ + ormcontract.EventCreating: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil { + if name.(string) == "event_creating_name" { + event.SetAttribute("avatar", "event_creating_avatar") + } + if name.(string) == "event_creating_FirstOrCreate_name" { + event.SetAttribute("avatar", "event_creating_FirstOrCreate_avatar") + } + if name.(string) == "event_creating_IsDirty_name" { + if event.IsDirty("name") { + event.SetAttribute("avatar", "event_creating_IsDirty_avatar") + } + } + if name.(string) == "event_context" { + val := event.Context().Value(testContextKey) + event.SetAttribute("avatar", val.(string)) + } + if name.(string) == "event_query" { + _ = event.Query().Create(&User{Name: "event_query1"}) + } + } + + return nil + }, + ormcontract.EventCreated: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil { + if name.(string) == "event_created_name" { + event.SetAttribute("avatar", "event_created_avatar") + } + if name.(string) == "event_created_FirstOrCreate_name" { + event.SetAttribute("avatar", "event_created_FirstOrCreate_avatar") + } + } + + return nil + }, + ormcontract.EventSaving: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil { + if name.(string) == "event_saving_create_name" { + event.SetAttribute("avatar", "event_saving_create_avatar") + } + if name.(string) == "event_saving_save_name" { + event.SetAttribute("avatar", "event_saving_save_avatar") + } + if name.(string) == "event_saving_FirstOrCreate_name" { + event.SetAttribute("avatar", "event_saving_FirstOrCreate_avatar") + } + if name.(string) == "event_save_without_name" { + event.SetAttribute("avatar", "event_save_without_avatar") + } + if name.(string) == "event_save_quietly_name" { + event.SetAttribute("avatar", "event_save_quietly_avatar") + } + if name.(string) == "event_saving_IsDirty_name" { + if event.IsDirty("name") { + event.SetAttribute("avatar", "event_saving_IsDirty_avatar") + } + } + } + + avatar := event.GetAttribute("avatar") + if avatar != nil && avatar.(string) == "event_saving_single_update_avatar" { + event.SetAttribute("avatar", "event_saving_single_update_avatar1") + } + + return nil + }, + ormcontract.EventSaved: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil { + if name.(string) == "event_saved_create_name" { + event.SetAttribute("avatar", "event_saved_create_avatar") + } + if name.(string) == "event_saved_save_name" { + event.SetAttribute("avatar", "event_saved_save_avatar") + } + if name.(string) == "event_saved_FirstOrCreate_name" { + event.SetAttribute("avatar", "event_saved_FirstOrCreate_avatar") + } + if name.(string) == "event_save_without_name" { + event.SetAttribute("avatar", "event_saved_without_avatar") + } + if name.(string) == "event_save_quietly_name" { + event.SetAttribute("avatar", "event_saved_quietly_avatar") + } + } + + avatar := event.GetAttribute("avatar") + if avatar != nil && avatar.(string) == "event_saved_map_update_avatar" { + event.SetAttribute("avatar", "event_saved_map_update_avatar1") + } + + return nil + }, + ormcontract.EventUpdating: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil { + if name.(string) == "event_updating_create_name" { + event.SetAttribute("avatar", "event_updating_create_avatar") + } + if name.(string) == "event_updating_save_name" { + event.SetAttribute("avatar", "event_updating_save_avatar") + } + if name.(string) == "event_updating_single_update_IsDirty_name1" { + if event.IsDirty("name") { + name := event.GetAttribute("name") + if name != "event_updating_single_update_IsDirty_name1" { + return errors.New("error") + } + + event.SetAttribute("avatar", "event_updating_single_update_IsDirty_avatar") + } + } + if name.(string) == "event_updating_map_update_IsDirty_name1" { + if event.IsDirty("name") { + name := event.GetAttribute("name") + if name != "event_updating_map_update_IsDirty_name1" { + return errors.New("error") + } + + event.SetAttribute("avatar", "event_updating_map_update_IsDirty_avatar") + } + } + if name.(string) == "event_updating_model_update_IsDirty_name1" { + if event.IsDirty("name") { + name := event.GetAttribute("name") + if name != "event_updating_model_update_IsDirty_name1" { + return errors.New("error") + } + event.SetAttribute("avatar", "event_updating_model_update_IsDirty_avatar") + } + } + } + + avatar := event.GetAttribute("avatar") + if avatar != nil { + if avatar.(string) == "event_updating_save_avatar" { + event.SetAttribute("avatar", "event_updating_save_avatar1") + } + if avatar.(string) == "event_updating_model_update_avatar" { + event.SetAttribute("avatar", "event_updating_model_update_avatar1") + } + } + + return nil + }, + ormcontract.EventUpdated: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil { + if name.(string) == "event_updated_create_name" { + event.SetAttribute("avatar", "event_updated_create_avatar") + } + if name.(string) == "event_updated_save_name" { + event.SetAttribute("avatar", "event_updated_save_avatar") + } + } + + avatar := event.GetAttribute("avatar") + if avatar != nil { + if avatar.(string) == "event_updated_save_avatar" { + event.SetAttribute("avatar", "event_updated_save_avatar1") + } + if avatar.(string) == "event_updated_model_update_avatar" { + event.SetAttribute("avatar", "event_updated_model_update_avatar1") + } + } + + return nil + }, + ormcontract.EventDeleting: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil && name.(string) == "event_deleting_name" { + return errors.New("deleting error") + } + + return nil + }, + ormcontract.EventDeleted: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil && name.(string) == "event_deleted_name" { + return errors.New("deleted error") + } + + return nil + }, + ormcontract.EventForceDeleting: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil && name.(string) == "event_force_deleting_name" { + return errors.New("force deleting error") + } + + return nil + }, + ormcontract.EventForceDeleted: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil && name.(string) == "event_force_deleted_name" { + return errors.New("force deleted error") + } + + return nil + }, + ormcontract.EventRetrieved: func(event ormcontract.Event) error { + name := event.GetAttribute("name") + if name != nil && name.(string) == "event_retrieved_name" { + event.SetAttribute("name", "event_retrieved_name1") + } + + return nil + }, + } +} + +type Role struct { + orm.Model + Name string + Users []*User `gorm:"many2many:role_user"` +} + +type Address struct { + orm.Model + UserID uint + Name string + Province string + User *User +} + +type Book struct { + orm.Model + UserID uint + Name string + User *User + Author *Author +} + +type Author struct { + orm.Model + BookID uint + Name string +} + +type House struct { + orm.Model + Name string + HouseableID uint + HouseableType string +} + +func (h *House) Factory() string { + return "house" +} + +type Phone struct { + orm.Model + Name string + PhoneableID uint + PhoneableType string +} + +type Product struct { + orm.Model + orm.SoftDeletes + Name string +} + +func (p *Product) Connection() string { + return "postgresql" +} + +type Review struct { + orm.Model + orm.SoftDeletes + Body string +} + +func (r *Review) Connection() string { + return "" +} + +type People struct { + orm.Model + orm.SoftDeletes + Body string +} + +func (p *People) Connection() string { + return "dummy" +} + +type Person struct { + orm.Model + orm.SoftDeletes + Name string +} + +func (p *Person) Connection() string { + return "dummy" +} + +type Box struct { + orm.Model + orm.SoftDeletes + Name string +} + +func (p *Box) Connection() string { + return "mysql" +} diff --git a/database/gorm/test_utils.go b/database/gorm/test_utils.go index a22cdbd9d..a04c99d76 100644 --- a/database/gorm/test_utils.go +++ b/database/gorm/test_utils.go @@ -31,6 +31,12 @@ func NewMysqlDocker(database *supportdocker.Database) *MysqlDocker { return &MysqlDocker{MockConfig: &mocksconfig.Config{}, Port: config.Port, user: config.Username, password: config.Password, database: config.Database} } +func NewMysql1Docker(database *supportdocker.Database) *MysqlDocker { + config := database.Mysql1.Config() + + return &MysqlDocker{MockConfig: &mocksconfig.Config{}, Port: config.Port, user: config.Username, password: config.Password, database: config.Database} +} + func (r *MysqlDocker) New() (orm.Query, error) { r.mock() @@ -60,7 +66,7 @@ func (r *MysqlDocker) Query(createTable bool) (orm.Query, error) { } if createTable { - err := Table{}.Create(orm.DriverMysql, query) + err := Tables{}.Create(orm.DriverMysql, query) if err != nil { return nil, err } @@ -75,7 +81,7 @@ func (r *MysqlDocker) QueryWithPrefixAndSingular() (orm.Query, error) { return nil, errors.New("connect to mysql failed") } - err = Table{}.CreateWithPrefixAndSingular(orm.DriverMysql, query) + err = Tables{}.CreateWithPrefixAndSingular(orm.DriverMysql, query) if err != nil { return nil, err } @@ -175,7 +181,7 @@ func (r *PostgresqlDocker) Query(createTable bool) (orm.Query, error) { } if createTable { - err := Table{}.Create(orm.DriverPostgresql, query) + err := Tables{}.Create(orm.DriverPostgresql, query) if err != nil { return nil, err } @@ -190,7 +196,7 @@ func (r *PostgresqlDocker) QueryWithPrefixAndSingular() (orm.Query, error) { return nil, errors.New("connect to postgresql failed") } - err = Table{}.CreateWithPrefixAndSingular(orm.DriverPostgresql, query) + err = Tables{}.CreateWithPrefixAndSingular(orm.DriverPostgresql, query) if err != nil { return nil, err } @@ -284,7 +290,7 @@ func (r *SqliteDocker) Query(createTable bool) (orm.Query, error) { } if createTable { - err = Table{}.Create(orm.DriverSqlite, db) + err = Tables{}.Create(orm.DriverSqlite, db) if err != nil { return nil, err } @@ -299,7 +305,7 @@ func (r *SqliteDocker) QueryWithPrefixAndSingular() (orm.Query, error) { return nil, err } - err = Table{}.CreateWithPrefixAndSingular(orm.DriverSqlite, db) + err = Tables{}.CreateWithPrefixAndSingular(orm.DriverSqlite, db) if err != nil { return nil, err } @@ -391,7 +397,7 @@ func (r *SqlserverDocker) Query(createTable bool) (orm.Query, error) { } if createTable { - err := Table{}.Create(orm.DriverSqlserver, query) + err := Tables{}.Create(orm.DriverSqlserver, query) if err != nil { return nil, err } @@ -406,7 +412,7 @@ func (r *SqlserverDocker) QueryWithPrefixAndSingular() (orm.Query, error) { return nil, errors.New("connect to sqlserver failed") } - err = Table{}.CreateWithPrefixAndSingular(orm.DriverSqlserver, query) + err = Tables{}.CreateWithPrefixAndSingular(orm.DriverSqlserver, query) if err != nil { return nil, err } @@ -460,10 +466,10 @@ func (r *SqlserverDocker) mockOfCommon() { mockPool(r.MockConfig) } -type Table struct { +type Tables struct { } -func (r Table) Create(driver orm.Driver, db orm.Query) error { +func (r Tables) Create(driver orm.Driver, db orm.Query) error { _, err := db.Exec(r.createPeopleTable(driver)) if err != nil { return err @@ -512,7 +518,7 @@ func (r Table) Create(driver orm.Driver, db orm.Query) error { return nil } -func (r Table) CreateWithPrefixAndSingular(driver orm.Driver, db orm.Query) error { +func (r Tables) CreateWithPrefixAndSingular(driver orm.Driver, db orm.Query) error { _, err := db.Exec(r.createUserTableWithPrefixAndSingular(driver)) if err != nil { return err @@ -521,11 +527,11 @@ func (r Table) CreateWithPrefixAndSingular(driver orm.Driver, db orm.Query) erro return nil } -func (r Table) createPeopleTable(driver orm.Driver) string { +func (r Tables) createPeopleTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` -CREATE TABLE people ( +CREATE TABLE peoples ( id bigint(20) unsigned NOT NULL AUTO_INCREMENT, body varchar(255) NOT NULL, created_at datetime(3) NOT NULL, @@ -538,7 +544,7 @@ CREATE TABLE people ( ` case orm.DriverPostgresql: return ` -CREATE TABLE people ( +CREATE TABLE peoples ( id SERIAL PRIMARY KEY NOT NULL, body varchar(255) NOT NULL, created_at timestamp NOT NULL, @@ -548,7 +554,7 @@ CREATE TABLE people ( ` case orm.DriverSqlite: return ` -CREATE TABLE people ( +CREATE TABLE peoples ( id integer PRIMARY KEY AUTOINCREMENT NOT NULL, body varchar(255) NOT NULL, created_at datetime NOT NULL, @@ -558,7 +564,7 @@ CREATE TABLE people ( ` case orm.DriverSqlserver: return ` -CREATE TABLE people ( +CREATE TABLE peoples ( id bigint NOT NULL IDENTITY(1,1), body varchar(255) NOT NULL, created_at datetime NOT NULL, @@ -572,7 +578,7 @@ CREATE TABLE people ( } } -func (r Table) createReviewTable(driver orm.Driver) string { +func (r Tables) createReviewTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -623,7 +629,7 @@ CREATE TABLE reviews ( } } -func (r Table) createProductTable(driver orm.Driver) string { +func (r Tables) createProductTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -674,7 +680,7 @@ CREATE TABLE products ( } } -func (r Table) createUserTable(driver orm.Driver) string { +func (r Tables) createUserTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -733,7 +739,7 @@ CREATE TABLE users ( } } -func (r Table) createUserTableWithPrefixAndSingular(driver orm.Driver) string { +func (r Tables) createUserTableWithPrefixAndSingular(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -792,7 +798,7 @@ CREATE TABLE goravel_user ( } } -func (r Table) createAddressTable(driver orm.Driver) string { +func (r Tables) createAddressTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -847,7 +853,7 @@ CREATE TABLE addresses ( } } -func (r Table) createBookTable(driver orm.Driver) string { +func (r Tables) createBookTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -898,7 +904,7 @@ CREATE TABLE books ( } } -func (r Table) createAuthorTable(driver orm.Driver) string { +func (r Tables) createAuthorTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -949,7 +955,7 @@ CREATE TABLE authors ( } } -func (r Table) createRoleTable(driver orm.Driver) string { +func (r Tables) createRoleTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -996,7 +1002,7 @@ CREATE TABLE roles ( } } -func (r Table) createHouseTable(driver orm.Driver) string { +func (r Tables) createHouseTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -1051,7 +1057,7 @@ CREATE TABLE houses ( } } -func (r Table) createPhoneTable(driver orm.Driver) string { +func (r Tables) createPhoneTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` @@ -1106,7 +1112,7 @@ CREATE TABLE phones ( } } -func (r Table) createRoleUserTable(driver orm.Driver) string { +func (r Tables) createRoleUserTable(driver orm.Driver) string { switch driver { case orm.DriverMysql: return ` diff --git a/database/gorm/transaction.go b/database/gorm/transaction.go index 61bb91ac6..85ff9cf8b 100644 --- a/database/gorm/transaction.go +++ b/database/gorm/transaction.go @@ -12,11 +12,8 @@ type Transaction struct { instance *gorm.DB } -func NewTransaction(tx *gorm.DB, config config.Config) *Transaction { - return &Transaction{Query: NewQueryImplByInstance(tx, &QueryImpl{ - config: config, - withoutEvents: false, - }), instance: tx} +func NewTransaction(tx *gorm.DB, config config.Config, connection string) *Transaction { + return &Transaction{Query: NewQueryImpl(tx.Statement.Context, config, connection, tx, nil), instance: tx} } func (r *Transaction) Commit() error { diff --git a/database/gorm/wire.go b/database/gorm/wire.go index c1688e021..c43157861 100644 --- a/database/gorm/wire.go +++ b/database/gorm/wire.go @@ -23,7 +23,7 @@ func InitializeGorm(config config.Config, connection string) *GormImpl { //go:generate wire func InitializeQuery(ctx context.Context, config config.Config, connection string) (*QueryImpl, error) { - wire.Build(NewQueryImpl, GormSet, db.ConfigSet, DialectorSet) + wire.Build(BuildQueryImpl, GormSet, db.ConfigSet, DialectorSet) return nil, nil } diff --git a/database/gorm/wire_gen.go b/database/gorm/wire_gen.go index bb5259615..7147c077a 100644 --- a/database/gorm/wire_gen.go +++ b/database/gorm/wire_gen.go @@ -27,7 +27,7 @@ func InitializeQuery(ctx context.Context, config2 config.Config, connection stri configImpl := db.NewConfigImpl(config2, connection) dialectorImpl := NewDialectorImpl(config2, connection) gormImpl := NewGormImpl(config2, connection, configImpl, dialectorImpl) - queryImpl, err := NewQueryImpl(ctx, config2, gormImpl) + queryImpl, err := BuildQueryImpl(ctx, config2, connection, gormImpl) if err != nil { return nil, err } diff --git a/database/orm.go b/database/orm.go index e81b39dc2..3a23e06c6 100644 --- a/database/orm.go +++ b/database/orm.go @@ -40,9 +40,11 @@ func (r *OrmImpl) Connection(name string) ormcontract.Orm { } if instance, exist := r.queries[name]; exist { return &OrmImpl{ - ctx: r.ctx, - query: instance, - queries: r.queries, + ctx: r.ctx, + config: r.config, + connection: name, + query: instance, + queries: r.queries, } } @@ -56,9 +58,11 @@ func (r *OrmImpl) Connection(name string) ormcontract.Orm { r.queries[name] = queue return &OrmImpl{ - ctx: r.ctx, - query: queue, - queries: r.queries, + ctx: r.ctx, + config: r.config, + connection: name, + query: queue, + queries: r.queries, } } diff --git a/database/wire_gen.go b/database/wire_gen.go index 500c34a83..f78d32167 100644 --- a/database/wire_gen.go +++ b/database/wire_gen.go @@ -20,7 +20,7 @@ func InitializeOrm(ctx context.Context, config2 config.Config, connection string configImpl := db.NewConfigImpl(config2, connection) dialectorImpl := gorm.NewDialectorImpl(config2, connection) gormImpl := gorm.NewGormImpl(config2, connection, configImpl, dialectorImpl) - queryImpl, err := gorm.NewQueryImpl(ctx, config2, gormImpl) + queryImpl, err := gorm.BuildQueryImpl(ctx, config2, connection, gormImpl) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index a8523a8b9..bc6c6f3b3 100644 --- a/go.mod +++ b/go.mod @@ -118,6 +118,7 @@ require ( go.opencensus.io v0.24.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect + golang.org/x/mod v0.10.0 // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/oauth2 v0.15.0 // indirect golang.org/x/sync v0.6.0 // indirect diff --git a/log/formatter/general.go b/log/formatter/general.go index 74415605d..5052a2f78 100644 --- a/log/formatter/general.go +++ b/log/formatter/general.go @@ -76,7 +76,7 @@ func formatData(data logrus.Fields) (string, error) { return "", err } - builder.WriteString(fmt.Sprintf("%s: %v\n", key, string(v))) + builder.WriteString(fmt.Sprintf(`%s: %v\n"`, key, string(v))) } } diff --git a/queue/application_test.go b/queue/application_test.go index 576db3fe7..330cb6278 100644 --- a/queue/application_test.go +++ b/queue/application_test.go @@ -2,9 +2,11 @@ package queue import ( "context" + "errors" "testing" "time" + "github.com/spf13/cast" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -17,13 +19,15 @@ import ( ) var ( - testSyncJob = 0 - testAsyncJob = 0 - testDelayAsyncJob = 0 - testCustomAsyncJob = 0 - testErrorAsyncJob = 0 - testChainAsyncJob = 0 - testChainSyncJob = 0 + testSyncJob = 0 + testAsyncJob = 0 + testDelayAsyncJob = 0 + testCustomAsyncJob = 0 + testErrorAsyncJob = 0 + testChainAsyncJob = 0 + testChainSyncJob = 0 + testChainAsyncJobError = 0 + testChainSyncJobError = 0 ) type QueueTestSuite struct { @@ -255,6 +259,51 @@ func (s *QueueTestSuite) TestChainAsyncQueue() { s.mockConfig.AssertExpectations(s.T()) } +func (s *QueueTestSuite) TestChainAsyncQueue_Error() { + s.mockConfig.On("GetString", "queue.default").Return("redis").Times(2) + s.mockConfig.On("GetString", "app.name").Return("goravel").Times(4) + s.mockConfig.On("GetString", "queue.connections.redis.queue", "default").Return("default").Twice() + s.mockConfig.On("GetString", "queue.connections.redis.driver").Return("redis").Times(3) + s.mockConfig.On("GetString", "queue.connections.redis.connection").Return("default").Twice() + s.mockConfig.On("GetString", "database.redis.default.host").Return("localhost").Twice() + s.mockConfig.On("GetString", "database.redis.default.password").Return("").Twice() + s.mockConfig.On("GetInt", "database.redis.default.port").Return(s.port).Twice() + s.mockConfig.On("GetInt", "database.redis.default.database").Return(0).Twice() + s.app.jobs = []queue.Job{&TestChainAsyncJob{}, &TestChainSyncJob{}} + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func(ctx context.Context) { + s.Nil(s.app.Worker(&queue.Args{ + Queue: "chain", + }).Run()) + + for range ctx.Done() { + return + } + }(ctx) + + time.Sleep(2 * time.Second) + s.Nil(s.app.Chain([]queue.Jobs{ + { + Job: &TestChainAsyncJob{}, + Args: []queue.Arg{ + {Type: "bool", Value: true}, + }, + }, + { + Job: &TestChainSyncJob{}, + Args: []queue.Arg{}, + }, + }).OnQueue("chain").Dispatch()) + + time.Sleep(2 * time.Second) + s.Equal(1, testChainAsyncJobError) + s.Equal(0, testChainSyncJobError) + + s.mockConfig.AssertExpectations(s.T()) +} + type TestAsyncJob struct { } @@ -340,6 +389,12 @@ func (receiver *TestChainAsyncJob) Signature() string { // Handle Execute the job. func (receiver *TestChainAsyncJob) Handle(args ...any) error { + if len(args) > 0 && cast.ToBool(args[0]) { + testChainAsyncJobError++ + + return errors.New("error") + } + testChainAsyncJob++ return nil diff --git a/queue/task.go b/queue/task.go index d98c864b0..eadc6e020 100644 --- a/queue/task.go +++ b/queue/task.go @@ -68,13 +68,7 @@ func (receiver *Task) Dispatch() error { receiver.server = server if receiver.chain { - for _, job := range receiver.jobs { - if err := receiver.handleAsync(job.Job, job.Args); err != nil { - return err - } - } - - return nil + return receiver.handleChain(receiver.jobs) } else { job := receiver.jobs[0] @@ -110,6 +104,34 @@ func (receiver *Task) OnQueue(queue string) queue.Task { return receiver } +func (receiver *Task) handleChain(jobs []queue.Jobs) error { + var signatures []*tasks.Signature + for _, job := range jobs { + var realArgs []tasks.Arg + for _, arg := range job.Args { + realArgs = append(realArgs, tasks.Arg{ + Type: arg.Type, + Value: arg.Value, + }) + } + + signatures = append(signatures, &tasks.Signature{ + Name: job.Job.Signature(), + Args: realArgs, + ETA: receiver.delay, + }) + } + + chain, err := tasks.NewChain(signatures...) + if err != nil { + return err + } + + _, err = receiver.server.SendChain(chain) + + return err +} + func (receiver *Task) handleAsync(job queue.Job, args []queue.Arg) error { var realArgs []tasks.Arg for _, arg := range args { diff --git a/support/constant.go b/support/constant.go index 21b32e14f..063454200 100644 --- a/support/constant.go +++ b/support/constant.go @@ -1,6 +1,6 @@ package support -const Version string = "v1.13.4" +const Version string = "v1.13.7" const ( EnvRuntime = "runtime" diff --git a/support/docker/database.go b/support/docker/database.go index d355a0334..0569a81ff 100644 --- a/support/docker/database.go +++ b/support/docker/database.go @@ -8,6 +8,7 @@ const ( type Database struct { Mysql *Mysql + Mysql1 *Mysql Postgresql *Postgresql Sqlserver *Sqlserver Sqlite *Sqlite @@ -19,6 +20,11 @@ func InitDatabase() (*Database, error) { return nil, err } + mysql1Docker := NewMysql(database, username, password) + if err := mysql1Docker.Build(); err != nil { + return nil, err + } + postgresqlDocker := NewPostgresql(database, username, password) if err := postgresqlDocker.Build(); err != nil { return nil, err @@ -36,6 +42,7 @@ func InitDatabase() (*Database, error) { return &Database{ Mysql: mysqlDocker, + Mysql1: mysql1Docker, Postgresql: postgresqlDocker, Sqlserver: sqlserverDocker, Sqlite: sqliteDocker, diff --git a/testing/mock/log.go b/testing/mock/log.go index eb8d364e3..821fd8daa 100644 --- a/testing/mock/log.go +++ b/testing/mock/log.go @@ -12,116 +12,165 @@ import ( var _ log.Log = &TestLog{} type TestLog struct { + *TestLogWriter } func NewTestLog() log.Log { - return &TestLog{} + return &TestLog{ + TestLogWriter: NewTestLogWriter(), + } } func (r *TestLog) WithContext(ctx context.Context) log.Writer { - return r + return NewTestLogWriter() +} + +type TestLogWriter struct { + data map[string]any +} + +func NewTestLogWriter() *TestLogWriter { + return &TestLogWriter{ + data: make(map[string]any), + } } -func (r *TestLog) Debug(args ...any) { +func (r *TestLogWriter) Debug(args ...any) { fmt.Print(prefix("debug")) fmt.Println(args...) + r.printData() } -func (r *TestLog) Debugf(format string, args ...any) { +func (r *TestLogWriter) Debugf(format string, args ...any) { fmt.Print(prefix("debug")) fmt.Printf(format+"\n", args...) + r.printData() } -func (r *TestLog) Info(args ...any) { +func (r *TestLogWriter) Info(args ...any) { fmt.Print(prefix("info")) fmt.Println(args...) + r.printData() } -func (r *TestLog) Infof(format string, args ...any) { +func (r *TestLogWriter) Infof(format string, args ...any) { fmt.Print(prefix("info")) fmt.Printf(format+"\n", args...) + r.printData() } -func (r *TestLog) Warning(args ...any) { +func (r *TestLogWriter) Warning(args ...any) { fmt.Print(prefix("warning")) fmt.Println(args...) + r.printData() } -func (r *TestLog) Warningf(format string, args ...any) { +func (r *TestLogWriter) Warningf(format string, args ...any) { fmt.Print(prefix("warning")) fmt.Printf(format+"\n", args...) + r.printData() } -func (r *TestLog) Error(args ...any) { +func (r *TestLogWriter) Error(args ...any) { fmt.Print(prefix("error")) fmt.Println(args...) + r.printData() } -func (r *TestLog) Errorf(format string, args ...any) { +func (r *TestLogWriter) Errorf(format string, args ...any) { fmt.Print(prefix("error")) fmt.Printf(format+"\n", args...) + r.printData() } -func (r *TestLog) Fatal(args ...any) { +func (r *TestLogWriter) Fatal(args ...any) { fmt.Print(prefix("fatal")) fmt.Println(args...) + r.printData() } -func (r *TestLog) Fatalf(format string, args ...any) { +func (r *TestLogWriter) Fatalf(format string, args ...any) { fmt.Print(prefix("fatal")) fmt.Printf(format+"\n", args...) + r.printData() } -func (r *TestLog) Panic(args ...any) { +func (r *TestLogWriter) Panic(args ...any) { fmt.Print(prefix("panic")) fmt.Println(args...) + r.printData() } -func (r *TestLog) Panicf(format string, args ...any) { +func (r *TestLogWriter) Panicf(format string, args ...any) { fmt.Print(prefix("panic")) fmt.Printf(format+"\n", args...) + r.printData() } -func (r *TestLog) User(user any) log.Writer { +func (r *TestLogWriter) User(user any) log.Writer { + r.data["user"] = user + return r } -func (r *TestLog) Owner(owner any) log.Writer { +func (r *TestLogWriter) Owner(owner any) log.Writer { + r.data["owner"] = owner + return r } -func (r *TestLog) Hint(hint string) log.Writer { +func (r *TestLogWriter) Hint(hint string) log.Writer { + r.data["hint"] = hint + return r } -func (r *TestLog) Code(code string) log.Writer { +func (r *TestLogWriter) Code(code string) log.Writer { + r.data["code"] = code + return r } -func (r *TestLog) With(data map[string]any) log.Writer { +func (r *TestLogWriter) With(data map[string]any) log.Writer { + r.data["with"] = data + return r } -func (r *TestLog) Tags(tags ...string) log.Writer { +func (r *TestLogWriter) Tags(tags ...string) log.Writer { + r.data["tags"] = tags + return r } -func (r *TestLog) WithTrace() log.Writer { +func (r *TestLogWriter) WithTrace() log.Writer { return r } -func (r *TestLog) Request(req http.ContextRequest) log.Writer { +func (r *TestLogWriter) Request(req http.ContextRequest) log.Writer { + r.data["request"] = req + return r } -func (r *TestLog) Response(res http.ContextResponse) log.Writer { +func (r *TestLogWriter) Response(res http.ContextResponse) log.Writer { + r.data["response"] = res + return r } -func (r *TestLog) In(domain string) log.Writer { +func (r *TestLogWriter) In(domain string) log.Writer { + r.data["in"] = domain + return r } +func (r *TestLogWriter) printData() { + if len(r.data) > 0 { + fmt.Println(r.data) + } +} + func prefix(model string) string { timestamp := carbon.Now().ToDateTimeString()