diff --git a/pkg/util/errors.go b/pkg/util/errors.go index a6b890d6ddbe..be82e197eaf2 100644 --- a/pkg/util/errors.go +++ b/pkg/util/errors.go @@ -71,27 +71,13 @@ func (es MultiError) Err() error { // Is tells if all errors are the same as the target error. func (es MultiError) Is(target error) bool { - for _, err := range es { - if !errors.Is(err, target) { - return false - } - } - return true -} - -// IsCancel tells if all errors are either context.Canceled or grpc codes.Canceled. -func (es MultiError) IsCancel() bool { if len(es) == 0 { return false } for _, err := range es { - if errors.Is(err, context.Canceled) { - continue - } - if IsConnCanceled(err) { - continue + if !errors.Is(err, target) { + return false } - return false } return true } diff --git a/pkg/util/server/error.go b/pkg/util/server/error.go index d8040f5ffdaa..d2598e0d27c2 100644 --- a/pkg/util/server/error.go +++ b/pkg/util/server/error.go @@ -56,7 +56,7 @@ func WriteError(err error, w http.ResponseWriter) { ) me, ok := err.(util.MultiError) - if ok && me.IsCancel() { + if ok && me.Is(context.Canceled) { JSONError(w, StatusClientClosedRequest, ErrClientCanceled) return } @@ -68,7 +68,6 @@ func WriteError(err error, w http.ResponseWriter) { s, isRPC := status.FromError(err) switch { case errors.Is(err, context.Canceled) || - (isRPC && s.Code() == codes.Canceled) || (errors.As(err, &promErr) && errors.Is(promErr.Err, context.Canceled)): JSONError(w, StatusClientClosedRequest, ErrClientCanceled) case errors.Is(err, context.DeadlineExceeded) || diff --git a/pkg/util/server/error_test.go b/pkg/util/server/error_test.go index de90b271bf6f..211ae1380374 100644 --- a/pkg/util/server/error_test.go +++ b/pkg/util/server/error_test.go @@ -27,29 +27,57 @@ func Test_writeError(t *testing.T) { name string err error - msg string + expectedMsg string expectedStatus int }{ {"cancelled", context.Canceled, ErrClientCanceled, StatusClientClosedRequest}, {"cancelled multi", util.MultiError{context.Canceled, context.Canceled}, ErrClientCanceled, StatusClientClosedRequest}, - {"rpc cancelled", status.New(codes.Canceled, context.Canceled.Error()).Err(), ErrClientCanceled, StatusClientClosedRequest}, - {"rpc cancelled multi", util.MultiError{status.New(codes.Canceled, context.Canceled.Error()).Err(), status.New(codes.Canceled, context.Canceled.Error()).Err()}, ErrClientCanceled, StatusClientClosedRequest}, - {"mixed context and rpc cancelled", util.MultiError{context.Canceled, status.New(codes.Canceled, context.Canceled.Error()).Err()}, ErrClientCanceled, StatusClientClosedRequest}, - {"mixed context, rpc cancelled and another", util.MultiError{errors.New("standard error"), context.Canceled, status.New(codes.Canceled, context.Canceled.Error()).Err()}, "3 errors: standard error; context canceled; rpc error: code = Canceled desc = context canceled", http.StatusInternalServerError}, + {"rpc cancelled", + status.New(codes.Canceled, context.Canceled.Error()).Err(), + "rpc error: code = Canceled desc = context canceled", + http.StatusInternalServerError}, + {"rpc cancelled multi", + util.MultiError{status.New(codes.Canceled, context.Canceled.Error()).Err(), status.New(codes.Canceled, context.Canceled.Error()).Err()}, + "2 errors: rpc error: code = Canceled desc = context canceled; rpc error: code = Canceled desc = context canceled", + http.StatusInternalServerError}, + {"mixed context and rpc cancelled", + util.MultiError{context.Canceled, status.New(codes.Canceled, context.Canceled.Error()).Err()}, + "2 errors: context canceled; rpc error: code = Canceled desc = context canceled", + http.StatusInternalServerError}, + {"mixed context, rpc cancelled and another", + util.MultiError{errors.New("standard error"), context.Canceled, status.New(codes.Canceled, context.Canceled.Error()).Err()}, + "3 errors: standard error; context canceled; rpc error: code = Canceled desc = context canceled", + http.StatusInternalServerError}, {"cancelled storage", promql.ErrStorage{Err: context.Canceled}, ErrClientCanceled, StatusClientClosedRequest}, {"orgid", user.ErrNoOrgID, user.ErrNoOrgID.Error(), http.StatusBadRequest}, {"deadline", context.DeadlineExceeded, ErrDeadlineExceeded, http.StatusGatewayTimeout}, {"deadline multi", util.MultiError{context.DeadlineExceeded, context.DeadlineExceeded}, ErrDeadlineExceeded, http.StatusGatewayTimeout}, {"rpc deadline", status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err(), ErrDeadlineExceeded, http.StatusGatewayTimeout}, - {"rpc deadline multi", util.MultiError{status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err(), status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, ErrDeadlineExceeded, http.StatusGatewayTimeout}, - {"mixed context and rpc deadline", util.MultiError{context.DeadlineExceeded, status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, ErrDeadlineExceeded, http.StatusGatewayTimeout}, - {"mixed context, rpc deadline and another", util.MultiError{errors.New("standard error"), context.DeadlineExceeded, status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, "3 errors: standard error; context deadline exceeded; rpc error: code = DeadlineExceeded desc = context deadline exceeded", http.StatusInternalServerError}, + {"rpc deadline multi", + util.MultiError{status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err(), + status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, + ErrDeadlineExceeded, + http.StatusGatewayTimeout}, + {"mixed context and rpc deadline", + util.MultiError{context.DeadlineExceeded, status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, + ErrDeadlineExceeded, + http.StatusGatewayTimeout}, + {"mixed context, rpc deadline and another", + util.MultiError{errors.New("standard error"), context.DeadlineExceeded, status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, + "3 errors: standard error; context deadline exceeded; rpc error: code = DeadlineExceeded desc = context deadline exceeded", + http.StatusInternalServerError}, {"parse error", logqlmodel.ParseError{}, "parse error : ", http.StatusBadRequest}, {"httpgrpc", httpgrpc.Errorf(http.StatusBadRequest, errors.New("foo").Error()), "foo", http.StatusBadRequest}, {"internal", errors.New("foo"), "foo", http.StatusInternalServerError}, {"query error", chunk.ErrQueryMustContainMetricName, chunk.ErrQueryMustContainMetricName.Error(), http.StatusBadRequest}, - {"wrapped query error", fmt.Errorf("wrapped: %w", chunk.ErrQueryMustContainMetricName), "wrapped: " + chunk.ErrQueryMustContainMetricName.Error(), http.StatusBadRequest}, - {"multi mixed", util.MultiError{context.Canceled, context.DeadlineExceeded}, "2 errors: context canceled; context deadline exceeded", http.StatusInternalServerError}, + {"wrapped query error", + fmt.Errorf("wrapped: %w", chunk.ErrQueryMustContainMetricName), + "wrapped: " + chunk.ErrQueryMustContainMetricName.Error(), + http.StatusBadRequest}, + {"multi mixed", + util.MultiError{context.Canceled, context.DeadlineExceeded}, + "2 errors: context canceled; context deadline exceeded", + http.StatusInternalServerError}, } { t.Run(tt.name, func(t *testing.T) { rec := httptest.NewRecorder() @@ -58,7 +86,7 @@ func Test_writeError(t *testing.T) { json.NewDecoder(rec.Result().Body).Decode(res) require.Equal(t, tt.expectedStatus, res.Code) require.Equal(t, tt.expectedStatus, rec.Result().StatusCode) - require.Equal(t, tt.msg, res.Message) + require.Equal(t, tt.expectedMsg, res.Message) }) } }