Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix errors for unions with unit target membershape #3547

Merged
merged 5 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"

[[smithy-rs]]
message = """
Stalled stream protection now supports request upload streams. It is currently off by default, but will be enabled by default in a future release. To enable it now, you can do the following:
Expand Down Expand Up @@ -52,3 +51,9 @@ message = "Stalled stream protection on downloads will now only trigger if the u
references = ["smithy-rs#3485"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "jdisanti"

[[smithy-rs]]
message = "Unions with unit target member shape are now fully supported"
references = ["smithy-rs#2546"]
meta = { "breaking" = false, "tada" = false, "bug" = true, "target" = "all"}
author = "drganjoo"
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.transformers.eventStreamE
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.expectTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isTargetUnit
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

fun RustModule.Companion.eventStreamSerdeModule(): RustModule.LeafModule = private("event_stream_serde")
Expand Down Expand Up @@ -189,12 +190,20 @@ class EventStreamUnmarshallerGenerator(
// Don't attempt to parse the payload for an empty struct. The payload can be empty, or if the model was
// updated since the code was generated, it can have content that would not be understood.
empty -> {
rustTemplate(
"Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(#{UnionStruct}::builder().build())))",
"Output" to unionSymbol,
"UnionStruct" to symbolProvider.toSymbol(unionStruct),
*codegenScope,
)
if (unionMember.isTargetUnit()) {
rustTemplate(
"Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName))",
"Output" to unionSymbol,
*codegenScope,
)
} else {
rustTemplate(
"Ok(#{UnmarshalledMessage}::Event(#{Output}::$unionMemberName(#{UnionStruct}::builder().build())))",
"Output" to unionSymbol,
"UnionStruct" to symbolProvider.toSymbol(unionStruct),
*codegenScope,
)
}
}

payloadOnly -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.eventStre
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isTargetUnit
import software.amazon.smithy.rust.codegen.core.util.toPascalCase

open class EventStreamMarshallerGenerator(
Expand Down Expand Up @@ -107,7 +108,15 @@ open class EventStreamMarshallerGenerator(
rustBlock("let payload = match input") {
for (member in unionShape.members()) {
val eventType = member.memberName // must be the original name, not the Rust-safe name
rustBlock("Self::Input::${symbolProvider.toMemberName(member)}(inner) => ") {
// Union members targeting the Smithy `Unit` type do not have associated data in the
// Rust enum generated for the type.
val mayHaveInner =
if (!member.isTargetUnit()) {
"(inner)"
} else {
""
}
rustBlock("Self::Input::${symbolProvider.toMemberName(member)}$mayHaveInner => ") {
addStringHeader(":event-type", "${eventType.dq()}.into()")
val target = model.expectShape(member.target, StructureShape::class.java)
renderMarshallEvent(member, target)
Expand Down Expand Up @@ -147,7 +156,15 @@ open class EventStreamMarshallerGenerator(
renderMarshallEventPayload("inner.$memberName", payloadMember, target, serializerFn)
} else if (headerMembers.isEmpty()) {
val serializerFn = serializerGenerator.payloadSerializer(unionMember)
renderMarshallEventPayload("inner", unionMember, eventStruct, serializerFn)
// Union members targeting the Smithy `Unit` type do not have associated data in the
// Rust enum generated for the type. For these, we need to pass the `crate::model::Unit` data type.
val inner =
if (unionMember.isTargetUnit()) {
"crate::model::Unit::builder().build()"
} else {
"inner"
}
renderMarshallEventPayload(inner, unionMember, eventStruct, serializerFn)
} else {
rust("Vec::new()")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.makeMaybeConstrained
import software.amazon.smithy.rust.codegen.core.smithy.makeRustBoxed
import software.amazon.smithy.rust.codegen.core.smithy.traits.RustBoxTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.isTargetUnit
import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.toPascalCase
import software.amazon.smithy.rust.codegen.server.smithy.InlineModuleCreator
Expand Down Expand Up @@ -86,10 +87,16 @@ class UnconstrainedUnionGenerator(
""",
) {
sortedMembers.forEach { member ->
rust(
"${unconstrainedShapeSymbolProvider.toMemberName(member)}(#T),",
unconstrainedShapeSymbolProvider.toSymbol(member),
)
if (member.isTargetUnit()) {
rust(
"${unconstrainedShapeSymbolProvider.toMemberName(member)},",
)
} else {
rust(
"${unconstrainedShapeSymbolProvider.toMemberName(member)}(#T),",
unconstrainedShapeSymbolProvider.toSymbol(member),
)
}
}
}

Expand Down Expand Up @@ -198,65 +205,80 @@ class UnconstrainedUnionGenerator(
withBlock("match value {", "}") {
sortedMembers.forEach { member ->
val memberName = unconstrainedShapeSymbolProvider.toMemberName(member)
withBlockTemplate(
"#{UnconstrainedUnion}::$memberName(unconstrained) => Self::$memberName(",
"),",
"UnconstrainedUnion" to symbol,
) {
if (!member.canReachConstrainedShape(model, symbolProvider)) {
rust("unconstrained")
} else {
val targetShape = model.expectShape(member.target)
val resolveToNonPublicConstrainedType =
targetShape !is StructureShape && targetShape !is UnionShape && !targetShape.hasTrait<EnumTrait>() &&
(!publicConstrainedTypes || !targetShape.isDirectlyConstrained(symbolProvider))

val (unconstrainedVar, boxIt) =
if (member.hasTrait<RustBoxTrait>()) {
"(*unconstrained)" to ".map(Box::new)"
} else {
"unconstrained" to ""
}
val boxErr =
if (member.hasTrait<ConstraintViolationRustBoxTrait>()) {
".map_err(Box::new)"
} else {
""
}

if (resolveToNonPublicConstrainedType) {
val constrainedSymbol =
if (!publicConstrainedTypes && targetShape.isDirectlyConstrained(symbolProvider)) {
codegenContext.constrainedShapeSymbolProvider.toSymbol(targetShape)
} else {
pubCrateConstrainedShapeSymbolProvider.toSymbol(targetShape)
}
rustTemplate(
"""
{
let constrained: #{ConstrainedSymbol} = $unconstrainedVar
.try_into()$boxIt$boxErr
.map_err(Self::Error::${ConstraintViolation(member).name()})?;
constrained.into()
}
""",
"ConstrainedSymbol" to constrainedSymbol,
)
if (member.isTargetUnit()) {
// Unit type within Unions do not have associated data.
rustTemplate(
"""
#{UnconstrainedUnion}::$memberName => Self::$memberName,
""",
"UnconstrainedUnion" to symbol,
)
} else {
withBlockTemplate(
"#{UnconstrainedUnion}::$memberName(unconstrained) => Self::$memberName(",
"),",
"UnconstrainedUnion" to symbol,
) {
if (!member.canReachConstrainedShape(model, symbolProvider)) {
rust("unconstrained")
} else {
rust(
"""
$unconstrainedVar
.try_into()
$boxIt
$boxErr
.map_err(Self::Error::${ConstraintViolation(member).name()})?
""",
)
generateTryFromImplForReachableConstrainedShape(member).invoke(this)
}
}
}
}
}
}
}

private fun generateTryFromImplForReachableConstrainedShape(member: MemberShape) =
writable {
val targetShape = model.expectShape(member.target)
val resolveToNonPublicConstrainedType =
targetShape !is StructureShape && targetShape !is UnionShape && !targetShape.hasTrait<EnumTrait>() &&
(!publicConstrainedTypes || !targetShape.isDirectlyConstrained(symbolProvider))

val (unconstrainedVar, boxIt) =
if (member.hasTrait<RustBoxTrait>()) {
"(*unconstrained)" to ".map(Box::new)"
} else {
"unconstrained" to ""
}
val boxErr =
if (member.hasTrait<ConstraintViolationRustBoxTrait>()) {
".map_err(Box::new)"
} else {
""
}

if (resolveToNonPublicConstrainedType) {
val constrainedSymbol =
if (!publicConstrainedTypes && targetShape.isDirectlyConstrained(symbolProvider)) {
codegenContext.constrainedShapeSymbolProvider.toSymbol(targetShape)
} else {
pubCrateConstrainedShapeSymbolProvider.toSymbol(targetShape)
}
rustTemplate(
"""
{
let constrained: #{ConstrainedSymbol} = $unconstrainedVar
.try_into()$boxIt$boxErr
.map_err(Self::Error::${ConstraintViolation(member).name()})?;
constrained.into()
}
""",
"ConstrainedSymbol" to constrainedSymbol,
)
} else {
rust(
"""
$unconstrainedVar
.try_into()
$boxIt
$boxErr
.map_err(Self::Error::${ConstraintViolation(member).name()})?
""",
)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.server.smithy

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest

class UnionWithUnitTest {
@Test
fun `a constrained union that has a unit member should compile`() {
val model =
"""
${'$'}version: "2"
namespace com.example
use aws.protocols#restJson1
use smithy.framework#ValidationException

@restJson1 @title("Test Service")
service TestService {
version: "0.1",
operations: [
TestOperation
TestSimpleUnionWithUnit
]
}

@http(uri: "/testunit", method: "POST")
operation TestSimpleUnionWithUnit {
input := {
@required
request: SomeUnionWithUnit
}
output := {
result : SomeUnionWithUnit
}
errors: [
ValidationException
]
}

@length(min: 13)
string StringRestricted

union SomeUnionWithUnit {
Option1: Unit
Option2: StringRestricted
}

@http(uri: "/test", method: "POST")
operation TestOperation {
input := { payload: String }
output := {
@httpPayload
events: TestEvent
},
errors: [ValidationException]
}

@streaming
union TestEvent {
KeepAlive: Unit,
Response: TestResponseEvent,
}

structure TestResponseEvent {
data: String
}
""".asSmithyModel()

// Ensure the generated SDK compiles.
serverIntegrationTest(model) { _, _ -> }
}
}