Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for decoding event streams #74

Merged
merged 2 commits into from
Jul 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 63 additions & 23 deletions Sources/SotoCodeGeneratorLib/AwsService+shapes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,15 @@ extension AwsService {
}
// check streaming traits
if let payloadMember = payloadMember, let payload = model.shape(for: payloadMember.value.target) {
if payload is BlobShape {
shapeOptions.append("rawPayload")
if payload.hasTrait(type: StreamingTrait.self) {
if isOutput {
if payload is BlobShape || payload.hasTrait(type: StreamingTrait.self) {
shapeOptions.append("rawPayload")
}
} else if isInput {
// currently only support request streaming of blobs
if payload is BlobShape,
payload.hasTrait(type: StreamingTrait.self)
{
shapeOptions.append("allowStreaming")
if !payload.hasTrait(type: RequiresLengthTrait.self),
let operationShape = operationShape,
Expand Down Expand Up @@ -187,11 +193,13 @@ extension AwsService {
if isOutput {
let isResponse = shape.hasTrait(type: SotoResponseShapeTrait.self)
let hasCustomDecode = contexts.members.first { $0.decoding.fromCodable == nil } != nil
let hasNonDecodableElements = contexts.members.first {
$0.decoding.fromHeader != nil || $0.decoding.fromStatusCode != nil || $0.decoding.fromRawPayload == true || $0.decoding.fromEventStream == true
} != nil
decodeContext = .init(
requiresResponse: contexts.members.first {
$0.decoding.fromHeader != nil || $0.decoding.fromStatusCode != nil || $0.decoding.fromRawPayload == true
} != nil,
requiresDecodeInit: isResponse && hasCustomDecode
requiresResponse: hasNonDecodableElements && isResponse,
requiresEvent: hasNonDecodableElements && !isResponse,
requiresDecodeInit: hasCustomDecode
)
}
return StructureContext(
Expand Down Expand Up @@ -230,11 +238,24 @@ extension AwsService {
let isInputShape = shape.hasTrait(type: SotoInputShapeTrait.self)
let sortedMembers = members.map { $0 }.sorted { $0.key.lowercased() < $1.key.lowercased() }
for member in sortedMembers {
guard let targetShape = self.model.shape(for: member.value.target) else { continue }
// member context
let memberContext = self.generateMemberContext(member.value, name: member.key, shapeName: shapeName, typeIsUnion: typeIsUnion, isOutputShape: isOutputShape)
let memberContext = self.generateMemberContext(
member.value,
targetShape: targetShape,
name: member.key,
shapeName: shapeName,
typeIsUnion: typeIsUnion,
isOutputShape: isOutputShape
)
contexts.members.append(memberContext)
// coding key context
if let codingKeyContext = generateCodingKeyContext(member.value, name: member.key, isOutputShape: isOutputShape) {
if let codingKeyContext = generateCodingKeyContext(
member.value,
targetShape: targetShape,
name: member.key,
isOutputShape: isOutputShape
) {
contexts.codingKeys.append(codingKeyContext)
}
// member encoding context. We don't need this for response objects as a custom init(from:) as setup for these
Expand All @@ -260,9 +281,16 @@ extension AwsService {
return contexts
}

func generateMemberContext(_ member: MemberShape, name: String, shapeName: String, typeIsUnion: Bool, isOutputShape: Bool) -> MemberContext {
func generateMemberContext(
_ member: MemberShape,
targetShape: Shape,
name: String,
shapeName: String,
typeIsUnion: Bool,
isOutputShape: Bool
) -> MemberContext {
var required = member.hasTrait(type: RequiredTrait.self) ||
(member.hasTrait(type: HttpPayloadTrait.self) && isOutputShape)
((member.hasTrait(type: HttpPayloadTrait.self) || member.hasTrait(type: EventPayloadTrait.self)) && isOutputShape)
let idempotencyToken = member.hasTrait(type: IdempotencyTokenTrait.self)
let deprecated = member.hasTrait(type: DeprecatedTrait.self)
precondition((required && deprecated) == false, "Member cannot be required and deprecated")
Expand All @@ -281,14 +309,13 @@ extension AwsService {
case .number(let d):
defaultValue = String(format: "%g", d)
case .string(let s):
let shape = self.model.shape(for: member.target)
if let enumShape = shape as? EnumShape {
if let enumShape = targetShape as? EnumShape {
guard let enumCase = self.getEnumCaseFromRawValue(enumShape: enumShape, value: .string(s)) else {
preconditionFailure("Default enum value does not exist")
}
defaultValue = ".\(enumCase.toSwiftEnumCase())"
} else if shape is BlobShape {
if member.hasTrait(type: HttpPayloadTrait.self) == true {
} else if targetShape is BlobShape {
if member.hasTrait(type: HttpPayloadTrait.self) || member.hasTrait(type: EventPayloadTrait.self) {
defaultValue = ".init(string: \"\(s)\")"
} else {
defaultValue = ".data(\"\(s)\".utf8)"
Expand All @@ -297,10 +324,9 @@ extension AwsService {
defaultValue = "\"\(s)\""
}
case .empty:
let shape = self.model.shape(for: member.target)
if shape is ListShape {
if targetShape is ListShape {
defaultValue = "[]"
} else if shape is MapShape {
} else if targetShape is MapShape {
defaultValue = "[:]"
} else {
defaultValue = nil
Expand All @@ -325,8 +351,14 @@ extension AwsService {
memberDecodeContext = .init(fromHeader: headerTrait.value, decodeType: type)
} else if member.hasTrait(type: HttpResponseCodeTrait.self) {
memberDecodeContext = .init(fromStatusCode: true, decodeType: type)
} else if member.hasTrait(type: HttpPayloadTrait.self) {
if model.shape(for: member.target) is BlobShape {
} else if targetShape.hasTrait(type: StreamingTrait.self) {
if targetShape is BlobShape {
memberDecodeContext = .init(fromRawPayload: true, decodeType: type)
} else {
memberDecodeContext = .init(fromEventStream: true, decodeType: type)
}
} else if member.hasTrait(type: HttpPayloadTrait.self) || member.hasTrait(type: EventPayloadTrait.self) {
if targetShape is BlobShape {
memberDecodeContext = .init(fromRawPayload: true, decodeType: type)
} else {
memberDecodeContext = .init(fromPayload: true, decodeType: type)
Expand Down Expand Up @@ -373,7 +405,7 @@ extension AwsService {
let name = isPropertyWrapper ? "_\(name.toSwiftLabelCase())" : name.toSwiftLabelCase()
memberEncoding.append(.init(name: name, location: ".statusCode"))
// if payload and not a blob or shape is an output shape
} else if member.hasTrait(type: HttpPayloadTrait.self),
} else if member.hasTrait(type: HttpPayloadTrait.self) || member.hasTrait(type: EventPayloadTrait.self),
!(model.shape(for: member.target) is BlobShape) || isOutputShape
{
let aliasTrait = member.traits?.first(where: { $0 is AliasTrait }) as? AliasTrait
Expand All @@ -393,9 +425,16 @@ extension AwsService {
return memberEncoding
}

func generateCodingKeyContext(_ member: MemberShape, name: String, isOutputShape: Bool) -> CodingKeysContext? {
func generateCodingKeyContext(
_ member: MemberShape,
targetShape: Shape,
name: String,
isOutputShape: Bool
) -> CodingKeysContext? {
guard isMemberInBody(member, isOutputShape: isOutputShape),
!(member.hasTrait(type: HttpPayloadTrait.self))
!member.hasTrait(type: HttpPayloadTrait.self),
!member.hasTrait(type: EventPayloadTrait.self),
!targetShape.hasTrait(type: StreamingTrait.self)
else {
return nil
}
Expand Down Expand Up @@ -490,6 +529,7 @@ extension AwsService {
guard !alreadyProcessed.contains(shapeId) else { return nil }
guard let shape = model.shape(for: shapeId) else { return nil }
guard !shape.hasTrait(type: EnumTrait.self) else { return nil }
guard !shape.hasTrait(type: StreamingTrait.self) else { return nil }

var requirements: [String: Any] = [:]
if !(shape is EnumShape) {
Expand Down
2 changes: 2 additions & 0 deletions Sources/SotoCodeGeneratorLib/AwsService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ extension AwsService {
var fromHeader: String?
var fromPayload: Bool?
var fromRawPayload: Bool?
var fromEventStream: Bool?
var fromCodable: Bool?
var fromStatusCode: Bool?
var decodeType: String
Expand Down Expand Up @@ -777,6 +778,7 @@ extension AwsService {

struct DecodeContext {
let requiresResponse: Bool
let requiresEvent: Bool
let requiresDecodeInit: Bool
}

Expand Down
7 changes: 6 additions & 1 deletion Sources/SotoCodeGeneratorLib/Smithy+CodeGeneration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,14 @@ extension MemberShape {
return "String"
} else if memberShape is BlobShape {
if self.hasTrait(type: HttpPayloadTrait.self) { return "AWSHTTPBody" }
else if self.hasTrait(type: EventPayloadTrait.self) { return "ByteBuffer" }
return "AWSBase64Data"
} else if memberShape is CollectionShape {
return self.target.shapeName.toSwiftClassCase()
if memberShape.hasTrait(type: StreamingTrait.self) {
return "AWSEventStream<\(self.target.shapeName.toSwiftClassCase())>"
} else {
return self.target.shapeName.toSwiftClassCase()
}
} else if let listShape = memberShape as? ListShape {
return "[\(listShape.member.output(model))]"
} else if let setShape = memberShape as? SetShape {
Expand Down
6 changes: 5 additions & 1 deletion Sources/SotoCodeGeneratorLib/Templates/struct.swift
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,17 @@ extension Templates {
{{#decode.requiresResponse}}
let response = decoder.userInfo[.awsResponse]! as! ResponseDecodingContainer
{{/decode.requiresResponse}}
{{#decode.requiresEvent}}
let response = decoder.userInfo[.awsEvent]! as! EventDecodingContainer
{{/decode.requiresEvent}}
{{^empty(codingKeys)}}
let container = try decoder.container(keyedBy: CodingKeys.self)
{{/empty(codingKeys)}}
{{#members}}{{#decoding}}{{#fromCodable}}
self.{{variable}} = try container.decode{{^propertyWrapper}}{{^required}}IfPresent{{/required}}{{/propertyWrapper}}({{decodeType}}.self, forKey: .{{variable}}){{#propertyWrapper}}.wrappedValue{{/propertyWrapper}}{{/fromCodable}}{{#fromHeader}}
self.{{variable}} = try response.decode{{^required}}IfPresent{{/required}}({{decodeType}}.self, forHeader: "{{.}}"){{/fromHeader}}{{#fromRawPayload}}
self.{{variable}} = response.decodePayload(){{/fromRawPayload}}{{#fromPayload}}
self.{{variable}} = response.decodePayload(){{/fromRawPayload}}{{#fromEventStream}}
self.{{variable}} = response.decodeEventStream(){{/fromEventStream}}{{#fromPayload}}
self.{{variable}} = try .init(from: decoder){{/fromPayload}}{{#fromStatusCode}}
self.{{variable}} = response.decodeStatus(){{/fromStatusCode}}
{{/decoding}}{{/members}}
Expand Down