Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a null-check to java enum safeValueOf #5904

Merged
merged 12 commits into from
May 28, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ internal object JavaClassNames {
val Map: ClassName = ClassName.get("java.util", "Map")
val MapOfStringToObject = ParameterizedTypeName.get(Map, String, Object)
val JavaOptional = ClassName.get("java.util", "Optional")
val Objects = ClassName.get("java.util", "Objects")

val ObjectBuilderKt = ClassName.get(apolloApiPackageName, "ObjectBuilderKt")
val ObjectMap = ClassName.get(apolloApiPackageName, "ObjectMap")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package com.apollographql.apollo3.compiler.codegen.java.helpers

import com.apollographql.apollo3.compiler.GeneratedMethod
import com.apollographql.apollo3.compiler.GeneratedMethod.*
import com.apollographql.apollo3.compiler.internal.applyIf
import com.apollographql.apollo3.compiler.GeneratedMethod.EQUALS_HASH_CODE
import com.apollographql.apollo3.compiler.GeneratedMethod.TO_STRING
import com.apollographql.apollo3.compiler.codegen.Identifier.__h
import com.apollographql.apollo3.compiler.codegen.java.JavaClassNames
import com.apollographql.apollo3.compiler.codegen.java.L
import com.apollographql.apollo3.compiler.codegen.java.joinToCode
import com.apollographql.apollo3.compiler.internal.applyIf
import com.squareup.javapoet.ClassName
import com.squareup.javapoet.CodeBlock
import com.squareup.javapoet.FieldSpec
Expand All @@ -29,8 +30,8 @@ import javax.lang.model.element.Modifier
internal fun TypeSpec.Builder.makeClassFromParameters(
generateMethods: List<GeneratedMethod>,
parameters: List<ParameterSpec>,
className: ClassName
): TypeSpec.Builder {
className: ClassName,
): TypeSpec.Builder {
addMethod(
MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
Expand All @@ -55,7 +56,7 @@ internal fun TypeSpec.Builder.makeClassFromParameters(

internal fun TypeSpec.Builder.addGeneratedMethods(
className: ClassName,
generateMethods: List<GeneratedMethod> = listOf(EQUALS_HASH_CODE, TO_STRING)
generateMethods: List<GeneratedMethod> = listOf(EQUALS_HASH_CODE, TO_STRING),
): TypeSpec.Builder {
return applyIf(generateMethods.contains(EQUALS_HASH_CODE)) { withEqualsImplementation(className) }
.applyIf(generateMethods.contains(EQUALS_HASH_CODE)) { withHashCodeImplementation() }
Expand All @@ -68,8 +69,8 @@ internal fun TypeSpec.Builder.addGeneratedMethods(
internal fun TypeSpec.Builder.makeClassFromProperties(
generateMethods: List<GeneratedMethod>,
fields: List<FieldSpec>,
className: ClassName
): TypeSpec.Builder {
className: ClassName,
): TypeSpec.Builder {
addMethod(
MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
Expand All @@ -94,42 +95,45 @@ internal fun TypeSpec.Builder.makeClassFromProperties(

internal fun TypeSpec.Builder.withToStringImplementation(className: ClassName): TypeSpec.Builder {
fun printFieldCode(fieldIndex: Int, fieldName: String) =
CodeBlock.builder()
.let { if (fieldIndex > 0) it.add(" + \", \"\n") else it.add("\n") }
.indent()
.add("+ \$S + \$L", "$fieldName=", fieldName)
.unindent()
.build()
CodeBlock.builder()
.let { if (fieldIndex > 0) it.add(" + \", \"\n") else it.add("\n") }
.indent()
.add("+ \$S + \$L", "$fieldName=", fieldName)
.unindent()
.build()

fun methodCode() =
CodeBlock.builder()
.beginControlFlow("if (\$L == null)", MEMOIZED_TO_STRING_VAR)
.add("\$L = \$S", "\$toString", "${className.simpleName()}{")
.add(fieldSpecs
.filter { !it.hasModifier(Modifier.STATIC) }
.filter { !it.hasModifier(Modifier.TRANSIENT) }
.map { it.name }
.mapIndexed(::printFieldCode)
.fold(CodeBlock.builder(), CodeBlock.Builder::add)
.build())
.add(CodeBlock.builder()
.indent()
.add("\n+ \$S;\n", "}")
.unindent()
.build())
.endControlFlow()
.addStatement("return \$L", MEMOIZED_TO_STRING_VAR)
.build()
CodeBlock.builder()
.beginControlFlow("if (\$L == null)", MEMOIZED_TO_STRING_VAR)
.add("\$L = \$S", "\$toString", "${className.simpleName()}{")
.add(fieldSpecs
.filter { !it.hasModifier(Modifier.STATIC) }
.filter { !it.hasModifier(Modifier.TRANSIENT) }
.map { it.name }
.mapIndexed(::printFieldCode)
.fold(CodeBlock.builder(), CodeBlock.Builder::add)
.build())
.add(CodeBlock.builder()
.indent()
.add("\n+ \$S;\n", "}")
.unindent()
.build()
)
.endControlFlow()
.addStatement("return \$L", MEMOIZED_TO_STRING_VAR)
.build()

return addField(FieldSpec.builder(JavaClassNames.String, MEMOIZED_TO_STRING_VAR, Modifier.PRIVATE, Modifier.VOLATILE,
Modifier.TRANSIENT)
.build())
.addMethod(MethodSpec.methodBuilder("toString")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(JavaClassNames.String)
.addCode(methodCode())
.build())
return addField(
FieldSpec.builder(JavaClassNames.String, MEMOIZED_TO_STRING_VAR, Modifier.PRIVATE, Modifier.VOLATILE, Modifier.TRANSIENT).build()
)
.addMethod(
MethodSpec.methodBuilder("toString")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(JavaClassNames.String)
.addCode(methodCode())
.build()
)
}

private fun List<FieldSpec>.equalsCode(): CodeBlock = filter { !it.hasModifier(Modifier.STATIC) }
Expand All @@ -138,92 +142,114 @@ private fun List<FieldSpec>.equalsCode(): CodeBlock = filter { !it.hasModifier(M
.joinToCode("\n &&")

private fun FieldSpec.equalsCode() =
CodeBlock.builder()
.let {
if (type.isPrimitive) {
if (type == TypeName.DOUBLE) {
it.add("Double.doubleToLongBits(this.\$L) == Double.doubleToLongBits(that.\$L)",
name, name)
} else {
it.add("this.\$L == that.\$L", name, name)
}
CodeBlock.builder()
.let {
if (type.isPrimitive) {
if (type == TypeName.DOUBLE) {
it.add("Double.doubleToLongBits(this.\$L) == Double.doubleToLongBits(that.\$L)", name, name)
} else {
it.add("((this.\$L == null) ? (that.\$L == null) : this.\$L.equals(that.\$L))", name, name, name, name)
it.add("this.\$L == that.\$L", name, name)
}
} else {
it.add("((this.\$L == null) ? (that.\$L == null) : this.\$L.equals(that.\$L))", name, name, name, name)
}
.build()
}
.build()

internal fun TypeSpec.Builder.withEqualsImplementation(className: ClassName): TypeSpec.Builder {
val hasSuperClass = build().superclass != ClassName.OBJECT
fun methodCode(typeJavaClass: ClassName) =
CodeBlock.builder()
.beginControlFlow("if (o == this)")
.addStatement("return true")
.endControlFlow()
.beginControlFlow("if (o instanceof \$T)", typeJavaClass)
.apply {
if (fieldSpecs.isEmpty()) {
CodeBlock.builder()
.beginControlFlow("if (o == this)")
.addStatement("return true")
.endControlFlow()
.beginControlFlow("if (o instanceof \$T)", typeJavaClass)
.apply {
if (fieldSpecs.isEmpty()) {
if (hasSuperClass) {
add("return super.equals(o);\n")
} else {
add("return true;\n")
}
} else {
addStatement("\$T that = (\$T) o", typeJavaClass, typeJavaClass)
if (hasSuperClass) {
add("return super.equals(o) && $L;\n", fieldSpecs.equalsCode())
} else {
addStatement("\$T that = (\$T) o", typeJavaClass, typeJavaClass)
add("return $L;\n", if (fieldSpecs.isEmpty()) "true" else fieldSpecs.equalsCode())
add("return $L;\n", fieldSpecs.equalsCode())
}
}
.endControlFlow()
.addStatement("return false")
.build()
}
.endControlFlow()
.addStatement("return false")
.build()

return addMethod(MethodSpec.methodBuilder("equals")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(TypeName.BOOLEAN)
.addParameter(ParameterSpec.builder(TypeName.OBJECT, "o").build())
.addCode(methodCode(className))
.build())
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(TypeName.BOOLEAN)
.addParameter(ParameterSpec.builder(TypeName.OBJECT, "o").build())
.addCode(methodCode(className))
.build()
)
}

internal fun TypeSpec.Builder.withHashCodeImplementation(): TypeSpec.Builder {
val hasSuperClass = build().superclass != ClassName.OBJECT
fun hashFieldCode(field: FieldSpec) =
CodeBlock.builder()
.addStatement("$__h *= 1000003")
.let {
if (field.type.isPrimitive) {
when (field.type.withoutAnnotations()) {
TypeName.DOUBLE -> it.addStatement("$__h ^= Double.valueOf(\$L).hashCode()", field.name)
TypeName.BOOLEAN -> it.addStatement("$__h ^= Boolean.valueOf(\$L).hashCode()", field.name)
else -> it.addStatement("$__h ^= \$L", field.name)
}
} else {
it.addStatement("$__h ^= (\$L == null) ? 0 : \$L.hashCode()", field.name, field.name)
CodeBlock.builder()
.addStatement("$__h *= 1000003")
.let {
if (field.type.isPrimitive) {
when (field.type.withoutAnnotations()) {
TypeName.DOUBLE -> it.addStatement("$__h ^= Double.valueOf(\$L).hashCode()", field.name)
TypeName.BOOLEAN -> it.addStatement("$__h ^= Boolean.valueOf(\$L).hashCode()", field.name)
else -> it.addStatement("$__h ^= \$L", field.name)
}
} else {
it.addStatement("$__h ^= (\$L == null) ? 0 : \$L.hashCode()", field.name, field.name)
}
.build()
}
.build()

fun methodCode() =
CodeBlock.builder()
.beginControlFlow("if (!\$L)", MEMOIZED_HASH_CODE_FLAG_VAR)
.addStatement("int $__h = 1")
.add(fieldSpecs
.filter { !it.hasModifier(Modifier.STATIC) }
.filter { !it.hasModifier(Modifier.TRANSIENT) }
.map(::hashFieldCode)
.fold(CodeBlock.builder(), CodeBlock.Builder::add)
.build())
.addStatement("\$L = $__h", MEMOIZED_HASH_CODE_VAR)
.addStatement("\$L = true", MEMOIZED_HASH_CODE_FLAG_VAR)
.endControlFlow()
.addStatement("return \$L", MEMOIZED_HASH_CODE_VAR)
.build()
CodeBlock.builder()
.beginControlFlow("if (!\$L)", MEMOIZED_HASH_CODE_FLAG_VAR)
.addStatement(
if (hasSuperClass) {
"int $__h = super.hashCode()"
} else {
"int $__h = 1"
}
)
.add(
fieldSpecs
.filter { !it.hasModifier(Modifier.STATIC) }
.filter { !it.hasModifier(Modifier.TRANSIENT) }
.map(::hashFieldCode)
.fold(CodeBlock.builder(), CodeBlock.Builder::add)
.build()
)
.addStatement("\$L = $__h", MEMOIZED_HASH_CODE_VAR)
.addStatement("\$L = true", MEMOIZED_HASH_CODE_FLAG_VAR)
.endControlFlow()
.addStatement("return \$L", MEMOIZED_HASH_CODE_VAR)
.build()

return addField(FieldSpec.builder(TypeName.INT, MEMOIZED_HASH_CODE_VAR, Modifier.PRIVATE, Modifier.VOLATILE,
Modifier.TRANSIENT).build())
.addField(FieldSpec.builder(TypeName.BOOLEAN, MEMOIZED_HASH_CODE_FLAG_VAR, Modifier.PRIVATE,
Modifier.VOLATILE, Modifier.TRANSIENT).build())
.addMethod(MethodSpec.methodBuilder("hashCode")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(TypeName.INT)
.addCode(methodCode())
.build())
return addField(
FieldSpec.builder(TypeName.INT, MEMOIZED_HASH_CODE_VAR, Modifier.PRIVATE, Modifier.VOLATILE, Modifier.TRANSIENT).build()
)
.addField(
FieldSpec.builder(TypeName.BOOLEAN, MEMOIZED_HASH_CODE_FLAG_VAR, Modifier.PRIVATE, Modifier.VOLATILE, Modifier.TRANSIENT).build()
)
.addMethod(
MethodSpec.methodBuilder("hashCode")
.addAnnotation(JavaClassNames.Override)
.addModifiers(Modifier.PUBLIC)
.returns(TypeName.INT)
.addCode(methodCode())
.build()
)
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ internal class EnumAsClassBuilder(
)
.addMethod(
MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
.addModifiers(Modifier.PRIVATE)
.addParameter(ParameterSpec.builder(JavaClassNames.String, rawValue).build())
.addCode("this.$rawValue = $rawValue;\n")
.build()
Expand All @@ -86,7 +86,7 @@ internal class EnumAsClassBuilder(
.returns(selfClassName)
.addCode(
CodeBlock.builder()
.beginControlFlow("switch($rawValue)")
.beginControlFlow("switch ($T.requireNonNull($rawValue))", JavaClassNames.Objects)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL about Objects.requireNonNull. Should we use that instead of

public static <T> T checkNotNull(T value, String errorMessage) {
?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was almost sure we had something like that but couldn't find it anymore 😅. Ours is a bit nicer since you can pass a message - so I'd keep it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment on ours?

// A version of Objects.requireNonNull that allows a customized message

.apply {
values.forEach {
add("case $S: return $T.$L;\n", it.name, selfClassName, it.targetName.escapeTypeReservedWord()
Expand All @@ -113,7 +113,7 @@ internal class EnumAsClassBuilder(
.addJavadoc(L, "An enum value that wasn't known at compile time.\n")
.addMethod(
MethodSpec.constructorBuilder()
.addModifiers(Modifier.PUBLIC)
.addModifiers(Modifier.PRIVATE)
.addParameter(ParameterSpec.builder(JavaClassNames.String, rawValue).build())
.addCode("super($rawValue);\n")
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ internal class EnumAsEnumBuilder(
.returns(selfClassName)
.addCode(
CodeBlock.builder()
.beginControlFlow("switch ($rawValue)")
.beginControlFlow("switch ($T.requireNonNull($rawValue))", JavaClassNames.Objects)
.apply {
values.forEach {
add("case $S: return $T.$L;\n", it.name, selfClassName, it.targetName.escapeTypeReservedWord()
Expand Down
Loading
Loading