diff --git a/otgorm/otgorm.go b/otgorm/otgorm.go index 75cc0d95..4fea5f45 100644 --- a/otgorm/otgorm.go +++ b/otgorm/otgorm.go @@ -18,7 +18,8 @@ func AddGormCallbacks(db *gorm.DB, tracer opentracing.Tracer) { registerCallbacks(db, "query", callbacks) registerCallbacks(db, "update", callbacks) registerCallbacks(db, "delete", callbacks) - registerCallbacks(db, "row_query", callbacks) + registerCallbacks(db, "row", callbacks) + registerCallbacks(db, "raw", callbacks) } type callbacks struct { @@ -29,16 +30,18 @@ func newCallbacks(tracer opentracing.Tracer) *callbacks { return &callbacks{tracer} } -func (c *callbacks) beforeCreate(scope *gorm.DB) { c.before(scope) } -func (c *callbacks) afterCreate(scope *gorm.DB) { c.after(scope, "INSERT") } -func (c *callbacks) beforeQuery(scope *gorm.DB) { c.before(scope) } -func (c *callbacks) afterQuery(scope *gorm.DB) { c.after(scope, "SELECT") } -func (c *callbacks) beforeUpdate(scope *gorm.DB) { c.before(scope) } -func (c *callbacks) afterUpdate(scope *gorm.DB) { c.after(scope, "UPDATE") } -func (c *callbacks) beforeDelete(scope *gorm.DB) { c.before(scope) } -func (c *callbacks) afterDelete(scope *gorm.DB) { c.after(scope, "DELETE") } -func (c *callbacks) beforeRowQuery(scope *gorm.DB) { c.before(scope) } -func (c *callbacks) afterRowQuery(scope *gorm.DB) { c.after(scope, "") } +func (c *callbacks) beforeCreate(scope *gorm.DB) { c.before(scope) } +func (c *callbacks) afterCreate(scope *gorm.DB) { c.after(scope, "INSERT") } +func (c *callbacks) beforeQuery(scope *gorm.DB) { c.before(scope) } +func (c *callbacks) afterQuery(scope *gorm.DB) { c.after(scope, "SELECT") } +func (c *callbacks) beforeUpdate(scope *gorm.DB) { c.before(scope) } +func (c *callbacks) afterUpdate(scope *gorm.DB) { c.after(scope, "UPDATE") } +func (c *callbacks) beforeDelete(scope *gorm.DB) { c.before(scope) } +func (c *callbacks) afterDelete(scope *gorm.DB) { c.after(scope, "DELETE") } +func (c *callbacks) beforeRow(scope *gorm.DB) { c.before(scope) } +func (c *callbacks) afterRow(scope *gorm.DB) { c.after(scope, "") } +func (c *callbacks) beforeRaw(scope *gorm.DB) { c.before(scope) } +func (c *callbacks) afterRaw(scope *gorm.DB) { c.after(scope, "") } func (c *callbacks) before(db *gorm.DB) { span, newCtx := opentracing.StartSpanFromContextWithTracer(db.Statement.Context, c.tracer, "sql") @@ -48,6 +51,7 @@ func (c *callbacks) before(db *gorm.DB) { } func (c *callbacks) after(db *gorm.DB, operation string) { + spanInterface, ok := db.Get("span") if !ok { return @@ -85,8 +89,11 @@ func registerCallbacks(db *gorm.DB, name string, c *callbacks) { case "delete": db.Callback().Delete().Before(gormCallbackName).Register(beforeName, c.beforeDelete) db.Callback().Delete().After(gormCallbackName).Register(afterName, c.afterDelete) - case "row_query": - db.Callback().Row().Before(gormCallbackName).Register(beforeName, c.beforeRowQuery) - db.Callback().Row().After(gormCallbackName).Register(afterName, c.afterRowQuery) + case "row": + db.Callback().Row().Before(gormCallbackName).Register(beforeName, c.beforeRow) + db.Callback().Row().After(gormCallbackName).Register(afterName, c.afterRow) + case "raw": + db.Callback().Raw().Before(gormCallbackName).Register(beforeName, c.beforeRaw) + db.Callback().Raw().After(gormCallbackName).Register(afterName, c.afterRaw) } } diff --git a/otgorm/otgorm_test.go b/otgorm/otgorm_test.go index b5e9c370..e4dbdf27 100644 --- a/otgorm/otgorm_test.go +++ b/otgorm/otgorm_test.go @@ -50,3 +50,55 @@ func TestHook(t *testing.T) { assert.True(t, interceptorCalled) } + +func TestHook_raw(t *testing.T) { + tracer := mocktracer.New() + out, cleanup, _ := provideDBFactory(&providersOption{ + drivers: map[string]func(dsn string) gorm.Dialector{"sqlite": sqlite.Open}, + })(factoryIn{ + Conf: config.MapAdapter{ + "gorm": map[string]interface{}{ + "default": map[string]interface{}{ + "database": "sqlite", + "dsn": ":memory:", + }, + }, + }, + Logger: log.NewNopLogger(), + Tracer: tracer, + }) + defer cleanup() + + factory := out.Factory + + db, err := factory.Make("default") + assert.NoError(t, err) + + _, ctx := opentracing.StartSpanFromContextWithTracer(context.Background(), tracer, "test") + + err = db.WithContext(ctx).Exec("CREATE TABLE test (id uint)").Error + assert.NoError(t, err) + + err = db.WithContext(ctx).Exec("INSERT INTO test (id) VALUES (1)").Error + assert.NoError(t, err) + + err = db.WithContext(ctx).Exec("INSERT INTO test (id) VALUES (2)").Error + assert.NoError(t, err) + + rows, err := db.WithContext(ctx).Raw("SELECT * FROM test").Rows() + assert.NoError(t, err) + + var models []mockModel + for rows.Next() { + var m mockModel + err = db.WithContext(ctx).ScanRows(rows, &m) + assert.NoError(t, err) + models = append(models, m) + } + t.Log(models) + + db.WithContext(ctx).Raw("SELECT * FROM test").Scan(&models) + t.Log(models) + + assert.Len(t, tracer.FinishedSpans(), 5) +}