Skip to content

Commit

Permalink
chore: Update pkg/interceptors with gorm.v2 logger
Browse files Browse the repository at this point in the history
  • Loading branch information
laynax committed Oct 14, 2022
1 parent 393c674 commit 6edae2e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 60 deletions.
2 changes: 1 addition & 1 deletion pkg/interceptors/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"

grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/jinzhu/gorm"
"google.golang.org/grpc"
"gorm.io/gorm"

sdkcontext "github.com/scribd/go-sdk/pkg/context/database"
sdkinstrumentation "github.com/scribd/go-sdk/pkg/instrumentation"
Expand Down
27 changes: 15 additions & 12 deletions pkg/interceptors/database_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"google.golang.org/grpc"
"gorm.io/gorm"

sdkdatabasecontext "github.com/scribd/go-sdk/pkg/context/database"
sdkloggercontext "github.com/scribd/go-sdk/pkg/context/logger"
Expand All @@ -19,10 +20,10 @@ func DatabaseLoggingUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
_ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
d, err := sdkdatabasecontext.Extract(ctx)
db, err := sdkdatabasecontext.Extract(ctx)
if err != nil {
return nil, err
}
Expand All @@ -32,11 +33,12 @@ func DatabaseLoggingUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return nil, err
}

newDb := d.New()
newDb.LogMode(true)
newDb.SetLogger(sdklogger.NewGormLogger(l))
newDB := db.Session(&gorm.Session{
Logger: sdklogger.NewGormLogger(l),
NewDB: true,
})

newCtx := sdkdatabasecontext.ToContext(ctx, newDb)
newCtx := sdkdatabasecontext.ToContext(ctx, newDB)

return handler(newCtx, req)
}
Expand All @@ -50,10 +52,10 @@ func DatabaseLoggingStreamServerInterceptor() grpc.StreamServerInterceptor {
return func(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
_ *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
d, err := sdkdatabasecontext.Extract(stream.Context())
db, err := sdkdatabasecontext.Extract(stream.Context())
if err != nil {
return err
}
Expand All @@ -63,11 +65,12 @@ func DatabaseLoggingStreamServerInterceptor() grpc.StreamServerInterceptor {
return err
}

newDb := d.New()
newDb.LogMode(true)
newDb.SetLogger(sdklogger.NewGormLogger(l))
newDB := db.Session(&gorm.Session{
Logger: sdklogger.NewGormLogger(l),
NewDB: true,
})

newCtx := sdkdatabasecontext.ToContext(stream.Context(), newDb)
newCtx := sdkdatabasecontext.ToContext(stream.Context(), newDB)
wrapped := grpcmiddleware.WrapServerStream(stream)
wrapped.WrappedContext = newCtx

Expand Down
86 changes: 39 additions & 47 deletions pkg/interceptors/database_logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,20 @@ import (
"io"
golog "log"
"net"
"os"
"path"
"testing"

"github.com/jinzhu/gorm"
"google.golang.org/grpc/credentials/insecure"

"github.com/scribd/go-sdk/pkg/testing/testproto"

_ "github.com/mattn/go-sqlite3"

sdktesting "github.com/scribd/go-sdk/pkg/testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"
"gorm.io/driver/sqlite"
"gorm.io/gorm"

sdktesting "github.com/scribd/go-sdk/pkg/testing"
"github.com/scribd/go-sdk/pkg/testing/testproto"
)

type TestRecord struct {
Expand All @@ -36,28 +32,26 @@ func TestDatabaseLoggingUnaryServerInterceptor(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

dbFile := path.Join(t.TempDir(), "test_db")
defer os.Remove(dbFile)
tempDBPath := path.Join(t.TempDir(), "test_db")

db, err := gorm.Open("sqlite3", dbFile)
db, err := gorm.Open(sqlite.Open(tempDBPath))
if err != nil {
t.Fatalf("Failed to open DB: %s", err)
}
defer db.Close()

var (
testRecordOne = TestRecord{ID: 1, Name: "test_name"}
testRecordTwo = TestRecord{ID: 2, Name: "new_test_name"}
)
testRecordOne := TestRecord{ID: 1, Name: "test_name"}
testRecordTwo := TestRecord{ID: 2, Name: "new_test_name"}

errors := db.Begin().
CreateTable(TestRecord{}).
Create(testRecordOne).
Create(testRecordTwo).
Commit().GetErrors()
if err = db.AutoMigrate(TestRecord{}); err != nil {
t.Fatalf("Failed to migrate DB: %s", err)
}

for _, err := range errors {
t.Fatalf("Errors: %v", err)
err = db.Begin().
Create(&testRecordOne).
Create(&testRecordTwo).
Commit().Error
if err != nil {
t.Fatalf("Failed to create record: %s", err)
}

var buffer bytes.Buffer
Expand Down Expand Up @@ -112,28 +106,25 @@ func TestDatabaseLoggingStreamServerInterceptors(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

dbFile := path.Join(t.TempDir(), "test_db")
defer os.Remove(dbFile)

db, err := gorm.Open("sqlite3", dbFile)
tempDBPath := path.Join(t.TempDir(), "test_db")
db, err := gorm.Open(sqlite.Open(tempDBPath))
if err != nil {
t.Fatalf("Failed to open DB: %s", err)
}
defer db.Close()

var (
testRecordOne = TestRecord{ID: 1, Name: "test_name"}
testRecordTwo = TestRecord{ID: 2, Name: "new_test_name"}
)
testRecordOne := TestRecord{ID: 1, Name: "test_name"}
testRecordTwo := TestRecord{ID: 2, Name: "new_test_name"}

errors := db.Begin().
CreateTable(TestRecord{}).
Create(testRecordOne).
Create(testRecordTwo).
Commit().GetErrors()
if err = db.AutoMigrate(TestRecord{}); err != nil {
t.Fatalf("Failed to migrate DB: %s", err)
}

for _, err := range errors {
t.Fatalf("Errors: %v", err)
err = db.Begin().
Create(&testRecordOne).
Create(&testRecordTwo).
Commit().Error
if err != nil {
t.Fatalf("Failed to create record: %s", err)
}

var buffer bytes.Buffer
Expand Down Expand Up @@ -197,12 +188,13 @@ func TestDatabaseLoggingStreamServerInterceptors(t *testing.T) {
}

func checkGormLoggerFields(t *testing.T, fields map[string]interface{}) {
assert.NotEmpty(t, fields["sql"])
assert.NotEmpty(t, fields["trace"])

var sql = (fields["sql"]).(map[string]interface{})
dbFields, ok := (fields["trace"]).(map[string]interface{})
assert.True(t, ok, "%s not found in log fields", "trace")
assert.NotEmpty(t, dbFields)

assert.NotEmpty(t, sql["duration"])
assert.NotEmpty(t, sql["affected_rows"])
assert.NotEmpty(t, sql["file_location"])
assert.NotNil(t, sql["values"])
assert.NotEmpty(t, dbFields["elapsed_seconds"])
assert.NotEmpty(t, dbFields["affected_rows"])
assert.NotEmpty(t, dbFields["sql"])
}

0 comments on commit 6edae2e

Please sign in to comment.