Skip to content

Commit

Permalink
Merge pull request #585 from EliverQ/0.2.x
Browse files Browse the repository at this point in the history
FEA:add param_dict in hyper_tuning.py
  • Loading branch information
EliverQ authored Dec 16, 2020
2 parents 333a4e3 + b6d9e5e commit 8311b8b
Showing 1 changed file with 36 additions and 2 deletions.
38 changes: 36 additions & 2 deletions recbole/trainer/hyper_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class HyperTuning(object):
https://github.com/hyperopt/hyperopt/issues/200
"""

def __init__(self, objective_function, space=None, params_file=None, fixed_config_file_list=None,
def __init__(self, objective_function, space=None, params_file=None, params_dict=None, fixed_config_file_list=None,
algo='exhaustive', max_evals=100):
self.best_score = None
self.best_params = None
Expand All @@ -150,8 +150,10 @@ def __init__(self, objective_function, space=None, params_file=None, fixed_confi
self.space = space
elif params_file:
self.space = self._build_space_from_file(params_file)
elif params_dict:
self.space = self._build_space_from_dict(params_dict)
else:
raise ValueError('at least one of `space` and `params_file` is provided')
raise ValueError('at least one of `space`, `params_file` and `params_dict` is provided')
if isinstance(algo, str):
if algo == 'exhaustive':
self.algo = partial(exhaustive_search, nbMaxSucessiveFailures=1000)
Expand Down Expand Up @@ -187,6 +189,38 @@ def _build_space_from_file(file):
raise ValueError('Illegal param type [{}]'.format(para_type))
return space

@staticmethod
def _build_space_from_dict(config_dict):
from hyperopt import hp
space = {}
for para_type in config_dict:
if para_type == 'choice':
for para_name in config_dict['choice']:
para_value = config_dict['choice'][para_name]
space[para_name] = hp.choice(para_name, para_value)
elif para_type == 'uniform':
for para_name in config_dict['uniform']:
para_value = config_dict['uniform'][para_name]
low = para_value[0]
high = para_value[1]
space[para_name] = hp.uniform(para_name, float(low), float(high))
elif para_type == 'quniform':
for para_name in config_dict['quniform']:
para_value = config_dict['quniform'][para_name]
low = para_value[0]
high = para_value[1]
q = para_value[2]
space[para_name] = hp.quniform(para_name, float(low), float(high), float(q))
elif para_type =='loguniform':
for para_name in config_dict['loguniform']:
para_value = config_dict['loguniform'][para_name]
low = para_value[0]
high = para_value[1]
space[para_name] = hp.loguniform(para_name, float(low), float(high))
else:
raise ValueError('Illegal param type [{}]'.format(para_type))
return space

@staticmethod
def params2str(params):
r""" convert dict to str
Expand Down

0 comments on commit 8311b8b

Please sign in to comment.