From b17ae64b5ec7c9db59fc00e4c66f862b075c8f08 Mon Sep 17 00:00:00 2001 From: eric-millin <110051399+eric-millin@users.noreply.github.com> Date: Mon, 5 Feb 2024 06:24:24 -0500 Subject: [PATCH] Fix GetBatchResponseById error deserialization (#249) --- CHANGELOG.md | 6 ++++++ batch_request_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++ batch_requests.go | 29 +++++++++++++++++-------- internal/errors.go | 17 +++++++++++++-- 4 files changed, 91 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fe7a91..f79c7bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +## [1.0.2] - 2023-12-01 + +### Changed + +- Fixed a bug where GetBatchResponseById failed to deserialize error response bodies. + ## [1.0.1] - 2023-11-24 ### Changed diff --git a/batch_request_test.go b/batch_request_test.go index 4f3c5dd..99680c1 100644 --- a/batch_request_test.go +++ b/batch_request_test.go @@ -308,6 +308,56 @@ func TestGetResponseByIdFailedRequest(t *testing.T) { assert.Equal(t, "The server returned an unexpected status code and no error factory is registered for this code: 401", err.Error()) } +func TestGetErrorResponseBodyById(t *testing.T) { + var jsonBlob = `{ + "responses": [{ + "id": "3", + "status": 400, + "headers": { + "Content-Type": "application/json" + }, + "body": { + "error": { + "code": "ExtensionError", + "message": "Exception: [Status Code: BadRequest; Reason: Boom]", + "innerError": { + "request-id": "123" + } + } + } + }] + }` + + errorMapping := abstractions.ErrorMappings{ + "4XX": internal.CreateSampleErrorFromDiscriminatorValue, + "5XX": internal.CreateSampleErrorFromDiscriminatorValue, + } + err := RegisterError("Userable", errorMapping) + assert.NoError(t, err) + + mockServer := makeMockRequest(200, jsonBlob) + defer mockServer.Close() + + mockPath := mockServer.URL + "/$batch" + reqAdapter.SetBaseUrl(mockPath) + + reqInfo := getRequestInfo() + batch := NewBatchRequest(reqAdapter) + _, err = batch.AddBatchRequestStep(*reqInfo) + require.NoError(t, err) + + resp, err := batch.Send(context.Background(), reqAdapter) + require.NoError(t, err) + + _, err = GetBatchResponseById[Userable](resp, "3", CreateUser) + serr := &internal.SampleError{} + assert.ErrorAs(t, err, &serr) + assert.Equal(t, "Exception: [Status Code: BadRequest; Reason: Boom]", serr.Message) + + err = DeRegisterError("Userable") + require.NoError(t, err) +} + func TestGetResponseByIdFailedRequestWithFactory(t *testing.T) { mockServer := makeMockRequest(200, getDummyJSON()) defer mockServer.Close() diff --git a/batch_requests.go b/batch_requests.go index 1fc5a50..2d35bef 100644 --- a/batch_requests.go +++ b/batch_requests.go @@ -6,7 +6,6 @@ import ( "encoding/gob" "encoding/json" "errors" - nethttplibrary "github.com/microsoft/kiota-http-go" "net/url" "reflect" "strconv" @@ -17,6 +16,7 @@ import ( abstractions "github.com/microsoft/kiota-abstractions-go" "github.com/microsoft/kiota-abstractions-go/serialization" absser "github.com/microsoft/kiota-abstractions-go/serialization" + nethttplibrary "github.com/microsoft/kiota-http-go" ) const BatchRequestErrorRegistryKey = "BATCH_REQUEST_ERROR_REGISTRY_KEY" @@ -210,13 +210,24 @@ func getRootParseNode(responseItem BatchItem) (absser.ParseNode, error) { if contentType == "" { return nil, nil } - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - err := enc.Encode(responseItem.GetBody()) - if err != nil { - return nil, err + + var ( + content []byte + err error + ) + if contentType == jsonContentType { + if content, err = json.Marshal(responseItem.GetBody()); err != nil { + return nil, err + } + } else { + var buf bytes.Buffer + if err = gob.NewEncoder(&buf).Encode(responseItem.GetBody()); err != nil { + return nil, err + } + content = buf.Bytes() } - return serialization.DefaultParseNodeFactoryInstance.GetRootParseNode(contentType, buf.Bytes()) + + return serialization.DefaultParseNodeFactoryInstance.GetRootParseNode(contentType, content) } func throwErrors(responseItem BatchItem, typeName string) error { @@ -266,7 +277,7 @@ func GetBatchResponseById[T serialization.Parsable](resp BatchResponse, itemId s item := resp.GetResponseById(itemId) if *item.GetStatus() >= 400 { - return res, throwErrors(item, reflect.TypeOf(new(T)).Name()) + return res, throwErrors(item, reflect.TypeOf(new(T)).Elem().Name()) } jsonStr, err := json.Marshal(item.GetBody()) @@ -282,7 +293,7 @@ func GetBatchResponseById[T serialization.Parsable](resp BatchResponse, itemId s } result, err := parseNode.GetObjectValue(constructor) - return result.(T), nil + return result.(T), err } func getErrorMapper(key string) abstractions.ErrorMappings { diff --git a/internal/errors.go b/internal/errors.go index ffe2e0d..fc707a3 100644 --- a/internal/errors.go +++ b/internal/errors.go @@ -15,8 +15,21 @@ func (s SampleError) Serialize(writer serialization.SerializationWriter) error { return nil } -func (s SampleError) GetFieldDeserializers() map[string]func(serialization.ParseNode) error { - return make(map[string]func(serialization.ParseNode) error) +func (s *SampleError) GetFieldDeserializers() map[string]func(serialization.ParseNode) error { + res := make(map[string]func(serialization.ParseNode) error) + res["error"] = func(n serialization.ParseNode) error { + v, err := n.GetRawValue() + if err != nil { + return err + } + if vm, ok := v.(map[string]interface{}); ok { + if msg, ok := vm["message"]; ok && msg != nil { + s.Message = *msg.(*string) + } + } + return nil + } + return res } func CreateSampleErrorFromDiscriminatorValue(parseNode serialization.ParseNode) (serialization.Parsable, error) {