diff --git a/matbench/metalearning/utils.py b/matbench/metalearning/utils.py
index 43141000..0447366e 100644
--- a/matbench/metalearning/utils.py
+++ b/matbench/metalearning/utils.py
@@ -170,7 +170,7 @@ def element_category(element):
                               six.string_types):
         element = Element(element)
     if element.is_transition_metal:
-        if element.is_lanthanoid or element.symbol in ("Y", "Sc"):
+        if element.is_lanthanoid or element.symbol in {"Y", "Sc"}:
             return 2
         elif element.is_actinoid:
             return 3
@@ -180,7 +180,7 @@ def element_category(element):
         return 4
     elif element.is_alkaline:
         return 5
-    elif element.symbol in ("Al", "Ga", "In", "Tl", "Sn", "Pb", "Bi", "Po"):
+    elif element.symbol in {"Al", "Ga", "In", "Tl", "Sn", "Pb", "Bi", "Po"}:
         return 6
     elif element.is_metalloid:
         return 7
@@ -188,7 +188,7 @@ def element_category(element):
     # return 8
     elif element.is_halogen:
         return 8
-    elif element.symbol in ("C", "H", "N", "P", "O", "S", "Se"):
+    elif element.symbol in {"C", "H", "N", "P", "O", "S", "Se"}:
         return 9
     elif element.is_noble_gas:
         return 10
diff --git a/matbench/utils/utils.py b/matbench/utils/utils.py
index 92324343..1deee9f5 100644
--- a/matbench/utils/utils.py
+++ b/matbench/utils/utils.py
@@ -44,6 +44,7 @@ def setup_custom_logger(name='matbench_logger', filepath='.',
     logger.setLevel(level)
     logger.addHandler(screen_handler)
     logger.addHandler(handler)
+    return logger
 
 
 
@@ -55,23 +56,37 @@ def is_greater_better(scoring_function):
         scoring_function (str): the name of the scoring function supported by
             TPOT and sklearn. Please see below for more information.
 
-    Returns (bool):
+    Returns (bool): Whether the scoring metric should be considered better if
+        it is larger or better if it is smaller
     """
-    if scoring_function in [
+    desired_high_metrics = {
         'accuracy', 'adjusted_rand_score', 'average_precision',
-        'balanced_accuracy','f1', 'f1_macro', 'f1_micro', 'f1_samples',
+        'balanced_accuracy', 'f1', 'f1_macro', 'f1_micro', 'f1_samples',
         'f1_weighted', 'precision', 'precision_macro', 'precision_micro',
-        'precision_samples','precision_weighted', 'recall',
-        'recall_macro', 'recall_micro','recall_samples',
-        'recall_weighted', 'roc_auc'] + \
-        ['r2', 'neg_median_absolute_error', 'neg_mean_absolute_error',
-         'neg_mean_squared_error']:
-        return True
-    elif scoring_function in ['median_absolute_error',
-                              'mean_absolute_error',
-                              'mean_squared_error']:
-        return False
-    else:
-        warnings.warn('The scoring_function: "{}" not found; continuing assuming'
-                      ' greater score is better'.format(scoring_function))
-        return True
\ No newline at end of file
+        'precision_samples', 'precision_weighted', 'recall',
+        'recall_macro', 'recall_micro', 'recall_samples',
+        'recall_weighted', 'roc_auc', 'r2', 'neg_median_absolute_error',
+        'neg_mean_absolute_error', 'neg_mean_squared_error'
+    }
+
+    desired_low_metrics = {
+        'median_absolute_error',
+        'mean_absolute_error',
+        'mean_squared_error'
+    }
+
+    # Check to ensure no metrics are accidentally placed in both sets
+    if desired_high_metrics.intersection(desired_low_metrics):
+        raise MatbenchError("Error, there is a metric in both desired"
+                            " high and desired low metrics")
+
+    if scoring_function not in desired_high_metrics \
+            and scoring_function not in desired_low_metrics:
+
+        warnings.warn(
+            'The scoring_function: "{}" not found; continuing assuming'
+            ' greater score is better'.format(scoring_function))
+
+    # True if not in either set or only in desired_high,
+    # False if in desired_low or both sets
+    return scoring_function not in desired_low_metrics