diff --git a/interop/ksp/src/main/kotlin/com/squareup/kotlinpoet/ksp/KsTypes.kt b/interop/ksp/src/main/kotlin/com/squareup/kotlinpoet/ksp/KsTypes.kt index 8d072201e5..6ffd36749e 100644 --- a/interop/ksp/src/main/kotlin/com/squareup/kotlinpoet/ksp/KsTypes.kt +++ b/interop/ksp/src/main/kotlin/com/squareup/kotlinpoet/ksp/KsTypes.kt @@ -15,6 +15,7 @@ */ package com.squareup.kotlinpoet.ksp +import com.google.devtools.ksp.symbol.KSCallableReference import com.google.devtools.ksp.symbol.KSClassDeclaration import com.google.devtools.ksp.symbol.KSType import com.google.devtools.ksp.symbol.KSTypeAlias @@ -27,6 +28,8 @@ import com.google.devtools.ksp.symbol.Variance.COVARIANT import com.google.devtools.ksp.symbol.Variance.INVARIANT import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.KModifier +import com.squareup.kotlinpoet.LambdaTypeName +import com.squareup.kotlinpoet.ParameterSpec import com.squareup.kotlinpoet.STAR import com.squareup.kotlinpoet.TypeName import com.squareup.kotlinpoet.TypeVariableName @@ -179,5 +182,14 @@ public fun KSTypeArgument.toTypeName( public fun KSTypeReference.toTypeName( typeParamResolver: TypeParameterResolver = TypeParameterResolver.EMPTY, ): TypeName { - return resolve().toTypeName(typeParamResolver, element?.typeArguments.orEmpty()) + return when (val elem = element) { + is KSCallableReference -> { + LambdaTypeName.get( + receiver = elem.receiverType?.toTypeName(typeParamResolver), + parameters = elem.functionParameters.map { ParameterSpec.unnamed(it.type.toTypeName(typeParamResolver)) }, + returnType = elem.returnType.toTypeName(typeParamResolver), + ) + } + else -> resolve().toTypeName(typeParamResolver, element?.typeArguments.orEmpty()) + } } diff --git a/interop/ksp/test-processor/src/main/kotlin/com/squareup/kotlinpoet/ksp/test/processor/TestProcessor.kt b/interop/ksp/test-processor/src/main/kotlin/com/squareup/kotlinpoet/ksp/test/processor/TestProcessor.kt index 82064c8548..3d9db1e1cf 100644 --- a/interop/ksp/test-processor/src/main/kotlin/com/squareup/kotlinpoet/ksp/test/processor/TestProcessor.kt +++ b/interop/ksp/test-processor/src/main/kotlin/com/squareup/kotlinpoet/ksp/test/processor/TestProcessor.kt @@ -171,7 +171,14 @@ class TestProcessor(private val env: SymbolProcessorEnvironment) : SymbolProcess ) .addParameters( function.parameters.map { parameter -> - val parameterType = parameter.type.toValidatedTypeName(functionTypeParams).let { + // Function references can't be obtained from a resolved KSType because it resolves to FunctionN<> which + // loses the necessary context, skip validation in these cases as we know they won't match. + val typeName = if (parameter.type.resolve().isFunctionType) { + parameter.type.toTypeName(functionTypeParams) + } else { + parameter.type.toValidatedTypeName(functionTypeParams) + } + val parameterType = typeName.let { if (unwrapTypeAliases) { it.unwrapTypeAlias() } else { diff --git a/interop/ksp/test-processor/src/test/kotlin/com/squareup/kotlinpoet/ksp/test/processor/TestProcessorTest.kt b/interop/ksp/test-processor/src/test/kotlin/com/squareup/kotlinpoet/ksp/test/processor/TestProcessorTest.kt index 1aaa1edb5e..e731852b9f 100644 --- a/interop/ksp/test-processor/src/test/kotlin/com/squareup/kotlinpoet/ksp/test/processor/TestProcessorTest.kt +++ b/interop/ksp/test-processor/src/test/kotlin/com/squareup/kotlinpoet/ksp/test/processor/TestProcessorTest.kt @@ -129,7 +129,10 @@ class TestProcessorTest { suspend fun functionD( param1: () -> String, param2: (String) -> String, - param3: String.() -> String + param3: String.() -> String, + param4: Function0, + param5: Function1, + param6: (Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int, Int) -> Unit, ) { } @@ -188,6 +191,7 @@ class TestProcessorTest { import kotlin.Int import kotlin.IntArray import kotlin.String + import kotlin.Unit import kotlin.collections.List import kotlin.collections.Map import kotlin.collections.MutableList @@ -261,9 +265,37 @@ class TestProcessorTest { } public suspend fun functionD( - param1: Function0, - param2: Function1, - param3: Function1, + param1: () -> String, + param2: (String) -> String, + param3: String.() -> String, + param4: Function0, + param5: Function1, + param6: ( + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + Int, + ) -> Unit, ) { }