Skip to content

Commit

Permalink
feat: Inject gorm.DB in the context using the context/database package
Browse files Browse the repository at this point in the history
  • Loading branch information
Edoardo Rossi committed Sep 10, 2020
1 parent 2984490 commit 66ef7ba
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 13 deletions.
38 changes: 38 additions & 0 deletions pkg/context/database/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package database

import (
"context"
"fmt"

"github.com/jinzhu/gorm"
)

type ctxDatabaseMarker struct{}

type ctxDatabase struct {
database *gorm.DB
}

var (
ctxDatabaseKey = &ctxDatabaseMarker{}
)

// Extract takes the gorm.DB database from the context.
// If the ctxDatabase wasn't used, an error is returned.
func Extract(ctx context.Context) (*gorm.DB, error) {
d, ok := ctx.Value(ctxDatabaseKey).(*ctxDatabase)
if !ok || d == nil {
return nil, fmt.Errorf("Unable to get the database")
}

return d.database, nil
}

// ToContext adds the gorm.DB database to the context for extraction later.
// Returning the new context that has been created.
func ToContext(ctx context.Context, db *gorm.DB) context.Context {
d := &ctxDatabase{
database: db,
}
return context.WithValue(ctx, ctxDatabaseKey, d)
}
9 changes: 4 additions & 5 deletions pkg/middleware/database.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package middleware

import (
"context"
"net/http"

"git.lo/microservices/sdk/go-sdk/pkg/contextkeys"
"git.lo/microservices/sdk/go-sdk/pkg/instrumentation"
sdkdatabasecontext "git.lo/microservices/sdk/go-sdk/pkg/context/database"
sdkinstrumentation "git.lo/microservices/sdk/go-sdk/pkg/instrumentation"

"github.com/jinzhu/gorm"
)
Expand All @@ -28,8 +27,8 @@ func NewDatabaseMiddleware(d *gorm.DB) DatabaseMiddleware {
// connection pool to the request context.
func (dm DatabaseMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
db := instrumentation.TraceDatabase(r.Context(), dm.Database)
ctx := context.WithValue(r.Context(), contextkeys.Database, db)
db := sdkinstrumentation.TraceDatabase(r.Context(), dm.Database)
ctx := sdkdatabasecontext.ToContext(r.Context(), db)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
12 changes: 4 additions & 8 deletions pkg/middleware/database_logging.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package middleware

import (
"context"
"database/sql/driver"
"fmt"
"net/http"
Expand All @@ -10,12 +9,9 @@ import (
"time"
"unicode"

"git.lo/microservices/sdk/go-sdk/pkg/contextkeys"

sdkdatabasecontext "git.lo/microservices/sdk/go-sdk/pkg/context/database"
sdkloggercontext "git.lo/microservices/sdk/go-sdk/pkg/context/logger"
sdklogger "git.lo/microservices/sdk/go-sdk/pkg/logger"

"github.com/jinzhu/gorm"
)

// DatabaseLoggingMiddleware wraps an instantiated sdk.Logger that will be injected
Expand All @@ -35,8 +31,8 @@ func NewDatabaseLoggingMiddleware() DatabaseLoggingMiddleware {
// meta-information using the logger.
func (dlm DatabaseLoggingMiddleware) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
d, ok := r.Context().Value(contextkeys.Database).(*gorm.DB)
if !ok {
d, err := sdkdatabasecontext.Extract(r.Context())
if err != nil {
http.Error(w, "Unable to get DB connection", http.StatusInternalServerError)
return
}
Expand All @@ -50,7 +46,7 @@ func (dlm DatabaseLoggingMiddleware) Handler(next http.Handler) http.Handler {
d.LogMode(true)
d.SetLogger(newGormLogger(l))

ctx := context.WithValue(r.Context(), contextkeys.Database, d)
ctx := sdkdatabasecontext.ToContext(r.Context(), d)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
Expand Down

0 comments on commit 66ef7ba

Please sign in to comment.