forked from autogluon/autogluon
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Neural network based quantile regression models
Tabular: Re-order NN model priority (autogluon#1059) Tabular: Added Adaptive Early Stopping (autogluon#1042) * Tabular: Added AdaptiveES, default adaptive to LightGBM * ag.es -> ag.early_stop * addressed comments Tabular: Upgraded CatBoost to v0.25 (autogluon#1064) Tabular: Added extra_metrics argument to leaderboard (autogluon#1058) * Tabular: Added extra_metrics argument to leaderboard * addressed comments Upgrade psutil and scipy (autogluon#1072) Tabular: Added efficient OOF functionality to RF/XT models (autogluon#1066) * Tabular: Added efficient OOF functionality to RF/XT models * addressed comments, disabled RF/XT use_child_oof by default Tabular: Adjusted per-level stack time (autogluon#1075) * Tabular: Added efficient OOF functionality to RF/XT models * addressed comments, disabled RF/XT use_child_oof by default * Tabular: Adjusted stack time limit allocation Constrained Bayesian optimization (autogluon#1034) * Constrained Bayesian optimization * Comments from Matthias * Fix random_seed keyword * constraint_attribute + other comments * Fix import Co-authored-by: Valerio Perrone <vperrone@amazon.com> Refactoring of FIFOScheduler, HyperbandScheduler: self.time_out bette… (autogluon#1050) * Refactoring of FIFOScheduler, HyperbandScheduler: self.time_out better respected by stopping jobs when they run over * Added an option and warning concerning the changed meaning of 'time_out' * Removed code to add time_this_iter to result in reporter (buggy, and not used) update predict_proba return (autogluon#1044) * update predict_proba return * non-api breaking * bump * update format * update label format and predict_proba * add test * fix d8 * remove squeeze * fix * fix incorrect class mapping, force it align with label column * fix * fix label * fix sorted list * fix * reset labels * fix test * address comments * fix test * fix * label * test for custom label Vision: Limited gluoncv version (autogluon#1081) Tabular: RF/XT Efficient OOB (autogluon#1082) * Tabular: Enabled efficient OOB for RF/XT * Tabular: Removed min_samples_leaf * 300 estimators Tabular: Refactored evaluate/evaluate_predictions (autogluon#1080) * Tabular: Refactored evaluate/evaluate_predictions * minor fix Tabular: Reorder model priority (autogluon#1084) * Tabular: Enabled efficient OOB for RF/XT * Tabular: Removed min_samples_leaf * 300 estimators * Tabular: Reordered model training priority * added memory check before training XGBoost * minor update * fix xgboost Updated to v0.2.0 (autogluon#1086) Restricted sklearn to >=0.23.2 (autogluon#1088) Update to 0.2.1 (autogluon#1087) TextPredictor fails if eval_metric = 'average_precision' (autogluon#1092) * TextPredictor fails if eval_metric = 'average_precision' Fixes autogluon#1085 * TextPredictor fails if eval_metric = 'average_precision' Fixes autogluon#1085 Co-authored-by: Rohit Jain <rohit@thetalake.com> upgrade SHAP notebooks (autogluon#1089) tell users to search closed issues (autogluon#1095) Added tutorial / API reference table to README.md (autogluon#1093) Tabular: Added ImagePredictorModel (autogluon#1041) * Tabular: Added ImagePredictorModel * Added ImagePredictorModel unittest * revert accidental minimum_cat_count change * addressed comments * addressed comments * Updated after ImagePredictor refactor * minor fix * Addressed comments add `tabular_nn_torch.py`
- Loading branch information
1 parent
b257068
commit e8998ed
Showing
79 changed files
with
6,269 additions
and
2,892 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
0.1.1 | ||
0.2.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,30 @@ | ||
|
||
from autogluon.core.utils.early_stopping import AdaptiveES, ES_CLASS_MAP | ||
|
||
|
||
# TODO: Add more strategies | ||
# - Adaptive early stopping: adjust rounds during model training | ||
def get_early_stopping_rounds(num_rows_train, strategy='auto', min_rounds=10, max_rounds=150, min_rows=10000): | ||
def get_early_stopping_rounds(num_rows_train, strategy='auto', min_patience=10, max_patience=150, min_rows=10000): | ||
if isinstance(strategy, (tuple, list)): | ||
strategy = list(strategy) | ||
if isinstance(strategy[0], str): | ||
if strategy[0] in ES_CLASS_MAP: | ||
strategy[0] = ES_CLASS_MAP[strategy[0]] | ||
else: | ||
raise AssertionError(f'unknown early stopping strategy: {strategy}') | ||
return strategy | ||
|
||
"""Gets early stopping rounds""" | ||
if strategy == 'auto': | ||
modifier = 1 if num_rows_train <= min_rows else min_rows / num_rows_train | ||
early_stopping_rounds = max( | ||
round(modifier * max_rounds), | ||
min_rounds, | ||
) | ||
strategy = 'simple' | ||
|
||
modifier = 1 if num_rows_train <= min_rows else min_rows / num_rows_train | ||
simple_early_stopping_rounds = max( | ||
round(modifier * max_patience), | ||
min_patience, | ||
) | ||
if strategy == 'simple': | ||
return simple_early_stopping_rounds | ||
elif strategy == 'adaptive': | ||
return AdaptiveES, dict(adaptive_offset=min_patience, min_patience=simple_early_stopping_rounds) | ||
else: | ||
raise AssertionError(f'unknown early stopping strategy: {strategy}') | ||
return early_stopping_rounds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.