From 3006ed693ff1e60310dab51eccb7d41b12a7e43d Mon Sep 17 00:00:00 2001 From: Rohan Raj Date: Wed, 22 Mar 2023 01:05:39 +0530 Subject: [PATCH 1/8] feat: add error logging in validator used logging.Logger interface to add error logging in validator interceptor addition: #494 --- interceptors/validator/validator.go | 29 ++++++++----- interceptors/validator/validator_test.go | 53 +++++++++++++++--------- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/interceptors/validator/validator.go b/interceptors/validator/validator.go index d19f48b69..9aceef75b 100644 --- a/interceptors/validator/validator.go +++ b/interceptors/validator/validator.go @@ -9,6 +9,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" ) // The validateAller interface at protoc-gen-validate main branch. @@ -28,20 +30,23 @@ type validatorLegacy interface { Validate() error } -func validate(req any, all bool) error { +func validate(req interface{}, all bool, l logging.Logger) error { if all { switch v := req.(type) { case validateAller: if err := v.ValidateAll(); err != nil { + l.Log(logging.ERROR, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } case validator: if err := v.Validate(true); err != nil { + l.Log(logging.ERROR, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } case validatorLegacy: // Fallback to legacy validator if err := v.Validate(); err != nil { + l.Log(logging.ERROR, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } } @@ -50,10 +55,12 @@ func validate(req any, all bool) error { switch v := req.(type) { case validatorLegacy: if err := v.Validate(); err != nil { + l.Log(logging.ERROR, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } case validator: if err := v.Validate(false); err != nil { + l.Log(logging.ERROR, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } } @@ -67,9 +74,9 @@ func validate(req any, all bool) error { // returns ALL validation error as a wrapped multi-error. // Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation // interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. -func UnaryServerInterceptor(all bool) grpc.UnaryServerInterceptor { - return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { - if err := validate(req, all); err != nil { +func UnaryServerInterceptor(all bool, logger logging.Logger) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if err := validate(req, all, logger); err != nil { return nil, err } return handler(ctx, req) @@ -83,9 +90,9 @@ func UnaryServerInterceptor(all bool) grpc.UnaryServerInterceptor { // returns ALL validation error as a wrapped multi-error. // Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation // interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. -func UnaryClientInterceptor(all bool) grpc.UnaryClientInterceptor { - return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if err := validate(req, all); err != nil { +func UnaryClientInterceptor(all bool, logger logging.Logger) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if err := validate(req, all, logger); err != nil { return err } return invoker(ctx, method, req, reply, cc, opts...) @@ -102,11 +109,12 @@ func UnaryClientInterceptor(all bool) grpc.UnaryClientInterceptor { // type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace // handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on // calls to `stream.Recv()`. -func StreamServerInterceptor(all bool) grpc.StreamServerInterceptor { - return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { +func StreamServerInterceptor(all bool, logger logging.Logger) grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { wrapper := &recvWrapper{ all: all, ServerStream: stream, + Logger: logger, } return handler(srv, wrapper) } @@ -115,13 +123,14 @@ func StreamServerInterceptor(all bool) grpc.StreamServerInterceptor { type recvWrapper struct { all bool grpc.ServerStream + logging.Logger } func (s *recvWrapper) RecvMsg(m any) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } - if err := validate(m, s.all); err != nil { + if err := validate(m, s.all, s.Logger); err != nil { return err } return nil diff --git a/interceptors/validator/validator_test.go b/interceptors/validator/validator_test.go index b5120aae4..df48ae415 100644 --- a/interceptors/validator/validator_test.go +++ b/interceptors/validator/validator_test.go @@ -11,34 +11,47 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" ) +type Logger struct { +} + +func (l *Logger) Log(lvl logging.Level, msg string) {} + +func (l *Logger) With(fields ...string) logging.Logger { + return &Logger{} +} + func TestValidateWrapper(t *testing.T) { - assert.NoError(t, validate(testpb.GoodPing, false)) - assert.Error(t, validate(testpb.BadPing, false)) - assert.NoError(t, validate(testpb.GoodPing, true)) - assert.Error(t, validate(testpb.BadPing, true)) - - assert.NoError(t, validate(testpb.GoodPingError, false)) - assert.Error(t, validate(testpb.BadPingError, false)) - assert.NoError(t, validate(testpb.GoodPingError, true)) - assert.Error(t, validate(testpb.BadPingError, true)) - - assert.NoError(t, validate(testpb.GoodPingResponse, false)) - assert.NoError(t, validate(testpb.GoodPingResponse, true)) - assert.Error(t, validate(testpb.BadPingResponse, false)) - assert.Error(t, validate(testpb.BadPingResponse, true)) + assert.NoError(t, validate(testpb.GoodPing, false, &Logger{})) + assert.Error(t, validate(testpb.BadPing, false, &Logger{})) + assert.NoError(t, validate(testpb.GoodPing, true, &Logger{})) + assert.Error(t, validate(testpb.BadPing, true, &Logger{})) + + assert.NoError(t, validate(testpb.GoodPingError, false, &Logger{})) + assert.Error(t, validate(testpb.BadPingError, false, &Logger{})) + assert.NoError(t, validate(testpb.GoodPingError, true, &Logger{})) + assert.Error(t, validate(testpb.BadPingError, true, &Logger{})) + + assert.NoError(t, validate(testpb.GoodPingResponse, false, &Logger{})) + assert.NoError(t, validate(testpb.GoodPingResponse, true, &Logger{})) + assert.Error(t, validate(testpb.BadPingResponse, false, &Logger{})) + assert.Error(t, validate(testpb.BadPingResponse, true, &Logger{})) } func TestValidatorTestSuite(t *testing.T) { s := &ValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(StreamServerInterceptor(false)), - grpc.UnaryInterceptor(UnaryServerInterceptor(false)), + grpc.StreamInterceptor(StreamServerInterceptor(false, &Logger{})), + grpc.UnaryInterceptor(UnaryServerInterceptor(false, &Logger{})), }, }, } @@ -46,8 +59,8 @@ func TestValidatorTestSuite(t *testing.T) { sAll := &ValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(StreamServerInterceptor(true)), - grpc.UnaryInterceptor(UnaryServerInterceptor(true)), + grpc.StreamInterceptor(StreamServerInterceptor(true, &Logger{})), + grpc.UnaryInterceptor(UnaryServerInterceptor(true, &Logger{})), }, }, } @@ -56,7 +69,7 @@ func TestValidatorTestSuite(t *testing.T) { cs := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ClientOpts: []grpc.DialOption{ - grpc.WithUnaryInterceptor(UnaryClientInterceptor(false)), + grpc.WithUnaryInterceptor(UnaryClientInterceptor(false, &Logger{})), }, }, } @@ -64,7 +77,7 @@ func TestValidatorTestSuite(t *testing.T) { csAll := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ClientOpts: []grpc.DialOption{ - grpc.WithUnaryInterceptor(UnaryClientInterceptor(true)), + grpc.WithUnaryInterceptor(UnaryClientInterceptor(true, &Logger{})), }, }, } From 035f67baeb0998b7a0a152b4b9d41511fe9834e7 Mon Sep 17 00:00:00 2001 From: Rohan Raj Date: Fri, 24 Mar 2023 03:11:14 +0530 Subject: [PATCH 2/8] feat: update interceptor implementation made fast fail and logger as optional args addition to that instead of providing values dynamically at the time of initialization made it more dynamic --- interceptors/validator/options.go | 69 ++++++++++++++++++++++++ interceptors/validator/validator.go | 68 ++++++++++------------- interceptors/validator/validator_test.go | 3 +- 3 files changed, 97 insertions(+), 43 deletions(-) create mode 100644 interceptors/validator/options.go diff --git a/interceptors/validator/options.go b/interceptors/validator/options.go new file mode 100644 index 000000000..eb6a4cfea --- /dev/null +++ b/interceptors/validator/options.go @@ -0,0 +1,69 @@ +package validator + +import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + +var ( + defaultOptions = &options{ + logger: DefaultLoggerMethod, + shouldFailFast: DefaultDeciderMethod, + } +) + +type options struct { + logger Logger + shouldFailFast Decider +} + +// Option +type Option func(*options) + +func evaluateServerOpt(opts []Option) *options { + optCopy := &options{} + *optCopy = *defaultOptions + for _, o := range opts { + o(optCopy) + } + return optCopy +} + +func evaluateClientOpt(opts []Option) *options { + optCopy := &options{} + *optCopy = *defaultOptions + for _, o := range opts { + o(optCopy) + } + return optCopy +} + +// Logger +type Logger func() (logging.Level, logging.Logger) + +// DefaultLoggerMethod +func DefaultLoggerMethod() (logging.Level, logging.Logger) { + return "", nil +} + +// WithLogger +func WithLogger(logger Logger) Option { + return func(o *options) { + o.logger = logger + } +} + +// Decision +type Decision bool + +// Decider function defines rules for suppressing any interceptor logs. +type Decider func() Decision + +// DefaultDeciderMethod +func DefaultDeciderMethod() Decision { + return false +} + +// WithFailFast +func WithFailFast(d Decider) Option { + return func(o *options) { + o.shouldFailFast = d + } +} diff --git a/interceptors/validator/validator.go b/interceptors/validator/validator.go index 9aceef75b..ae83ba345 100644 --- a/interceptors/validator/validator.go +++ b/interceptors/validator/validator.go @@ -30,23 +30,31 @@ type validatorLegacy interface { Validate() error } -func validate(req interface{}, all bool, l logging.Logger) error { - if all { +func log(level logging.Level, logger logging.Logger, msg string) { + if logger != nil { + logger.Log(level, msg) + } +} + +func validate(req interface{}, d Decider, l Logger) error { + isFailFast := bool(d()) + level, logger := l() + if isFailFast { switch v := req.(type) { case validateAller: if err := v.ValidateAll(); err != nil { - l.Log(logging.ERROR, err.Error()) + log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } case validator: if err := v.Validate(true); err != nil { - l.Log(logging.ERROR, err.Error()) + log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } case validatorLegacy: // Fallback to legacy validator if err := v.Validate(); err != nil { - l.Log(logging.ERROR, err.Error()) + log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } } @@ -55,82 +63,60 @@ func validate(req interface{}, all bool, l logging.Logger) error { switch v := req.(type) { case validatorLegacy: if err := v.Validate(); err != nil { - l.Log(logging.ERROR, err.Error()) + log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } case validator: if err := v.Validate(false); err != nil { - l.Log(logging.ERROR, err.Error()) + log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } } return nil } -// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages. -// -// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers. -// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor -// returns ALL validation error as a wrapped multi-error. -// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation -// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. -func UnaryServerInterceptor(all bool, logger logging.Logger) grpc.UnaryServerInterceptor { +func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { + o := evaluateServerOpt(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if err := validate(req, all, logger); err != nil { + if err := validate(req, o.shouldFailFast, o.logger); err != nil { return nil, err } return handler(ctx, req) } } -// UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages. -// -// Invalid messages will be rejected with `InvalidArgument` before sending the request to server. -// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor -// returns ALL validation error as a wrapped multi-error. -// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation -// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. -func UnaryClientInterceptor(all bool, logger logging.Logger) grpc.UnaryClientInterceptor { +func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { + o := evaluateClientOpt(opts) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if err := validate(req, all, logger); err != nil { + if err := validate(req, o.shouldFailFast, o.logger); err != nil { return err } return invoker(ctx, method, req, reply, cc, opts...) } } -// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages. -// -// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor -// returns ALL validation error as a wrapped multi-error. -// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation -// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. -// The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the -// type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace -// handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on -// calls to `stream.Recv()`. -func StreamServerInterceptor(all bool, logger logging.Logger) grpc.StreamServerInterceptor { +func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { + o := evaluateServerOpt(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { wrapper := &recvWrapper{ - all: all, + options: o, ServerStream: stream, - Logger: logger, } + return handler(srv, wrapper) } } type recvWrapper struct { - all bool + *options grpc.ServerStream - logging.Logger } func (s *recvWrapper) RecvMsg(m any) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } - if err := validate(m, s.all, s.Logger); err != nil { + if err := validate(m, s.shouldFailFast, s.logger); err != nil { return err } return nil diff --git a/interceptors/validator/validator_test.go b/interceptors/validator/validator_test.go index df48ae415..35296337f 100644 --- a/interceptors/validator/validator_test.go +++ b/interceptors/validator/validator_test.go @@ -1,13 +1,12 @@ // Copyright (c) The go-grpc-middleware Authors. // Licensed under the Apache License 2.0. -package validator +package validator_test import ( "io" "testing" - "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" From 233742c080114344407be8038f57dc81bef33efb Mon Sep 17 00:00:00 2001 From: Rohan Raj Date: Fri, 24 Mar 2023 04:25:05 +0530 Subject: [PATCH 3/8] fix: update options args updated args based on review --- interceptors/validator/options.go | 40 +++++++++---------------------- 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/interceptors/validator/options.go b/interceptors/validator/options.go index eb6a4cfea..fff62b662 100644 --- a/interceptors/validator/options.go +++ b/interceptors/validator/options.go @@ -4,17 +4,18 @@ import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" var ( defaultOptions = &options{ - logger: DefaultLoggerMethod, - shouldFailFast: DefaultDeciderMethod, + level: "", + logger: nil, + shouldFailFast: false, } ) type options struct { - logger Logger - shouldFailFast Decider + level logging.Level + logger logging.Logger + shouldFailFast bool } -// Option type Option func(*options) func evaluateServerOpt(opts []Option) *options { @@ -35,35 +36,16 @@ func evaluateClientOpt(opts []Option) *options { return optCopy } -// Logger -type Logger func() (logging.Level, logging.Logger) - -// DefaultLoggerMethod -func DefaultLoggerMethod() (logging.Level, logging.Logger) { - return "", nil -} - -// WithLogger -func WithLogger(logger Logger) Option { +// WithLogger tells validator to log all the error. +func WithLogger(level logging.Level, logger logging.Logger) Option { return func(o *options) { o.logger = logger } } -// Decision -type Decision bool - -// Decider function defines rules for suppressing any interceptor logs. -type Decider func() Decision - -// DefaultDeciderMethod -func DefaultDeciderMethod() Decision { - return false -} - -// WithFailFast -func WithFailFast(d Decider) Option { +// WithFailFast tells validator to immediately stop doing further validation after first validation error. +func WithFailFast() Option { return func(o *options) { - o.shouldFailFast = d + o.shouldFailFast = true } } From 4adb59e6e9598352dd95c05835dc5a3bed2bc166 Mon Sep 17 00:00:00 2001 From: Rohan Raj Date: Fri, 24 Mar 2023 04:28:58 +0530 Subject: [PATCH 4/8] refactor: update validate func restructured if statement in-order to make code execution based on shouldFailFast flag more relevant. --- interceptors/validator/validator.go | 81 ++++++----------------------- 1 file changed, 17 insertions(+), 64 deletions(-) diff --git a/interceptors/validator/validator.go b/interceptors/validator/validator.go index ae83ba345..dec237766 100644 --- a/interceptors/validator/validator.go +++ b/interceptors/validator/validator.go @@ -4,9 +4,6 @@ package validator import ( - "context" - - "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -36,88 +33,44 @@ func log(level logging.Level, logger logging.Logger, msg string) { } } -func validate(req interface{}, d Decider, l Logger) error { - isFailFast := bool(d()) - level, logger := l() - if isFailFast { +func validate(req interface{}, shouldFailFast bool, level logging.Level, logger logging.Logger) error { + // shouldFailFast tells validator to immediately stop doing further validation after first validation error. + if shouldFailFast { switch v := req.(type) { - case validateAller: - if err := v.ValidateAll(); err != nil { + case validatorLegacy: + if err := v.Validate(); err != nil { log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } case validator: - if err := v.Validate(true); err != nil { - log(level, logger, err.Error()) - return status.Error(codes.InvalidArgument, err.Error()) - } - case validatorLegacy: - // Fallback to legacy validator - if err := v.Validate(); err != nil { + if err := v.Validate(false); err != nil { log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } } + return nil } + + // shouldNotFailFast tells validator to continue doing further validation even if after a validation error. switch v := req.(type) { - case validatorLegacy: - if err := v.Validate(); err != nil { + case validateAller: + if err := v.ValidateAll(); err != nil { log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } case validator: - if err := v.Validate(false); err != nil { + if err := v.Validate(true); err != nil { log(level, logger, err.Error()) return status.Error(codes.InvalidArgument, err.Error()) } - } - return nil -} - -func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { - o := evaluateServerOpt(opts) - return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if err := validate(req, o.shouldFailFast, o.logger); err != nil { - return nil, err - } - return handler(ctx, req) - } -} - -func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { - o := evaluateClientOpt(opts) - return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if err := validate(req, o.shouldFailFast, o.logger); err != nil { - return err - } - return invoker(ctx, method, req, reply, cc, opts...) - } -} - -func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { - o := evaluateServerOpt(opts) - return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { - wrapper := &recvWrapper{ - options: o, - ServerStream: stream, + case validatorLegacy: + // Fallback to legacy validator + if err := v.Validate(); err != nil { + log(level, logger, err.Error()) + return status.Error(codes.InvalidArgument, err.Error()) } - - return handler(srv, wrapper) } -} -type recvWrapper struct { - *options - grpc.ServerStream -} - -func (s *recvWrapper) RecvMsg(m any) error { - if err := s.ServerStream.RecvMsg(m); err != nil { - return err - } - if err := validate(m, s.shouldFailFast, s.logger); err != nil { - return err - } return nil } From eb166eb2f1a59b438b901963955e5831991e5f6c Mon Sep 17 00:00:00 2001 From: Rohan Raj Date: Fri, 24 Mar 2023 04:32:17 +0530 Subject: [PATCH 5/8] refactor: shifted interceptors into new file restructured code in order to separate the concern. ie: in terms of code struct and testcases wise. --- interceptors/validator/interceptors.go | 81 ++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 interceptors/validator/interceptors.go diff --git a/interceptors/validator/interceptors.go b/interceptors/validator/interceptors.go new file mode 100644 index 000000000..cb6603787 --- /dev/null +++ b/interceptors/validator/interceptors.go @@ -0,0 +1,81 @@ +package validator + +import ( + "context" + + "google.golang.org/grpc" +) + +// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages. +// +// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers. +// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor +// returns ALL validation error as a wrapped multi-error. +// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging. +// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation +// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. +func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { + o := evaluateServerOpt(opts) + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil { + return nil, err + } + return handler(ctx, req) + } +} + +// UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages. +// +// Invalid messages will be rejected with `InvalidArgument` before sending the request to server. +// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor +// returns ALL validation error as a wrapped multi-error. +// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging. +// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation +// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. +func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { + o := evaluateClientOpt(opts) + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil { + return err + } + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages. +// +// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor +// returns ALL validation error as a wrapped multi-error. +// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging. +// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation +// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored. +// The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the +// type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace +// handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on +// calls to `stream.Recv()`. +func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { + o := evaluateServerOpt(opts) + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + wrapper := &recvWrapper{ + options: o, + ServerStream: stream, + } + + return handler(srv, wrapper) + } +} + +type recvWrapper struct { + *options + grpc.ServerStream +} + +func (s *recvWrapper) RecvMsg(m any) error { + if err := s.ServerStream.RecvMsg(m); err != nil { + return err + } + if err := validate(m, s.shouldFailFast, s.level, s.logger); err != nil { + return err + } + return nil +} From 09bcdb6fff79b9a28fa68032416c77439729fe5b Mon Sep 17 00:00:00 2001 From: Rohan Raj Date: Fri, 24 Mar 2023 04:33:26 +0530 Subject: [PATCH 6/8] test: updated the testcases modified testcases based on the current modifications made in the code base. --- interceptors/validator/interceptors_test.go | 172 ++++++++++++++++++++ interceptors/validator/validator_test.go | 154 +++--------------- 2 files changed, 191 insertions(+), 135 deletions(-) create mode 100644 interceptors/validator/interceptors_test.go diff --git a/interceptors/validator/interceptors_test.go b/interceptors/validator/interceptors_test.go new file mode 100644 index 000000000..3a0c54c1a --- /dev/null +++ b/interceptors/validator/interceptors_test.go @@ -0,0 +1,172 @@ +package validator_test + +import ( + "io" + "testing" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type TestLogger struct{} + +func (l *TestLogger) Log(lvl logging.Level, msg string) {} + +func (l *TestLogger) With(fields ...string) logging.Logger { + return &TestLogger{} +} + +type ValidatorTestSuite struct { + *testpb.InterceptorTestSuite +} + +func (s *ValidatorTestSuite) TestValidPasses_Unary() { + _, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing) + assert.NoError(s.T(), err, "no error expected") +} + +func (s *ValidatorTestSuite) TestInvalidErrors_Unary() { + _, err := s.Client.Ping(s.SimpleCtx(), testpb.BadPing) + assert.Error(s.T(), err, "no error expected") + assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") +} + +func (s *ValidatorTestSuite) TestValidPasses_ServerStream() { + stream, err := s.Client.PingList(s.SimpleCtx(), testpb.GoodPingList) + require.NoError(s.T(), err, "no error on stream establishment expected") + for { + _, err := stream.Recv() + if err == io.EOF { + break + } + assert.NoError(s.T(), err, "no error on messages sent occurred") + } +} + +type ClientValidatorTestSuite struct { + *testpb.InterceptorTestSuite +} + +func (s *ClientValidatorTestSuite) TestValidPasses_Unary() { + _, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing) + assert.NoError(s.T(), err, "no error expected") +} + +func (s *ClientValidatorTestSuite) TestInvalidErrors_Unary() { + _, err := s.Client.Ping(s.SimpleCtx(), testpb.BadPing) + assert.Error(s.T(), err, "error expected") + assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") +} + +func (s *ValidatorTestSuite) TestInvalidErrors_ServerStream() { + stream, err := s.Client.PingList(s.SimpleCtx(), testpb.BadPingList) + require.NoError(s.T(), err, "no error on stream establishment expected") + _, err = stream.Recv() + assert.Error(s.T(), err, "error should be received on first message") + assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") +} + +func (s *ValidatorTestSuite) TestInvalidErrors_BidiStream() { + stream, err := s.Client.PingStream(s.SimpleCtx()) + require.NoError(s.T(), err, "no error on stream establishment expected") + + require.NoError(s.T(), stream.Send(testpb.GoodPingStream)) + _, err = stream.Recv() + assert.NoError(s.T(), err, "receiving a good ping should return a good pong") + require.NoError(s.T(), stream.Send(testpb.GoodPingStream)) + _, err = stream.Recv() + assert.NoError(s.T(), err, "receiving a good ping should return a good pong") + + require.NoError(s.T(), stream.Send(testpb.BadPingStream)) + _, err = stream.Recv() + assert.Error(s.T(), err, "receiving a bad ping should return a bad pong") + assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") + + err = stream.CloseSend() + assert.NoError(s.T(), err, "there should be no error closing the stream on send") +} + +func TestValidatorTestSuite(t *testing.T) { + sWithNoArgs := &ValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ServerOpts: []grpc.ServerOption{ + grpc.StreamInterceptor(validator.StreamServerInterceptor()), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor()), + }, + }, + } + suite.Run(t, sWithNoArgs) + + sWithWithFailFastArgs := &ValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ServerOpts: []grpc.ServerOption{ + grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast())), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())), + }, + }, + } + suite.Run(t, sWithWithFailFastArgs) + + sWithWithLoggerArgs := &ValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ServerOpts: []grpc.ServerOption{ + grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithLogger(logging.DEBUG, &TestLogger{}))), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.DEBUG, &TestLogger{}))), + }, + }, + } + suite.Run(t, sWithWithLoggerArgs) + + sAll := &ValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ServerOpts: []grpc.ServerOption{ + grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.DEBUG, &TestLogger{}))), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.DEBUG, &TestLogger{}))), + }, + }, + } + suite.Run(t, sAll) + + csWithNoArgs := &ClientValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ClientOpts: []grpc.DialOption{ + grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor()), + }, + }, + } + suite.Run(t, csWithNoArgs) + + csWithWithFailFastArgs := &ClientValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ServerOpts: []grpc.ServerOption{ + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())), + }, + }, + } + suite.Run(t, csWithWithFailFastArgs) + + csWithWithLoggerArgs := &ClientValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ServerOpts: []grpc.ServerOption{ + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.DEBUG, &TestLogger{}))), + }, + }, + } + suite.Run(t, csWithWithLoggerArgs) + + csAll := &ClientValidatorTestSuite{ + InterceptorTestSuite: &testpb.InterceptorTestSuite{ + ClientOpts: []grpc.DialOption{ + grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast())), + }, + }, + } + suite.Run(t, csAll) +} diff --git a/interceptors/validator/validator_test.go b/interceptors/validator/validator_test.go index 35296337f..e0a2e8131 100644 --- a/interceptors/validator/validator_test.go +++ b/interceptors/validator/validator_test.go @@ -1,154 +1,38 @@ // Copyright (c) The go-grpc-middleware Authors. // Licensed under the Apache License 2.0. -package validator_test +package validator import ( - "io" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" ) -type Logger struct { -} +type TestLogger struct{} -func (l *Logger) Log(lvl logging.Level, msg string) {} +func (l *TestLogger) Log(lvl logging.Level, msg string) {} -func (l *Logger) With(fields ...string) logging.Logger { - return &Logger{} +func (l *TestLogger) With(fields ...string) logging.Logger { + return &TestLogger{} } func TestValidateWrapper(t *testing.T) { - assert.NoError(t, validate(testpb.GoodPing, false, &Logger{})) - assert.Error(t, validate(testpb.BadPing, false, &Logger{})) - assert.NoError(t, validate(testpb.GoodPing, true, &Logger{})) - assert.Error(t, validate(testpb.BadPing, true, &Logger{})) - - assert.NoError(t, validate(testpb.GoodPingError, false, &Logger{})) - assert.Error(t, validate(testpb.BadPingError, false, &Logger{})) - assert.NoError(t, validate(testpb.GoodPingError, true, &Logger{})) - assert.Error(t, validate(testpb.BadPingError, true, &Logger{})) - - assert.NoError(t, validate(testpb.GoodPingResponse, false, &Logger{})) - assert.NoError(t, validate(testpb.GoodPingResponse, true, &Logger{})) - assert.Error(t, validate(testpb.BadPingResponse, false, &Logger{})) - assert.Error(t, validate(testpb.BadPingResponse, true, &Logger{})) -} - -func TestValidatorTestSuite(t *testing.T) { - s := &ValidatorTestSuite{ - InterceptorTestSuite: &testpb.InterceptorTestSuite{ - ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(StreamServerInterceptor(false, &Logger{})), - grpc.UnaryInterceptor(UnaryServerInterceptor(false, &Logger{})), - }, - }, - } - suite.Run(t, s) - sAll := &ValidatorTestSuite{ - InterceptorTestSuite: &testpb.InterceptorTestSuite{ - ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(StreamServerInterceptor(true, &Logger{})), - grpc.UnaryInterceptor(UnaryServerInterceptor(true, &Logger{})), - }, - }, - } - suite.Run(t, sAll) - - cs := &ClientValidatorTestSuite{ - InterceptorTestSuite: &testpb.InterceptorTestSuite{ - ClientOpts: []grpc.DialOption{ - grpc.WithUnaryInterceptor(UnaryClientInterceptor(false, &Logger{})), - }, - }, - } - suite.Run(t, cs) - csAll := &ClientValidatorTestSuite{ - InterceptorTestSuite: &testpb.InterceptorTestSuite{ - ClientOpts: []grpc.DialOption{ - grpc.WithUnaryInterceptor(UnaryClientInterceptor(true, &Logger{})), - }, - }, - } - suite.Run(t, csAll) -} - -type ValidatorTestSuite struct { - *testpb.InterceptorTestSuite -} - -func (s *ValidatorTestSuite) TestValidPasses_Unary() { - _, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing) - assert.NoError(s.T(), err, "no error expected") -} - -func (s *ValidatorTestSuite) TestInvalidErrors_Unary() { - _, err := s.Client.Ping(s.SimpleCtx(), testpb.BadPing) - assert.Error(s.T(), err, "no error expected") - assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") -} - -func (s *ValidatorTestSuite) TestValidPasses_ServerStream() { - stream, err := s.Client.PingList(s.SimpleCtx(), testpb.GoodPingList) - require.NoError(s.T(), err, "no error on stream establishment expected") - for { - _, err := stream.Recv() - if err == io.EOF { - break - } - assert.NoError(s.T(), err, "no error on messages sent occurred") - } -} - -func (s *ValidatorTestSuite) TestInvalidErrors_ServerStream() { - stream, err := s.Client.PingList(s.SimpleCtx(), testpb.BadPingList) - require.NoError(s.T(), err, "no error on stream establishment expected") - _, err = stream.Recv() - assert.Error(s.T(), err, "error should be received on first message") - assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") -} - -func (s *ValidatorTestSuite) TestInvalidErrors_BidiStream() { - stream, err := s.Client.PingStream(s.SimpleCtx()) - require.NoError(s.T(), err, "no error on stream establishment expected") - - require.NoError(s.T(), stream.Send(testpb.GoodPingStream)) - _, err = stream.Recv() - assert.NoError(s.T(), err, "receiving a good ping should return a good pong") - require.NoError(s.T(), stream.Send(testpb.GoodPingStream)) - _, err = stream.Recv() - assert.NoError(s.T(), err, "receiving a good ping should return a good pong") - - require.NoError(s.T(), stream.Send(testpb.BadPingStream)) - _, err = stream.Recv() - assert.Error(s.T(), err, "receiving a bad ping should return a bad pong") - assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") - - err = stream.CloseSend() - assert.NoError(s.T(), err, "there should be no error closing the stream on send") -} - -type ClientValidatorTestSuite struct { - *testpb.InterceptorTestSuite -} - -func (s *ClientValidatorTestSuite) TestValidPasses_Unary() { - _, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing) - assert.NoError(s.T(), err, "no error expected") -} - -func (s *ClientValidatorTestSuite) TestInvalidErrors_Unary() { - _, err := s.Client.Ping(s.SimpleCtx(), testpb.BadPing) - assert.Error(s.T(), err, "error expected") - assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") + assert.NoError(t, validate(testpb.GoodPing, false, logging.ERROR, &TestLogger{})) + assert.Error(t, validate(testpb.BadPing, false, logging.ERROR, &TestLogger{})) + assert.NoError(t, validate(testpb.GoodPing, true, logging.ERROR, &TestLogger{})) + assert.Error(t, validate(testpb.BadPing, true, logging.ERROR, &TestLogger{})) + + assert.NoError(t, validate(testpb.GoodPingError, false, logging.ERROR, &TestLogger{})) + assert.Error(t, validate(testpb.BadPingError, false, logging.ERROR, &TestLogger{})) + assert.NoError(t, validate(testpb.GoodPingError, true, logging.ERROR, &TestLogger{})) + assert.Error(t, validate(testpb.BadPingError, true, logging.ERROR, &TestLogger{})) + + assert.NoError(t, validate(testpb.GoodPingResponse, false, logging.ERROR, &TestLogger{})) + assert.NoError(t, validate(testpb.GoodPingResponse, true, logging.ERROR, &TestLogger{})) + assert.Error(t, validate(testpb.BadPingResponse, false, logging.ERROR, &TestLogger{})) + assert.Error(t, validate(testpb.BadPingResponse, true, logging.ERROR, &TestLogger{})) } From 5398fb72b80abf79a35c3bb2ec8c25e2f2c47be5 Mon Sep 17 00:00:00 2001 From: Rohan Raj Date: Fri, 24 Mar 2023 04:47:11 +0530 Subject: [PATCH 7/8] fix: add copyright headers --- interceptors/validator/interceptors.go | 3 +++ interceptors/validator/interceptors_test.go | 3 +++ interceptors/validator/options.go | 3 +++ 3 files changed, 9 insertions(+) diff --git a/interceptors/validator/interceptors.go b/interceptors/validator/interceptors.go index cb6603787..69c8a054a 100644 --- a/interceptors/validator/interceptors.go +++ b/interceptors/validator/interceptors.go @@ -1,3 +1,6 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + package validator import ( diff --git a/interceptors/validator/interceptors_test.go b/interceptors/validator/interceptors_test.go index 3a0c54c1a..f2714a7e3 100644 --- a/interceptors/validator/interceptors_test.go +++ b/interceptors/validator/interceptors_test.go @@ -1,3 +1,6 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + package validator_test import ( diff --git a/interceptors/validator/options.go b/interceptors/validator/options.go index fff62b662..4184bfb15 100644 --- a/interceptors/validator/options.go +++ b/interceptors/validator/options.go @@ -1,3 +1,6 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. + package validator import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" From 5789ad144982a286a1130002e962a44ff15a726d Mon Sep 17 00:00:00 2001 From: Rohan Raj Date: Fri, 24 Mar 2023 04:49:33 +0530 Subject: [PATCH 8/8] fix: update comment and code updated code based on reviews --- interceptors/validator/options.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/interceptors/validator/options.go b/interceptors/validator/options.go index 4184bfb15..c267babfe 100644 --- a/interceptors/validator/options.go +++ b/interceptors/validator/options.go @@ -39,9 +39,10 @@ func evaluateClientOpt(opts []Option) *options { return optCopy } -// WithLogger tells validator to log all the error. +// WithLogger tells validator to log all the validation errors with the given log level. func WithLogger(level logging.Level, logger logging.Logger) Option { return func(o *options) { + o.level = level o.logger = logger } }