From 66ef7baf99cc0338c817fabe283dfc75456ddfa9 Mon Sep 17 00:00:00 2001 From: Edoardo Rossi Date: Wed, 9 Sep 2020 16:07:59 +0200 Subject: [PATCH] feat: Inject gorm.DB in the context using the context/database package --- pkg/context/database/context.go | 38 ++++++++++++++++++++++++++++++ pkg/middleware/database.go | 9 ++++--- pkg/middleware/database_logging.go | 12 ++++------ 3 files changed, 46 insertions(+), 13 deletions(-) create mode 100644 pkg/context/database/context.go diff --git a/pkg/context/database/context.go b/pkg/context/database/context.go new file mode 100644 index 00000000..30c467c9 --- /dev/null +++ b/pkg/context/database/context.go @@ -0,0 +1,38 @@ +package database + +import ( + "context" + "fmt" + + "github.com/jinzhu/gorm" +) + +type ctxDatabaseMarker struct{} + +type ctxDatabase struct { + database *gorm.DB +} + +var ( + ctxDatabaseKey = &ctxDatabaseMarker{} +) + +// Extract takes the gorm.DB database from the context. +// If the ctxDatabase wasn't used, an error is returned. +func Extract(ctx context.Context) (*gorm.DB, error) { + d, ok := ctx.Value(ctxDatabaseKey).(*ctxDatabase) + if !ok || d == nil { + return nil, fmt.Errorf("Unable to get the database") + } + + return d.database, nil +} + +// ToContext adds the gorm.DB database to the context for extraction later. +// Returning the new context that has been created. +func ToContext(ctx context.Context, db *gorm.DB) context.Context { + d := &ctxDatabase{ + database: db, + } + return context.WithValue(ctx, ctxDatabaseKey, d) +} diff --git a/pkg/middleware/database.go b/pkg/middleware/database.go index b3e17dfa..f100e363 100644 --- a/pkg/middleware/database.go +++ b/pkg/middleware/database.go @@ -1,11 +1,10 @@ package middleware import ( - "context" "net/http" - "git.lo/microservices/sdk/go-sdk/pkg/contextkeys" - "git.lo/microservices/sdk/go-sdk/pkg/instrumentation" + sdkdatabasecontext "git.lo/microservices/sdk/go-sdk/pkg/context/database" + sdkinstrumentation "git.lo/microservices/sdk/go-sdk/pkg/instrumentation" "github.com/jinzhu/gorm" ) @@ -28,8 +27,8 @@ func NewDatabaseMiddleware(d *gorm.DB) DatabaseMiddleware { // connection pool to the request context. func (dm DatabaseMiddleware) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - db := instrumentation.TraceDatabase(r.Context(), dm.Database) - ctx := context.WithValue(r.Context(), contextkeys.Database, db) + db := sdkinstrumentation.TraceDatabase(r.Context(), dm.Database) + ctx := sdkdatabasecontext.ToContext(r.Context(), db) next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/pkg/middleware/database_logging.go b/pkg/middleware/database_logging.go index aba0a25e..79854360 100644 --- a/pkg/middleware/database_logging.go +++ b/pkg/middleware/database_logging.go @@ -1,7 +1,6 @@ package middleware import ( - "context" "database/sql/driver" "fmt" "net/http" @@ -10,12 +9,9 @@ import ( "time" "unicode" - "git.lo/microservices/sdk/go-sdk/pkg/contextkeys" - + sdkdatabasecontext "git.lo/microservices/sdk/go-sdk/pkg/context/database" sdkloggercontext "git.lo/microservices/sdk/go-sdk/pkg/context/logger" sdklogger "git.lo/microservices/sdk/go-sdk/pkg/logger" - - "github.com/jinzhu/gorm" ) // DatabaseLoggingMiddleware wraps an instantiated sdk.Logger that will be injected @@ -35,8 +31,8 @@ func NewDatabaseLoggingMiddleware() DatabaseLoggingMiddleware { // meta-information using the logger. func (dlm DatabaseLoggingMiddleware) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - d, ok := r.Context().Value(contextkeys.Database).(*gorm.DB) - if !ok { + d, err := sdkdatabasecontext.Extract(r.Context()) + if err != nil { http.Error(w, "Unable to get DB connection", http.StatusInternalServerError) return } @@ -50,7 +46,7 @@ func (dlm DatabaseLoggingMiddleware) Handler(next http.Handler) http.Handler { d.LogMode(true) d.SetLogger(newGormLogger(l)) - ctx := context.WithValue(r.Context(), contextkeys.Database, d) + ctx := sdkdatabasecontext.ToContext(r.Context(), d) next.ServeHTTP(w, r.WithContext(ctx)) }) }