Skip to content

Commit

Permalink
chore: Update pkg/interceptors with gorm.v2 functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
laynax committed Oct 12, 2022
1 parent c74f93b commit bb41729
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 60 deletions.
30 changes: 16 additions & 14 deletions pkg/interceptors/database_logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import (
"context"

grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
"google.golang.org/grpc"

sdkdatabasecontext "github.com/scribd/go-sdk/pkg/context/database"
sdkloggercontext "github.com/scribd/go-sdk/pkg/context/logger"
sdklogger "github.com/scribd/go-sdk/pkg/logger"
"google.golang.org/grpc"
"gorm.io/gorm"
)

// DatabaseLoggingUnaryServerInterceptor returns a unary server interceptor.
Expand All @@ -19,10 +19,10 @@ func DatabaseLoggingUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
_ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
d, err := sdkdatabasecontext.Extract(ctx)
db, err := sdkdatabasecontext.Extract(ctx)
if err != nil {
return nil, err
}
Expand All @@ -32,11 +32,12 @@ func DatabaseLoggingUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return nil, err
}

newDb := d.New()
newDb.LogMode(true)
newDb.SetLogger(sdklogger.NewGormLogger(l))
newDB := db.Session(&gorm.Session{
Logger: sdklogger.NewGormLogger(l),
NewDB: true,
})

newCtx := sdkdatabasecontext.ToContext(ctx, newDb)
newCtx := sdkdatabasecontext.ToContext(ctx, newDB)

return handler(newCtx, req)
}
Expand All @@ -50,10 +51,10 @@ func DatabaseLoggingStreamServerInterceptor() grpc.StreamServerInterceptor {
return func(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
_ *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
d, err := sdkdatabasecontext.Extract(stream.Context())
db, err := sdkdatabasecontext.Extract(stream.Context())
if err != nil {
return err
}
Expand All @@ -63,11 +64,12 @@ func DatabaseLoggingStreamServerInterceptor() grpc.StreamServerInterceptor {
return err
}

newDb := d.New()
newDb.LogMode(true)
newDb.SetLogger(sdklogger.NewGormLogger(l))
newDB := db.Session(&gorm.Session{
Logger: sdklogger.NewGormLogger(l),
NewDB: true,
})

newCtx := sdkdatabasecontext.ToContext(stream.Context(), newDb)
newCtx := sdkdatabasecontext.ToContext(stream.Context(), newDB)
wrapped := grpcmiddleware.WrapServerStream(stream)
wrapped.WrappedContext = newCtx

Expand Down
85 changes: 39 additions & 46 deletions pkg/interceptors/database_logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,20 @@ import (
"io"
golog "log"
"net"
"os"
"path"
"testing"

"github.com/jinzhu/gorm"
"google.golang.org/grpc/credentials/insecure"

"github.com/scribd/go-sdk/pkg/testing/testproto"

_ "github.com/mattn/go-sqlite3"

sdktesting "github.com/scribd/go-sdk/pkg/testing"

"github.com/scribd/go-sdk/pkg/logger"
"github.com/scribd/go-sdk/pkg/testing/testproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

type TestRecord struct {
Expand All @@ -36,28 +32,26 @@ func TestDatabaseLoggingUnaryServerInterceptor(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

dbFile := path.Join(t.TempDir(), "test_db")
defer os.Remove(dbFile)
tempDBPath := path.Join(t.TempDir(), "test_db")

db, err := gorm.Open("sqlite3", dbFile)
db, err := gorm.Open(sqlite.Open(tempDBPath))
if err != nil {
t.Fatalf("Failed to open DB: %s", err)
}
defer db.Close()

var (
testRecordOne = TestRecord{ID: 1, Name: "test_name"}
testRecordTwo = TestRecord{ID: 2, Name: "new_test_name"}
)
testRecordOne := TestRecord{ID: 1, Name: "test_name"}
testRecordTwo := TestRecord{ID: 2, Name: "new_test_name"}

errors := db.Begin().
CreateTable(TestRecord{}).
Create(testRecordOne).
Create(testRecordTwo).
Commit().GetErrors()
if err = db.AutoMigrate(TestRecord{}); err != nil {
t.Fatalf("Failed to migrate DB: %s", err)
}

for _, err := range errors {
t.Fatalf("Errors: %v", err)
err = db.Begin().
Create(&testRecordOne).
Create(&testRecordTwo).
Commit().Error
if err != nil {
t.Fatalf("Failed to create record: %s", err)
}

var buffer bytes.Buffer
Expand Down Expand Up @@ -112,28 +106,25 @@ func TestDatabaseLoggingStreamServerInterceptors(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

dbFile := path.Join(t.TempDir(), "test_db")
defer os.Remove(dbFile)

db, err := gorm.Open("sqlite3", dbFile)
tempDBPath := path.Join(t.TempDir(), "test_db")
db, err := gorm.Open(sqlite.Open(tempDBPath))
if err != nil {
t.Fatalf("Failed to open DB: %s", err)
}
defer db.Close()

var (
testRecordOne = TestRecord{ID: 1, Name: "test_name"}
testRecordTwo = TestRecord{ID: 2, Name: "new_test_name"}
)
testRecordOne := TestRecord{ID: 1, Name: "test_name"}
testRecordTwo := TestRecord{ID: 2, Name: "new_test_name"}

errors := db.Begin().
CreateTable(TestRecord{}).
Create(testRecordOne).
Create(testRecordTwo).
Commit().GetErrors()
if err = db.AutoMigrate(TestRecord{}); err != nil {
t.Fatalf("Failed to migrate DB: %s", err)
}

for _, err := range errors {
t.Fatalf("Errors: %v", err)
err = db.Begin().
Create(&testRecordOne).
Create(&testRecordTwo).
Commit().Error
if err != nil {
t.Fatalf("Failed to create record: %s", err)
}

var buffer bytes.Buffer
Expand Down Expand Up @@ -197,12 +188,14 @@ func TestDatabaseLoggingStreamServerInterceptors(t *testing.T) {
}

func checkGormLoggerFields(t *testing.T, fields map[string]interface{}) {
assert.NotEmpty(t, fields["sql"])
assert.NotEmpty(t, fields[logger.GormTraceFieldKey])

dbFields, ok := (fields[logger.GormTraceFieldKey]).(map[string]interface{})
assert.True(t, ok, "%s not found in log fields", logger.GormTraceFieldKey)
assert.NotEmpty(t, dbFields)

var sql = (fields["sql"]).(map[string]interface{})

assert.NotEmpty(t, sql["duration"])
assert.NotEmpty(t, sql["affected_rows"])
assert.NotEmpty(t, sql["file_location"])
assert.NotNil(t, sql["values"])
assert.NotEmpty(t, dbFields["elapsed seconds"])
assert.NotEmpty(t, dbFields["affected rows"])
assert.NotEmpty(t, dbFields["sql"])
}

0 comments on commit bb41729

Please sign in to comment.