Skip to content

Commit

Permalink
Fix GetBatchResponseById error deserialization (#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-millin authored Feb 5, 2024
1 parent 4bd16ac commit b17ae64
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 11 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

## [1.0.2] - 2023-12-01

### Changed

- Fixed a bug where GetBatchResponseById failed to deserialize error response bodies.

## [1.0.1] - 2023-11-24

### Changed
Expand Down
50 changes: 50 additions & 0 deletions batch_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,56 @@ func TestGetResponseByIdFailedRequest(t *testing.T) {
assert.Equal(t, "The server returned an unexpected status code and no error factory is registered for this code: 401", err.Error())
}

func TestGetErrorResponseBodyById(t *testing.T) {
var jsonBlob = `{
"responses": [{
"id": "3",
"status": 400,
"headers": {
"Content-Type": "application/json"
},
"body": {
"error": {
"code": "ExtensionError",
"message": "Exception: [Status Code: BadRequest; Reason: Boom]",
"innerError": {
"request-id": "123"
}
}
}
}]
}`

errorMapping := abstractions.ErrorMappings{
"4XX": internal.CreateSampleErrorFromDiscriminatorValue,
"5XX": internal.CreateSampleErrorFromDiscriminatorValue,
}
err := RegisterError("Userable", errorMapping)
assert.NoError(t, err)

mockServer := makeMockRequest(200, jsonBlob)
defer mockServer.Close()

mockPath := mockServer.URL + "/$batch"
reqAdapter.SetBaseUrl(mockPath)

reqInfo := getRequestInfo()
batch := NewBatchRequest(reqAdapter)
_, err = batch.AddBatchRequestStep(*reqInfo)
require.NoError(t, err)

resp, err := batch.Send(context.Background(), reqAdapter)
require.NoError(t, err)

_, err = GetBatchResponseById[Userable](resp, "3", CreateUser)
serr := &internal.SampleError{}
assert.ErrorAs(t, err, &serr)
assert.Equal(t, "Exception: [Status Code: BadRequest; Reason: Boom]", serr.Message)

err = DeRegisterError("Userable")
require.NoError(t, err)
}

func TestGetResponseByIdFailedRequestWithFactory(t *testing.T) {
mockServer := makeMockRequest(200, getDummyJSON())
defer mockServer.Close()
Expand Down
29 changes: 20 additions & 9 deletions batch_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/gob"
"encoding/json"
"errors"
nethttplibrary "github.com/microsoft/kiota-http-go"
"net/url"
"reflect"
"strconv"
Expand All @@ -17,6 +16,7 @@ import (
abstractions "github.com/microsoft/kiota-abstractions-go"
"github.com/microsoft/kiota-abstractions-go/serialization"
absser "github.com/microsoft/kiota-abstractions-go/serialization"
nethttplibrary "github.com/microsoft/kiota-http-go"
)

const BatchRequestErrorRegistryKey = "BATCH_REQUEST_ERROR_REGISTRY_KEY"
Expand Down Expand Up @@ -210,13 +210,24 @@ func getRootParseNode(responseItem BatchItem) (absser.ParseNode, error) {
if contentType == "" {
return nil, nil
}
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(responseItem.GetBody())
if err != nil {
return nil, err

var (
content []byte
err error
)
if contentType == jsonContentType {
if content, err = json.Marshal(responseItem.GetBody()); err != nil {
return nil, err
}
} else {
var buf bytes.Buffer
if err = gob.NewEncoder(&buf).Encode(responseItem.GetBody()); err != nil {
return nil, err
}
content = buf.Bytes()
}
return serialization.DefaultParseNodeFactoryInstance.GetRootParseNode(contentType, buf.Bytes())

return serialization.DefaultParseNodeFactoryInstance.GetRootParseNode(contentType, content)
}

func throwErrors(responseItem BatchItem, typeName string) error {
Expand Down Expand Up @@ -266,7 +277,7 @@ func GetBatchResponseById[T serialization.Parsable](resp BatchResponse, itemId s
item := resp.GetResponseById(itemId)

if *item.GetStatus() >= 400 {
return res, throwErrors(item, reflect.TypeOf(new(T)).Name())
return res, throwErrors(item, reflect.TypeOf(new(T)).Elem().Name())
}

jsonStr, err := json.Marshal(item.GetBody())
Expand All @@ -282,7 +293,7 @@ func GetBatchResponseById[T serialization.Parsable](resp BatchResponse, itemId s
}

result, err := parseNode.GetObjectValue(constructor)
return result.(T), nil
return result.(T), err
}

func getErrorMapper(key string) abstractions.ErrorMappings {
Expand Down
17 changes: 15 additions & 2 deletions internal/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,21 @@ func (s SampleError) Serialize(writer serialization.SerializationWriter) error {
return nil
}

func (s SampleError) GetFieldDeserializers() map[string]func(serialization.ParseNode) error {
return make(map[string]func(serialization.ParseNode) error)
func (s *SampleError) GetFieldDeserializers() map[string]func(serialization.ParseNode) error {
res := make(map[string]func(serialization.ParseNode) error)
res["error"] = func(n serialization.ParseNode) error {
v, err := n.GetRawValue()
if err != nil {
return err
}
if vm, ok := v.(map[string]interface{}); ok {
if msg, ok := vm["message"]; ok && msg != nil {
s.Message = *msg.(*string)
}
}
return nil
}
return res
}

func CreateSampleErrorFromDiscriminatorValue(parseNode serialization.ParseNode) (serialization.Parsable, error) {
Expand Down

0 comments on commit b17ae64

Please sign in to comment.