From 6edae2e1824361256c1f32beb4f40c02f7e3cd16 Mon Sep 17 00:00:00 2001 From: Daniel Bedrood Date: Fri, 14 Oct 2022 13:43:47 +0200 Subject: [PATCH] chore: Update pkg/interceptors with gorm.v2 logger --- pkg/interceptors/database.go | 2 +- pkg/interceptors/database_logging.go | 27 +++---- pkg/interceptors/database_logging_test.go | 86 ++++++++++------------- 3 files changed, 55 insertions(+), 60 deletions(-) diff --git a/pkg/interceptors/database.go b/pkg/interceptors/database.go index 2c8dc03b..51ca31ae 100644 --- a/pkg/interceptors/database.go +++ b/pkg/interceptors/database.go @@ -4,8 +4,8 @@ import ( "context" grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" - "github.com/jinzhu/gorm" "google.golang.org/grpc" + "gorm.io/gorm" sdkcontext "github.com/scribd/go-sdk/pkg/context/database" sdkinstrumentation "github.com/scribd/go-sdk/pkg/instrumentation" diff --git a/pkg/interceptors/database_logging.go b/pkg/interceptors/database_logging.go index 0a0195e7..18273390 100644 --- a/pkg/interceptors/database_logging.go +++ b/pkg/interceptors/database_logging.go @@ -5,6 +5,7 @@ import ( grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" "google.golang.org/grpc" + "gorm.io/gorm" sdkdatabasecontext "github.com/scribd/go-sdk/pkg/context/database" sdkloggercontext "github.com/scribd/go-sdk/pkg/context/logger" @@ -19,10 +20,10 @@ func DatabaseLoggingUnaryServerInterceptor() grpc.UnaryServerInterceptor { return func( ctx context.Context, req interface{}, - info *grpc.UnaryServerInfo, + _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { - d, err := sdkdatabasecontext.Extract(ctx) + db, err := sdkdatabasecontext.Extract(ctx) if err != nil { return nil, err } @@ -32,11 +33,12 @@ func DatabaseLoggingUnaryServerInterceptor() grpc.UnaryServerInterceptor { return nil, err } - newDb := d.New() - newDb.LogMode(true) - newDb.SetLogger(sdklogger.NewGormLogger(l)) + newDB := db.Session(&gorm.Session{ + Logger: sdklogger.NewGormLogger(l), + NewDB: true, + }) - newCtx := sdkdatabasecontext.ToContext(ctx, newDb) + newCtx := sdkdatabasecontext.ToContext(ctx, newDB) return handler(newCtx, req) } @@ -50,10 +52,10 @@ func DatabaseLoggingStreamServerInterceptor() grpc.StreamServerInterceptor { return func( srv interface{}, stream grpc.ServerStream, - info *grpc.StreamServerInfo, + _ *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { - d, err := sdkdatabasecontext.Extract(stream.Context()) + db, err := sdkdatabasecontext.Extract(stream.Context()) if err != nil { return err } @@ -63,11 +65,12 @@ func DatabaseLoggingStreamServerInterceptor() grpc.StreamServerInterceptor { return err } - newDb := d.New() - newDb.LogMode(true) - newDb.SetLogger(sdklogger.NewGormLogger(l)) + newDB := db.Session(&gorm.Session{ + Logger: sdklogger.NewGormLogger(l), + NewDB: true, + }) - newCtx := sdkdatabasecontext.ToContext(stream.Context(), newDb) + newCtx := sdkdatabasecontext.ToContext(stream.Context(), newDB) wrapped := grpcmiddleware.WrapServerStream(stream) wrapped.WrappedContext = newCtx diff --git a/pkg/interceptors/database_logging_test.go b/pkg/interceptors/database_logging_test.go index 645541be..9c27b3a5 100644 --- a/pkg/interceptors/database_logging_test.go +++ b/pkg/interceptors/database_logging_test.go @@ -7,24 +7,20 @@ import ( "io" golog "log" "net" - "os" "path" "testing" - "github.com/jinzhu/gorm" - "google.golang.org/grpc/credentials/insecure" - - "github.com/scribd/go-sdk/pkg/testing/testproto" - - _ "github.com/mattn/go-sqlite3" - - sdktesting "github.com/scribd/go-sdk/pkg/testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + sdktesting "github.com/scribd/go-sdk/pkg/testing" + "github.com/scribd/go-sdk/pkg/testing/testproto" ) type TestRecord struct { @@ -36,28 +32,26 @@ func TestDatabaseLoggingUnaryServerInterceptor(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - dbFile := path.Join(t.TempDir(), "test_db") - defer os.Remove(dbFile) + tempDBPath := path.Join(t.TempDir(), "test_db") - db, err := gorm.Open("sqlite3", dbFile) + db, err := gorm.Open(sqlite.Open(tempDBPath)) if err != nil { t.Fatalf("Failed to open DB: %s", err) } - defer db.Close() - var ( - testRecordOne = TestRecord{ID: 1, Name: "test_name"} - testRecordTwo = TestRecord{ID: 2, Name: "new_test_name"} - ) + testRecordOne := TestRecord{ID: 1, Name: "test_name"} + testRecordTwo := TestRecord{ID: 2, Name: "new_test_name"} - errors := db.Begin(). - CreateTable(TestRecord{}). - Create(testRecordOne). - Create(testRecordTwo). - Commit().GetErrors() + if err = db.AutoMigrate(TestRecord{}); err != nil { + t.Fatalf("Failed to migrate DB: %s", err) + } - for _, err := range errors { - t.Fatalf("Errors: %v", err) + err = db.Begin(). + Create(&testRecordOne). + Create(&testRecordTwo). + Commit().Error + if err != nil { + t.Fatalf("Failed to create record: %s", err) } var buffer bytes.Buffer @@ -112,28 +106,25 @@ func TestDatabaseLoggingStreamServerInterceptors(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - dbFile := path.Join(t.TempDir(), "test_db") - defer os.Remove(dbFile) - - db, err := gorm.Open("sqlite3", dbFile) + tempDBPath := path.Join(t.TempDir(), "test_db") + db, err := gorm.Open(sqlite.Open(tempDBPath)) if err != nil { t.Fatalf("Failed to open DB: %s", err) } - defer db.Close() - var ( - testRecordOne = TestRecord{ID: 1, Name: "test_name"} - testRecordTwo = TestRecord{ID: 2, Name: "new_test_name"} - ) + testRecordOne := TestRecord{ID: 1, Name: "test_name"} + testRecordTwo := TestRecord{ID: 2, Name: "new_test_name"} - errors := db.Begin(). - CreateTable(TestRecord{}). - Create(testRecordOne). - Create(testRecordTwo). - Commit().GetErrors() + if err = db.AutoMigrate(TestRecord{}); err != nil { + t.Fatalf("Failed to migrate DB: %s", err) + } - for _, err := range errors { - t.Fatalf("Errors: %v", err) + err = db.Begin(). + Create(&testRecordOne). + Create(&testRecordTwo). + Commit().Error + if err != nil { + t.Fatalf("Failed to create record: %s", err) } var buffer bytes.Buffer @@ -197,12 +188,13 @@ func TestDatabaseLoggingStreamServerInterceptors(t *testing.T) { } func checkGormLoggerFields(t *testing.T, fields map[string]interface{}) { - assert.NotEmpty(t, fields["sql"]) + assert.NotEmpty(t, fields["trace"]) - var sql = (fields["sql"]).(map[string]interface{}) + dbFields, ok := (fields["trace"]).(map[string]interface{}) + assert.True(t, ok, "%s not found in log fields", "trace") + assert.NotEmpty(t, dbFields) - assert.NotEmpty(t, sql["duration"]) - assert.NotEmpty(t, sql["affected_rows"]) - assert.NotEmpty(t, sql["file_location"]) - assert.NotNil(t, sql["values"]) + assert.NotEmpty(t, dbFields["elapsed_seconds"]) + assert.NotEmpty(t, dbFields["affected_rows"]) + assert.NotEmpty(t, dbFields["sql"]) }