Skip to content

Commit

Permalink
Merge pull request #6280 from VesnaT/rf_max_features
Browse files Browse the repository at this point in the history
Random Forest, Gradient Boosting: Handle deprecated parameters
  • Loading branch information
markotoplak authored Jan 11, 2023
2 parents 4f6d627 + 3a4bc8d commit 4348f1e
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Orange/classification/gb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class GBClassifier(SklLearner, _FeatureScorerMixin):
__returns__ = SklModel

def __init__(self,
loss="deviance",
loss="log_loss",
learning_rate=0.1,
n_estimators=100,
subsample=1.0,
Expand Down
2 changes: 1 addition & 1 deletion Orange/classification/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_features="auto",
max_features="sqrt",
max_leaf_nodes=None,
bootstrap=True,
oob_score=False,
Expand Down
2 changes: 0 additions & 2 deletions Orange/classification/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def __init__(self,
importance_type=importance_type,
gpu_id=gpu_id,
validate_parameters=validate_parameters,
use_label_encoder=False,
preprocessors=preprocessors)


Expand Down Expand Up @@ -146,5 +145,4 @@ def __init__(self,
importance_type=importance_type,
gpu_id=gpu_id,
validate_parameters=validate_parameters,
use_label_encoder=False,
preprocessors=preprocessors)
2 changes: 1 addition & 1 deletion Orange/regression/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self,
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_features="auto",
max_features=1.0,
max_leaf_nodes=None,
bootstrap=True,
oob_score=False,
Expand Down
25 changes: 25 additions & 0 deletions Orange/tests/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,28 @@ def test_get_regression_trees(self):
self.assertEqual(len(model.trees), n)
tree = model.trees[0]
tree(self.housing[0])

def test_max_features_cls(self):
data = Table("heart_disease")
forest_1 = RandomForestLearner(random_state=0, max_features=1)
model_1 = forest_1(data[1:])

forest_2 = RandomForestLearner(random_state=0, max_features=1.)
model_2 = forest_2(data[1:])
diff = np.sum(np.abs(model_1(data[:1], ret=model_2.Probs) -
model_2(data[:1], ret=model_2.Probs)))
self.assertGreaterEqual(diff, 0.2)

def test_max_features_reg(self):
data = self.housing
forest_1 = RandomForestRegressionLearner(random_state=0, max_features=1)
model_1 = forest_1(data[1:])

forest_2 = RandomForestRegressionLearner(random_state=0, max_features=1.)
model_2 = forest_2(data[1:])
diff = np.sum(np.abs(model_1(data[:1]) - model_2(data[:1])))
self.assertGreater(diff, 1.2)


if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions Orange/tests/test_score_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,7 @@ def random_column(d):
scores2 = learner.score_data(data)

np.testing.assert_equal(scores1[0][:-1], scores2[0])


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion Orange/widgets/model/owrandomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def add_main_layout(self):
alignment=Qt.AlignRight, label="Number of trees: ",
callback=self.settings_changed)
self.max_features_spin = gui.spin(
box, self, "max_features", 2, 50, controlWidth=80,
box, self, "max_features", 1, 50, controlWidth=80,
label="Number of attributes considered at each split: ",
callback=self.settings_changed, checked="use_max_features",
checkCallback=self.settings_changed, alignment=Qt.AlignRight,)
Expand Down
8 changes: 7 additions & 1 deletion Orange/widgets/model/tests/test_owrandomforest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
import unittest

from AnyQt.QtCore import Qt
from Orange.data import Table
from Orange.widgets.model.owrandomforest import OWRandomForest
Expand Down Expand Up @@ -41,7 +43,7 @@ def test_parameters_unchecked(self):
self.widget.min_samples_split_spin[0].setCheckState(Qt.Unchecked)
self.parameters = self.parameters[:1]
self.parameters.extend([
DefaultParameterMapping("max_features", "auto"),
DefaultParameterMapping("max_features", "sqrt"),
DefaultParameterMapping("random_state", None),
DefaultParameterMapping("max_depth", None),
DefaultParameterMapping("min_samples_split", 2)])
Expand All @@ -56,3 +58,7 @@ def test_class_weights(self):
self.widget.apply_button.button.click()
self.assertEqual(self.widget.model.skl_model.class_weight, "balanced")
self.assertTrue(self.widget.Warning.class_weights_used.is_shown())


if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ requirements:
- setuptools >=41.0.0
- numpy >=1.19.5
- scipy >=1.9
- scikit-learn >=1.0.1,<1.2.0
- scikit-learn >=1.1.0,<1.2.0
- bottleneck >=1.3.4
- chardet >=3.0.2
- xlrd >=1.2.0
- xlsxwriter
- anyqt >=0.1.0
- pyqt >=5.12,!=5.15.1,<6.0
- pyqtgraph >=0.12.2,!=0.12.4
- joblib >=0.11
- joblib >=1.0.0
- keyring
- keyrings.alt
- pip >=18.0
Expand Down
4 changes: 2 additions & 2 deletions requirements-core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ numpy>=1.19.5
scipy>=1.9
# scikit-learn version 1.0.0 includes problematic libomp 12 which breaks xgboost
# https://github.com/scikit-learn/scikit-learn/pull/21227
scikit-learn>=1.0.1,<1.2.0
scikit-learn>=1.1.0,<1.2.0
bottleneck>=1.3.4
# Reading Excel files
xlrd>=1.2.0
Expand All @@ -12,7 +12,7 @@ xlsxwriter
# Encoding detection
chardet>=3.0.2
# Multiprocessing abstraction
joblib>=0.11
joblib>=1.0.0
keyring
keyrings.alt # for alternative keyring implementations
setuptools>=41.0.0
Expand Down
2 changes: 1 addition & 1 deletion requirements-opt.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
catboost!=1.0.0 # 1.0.0 segfaults on Macs
catboost>=1.0.1
xgboost>=1.5.0
5 changes: 3 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ deps =
oldest: pip==18.0
oldest: numpy==1.19.5
oldest: scipy==1.9
oldest: scikit-learn==1.0.1
oldest: scikit-learn==1.1.0
oldest: bottleneck==1.3.4
oldest: xlrd==1.2.0
# oldest: xlsxwriter
oldest: chardet==3.0.2
oldest: joblib==0.11
oldest: joblib==1.0.0
# oldest: keyring
# oldest: keyrings.alt
oldest: setuptools==41.0.0
Expand All @@ -70,6 +70,7 @@ deps =
# oldest: openpyxl
oldest: httpx==0.21.0
oldest: xgboost==1.5.0
oldest: catboost==1.0.1

commands_pre =
# Verify installed packages have compatible dependencies
Expand Down

0 comments on commit 4348f1e

Please sign in to comment.