diff --git a/private/protocol/xml/xmlutil/unmarshal.go b/private/protocol/xml/xmlutil/unmarshal.go index 49f291a857b..fabb82f8816 100644 --- a/private/protocol/xml/xmlutil/unmarshal.go +++ b/private/protocol/xml/xmlutil/unmarshal.go @@ -112,7 +112,7 @@ func parseStruct(r reflect.Value, node *XMLNode, tag reflect.StructTag) error { if elems == nil { // try to find the field in attributes for _, a := range node.Attr { - if name == a.Name.Local { + if name == strings.Join([]string{a.Name.Space, a.Name.Local}, ":") { // turn this into a text node for de-serializing elems = []*XMLNode{{Text: a.Value}} } diff --git a/private/protocol/xml/xmlutil/xml_to_struct.go b/private/protocol/xml/xmlutil/xml_to_struct.go index 72c198a9d8d..8e0337bd33a 100644 --- a/private/protocol/xml/xmlutil/xml_to_struct.go +++ b/private/protocol/xml/xmlutil/xml_to_struct.go @@ -63,6 +63,8 @@ func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) { return out, e } node.Name = typed.Name + node.Attr = out.Attr + node = adaptNode(node) slice = append(slice, node) out.Children[name] = slice case xml.EndElement: @@ -74,6 +76,26 @@ func XMLToStruct(d *xml.Decoder, s *xml.StartElement) (*XMLNode, error) { return out, nil } +func adaptNode(node *XMLNode) *XMLNode { + ns := map[string]string{} + for _, a := range node.Attr { + if a.Name.Space == "xmlns" { + ns[a.Value] = a.Name.Local + break + } + } + + for i, a := range node.Attr { + if a.Name.Space == "xmlns" { + continue + } + if v, ok := ns[node.Attr[i].Name.Space]; ok { + node.Attr[i].Name.Space = v + } + } + return node +} + // StructToXML writes an XMLNode to a xml.Encoder as tokens. func StructToXML(e *xml.Encoder, node *XMLNode, sorted bool) error { e.EncodeToken(xml.StartElement{Name: node.Name, Attr: node.Attr}) diff --git a/private/protocol/xml/xmlutil/xml_to_struct_test.go b/private/protocol/xml/xmlutil/xml_to_struct_test.go new file mode 100644 index 00000000000..a64a2386073 --- /dev/null +++ b/private/protocol/xml/xmlutil/xml_to_struct_test.go @@ -0,0 +1,53 @@ +package xmlutil_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/awstesting/unit" + "github.com/aws/aws-sdk-go/service/s3" +) + +func TestUnmarshal(t *testing.T) { + xmlVal := []byte(` + foo-iduserfoo-iduserFULL_CONTROL< + /AccessControlPolicy>`) + + var server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(xmlVal) + })) + + sess := unit.Session + sess.Config.Endpoint = &server.URL + sess.Config.S3ForcePathStyle = aws.Bool(true) + svc := s3.New(sess) + + out, err := svc.GetBucketAcl(&s3.GetBucketAclInput{ + Bucket: aws.String("foo"), + }) + + assert.NoError(t, err) + + expected := &s3.GetBucketAclOutput{ + Grants: []*s3.Grant{ + &s3.Grant{ + Grantee: &s3.Grantee{ + DisplayName: aws.String("user"), + ID: aws.String("foo-id"), + Type: aws.String("type"), + }, + Permission: aws.String("FULL_CONTROL"), + }, + }, + + Owner: &s3.Owner{ + DisplayName: aws.String("user"), + ID: aws.String("foo-id"), + }, + } + assert.Equal(t, expected, out) +}