Skip to content

Commit

Permalink
Add more supported metrics to MhaSelector and MultiMhaSelector
Browse files Browse the repository at this point in the history
  • Loading branch information
thieu1995 committed Aug 7, 2023
1 parent a202714 commit 8ff68a5
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions mafese/wrapper/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from mafese.utils.mealpy_util import get_optimizer_by_name, get_all_optimizers, FeatureSelectionProblem, Optimizer
from mafese.utils import transfer
from mafese.utils.data_loader import Data
from mafese.utils.evaluator import get_metrics
from mafese.utils.evaluator import get_metrics, get_all_regression_metrics, get_all_classification_metrics
from permetrics.regression import RegressionMetric
from permetrics.classification import ClassificationMetric
import plotly.express as px
Expand Down Expand Up @@ -102,11 +102,8 @@ class MhaSelector(Selector):
SUPPORT = {
"estimator": ["knn", "svm", "rf", "adaboost", "xgb", "tree", "ann"],
"transfer_func": ["vstf_01", "vstf_02", "vstf_03", "vstf_04", "sstf_01", "sstf_02", "sstf_03", "sstf_04"],
"regression_objective": {"MAE": "min", "MSE": "min", "RMSE": "min", "MRE": "min", "MAPE": "min", "MASE": "min",
"NSE": "max", "NNSE": "max", "WI": "max", "PCC": "max", "R2s": "max", "R2": "max", "AR2": "max",
"CI": "max", "KGE": "max", "VAF": "max", "A10": "max", "A20": "max"},
"classification_objective": {"AS": "max", "PS": "max", "NPV": "max", "RS": "max", "F1S": "max", "F2S": "max",
"FBS": "max", "SS": "max", "MCC": "max", "JSI": "max", "CKS": "max", "ROC-AUC": "max"},
"regression_objective": get_all_regression_metrics(),
"classification_objective": get_all_classification_metrics(),
"optimizer": list(get_all_optimizers().keys())
}

Expand Down Expand Up @@ -243,11 +240,8 @@ class MultiMhaSelector(Selector):
SUPPORT = {
"estimator": ["knn", "svm", "rf", "adaboost", "xgb", "tree", "ann"],
"transfer_func": ["vstf_01", "vstf_02", "vstf_03", "vstf_04", "sstf_01", "sstf_02", "sstf_03", "sstf_04"],
"regression_objective": {"MAE": "min", "MSE": "min", "RMSE": "min", "MRE": "min", "MAPE": "min", "MASE": "min",
"NSE": "max", "NNSE": "max", "WI": "max", "PCC": "max", "R2s": "max", "R2": "max", "AR2": "max",
"CI": "max", "KGE": "max", "VAF": "max", "A10": "max", "A20": "max"},
"classification_objective": {"AS": "max", "PS": "max", "NPV": "max", "RS": "max", "F1S": "max", "F2S": "max",
"FBS": "max", "SS": "max", "MCC": "max", "JSI": "max", "CKS": "max", "ROC-AUC": "max"},
"regression_objective": get_all_regression_metrics(),
"classification_objective": get_all_classification_metrics(),
"optimizer": list(get_all_optimizers().keys())
}

Expand Down

0 comments on commit 8ff68a5

Please sign in to comment.