Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce constraints for unnamed enums #3884

Merged
merged 10 commits into from
Nov 14, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,37 @@ data class InfallibleEnumType(
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
drganjoo marked this conversation as resolved.
Show resolved Hide resolved
writable {
rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
drganjoo marked this conversation as resolved.
Show resolved Hide resolved
}
}
""",
*preludeScope,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;

fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)
}

override fun additionalEnumImpls(context: EnumGeneratorContext): Writable =
writable {
// `try_parse` isn't needed for unnamed enums
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ abstract class EnumType {
/** Returns a writable that implements `FromStr` for the enum */
abstract fun implFromStr(context: EnumGeneratorContext): Writable

/** Returns a writable that implements `From<&str>` and/or `TryFrom<&str>` for the unnamed enum */
abstract fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable

/** Returns a writable that implements `FromStr` for the unnamed enum */
abstract fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable

/** Optionally adds additional documentation to the `enum` docs */
open fun additionalDocs(context: EnumGeneratorContext): Writable = writable {}

Expand Down Expand Up @@ -237,32 +243,10 @@ open class EnumGenerator(
rust("&self.0")
},
)

// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;

fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)

rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}

""",
*preludeScope,
)
// impl From<str> for Blah { ... }
enumType.implFromForStrForUnnamedEnum(context)(this)
// impl FromStr for Blah { ... }
enumType.implFromStrForUnnamedEnum(context)(this)
}

private fun RustWriter.renderEnum() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ class EnumGeneratorTest {
// intentional no-op
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// intentional no-op
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// intentional no-op
}

override fun additionalEnumMembers(context: EnumGeneratorContext): Writable =
writable {
rust("// additional enum members")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.util.dq

object TestEnumType : EnumType() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, I wonder why we have this class at all i.e. why we don't test directly against InfallibleEnumType. It feels wrong to copy over the implementations from the "real" classes to this class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you on this. I'll see if it can be easily fixed in this PR, otherwise will raise another one for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InfallibleEnumType allows UnknownVariants, whereas TestEnumType panics when an unknown variant is encountered. I tried composing TestEnumType to use InfallibleEnumType internally, but InfallibleEnumType is defined in codegen-client, while TestEnumType is in codegen-core.

Expand Down Expand Up @@ -49,4 +50,35 @@ object TestEnumType : EnumType() {
""",
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
rustTemplate(
"""
impl<T> #{From}<T> for ${context.enumName} where T: #{AsRef}<str> {
fn from(s: T) -> Self {
${context.enumName}(s.as_ref().to_owned())
}
}
""",
*preludeScope,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
writable {
// Add an infallible FromStr implementation for uniformity
rustTemplate(
"""
impl ::std::str::FromStr for ${context.enumName} {
type Err = ::std::convert::Infallible;

fn from_str(s: &str) -> #{Result}<Self, <Self as ::std::str::FromStr>::Err> {
#{Ok}(${context.enumName}::from(s))
}
}
""",
*preludeScope,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
package software.amazon.smithy.rust.codegen.server.smithy.generators

import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
Expand Down Expand Up @@ -39,16 +38,14 @@ open class ConstrainedEnum(
}
private val constraintViolationSymbol = constraintViolationSymbolProvider.toSymbol(shape)
private val constraintViolationName = constraintViolationSymbol.name
private val codegenScope =
arrayOf(
"String" to RuntimeType.String,
)

override fun implFromForStr(context: EnumGeneratorContext): Writable =
writable {
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
rustTemplate(
"""
private fun generateConstraintViolation(
context: EnumGeneratorContext,
generateTryFromStrAndString: RustWriter.(EnumGeneratorContext) -> Unit,
) = writable {
withInlineModule(constraintViolationSymbol.module(), codegenContext.moduleDocProvider) {
rustTemplate(
"""
##[derive(Debug, PartialEq)]
pub struct $constraintViolationName(pub(crate) #{String});

Expand All @@ -60,47 +57,86 @@ open class ConstrainedEnum(

impl #{Error} for $constraintViolationName {}
""",
*codegenScope,
"Error" to RuntimeType.StdError,
"Display" to RuntimeType.Display,
)
*preludeScope,
"Error" to RuntimeType.StdError,
"Display" to RuntimeType.Display,
)

if (shape.isReachableFromOperationInput()) {
rustTemplate(
"""
if (shape.isReachableFromOperationInput()) {
rustTemplate(
"""
impl $constraintViolationName {
#{EnumShapeConstraintViolationImplBlock:W}
}
""",
"EnumShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
context.enumTrait,
),
)
}
"EnumShapeConstraintViolationImplBlock" to
validationExceptionConversionGenerator.enumShapeConstraintViolationImplBlock(
context.enumTrait,
),
)
}
rustBlock("impl #T<&str> for ${context.enumName}", RuntimeType.TryFrom) {
rust("type Error = #T;", constraintViolationSymbol)
rustBlockTemplate("fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error>", *preludeScope) {
rustBlock("match s") {
context.sortedMembers.forEach { member ->
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
}

generateTryFromStrAndString(context)
}

override fun implFromForStr(context: EnumGeneratorContext): Writable =
generateConstraintViolation(context) {
rustTemplate(
"""
impl #{TryFrom}<&str> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
match s {
#{MatchArms}
_ => Err(#{ConstraintViolation}(s.to_owned()))
}
rust("_ => Err(#T(s.to_owned()))", constraintViolationSymbol)
}
}
}
impl #{TryFrom}<#{String}> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
s.as_str().try_into()
}
}
""",
*preludeScope,
"ConstraintViolation" to constraintViolationSymbol,
"MatchArms" to
writable {
context.sortedMembers.forEach { member ->
rust("${member.value.dq()} => Ok(${context.enumName}::${member.derivedName()}),")
}
},
)
}

override fun implFromForStrForUnnamedEnum(context: EnumGeneratorContext): Writable =
generateConstraintViolation(context) {
rustTemplate(
"""
impl #{TryFrom}<&str> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: &str) -> #{Result}<Self, <Self as #{TryFrom}<&str>>::Error> {
s.to_owned().try_into()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is first converting to a heap-allocated String and only then matching on the enum values. So invalid enum values would unnecessarily be heap-allocated. We should make TryFrom<String> and TryFrom<&str> instead delegate to FromStr, which should only heap-allocate when the enum value is valid.

Copy link
Contributor Author

@drganjoo drganjoo Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unlike named enums, which can use &str for comparison and directly return an enum variant, unnamed enums need to store an owned, heap-allocated String. The TryFrom<String> implementation already receives an owned String as a parameter, so delegating to FromStr would result in an unnecessary heap allocation. Additionally, ConstraintViolation takes ownership of a heap-allocated String. Therefore, both code paths (valid and invalid enum values) require a heap-allocated String. Calling to_owned within TryFrom<&str> only shifts the allocation slightly earlier without adding any additional allocation.

}
}
impl #{TryFrom}<#{String}> for ${context.enumName} {
type Error = #{ConstraintViolation};
fn try_from(s: #{String}) -> #{Result}<Self, <Self as #{TryFrom}<#{String}>>::Error> {
s.as_str().try_into()
match s.as_str() {
#{Values} => Ok(Self(s)),
_ => Err(#{ConstraintViolation}(s))
}
}
}
""",
*preludeScope,
"ConstraintViolation" to constraintViolationSymbol,
"Values" to
writable {
rust(context.sortedMembers.joinToString(" | ") { it.value.dq() })
},
)
}

Expand All @@ -118,6 +154,8 @@ open class ConstrainedEnum(
"ConstraintViolation" to constraintViolationSymbol,
)
}

override fun implFromStrForUnnamedEnum(context: EnumGeneratorContext) = implFromStr(context)
}

class ServerEnumGenerator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.AbstractTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.unitTest
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import java.io.File

Expand Down Expand Up @@ -219,4 +223,59 @@ class ConstraintsTest {
structWithInnerDefault.canReachConstrainedShape(model, symbolProvider) shouldBe false
primitiveBoolean.isDirectlyConstrained(symbolProvider) shouldBe false
}

@Test
fun `unnamed and named enums should validate and have an associated ConstraintViolation error type`() {
val model =
"""
namespace test
use aws.protocols#restJson1
use smithy.framework#ValidationException

@restJson1
service SampleService {
operations: [SampleOp]
}

@http(uri: "/dailySummary", method: "POST")
operation SampleOp {
input := {
unnamedDay: UnnamedDayOfWeek
namedDay: DayOfWeek
}
errors: [ValidationException]
}
@enum([
{ value: "MONDAY" },
{ value: "TUESDAY" }
])
string UnnamedDayOfWeek
@enum([
{ value: "MONDAY", name: "MONDAY" },
{ value: "TUESDAY", name: "TUESDAY" }
])
string DayOfWeek
""".asSmithyModel(smithyVersion = "2")

serverIntegrationTest(
model,
IntegrationTestParams(
service = "test#SampleService",
),
) { _, crate ->
crate.unitTest("value_should_be_validated") {
rustTemplate(
"""
let x: Result<crate::model::DayOfWeek, crate::model::day_of_week::ConstraintViolation> =
"Friday".try_into();
assert!(x.is_err());

let x: Result<crate::model::UnnamedDayOfWeek, crate::model::unnamed_day_of_week::ConstraintViolation> =
"Friday".try_into();
assert!(x.is_err());
""",
)
}
}
}
}