diff --git a/http/mdm.go b/http/mdm.go index 2ef4946..0ef8980 100644 --- a/http/mdm.go +++ b/http/mdm.go @@ -35,13 +35,12 @@ func CheckinHandlerFunc(svc service.Checkin, logger log.Logger) http.HandlerFunc respBytes, err := service.CheckinRequest(svc, mdmReqFromHTTPReq(r), bodyBytes) if err != nil { logger.Info("msg", "check-in request", "err", err) - var decodeError *service.DecodeError - if errors.Is(err, mdm.ErrUnrecognizedMessageType) || errors.As(err, &decodeError) { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return + httpStatus := http.StatusInternalServerError + var statusErr *service.HTTPStatusError + if errors.As(err, &statusErr) { + httpStatus = statusErr.Status } - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return + http.Error(w, http.StatusText(httpStatus), httpStatus) } w.Write(respBytes) } @@ -59,13 +58,12 @@ func CommandAndReportResultsHandlerFunc(svc service.CommandAndReportResults, log respBytes, err := service.CommandAndReportResultsRequest(svc, mdmReqFromHTTPReq(r), bodyBytes) if err != nil { logger.Info("msg", "command report results", "err", err) - var decodeError *service.DecodeError - if errors.As(err, &decodeError) { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return + httpStatus := http.StatusInternalServerError + var statusErr *service.HTTPStatusError + if errors.As(err, &statusErr) { + httpStatus = statusErr.Status } - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return + http.Error(w, http.StatusText(httpStatus), httpStatus) } w.Write(respBytes) } diff --git a/service/request.go b/service/request.go index 86d8d81..1325864 100644 --- a/service/request.go +++ b/service/request.go @@ -2,26 +2,34 @@ package service import ( "fmt" + "net/http" "github.com/micromdm/nanomdm/mdm" ) -type DecodeError struct { - Err error +type HTTPStatusError struct { + Status int + Err error } -func (e *DecodeError) Error() string { return "decoding MDM request: " + e.Err.Error() } +func (e *HTTPStatusError) Error() string { + return fmt.Sprintf("HTTP status %d (%s): %v", e.Status, http.StatusText(e.Status), e.Err) +} -func (e *DecodeError) Unwrap() error { return e.Err } +func (e *HTTPStatusError) Unwrap() error { + return e.Err +} -func NewDecodeError(err error) *DecodeError { return &DecodeError{Err: err} } +func NewHTTPStatusError(status int, err error) *HTTPStatusError { + return &HTTPStatusError{Status: status, Err: err} +} // CheckinRequest is a simple adapter that takes the raw check-in bodyBytes // and dispatches to the respective check-in method on svc. func CheckinRequest(svc Checkin, r *mdm.Request, bodyBytes []byte) ([]byte, error) { msg, err := mdm.DecodeCheckin(bodyBytes) if err != nil { - return nil, NewDecodeError(err) + return nil, NewHTTPStatusError(http.StatusBadRequest, fmt.Errorf("decoding check-in: %w", err)) } switch m := msg.(type) { case *mdm.Authenticate: @@ -40,7 +48,7 @@ func CheckinRequest(svc Checkin, r *mdm.Request, bodyBytes []byte) ([]byte, erro err = fmt.Errorf("checkout service: %w", err) } default: - return nil, mdm.ErrUnrecognizedMessageType + return nil, NewHTTPStatusError(http.StatusBadRequest, mdm.ErrUnrecognizedMessageType) } return nil, err } @@ -51,7 +59,7 @@ func CheckinRequest(svc Checkin, r *mdm.Request, bodyBytes []byte) ([]byte, erro func CommandAndReportResultsRequest(svc CommandAndReportResults, r *mdm.Request, bodyBytes []byte) ([]byte, error) { report, err := mdm.DecodeCommandResults(bodyBytes) if err != nil { - return nil, NewDecodeError(err) + return nil, NewHTTPStatusError(http.StatusBadRequest, fmt.Errorf("decoding command results: %w", err)) } cmd, err := svc.CommandAndReportResults(r, report) if err != nil {