diff --git a/pkg/instrumentation/database.go b/pkg/instrumentation/database.go index 19922c2f..73f72822 100644 --- a/pkg/instrumentation/database.go +++ b/pkg/instrumentation/database.go @@ -10,11 +10,13 @@ import ( "gorm.io/gorm" ) -const ( +type spanContextKey string + +var ( // ParentSpanGormKey is the name of the parent span key - ParentSpanGormKey = "tracingParentSpan" + parentSpanGormKey = spanContextKey("trancingParentSpan") // SpanGormKey is the name of the span key - SpanGormKey = "tracingSpan" + spanGormKey = spanContextKey("tracingSpan") ) // TraceDatabase sets span to gorm settings, returns cloned DB @@ -22,8 +24,12 @@ func TraceDatabase(ctx context.Context, db *gorm.DB) *gorm.DB { if ctx == nil { return db } + parentSpan, _ := ddtrace.SpanFromContext(ctx) - return db.Set(ParentSpanGormKey, parentSpan) + + return db.Session(&gorm.Session{ + Context: context.WithValue(db.Statement.Context, parentSpanGormKey, parentSpan), + }) } // InstrumentDatabase adds callbacks for tracing, call TraceDatabase to make it work @@ -58,12 +64,11 @@ func (c *callbacks) afterDelete(db *gorm.DB) { c.after(db) } func (c *callbacks) beforeRow(db *gorm.DB) { c.before(db, "", c.serviceName) } func (c *callbacks) afterRow(db *gorm.DB) { c.after(db) } func (c *callbacks) before(db *gorm.DB, operationName string, serviceName string) { - val, ok := db.Get(ParentSpanGormKey) + parentSpan, ok := db.Statement.Context.Value(parentSpanGormKey).(ddtrace.Span) if !ok { return } - parentSpan := val.(ddtrace.Span) spanOpts := []ddtrace.StartSpanOption{ ddtrace.ChildOf(parentSpan.Context()), ddtrace.SpanType(ext.SpanTypeSQL), @@ -73,16 +78,15 @@ func (c *callbacks) before(db *gorm.DB, operationName string, serviceName string operationName = strings.Split(db.Statement.SQL.String(), " ")[0] } sp := ddtrace.StartSpan(operationName, spanOpts...) - db.Set(SpanGormKey, sp) + db.Statement.Context = context.WithValue(db.Statement.Context, spanGormKey, sp) } func (c *callbacks) after(db *gorm.DB) { - val, ok := db.Get(SpanGormKey) + sp, ok := db.Statement.Context.Value(spanGormKey).(ddtrace.Span) if !ok { return } - sp := val.(ddtrace.Span) sp.SetTag(ext.ResourceName, strings.ToUpper(db.Statement.SQL.String())) sp.SetTag("db.table", db.Statement.Table) sp.SetTag("db.query", strings.ToUpper(db.Statement.SQL.String()))