Skip to content

Commit

Permalink
feat: add directed codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewFossAWS authored and syall committed Mar 27, 2024
1 parent e42afe5 commit 6349ceb
Show file tree
Hide file tree
Showing 21 changed files with 549 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ class CodegenVisitor(context: PluginContext) : ShapeVisitor.Default<Void>() {
}
}

baseGenerationContext = GenerationContext(model, symbolProvider, settings, protocolGenerator, integrations)
baseGenerationContext = GenerationContext(
model, symbolProvider, settings, fileManifest, protocolGenerator, integrations
)

protocolContext = protocolGenerator?.let { ProtocolGenerator.GenerationContext(settings, model, service, symbolProvider, integrations, it.protocol, writers) }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

package software.amazon.smithy.swift.codegen

import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.codegen.core.directed.CreateContextDirective
import software.amazon.smithy.codegen.core.directed.CreateSymbolProviderDirective
import software.amazon.smithy.codegen.core.directed.DirectedCodegen
import software.amazon.smithy.codegen.core.directed.GenerateEnumDirective
import software.amazon.smithy.codegen.core.directed.GenerateErrorDirective
import software.amazon.smithy.codegen.core.directed.GenerateIntEnumDirective
import software.amazon.smithy.codegen.core.directed.GenerateServiceDirective
import software.amazon.smithy.codegen.core.directed.GenerateStructureDirective
import software.amazon.smithy.codegen.core.directed.GenerateUnionDirective
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.ServiceIndex
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.traits.SensitiveTrait
import software.amazon.smithy.swift.codegen.core.GenerationContext
import software.amazon.smithy.swift.codegen.integration.CustomDebugStringConvertibleGenerator
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.model.hasTrait
import java.util.logging.Logger

class DirectedSwiftCodegen(val context: PluginContext) :
DirectedCodegen<GenerationContext, SwiftSettings, SwiftIntegration> {
private val LOGGER = Logger.getLogger(javaClass.name)

override fun createSymbolProvider(directive: CreateSymbolProviderDirective<SwiftSettings>): SymbolProvider {
return SwiftSymbolProvider(directive.model(), directive.settings())
}

override fun createContext(directive: CreateContextDirective<SwiftSettings, SwiftIntegration>): GenerationContext {
val model = directive.model()
val service = directive.service()
val settings = directive.settings()
val integrations = directive.integrations()

val protocolGenerator = resolveProtocolGenerator(integrations, model, service, settings)

for (integration in integrations) {
integration.serviceErrorProtocolSymbol()?.let {
protocolGenerator?.serviceErrorProtocolSymbol = it
}
}

return GenerationContext(
directive.model(),
directive.symbolProvider(),
directive.settings(),
directive.fileManifest(),
protocolGenerator,
directive.integrations()
)
}

override fun generateService(directive: GenerateServiceDirective<GenerationContext, SwiftSettings>) {
val service = directive.service()
val settings = directive.settings()
val symbolProvider = directive.symbolProvider()
val context = directive.context()
val model = directive.model()
val integrations = context.integrations
val fileManifest = context.fileManifest
val writers = context.writerDelegator()

LOGGER.info("Generating Swift client for service ${directive.settings().service}")

var shouldGenerateTestTarget = false
context.protocolGenerator?.apply {
val ctx = ProtocolGenerator.GenerationContext(settings, model, service, symbolProvider, integrations, this.protocol, writers)
LOGGER.info("[${service.id}] Generating serde for protocol ${this.protocol}")
generateSerializers(ctx)
generateDeserializers(ctx)
generateMessageMarshallable(ctx)
generateMessageUnmarshallable(ctx)
generateCodableConformanceForNestedTypes(ctx)

initializeMiddleware(ctx)

LOGGER.info("[${service.id}] Generating unit tests for protocol ${this.protocol}")
val numProtocolUnitTestsGenerated = generateProtocolUnitTests(ctx)
shouldGenerateTestTarget = (numProtocolUnitTestsGenerated > 0)

LOGGER.info("[${service.id}] Generating service client for protocol ${this.protocol}")

generateProtocolClient(ctx)

integrations.forEach { it.writeAdditionalFiles(context, ctx, writers) }
}

println("Flushing swift writers")
val dependencies = writers.dependencies
writers.flushWriters()

println("Generating package manifest file")
writePackageManifest(settings, fileManifest, dependencies, shouldGenerateTestTarget)
}

override fun generateStructure(directive: GenerateStructureDirective<GenerationContext, SwiftSettings>) {
val shape = directive.shape()
val context = directive.context()
val model = directive.model()
val settings = directive.settings()
val symbolProvider = directive.symbolProvider()
val protocolGenerator = context.protocolGenerator
val writers = context.writerDelegator()

writers.useShapeWriter(shape) { writer: SwiftWriter ->
StructureGenerator(model, symbolProvider, writer, shape, settings, protocolGenerator?.serviceErrorProtocolSymbol).render()
}

if (shape.hasTrait<SensitiveTrait>() || shape.members().any { it.hasTrait<SensitiveTrait>() || model.expectShape(it.target).hasTrait<SensitiveTrait>() }) {
writers.useShapeExtensionWriter(shape, "CustomDebugStringConvertible") { writer: SwiftWriter ->
CustomDebugStringConvertibleGenerator(symbolProvider, writer, shape, model).render()
}
}
}

override fun generateError(directive: GenerateErrorDirective<GenerationContext, SwiftSettings>) {
val shape = directive.shape()
val context = directive.context()
val model = directive.model()
val settings = directive.settings()
val symbolProvider = directive.symbolProvider()
val protocolGenerator = context.protocolGenerator
val writers = context.writerDelegator()

writers.useShapeWriter(shape) { writer: SwiftWriter ->
StructureGenerator(model, symbolProvider, writer, shape, settings, protocolGenerator?.serviceErrorProtocolSymbol).renderErrors()
}
}

override fun generateUnion(directive: GenerateUnionDirective<GenerationContext, SwiftSettings>) {
val shape = directive.shape()
val context = directive.context()
val model = directive.model()
val settings = directive.settings()
val symbolProvider = directive.symbolProvider()
val writers = context.writerDelegator()
writers.useShapeWriter(shape) { writer: SwiftWriter -> UnionGenerator(model, symbolProvider, writer, shape, settings).render() }
}

override fun generateEnumShape(directive: GenerateEnumDirective<GenerationContext, SwiftSettings>) {
val shape = directive.shape()
val context = directive.context()
val model = directive.model()
val settings = directive.settings()
val symbolProvider = directive.symbolProvider()
val writers = context.writerDelegator()
writers.useShapeWriter(shape) { writer: SwiftWriter -> EnumGenerator(model, symbolProvider, writer, shape, settings).render() }
}

override fun generateIntEnumShape(directive: GenerateIntEnumDirective<GenerationContext, SwiftSettings>) {
val shape = directive.shape()
val context = directive.context()
val model = directive.model()
val settings = directive.settings()
val symbolProvider = directive.symbolProvider()
val writers = context.writerDelegator()
writers.useShapeWriter(shape) { writer: SwiftWriter -> IntEnumGenerator(model, symbolProvider, writer, shape.asIntEnumShape().get(), settings).render() }
}

private fun resolveProtocolGenerator(
integrations: List<SwiftIntegration>,
model: Model,
service: ServiceShape,
settings: SwiftSettings
): ProtocolGenerator? {
val generators = integrations.flatMap { it.protocolGenerators }.associateBy { it.protocol }
val serviceIndex = ServiceIndex.of(model)

try {
val protocolTrait = settings.resolveServiceProtocol(serviceIndex, service, generators.keys)
return generators[protocolTrait]
} catch (ex: UnresolvableProtocolException) {
LOGGER.warning("Unable to find protocol generator for ${service.id}: ${ex.message}")
}
return null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package software.amazon.smithy.swift.codegen
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.EnumDefinition
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.swift.codegen.customtraits.NestedTrait
Expand Down Expand Up @@ -107,7 +107,7 @@ class EnumGenerator(
private val model: Model,
private val symbolProvider: SymbolProvider,
private val writer: SwiftWriter,
private val shape: StringShape,
private val shape: Shape,
private val settings: SwiftSettings
) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

package software.amazon.smithy.swift.codegen

class ImportDeclarations {
import software.amazon.smithy.codegen.core.ImportContainer
import software.amazon.smithy.codegen.core.Symbol

class ImportDeclarations : ImportContainer {
private val imports = mutableSetOf<ImportStatement>()

fun addImport(
Expand Down Expand Up @@ -41,6 +44,10 @@ class ImportDeclarations {
.map(ImportStatement::statement)
.joinToString(separator = "\n")
}

override fun importSymbol(symbol: Symbol?, alias: String?) {
TODO("Not yet implemented")
}
}

private data class ImportStatement(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ class StructureGenerator(
writer.removeContext("struct.name")
}

fun renderErrors() {
writer.putContext("struct.name", structSymbol.name.toUpperCamelCase())
renderErrorStructure()
writer.removeContext("struct.name")
}

/**
* Generates an appropriate Swift type for a Smithy Structure shape without error trait.
* If the structure is a recursive nested type it will generate a boxed member Box<T>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,27 @@ package software.amazon.smithy.swift.codegen
import software.amazon.smithy.build.PluginContext
import software.amazon.smithy.build.SmithyBuildPlugin
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.codegen.core.directed.CodegenDirector
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.swift.codegen.core.GenerationContext
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.model.AddOperationShapes
import software.amazon.smithy.swift.codegen.model.NestedShapeTransformer
import software.amazon.smithy.swift.codegen.model.RecursiveShapeBoxer
import software.amazon.smithy.swift.codegen.model.UnionIndirectivizer
import java.util.ServiceLoader
import java.util.logging.Logger

/**
* Plugin to trigger Swift code generation.
*/
class SwiftCodegenPlugin : SmithyBuildPlugin {

companion object {

private val LOGGER = Logger.getLogger(SwiftCodegenPlugin::class.java.getName())

/**
* Creates a Kotlin symbol provider.
* @param model The model to generate symbols for
Expand All @@ -24,14 +37,56 @@ class SwiftCodegenPlugin : SmithyBuildPlugin {
* @param sdkId name to use to represent client type. e.g. an sdkId of "foo" would produce a client type "FooClient".
* @return Returns the created provider
*/
fun createSymbolProvider(model: Model, swiftSettings: SwiftSettings): SymbolProvider = SymbolVisitor(model, swiftSettings)
fun createSymbolProvider(model: Model, swiftSettings: SwiftSettings): SymbolProvider = SwiftSymbolProvider(model, swiftSettings)
}

override fun getName(): String = "swift-codegen"

override fun execute(context: PluginContext) {
println("executing swift codegen")

CodegenVisitor(context).execute()
val codegenDirector = CodegenDirector<SwiftWriter, SwiftIntegration, GenerationContext, SwiftSettings>()

val swiftSettings = SwiftSettings.from(context.model, context.settings)

codegenDirector.directedCodegen(DirectedSwiftCodegen(context))

codegenDirector.settings(swiftSettings)

codegenDirector.integrationClass(SwiftIntegration::class.java)

codegenDirector.fileManifest(context.fileManifest)

val enabledIntegrations = ServiceLoader.load(SwiftIntegration::class.java, CodegenDirector::class.java.getClassLoader())
.also { integration -> LOGGER.info("Loaded SwiftIntegration: ${integration.javaClass.name}") }
.filter { integration -> integration.enabledForService(context.model, swiftSettings) }
.also { integration -> LOGGER.info("Enabled SwiftIntegration: ${integration.javaClass.name}") }
.sortedBy(SwiftIntegration::order)
.toList()

val resolvedModel = preprocessModel(context.model, swiftSettings, enabledIntegrations)

codegenDirector.model(resolvedModel)

codegenDirector.integrationFinder { enabledIntegrations.asIterable() }

codegenDirector.service(swiftSettings.getService(resolvedModel).id)

codegenDirector.run()
}

private fun preprocessModel(model: Model, settings: SwiftSettings, integrations: List<SwiftIntegration>): Model {
var resolvedModel = model

for (integration in integrations) {
resolvedModel = integration.preprocessModel(resolvedModel, settings)
}

resolvedModel = ModelTransformer.create().flattenAndRemoveMixins(resolvedModel)
resolvedModel = AddOperationShapes.execute(resolvedModel, settings.getService(resolvedModel), settings.moduleName)
resolvedModel = RecursiveShapeBoxer.transform(resolvedModel)
resolvedModel = NestedShapeTransformer.transform(resolvedModel, settings.getService(resolvedModel))
resolvedModel = UnionIndirectivizer.transform(resolvedModel)
return resolvedModel
}
}
Loading

0 comments on commit 6349ceb

Please sign in to comment.