Skip to content

Commit

Permalink
feat: update interceptor implementation
Browse files Browse the repository at this point in the history
made fast fail and logger as optional args
addition to that instead of providing values
dynamically at the time of initialization made
it more dynamic
  • Loading branch information
rohanraj7316 committed Mar 23, 2023
1 parent 3006ed6 commit 035f67b
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 43 deletions.
69 changes: 69 additions & 0 deletions interceptors/validator/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package validator

import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"

var (
defaultOptions = &options{
logger: DefaultLoggerMethod,
shouldFailFast: DefaultDeciderMethod,
}
)

type options struct {
logger Logger
shouldFailFast Decider
}

// Option
type Option func(*options)

func evaluateServerOpt(opts []Option) *options {
optCopy := &options{}
*optCopy = *defaultOptions
for _, o := range opts {
o(optCopy)
}
return optCopy
}

func evaluateClientOpt(opts []Option) *options {
optCopy := &options{}
*optCopy = *defaultOptions
for _, o := range opts {
o(optCopy)
}
return optCopy
}

// Logger
type Logger func() (logging.Level, logging.Logger)

// DefaultLoggerMethod
func DefaultLoggerMethod() (logging.Level, logging.Logger) {
return "", nil
}

// WithLogger
func WithLogger(logger Logger) Option {
return func(o *options) {
o.logger = logger
}
}

// Decision
type Decision bool

// Decider function defines rules for suppressing any interceptor logs.
type Decider func() Decision

// DefaultDeciderMethod
func DefaultDeciderMethod() Decision {
return false
}

// WithFailFast
func WithFailFast(d Decider) Option {
return func(o *options) {
o.shouldFailFast = d
}
}
68 changes: 27 additions & 41 deletions interceptors/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,31 @@ type validatorLegacy interface {
Validate() error
}

func validate(req interface{}, all bool, l logging.Logger) error {
if all {
func log(level logging.Level, logger logging.Logger, msg string) {
if logger != nil {
logger.Log(level, msg)
}
}

func validate(req interface{}, d Decider, l Logger) error {
isFailFast := bool(d())
level, logger := l()
if isFailFast {
switch v := req.(type) {
case validateAller:
if err := v.ValidateAll(); err != nil {
l.Log(logging.ERROR, err.Error())
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validator:
if err := v.Validate(true); err != nil {
l.Log(logging.ERROR, err.Error())
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validatorLegacy:
// Fallback to legacy validator
if err := v.Validate(); err != nil {
l.Log(logging.ERROR, err.Error())
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
}
Expand All @@ -55,82 +63,60 @@ func validate(req interface{}, all bool, l logging.Logger) error {
switch v := req.(type) {
case validatorLegacy:
if err := v.Validate(); err != nil {
l.Log(logging.ERROR, err.Error())
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
case validator:
if err := v.Validate(false); err != nil {
l.Log(logging.ERROR, err.Error())
log(level, logger, err.Error())
return status.Error(codes.InvalidArgument, err.Error())
}
}
return nil
}

// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
//
// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers.
// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryServerInterceptor(all bool, logger logging.Logger) grpc.UnaryServerInterceptor {
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateServerOpt(opts)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if err := validate(req, all, logger); err != nil {
if err := validate(req, o.shouldFailFast, o.logger); err != nil {
return nil, err
}
return handler(ctx, req)
}
}

// UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages.
//
// Invalid messages will be rejected with `InvalidArgument` before sending the request to server.
// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
func UnaryClientInterceptor(all bool, logger logging.Logger) grpc.UnaryClientInterceptor {
func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
o := evaluateClientOpt(opts)
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if err := validate(req, all, logger); err != nil {
if err := validate(req, o.shouldFailFast, o.logger); err != nil {
return err
}
return invoker(ctx, method, req, reply, cc, opts...)
}
}

// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
//
// If `all` is false, the interceptor returns first validation error. Otherwise, the interceptor
// returns ALL validation error as a wrapped multi-error.
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
// The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the
// type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace
// handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on
// calls to `stream.Recv()`.
func StreamServerInterceptor(all bool, logger logging.Logger) grpc.StreamServerInterceptor {
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
o := evaluateServerOpt(opts)
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
wrapper := &recvWrapper{
all: all,
options: o,
ServerStream: stream,
Logger: logger,
}

return handler(srv, wrapper)
}
}

type recvWrapper struct {
all bool
*options
grpc.ServerStream
logging.Logger
}

func (s *recvWrapper) RecvMsg(m any) error {
if err := s.ServerStream.RecvMsg(m); err != nil {
return err
}
if err := validate(m, s.all, s.Logger); err != nil {
if err := validate(m, s.shouldFailFast, s.logger); err != nil {
return err
}
return nil
Expand Down
3 changes: 1 addition & 2 deletions interceptors/validator/validator_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.

package validator
package validator_test

import (
"io"
"testing"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
Expand Down

0 comments on commit 035f67b

Please sign in to comment.