diff --git a/primaryHandler.go b/primaryHandler.go index 3d8bae3..05266d3 100644 --- a/primaryHandler.go +++ b/primaryHandler.go @@ -392,7 +392,7 @@ func NewPrimaryHandler(logger *zap.Logger, v *viper.Viper, registry xmetrics.Reg otelmux.WithPropagators(tracing.Propagator()), otelmux.WithTracerProvider(tracing.TracerProvider()), } - router.Use(otelmux.Middleware("mainSpan", otelMuxOptions...), candlelight.EchoFirstTraceNodeInfo(tracing.Propagator(), true), ValidateWRP()) + router.Use(otelmux.Middleware("mainSpan", otelMuxOptions...), candlelight.EchoFirstTraceNodeInfo(tracing.Propagator(), true), ValidateWRP(logger)) router.NotFoundHandler = http.HandlerFunc(func(response http.ResponseWriter, _ *http.Request) { xhttp.WriteError(response, http.StatusBadRequest, "Invalid endpoint") @@ -563,25 +563,39 @@ func validateDeviceID() alice.Chain { }) } -func ValidateWRP() func(http.Handler) http.Handler { +func ValidateWRP(logger *zap.Logger) func(http.Handler) http.Handler { return func(delegate http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if msg, ok := wrpcontext.GetMessage(r.Context()); ok { - validators := wrp.SpecValidators() var err error + var failureError error + var warningErrors error + + validators := wrp.SpecValidators() for _, v := range validators { - err = multierr.Append(err, v.Validate(*msg)) + err = v.Validate(*msg) + if errors.Is(err, wrp.ErrorInvalidMessageEncoding.Err) || errors.Is(err, wrp.ErrorInvalidMessageType.Err) { + failureError = multierr.Append(failureError, err) + } else if errors.Is(err, wrp.ErrorInvalidDestination.Err) || errors.Is(err, wrp.ErrorInvalidSource.Err) { + warningErrors = multierr.Append(warningErrors, err) + } } - if err != nil { + + if warningErrors != nil { + logger.Warn("WRP message validation warnings found", zap.Error(warningErrors)) + } + + if failureError != nil { + logger.Error("WRP message validation failures found", zap.Error(failureError)) + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) fmt.Fprintf( w, `{"code": %d, "message": "%s"}`, http.StatusBadRequest, - fmt.Sprintf("failed to validate WRP message: %s", err), - ) + fmt.Sprintf("failed to validate WRP message: %s", err)) return } }