
Commit

Merge pull request #76 from Doppe1g4nger/master
Cleanup is_greater_better function
Doppe1g4nger authored Sep 24, 2018
2 parents 48bc86c + 63848d3 commit 05af98e
Showing 2 changed files with 35 additions and 20 deletions.
6 changes: 3 additions & 3 deletions matbench/metalearning/utils.py
@@ -170,7 +170,7 @@ def element_category(element):
                       six.string_types):
         element = Element(element)
     if element.is_transition_metal:
-        if element.is_lanthanoid or element.symbol in ("Y", "Sc"):
+        if element.is_lanthanoid or element.symbol in {"Y", "Sc"}:
             return 2
         elif element.is_actinoid:
             return 3
@@ -180,15 +180,15 @@ def element_category(element):
         return 4
     elif element.is_alkaline:
         return 5
-    elif element.symbol in ("Al", "Ga", "In", "Tl", "Sn", "Pb", "Bi", "Po"):
+    elif element.symbol in {"Al", "Ga", "In", "Tl", "Sn", "Pb", "Bi", "Po"}:
         return 6
     elif element.is_metalloid:
         return 7
     # elif element.is_chalcogen:
     #     return 8
     elif element.is_halogen:
         return 8
-    elif element.symbol in ("C", "H", "N", "P", "O", "S", "Se"):
+    elif element.symbol in {"C", "H", "N", "P", "O", "S", "Se"}:
         return 9
     elif element.is_noble_gas:
         return 10
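
The only functional change in this file is swapping tuple literals for set literals in the symbol membership checks. A minimal sketch of the pattern, using symbols taken from the diff above (the helper name is illustrative and not part of matbench):

    # Minimal sketch of the tuple -> set membership cleanup shown above;
    # is_post_transition is an illustrative helper, not part of matbench.
    POST_TRANSITION = {"Al", "Ga", "In", "Tl", "Sn", "Pb", "Bi", "Po"}

    def is_post_transition(symbol):
        # A set literal gives an O(1) average-time hash lookup, while the old
        # tuple literal required a linear scan; behavior is otherwise identical.
        return symbol in POST_TRANSITION

    print(is_post_transition("Sn"))  # True
    print(is_post_transition("Fe"))  # False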
49 changes: 32 additions & 17 deletions matbench/utils/utils.py
@@ -44,6 +44,7 @@ def setup_custom_logger(name='matbench_logger', filepath='.',
     logger.setLevel(level)
     logger.addHandler(screen_handler)
     logger.addHandler(handler)
+
     return logger


@@ -55,23 +56,37 @@ def is_greater_better(scoring_function):
         scoring_function (str): the name of the scoring function supported by
             TPOT and sklearn. Please see below for more information.
-    Returns (bool):
+    Returns (bool): Whether the scoring metric should be considered better if
+        it is larger or better if it is smaller
     """
-    if scoring_function in [
+    desired_high_metrics = {
         'accuracy', 'adjusted_rand_score', 'average_precision',
-        'balanced_accuracy','f1', 'f1_macro', 'f1_micro', 'f1_samples',
+        'balanced_accuracy', 'f1', 'f1_macro', 'f1_micro', 'f1_samples',
         'f1_weighted', 'precision', 'precision_macro', 'precision_micro',
-        'precision_samples','precision_weighted', 'recall',
-        'recall_macro', 'recall_micro','recall_samples',
-        'recall_weighted', 'roc_auc'] + \
-       ['r2', 'neg_median_absolute_error', 'neg_mean_absolute_error',
-        'neg_mean_squared_error']:
-        return True
-    elif scoring_function in ['median_absolute_error',
-                              'mean_absolute_error',
-                              'mean_squared_error']:
-        return False
-    else:
-        warnings.warn('The scoring_function: "{}" not found; continuing assuming'
-                      ' greater score is better'.format(scoring_function))
-        return True
+        'precision_samples', 'precision_weighted', 'recall',
+        'recall_macro', 'recall_micro', 'recall_samples',
+        'recall_weighted', 'roc_auc', 'r2', 'neg_median_absolute_error',
+        'neg_mean_absolute_error', 'neg_mean_squared_error'
+    }
+
+    desired_low_metrics = {
+        'median_absolute_error',
+        'mean_absolute_error',
+        'mean_squared_error'
+    }
+
+    # Check to ensure no metrics are accidentally placed in both sets
+    if desired_high_metrics.intersection(desired_low_metrics):
+        raise MatbenchError("Error, there is a metric in both desired"
+                            " high and desired low metrics")
+
+    if scoring_function not in desired_high_metrics \
+            and scoring_function not in desired_low_metrics:
+        warnings.warn(
+            'The scoring_function: "{}" not found; continuing assuming'
+            ' greater score is better'.format(scoring_function))
+
+    # True if not in either set or only in desired_high,
+    # False if in desired_low or both sets
+    return scoring_function not in desired_low_metrics
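
For reference, a usage sketch of the refactored helper based on the behavior shown in this diff; the import path is assumed from the file location (matbench/utils/utils.py) and may differ in other versions of the package:

    # Usage sketch for the refactored helper; the import path is assumed
    # from the diff's file location and may differ in practice.
    from matbench.utils.utils import is_greater_better

    print(is_greater_better('f1'))                      # True: higher f1 is better
    print(is_greater_better('neg_mean_squared_error'))  # True: closer to zero is better
    print(is_greater_better('mean_squared_error'))      # False: lower error is better
    print(is_greater_better('not_a_real_metric'))       # True, after emitting a warning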
