diff --git a/diktat-rules/src/main/kotlin/com/saveourtool/diktat/ruleset/rules/chapter4/NullChecksRule.kt b/diktat-rules/src/main/kotlin/com/saveourtool/diktat/ruleset/rules/chapter4/NullChecksRule.kt index 4a19d0afa5..9d40395f1b 100644 --- a/diktat-rules/src/main/kotlin/com/saveourtool/diktat/ruleset/rules/chapter4/NullChecksRule.kt +++ b/diktat-rules/src/main/kotlin/com/saveourtool/diktat/ruleset/rules/chapter4/NullChecksRule.kt @@ -9,6 +9,8 @@ import com.saveourtool.diktat.ruleset.utils.parent import org.jetbrains.kotlin.KtNodeTypes.BINARY_EXPRESSION import org.jetbrains.kotlin.KtNodeTypes.BLOCK import org.jetbrains.kotlin.KtNodeTypes.BREAK +import org.jetbrains.kotlin.KtNodeTypes.CALL_EXPRESSION +import org.jetbrains.kotlin.KtNodeTypes.CLASS_INITIALIZER import org.jetbrains.kotlin.KtNodeTypes.CONDITION import org.jetbrains.kotlin.KtNodeTypes.ELSE import org.jetbrains.kotlin.KtNodeTypes.IF @@ -26,6 +28,9 @@ import org.jetbrains.kotlin.psi.KtBlockExpression import org.jetbrains.kotlin.psi.KtIfExpression import org.jetbrains.kotlin.psi.KtPsiUtil import org.jetbrains.kotlin.psi.psiUtil.blockExpressionsOrSingle +import org.jetbrains.kotlin.psi.psiUtil.parents + +typealias ThenAndElseLines = Pair?, List?> /** * This rule check and fixes explicit null checks (explicit comparison with `null`) @@ -72,35 +77,68 @@ class NullChecksRule(configRules: List) : DiktatRule( private fun conditionInIfStatement(node: ASTNode) { node.findAllDescendantsWithSpecificType(BINARY_EXPRESSION).forEach { binaryExprNode -> val condition = (binaryExprNode.psi as KtBinaryExpression) + if (isNullCheckBinaryExpression(condition)) { - when (condition.operationToken) { + val isEqualToNull = when (condition.operationToken) { // `==` and `===` comparison can be fixed with `?:` operator - KtTokens.EQEQ, KtTokens.EQEQEQ -> - warnAndFixOnNullCheck( - condition, - isFixable(node, true), - "use '.let/.also/?:/e.t.c' instead of ${condition.text}" - ) { - fixNullInIfCondition(node, condition, true) - } - // `!==` and `!==` comparison can be fixed with `.let/also` operators - KtTokens.EXCLEQ, KtTokens.EXCLEQEQEQ -> - warnAndFixOnNullCheck( - condition, - isFixable(node, false), - "use '.let/.also/?:/e.t.c' instead of ${condition.text}" - ) { - fixNullInIfCondition(node, condition, false) - } + KtTokens.EQEQ, KtTokens.EQEQEQ -> true + // `!=` and `!==` comparison can be fixed with `.let/also` operators + KtTokens.EXCLEQ, KtTokens.EXCLEQEQEQ -> false else -> return } + + val (_, elseCodeLines) = getThenAndElseLines(node, isEqualToNull) + val (numberOfStatementsInElseBlock, isAssignmentInNewElseBlock) = getInfoAboutElseBlock(node, isEqualToNull) + + // if `if-else` block inside `init` or 'run', 'with', 'apply' blocks and there is more than one statement inside 'else' block, then + // we don't have to make any fixes, because this leads to kotlin compilation error after adding 'run' block instead of 'else' block + // read https://youtrack.jetbrains.com/issue/KT-64174 for more information + if (shouldBeWarned(node, elseCodeLines, numberOfStatementsInElseBlock, isAssignmentInNewElseBlock)) { + warnAndFixOnNullCheck( + condition, + canBeAutoFixed(node, isEqualToNull), + "use '.let/.also/?:/e.t.c' instead of ${condition.text}" + ) { + fixNullInIfCondition(node, condition, isEqualToNull) + } + } } } } + /** + * Checks whether it is necessary to warn about null-check + */ + private fun shouldBeWarned( + condition: ASTNode, + elseCodeLines: List?, + numberOfStatementsInElseBlock: Int, + isAssignment: Boolean, + ): Boolean = when { + // else { "null"/empty } -> "" + isNullOrEmptyElseBlock(elseCodeLines) -> true + // else { bar() } -> ?: bar() + isOnlyOneNonAssignmentStatementInElseBlock(numberOfStatementsInElseBlock, isAssignment) -> true + // else { ... } -> ?: run { ... } + else -> isNotInsideWrongBlock(condition) + } + + private fun isOnlyOneNonAssignmentStatementInElseBlock(numberOfStatementsInElseBlock: Int, isAssignment: Boolean) = numberOfStatementsInElseBlock == 1 && !isAssignment + + private fun isNullOrEmptyElseBlock(elseCodeLines: List?) = elseCodeLines == null || elseCodeLines.singleOrNull() == "null" + + private fun isNotInsideWrongBlock(condition: ASTNode): Boolean = condition.parents().none { + it.elementType == CLASS_INITIALIZER || + (it.elementType == CALL_EXPRESSION && + it.findChildByType(REFERENCE_EXPRESSION)?.text in listOf("run", "with", "apply")) + } + + /** + * Checks whether null-check can be auto fixed + */ @Suppress("UnsafeCallOnNullableType") - private fun isFixable(condition: ASTNode, - isEqualToNull: Boolean): Boolean { + private fun canBeAutoFixed(condition: ASTNode, + isEqualToNull: Boolean): Boolean { // Handle cases with `break` word in blocks val typePair = if (isEqualToNull) { (ELSE to THEN) @@ -126,6 +164,19 @@ class NullChecksRule(configRules: List) : DiktatRule( isEqualToNull: Boolean ) { val variableName = binaryExpression.left!!.text + val (thenCodeLines, elseCodeLines) = getThenAndElseLines(condition, isEqualToNull) + val (numberOfStatementsInElseBlock, isAssignmentInNewElseBlock) = getInfoAboutElseBlock(condition, isEqualToNull) + + val elseEditedCodeLines = getEditedElseCodeLines(elseCodeLines, numberOfStatementsInElseBlock, isAssignmentInNewElseBlock) + val thenEditedCodeLines = getEditedThenCodeLines(variableName, thenCodeLines, elseEditedCodeLines) + + val newTextForReplacement = "$thenEditedCodeLines $elseEditedCodeLines" + val newNodeForReplacement = KotlinParser().createNode(newTextForReplacement) + val ifNode = condition.treeParent + ifNode.treeParent.replaceChild(ifNode, newNodeForReplacement) + } + + private fun getThenAndElseLines(condition: ASTNode, isEqualToNull: Boolean): ThenAndElseLines { val thenFromExistingCode = condition.extractLinesFromBlock(THEN) val elseFromExistingCode = condition.extractLinesFromBlock(ELSE) @@ -141,7 +192,11 @@ class NullChecksRule(configRules: List) : DiktatRule( elseFromExistingCode } - val (numberOfStatementsInElseBlock, isAssignmentInNewElseBlock) = (condition.treeParent.psi as? KtIfExpression) + return Pair(thenCodeLines, elseCodeLines) + } + + private fun getInfoAboutElseBlock(condition: ASTNode, isEqualToNull: Boolean) = + ((condition.treeParent.psi as? KtIfExpression) ?.let { if (isEqualToNull) { it.then @@ -155,28 +210,20 @@ class NullChecksRule(configRules: List) : DiktatRule( KtPsiUtil.isAssignment(element) } } - ?: Pair(0, false) - - val elseEditedCodeLines = getEditedElseCodeLines(elseCodeLines, numberOfStatementsInElseBlock, isAssignmentInNewElseBlock) - val thenEditedCodeLines = getEditedThenCodeLines(variableName, thenCodeLines, elseEditedCodeLines) - - val text = "$thenEditedCodeLines $elseEditedCodeLines" - val tree = KotlinParser().createNode(text) - val ifNode = condition.treeParent - ifNode.treeParent.replaceChild(ifNode, tree) - } + ?: Pair(0, false)) + @Suppress("UnsafeCallOnNullableType") private fun getEditedElseCodeLines( elseCodeLines: List?, numberOfStatementsInElseBlock: Int, isAssignment: Boolean, ): String = when { // else { "null"/empty } -> "" - elseCodeLines == null || elseCodeLines.singleOrNull() == "null" -> "" + isNullOrEmptyElseBlock(elseCodeLines) -> "" // else { bar() } -> ?: bar() - numberOfStatementsInElseBlock == 1 && !isAssignment -> "?: ${elseCodeLines.joinToString(postfix = "\n", separator = "\n")}" + isOnlyOneNonAssignmentStatementInElseBlock(numberOfStatementsInElseBlock, isAssignment) -> "?: ${elseCodeLines!!.joinToString(postfix = "\n", separator = "\n")}" // else { ... } -> ?: run { ... } - else -> getDefaultCaseElseCodeLines(elseCodeLines) + else -> getDefaultCaseElseCodeLines(elseCodeLines!!) } @Suppress("UnsafeCallOnNullableType") diff --git a/diktat-rules/src/test/kotlin/com/saveourtool/diktat/ruleset/chapter4/NullChecksRuleWarnTest.kt b/diktat-rules/src/test/kotlin/com/saveourtool/diktat/ruleset/chapter4/NullChecksRuleWarnTest.kt index e510491694..43122523ba 100644 --- a/diktat-rules/src/test/kotlin/com/saveourtool/diktat/ruleset/chapter4/NullChecksRuleWarnTest.kt +++ b/diktat-rules/src/test/kotlin/com/saveourtool/diktat/ruleset/chapter4/NullChecksRuleWarnTest.kt @@ -215,4 +215,109 @@ class NullChecksRuleWarnTest : LintTestBase(::NullChecksRule) { """.trimMargin() ) } + + @Test + @Tag(WarningNames.AVOID_NULL_CHECKS) + fun `don't trigger inside 'init' block when more than one statement in 'else' block`() { + lintMethod( + """ + |class Demo { + | val one: Int + | val two: String + | + | init { + | val number = get() + | if (number != null) { + | one = number.toInt() + | two = number + | } else { + | one = 0 + | two = "0" + | } + | } + | + | private fun get(): String? = if (Math.random() > 0.5) { "1" } else { null } + |} + """.trimMargin() + ) + } + + @Test + @Tag(WarningNames.AVOID_NULL_CHECKS) + fun `trigger inside 'init' block when only one statement in 'else' block`() { + lintMethod( + """ + |class Demo { + | val one: Int = 0 + | val two: String = "" + | + | init { + | val number = get() + | if (number != null) { + | print(number + 1) + | } else { + | print(null) + | } + | } + | + | private fun get(): String? = if (Math.random() > 0.5) { "1" } else { null } + |} + """.trimMargin(), + DiktatError(7, 13, ruleId, Warnings.AVOID_NULL_CHECKS.warnText() + + " use '.let/.also/?:/e.t.c' instead of number != null", true), + ) + } + + @Test + @Tag(WarningNames.AVOID_NULL_CHECKS) + fun `trigger inside 'init' block when no 'else' block`() { + lintMethod( + """ + |class Demo { + | val one: Int = 0 + | val two: String = "" + | + | init { + | val number = get() + | if (number != null) { + | print(number) + | } + | } + | + | private fun get(): String? = if (Math.random() > 0.5) { "1" } else { null } + |} + """.trimMargin(), + DiktatError(7, 13, ruleId, Warnings.AVOID_NULL_CHECKS.warnText() + + " use '.let/.also/?:/e.t.c' instead of number != null", true), + ) + } + + @Test + @Tag(WarningNames.AVOID_NULL_CHECKS) + fun `don't trigger inside 'run', 'with', 'apply' scope functions when more than one statement in 'else' block`() { + lintMethod( + """ + |class Demo { + | + | private fun set() { + | val one: Int + | val two: String + | + | run { + | val number: String? = get() + | if (number != null) { + | one = number.toInt() + | two = number + | } else { + | one = 0 + | two = "0" + | } + | } + | } + | + | private fun get(): String? = if (Math.random() > 0.5) { "1" } else { null } + |} + """.trimMargin() + ) + } }