diff --git a/pkg/instrumentation/database.go b/pkg/instrumentation/database.go index 19922c2f..931f217b 100644 --- a/pkg/instrumentation/database.go +++ b/pkg/instrumentation/database.go @@ -10,11 +10,13 @@ import ( "gorm.io/gorm" ) -const ( +type spanContextKey string + +var ( // ParentSpanGormKey is the name of the parent span key - ParentSpanGormKey = "tracingParentSpan" + parentSpanGormKey = spanContextKey("trancingParentSpan") // SpanGormKey is the name of the span key - SpanGormKey = "tracingSpan" + spanGormKey = spanContextKey("tracingSpan") ) // TraceDatabase sets span to gorm settings, returns cloned DB @@ -22,8 +24,12 @@ func TraceDatabase(ctx context.Context, db *gorm.DB) *gorm.DB { if ctx == nil { return db } + parentSpan, _ := ddtrace.SpanFromContext(ctx) - return db.Set(ParentSpanGormKey, parentSpan) + + return db.Session(&gorm.Session{ + Context: context.WithValue(db.Statement.Context, parentSpanGormKey, parentSpan), + }) } // InstrumentDatabase adds callbacks for tracing, call TraceDatabase to make it work @@ -58,12 +64,15 @@ func (c *callbacks) afterDelete(db *gorm.DB) { c.after(db) } func (c *callbacks) beforeRow(db *gorm.DB) { c.before(db, "", c.serviceName) } func (c *callbacks) afterRow(db *gorm.DB) { c.after(db) } func (c *callbacks) before(db *gorm.DB, operationName string, serviceName string) { - val, ok := db.Get(ParentSpanGormKey) + if db.Statement == nil || db.Statement.Context == nil { + return + } + + parentSpan, ok := db.Statement.Context.Value(parentSpanGormKey).(ddtrace.Span) if !ok { return } - parentSpan := val.(ddtrace.Span) spanOpts := []ddtrace.StartSpanOption{ ddtrace.ChildOf(parentSpan.Context()), ddtrace.SpanType(ext.SpanTypeSQL), @@ -73,16 +82,19 @@ func (c *callbacks) before(db *gorm.DB, operationName string, serviceName string operationName = strings.Split(db.Statement.SQL.String(), " ")[0] } sp := ddtrace.StartSpan(operationName, spanOpts...) - db.Set(SpanGormKey, sp) + db.Statement.Context = context.WithValue(db.Statement.Context, spanGormKey, sp) } func (c *callbacks) after(db *gorm.DB) { - val, ok := db.Get(SpanGormKey) + if db.Statement == nil || db.Statement.Context == nil { + return + } + + sp, ok := db.Statement.Context.Value(spanGormKey).(ddtrace.Span) if !ok { return } - sp := val.(ddtrace.Span) sp.SetTag(ext.ResourceName, strings.ToUpper(db.Statement.SQL.String())) sp.SetTag("db.table", db.Statement.Table) sp.SetTag("db.query", strings.ToUpper(db.Statement.SQL.String())) diff --git a/pkg/instrumentation/database_test.go b/pkg/instrumentation/database_test.go index 8d7ce542..de7d2e31 100644 --- a/pkg/instrumentation/database_test.go +++ b/pkg/instrumentation/database_test.go @@ -135,3 +135,18 @@ func TestInstrumentDatabase(t *testing.T) { } } } + +func TestTraceDatabase(t *testing.T) { + dbFile := path.Join(t.TempDir(), "test_db") + db, err := gorm.Open(sqlite.Open(dbFile)) + if err != nil { + t.Fatalf("Failed to open DB: %s", err) + } + + InstrumentDatabase(db, "test_app_name") + db = TraceDatabase(context.Background(), db) + + if sp := db.Statement.Context.Value(parentSpanGormKey); sp == nil { + t.Error("Parent span not set on statement") + } +}