diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1745a4b09e3d4..072651dbf1732 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1093,9 +1093,9 @@ object DecisionTree extends Serializable with Logging { */ if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 - require(numBins >= maxCategoriesForFeatures) + require(numBins > maxCategoriesForFeatures) if (strategy.isMultiClassification) { - require(numBins > math.pow(2, maxCategoriesForFeatures.toInt) - 1) + require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1) } }