Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup is_greater_better function #76

Merged
merged 5 commits into from
Sep 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions matbench/metalearning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def element_category(element):
six.string_types):
element = Element(element)
if element.is_transition_metal:
if element.is_lanthanoid or element.symbol in ("Y", "Sc"):
if element.is_lanthanoid or element.symbol in {"Y", "Sc"}:
return 2
elif element.is_actinoid:
return 3
Expand All @@ -180,15 +180,15 @@ def element_category(element):
return 4
elif element.is_alkaline:
return 5
elif element.symbol in ("Al", "Ga", "In", "Tl", "Sn", "Pb", "Bi", "Po"):
elif element.symbol in {"Al", "Ga", "In", "Tl", "Sn", "Pb", "Bi", "Po"}:
return 6
elif element.is_metalloid:
return 7
# elif element.is_chalcogen:
# return 8
elif element.is_halogen:
return 8
elif element.symbol in ("C", "H", "N", "P", "O", "S", "Se"):
elif element.symbol in {"C", "H", "N", "P", "O", "S", "Se"}:
return 9
elif element.is_noble_gas:
return 10
Expand Down
49 changes: 32 additions & 17 deletions matbench/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def setup_custom_logger(name='matbench_logger', filepath='.',
logger.setLevel(level)
logger.addHandler(screen_handler)
logger.addHandler(handler)

return logger


Expand All @@ -55,23 +56,37 @@ def is_greater_better(scoring_function):
scoring_function (str): the name of the scoring function supported by
TPOT and sklearn. Please see below for more information.

Returns (bool):
Returns (bool): True if a larger value of the scoring metric indicates a
better model, False if a smaller value is better
"""
if scoring_function in [
desired_high_metrics = {
'accuracy', 'adjusted_rand_score', 'average_precision',
'balanced_accuracy','f1', 'f1_macro', 'f1_micro', 'f1_samples',
'balanced_accuracy', 'f1', 'f1_macro', 'f1_micro', 'f1_samples',
'f1_weighted', 'precision', 'precision_macro', 'precision_micro',
'precision_samples','precision_weighted', 'recall',
'recall_macro', 'recall_micro','recall_samples',
'recall_weighted', 'roc_auc'] + \
['r2', 'neg_median_absolute_error', 'neg_mean_absolute_error',
'neg_mean_squared_error']:
return True
elif scoring_function in ['median_absolute_error',
'mean_absolute_error',
'mean_squared_error']:
return False
else:
warnings.warn('The scoring_function: "{}" not found; continuing assuming'
' greater score is better'.format(scoring_function))
return True
'precision_samples', 'precision_weighted', 'recall',
'recall_macro', 'recall_micro', 'recall_samples',
'recall_weighted', 'roc_auc', 'r2', 'neg_median_absolute_error',
'neg_mean_absolute_error', 'neg_mean_squared_error'
}

desired_low_metrics = {
'median_absolute_error',
'mean_absolute_error',
'mean_squared_error'
}

# Check to ensure no metrics are accidentally placed in both sets
if desired_high_metrics.intersection(desired_low_metrics):
raise MatbenchError("Error, there is a metric in both desired"
" high and desired low metrics")

if scoring_function not in desired_high_metrics \
and scoring_function not in desired_low_metrics:

warnings.warn(
'The scoring_function: "{}" not found; continuing assuming'
' greater score is better'.format(scoring_function))

# True if not in either set or only in desired_high,
# False if in desired_low or both sets
return scoring_function not in desired_low_metrics