diff --git a/pkg/transport/sqs/request_response_funcs.go b/pkg/transport/sqs/request_response_funcs.go index 2808a41..216d21e 100644 --- a/pkg/transport/sqs/request_response_funcs.go +++ b/pkg/transport/sqs/request_response_funcs.go @@ -5,7 +5,9 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "gorm.io/gorm" + sdkdatabasecontext "github.com/scribd/go-sdk/pkg/context/database" sdkloggercontext "github.com/scribd/go-sdk/pkg/context/logger" sdkmetricscontext "github.com/scribd/go-sdk/pkg/context/metrics" sdkrequestidcontext "github.com/scribd/go-sdk/pkg/context/requestid" @@ -108,3 +110,33 @@ func SetSubscriberMetrics(m sdkmetrics.Metrics) SubscriberRequestFunc { return sdkmetricscontext.ToContext(ctx, m) } } + +// SetSubscriberDatabase returns SubscriberRequestFunc that sets the GORM database to the request context. +func SetSubscriberDatabase(db *gorm.DB) SubscriberRequestFunc { + return func(ctx context.Context, cancel context.CancelFunc, message types.Message) context.Context { + return sdkdatabasecontext.ToContext(ctx, db) + } +} + +// SetSubscriberDatabaseLogging returns SubscriberRequestFunc that sets the SDL Logger to GORM database +// and sets new gorm DB to the request context. +func SetSubscriberDatabaseLogging() SubscriberRequestFunc { + return func(ctx context.Context, cancel context.CancelFunc, message types.Message) context.Context { + db, err := sdkdatabasecontext.Extract(ctx) + if err != nil { + return ctx + } + + l, err := sdkloggercontext.Extract(ctx) + if err != nil { + return ctx + } + + newDB := db.Session(&gorm.Session{ + Logger: sdklogger.NewGormLogger(l), + NewDB: true, + }) + + return sdkdatabasecontext.ToContext(ctx, newDB) + } +}