Skip to content

Commit

Permalink
enas move check algorithm settings to Validate (kubeflow#1146)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreyvelich authored and sperlingxx committed Apr 20, 2020
1 parent 9d2ae1d commit 976d8e0
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 63 deletions.
81 changes: 23 additions & 58 deletions pkg/suggestion/v1alpha3/nas/enas/AlgorithmSettings.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
def parseAlgorithmSettings(params_raw, logger):

param_standard = {
"controller_hidden_size": ['value', int, [1, 'inf']],
"controller_temperature": ['value', float, [0, 'inf']],
"controller_tanh_const": ['value', float, [0, 'inf']],
"controller_entropy_weight": ['value', float, [0.0, 'inf']],
"controller_baseline_decay": ['value', float, [0.0, 1.0]],
"controller_learning_rate": ['value', float, [0.0, 1.0]],
"controller_skip_target": ['value', float, [0.0, 1.0]],
"controller_skip_weight": ['value', float, [0.0, 'inf']],
"controller_train_steps": ['value', int, [1, 'inf']],
"controller_log_every_steps": ['value', int, [1, 'inf']],
}
algorithmSettingsValidator = {
"controller_hidden_size": [int, [1, 'inf']],
"controller_temperature": [float, [0, 'inf']],
"controller_tanh_const": [float, [0, 'inf']],
"controller_entropy_weight": [float, [0.0, 'inf']],
"controller_baseline_decay": [float, [0.0, 1.0]],
"controller_learning_rate": [float, [0.0, 1.0]],
"controller_skip_target": [float, [0.0, 1.0]],
"controller_skip_weight": [float, [0.0, 'inf']],
"controller_train_steps": [int, [1, 'inf']],
"controller_log_every_steps": [int, [1, 'inf']],
}


algorithm_settings = {
# TODO: Enable to add None values, e.g in controller_temperature parameter
def parseAlgorithmSettings(settings_raw):

algorithm_settings_default = {
"controller_hidden_size": 64,
"controller_temperature": 5.,
"controller_tanh_const": 2.25,
Expand All @@ -26,48 +29,10 @@ def parseAlgorithmSettings(params_raw, logger):
"controller_log_every_steps": 10,
}

# TODO: Enable to add None values, e.g in controller_temperature parameter
# TODO: Delete it and add to the Validation part
def checktype(param_name, param_value, check_mode, supposed_type, supposed_range=None, logger=None):
correct = True

try:
converted_value = supposed_type(param_value)
except:
correct = False
logger.info("Parameter {} is of wrong type. Set back to default value {}"
.format(param_name, algorithm_settings[param_name]))

if correct and check_mode == 'value':
if (
(supposed_range[0] != '-inf' and
((supposed_type == float and converted_value <= supposed_range[0]) or
converted_value < supposed_range[0])
) or
(supposed_range[1] != 'inf' and converted_value > supposed_range[1])
):
correct = False
logger.info("Parameter {} out of range. Set back to default value {}"
.format(param_name, algorithm_settings[param_name]))

elif correct and check_mode == 'categorical':
if converted_value not in supposed_range:
correct = False
logger.info("Parameter {} out of range. Set back to default value {}"
.format(param_name, algorithm_settings[param_name]))

if correct:
algorithm_settings[param_name] = converted_value

for param in params_raw:
if param.name in algorithm_settings.keys():
checktype(param.name,
param.value,
param_standard[param.name][0], # mode
param_standard[param.name][1], # type
param_standard[param.name][2], # range
logger)
else:
logger.info("Unknown Parameter name: {}".format(param.name))
for setting in settings_raw:
s_name = setting.name
s_value = setting.value
s_type = algorithmSettingsValidator[s_name][0]
algorithm_settings_default[s_name] = s_type(s_value)

return algorithm_settings
return algorithm_settings_default
33 changes: 28 additions & 5 deletions pkg/suggestion/v1alpha3/nas/enas_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pkg.apis.manager.v1alpha3.python import api_pb2_grpc
from pkg.suggestion.v1alpha3.nas.enas.Controller import Controller
from pkg.suggestion.v1alpha3.nas.enas.Operation import SearchSpace
from pkg.suggestion.v1alpha3.nas.enas.AlgorithmSettings import parseAlgorithmSettings
from pkg.suggestion.v1alpha3.nas.enas.AlgorithmSettings import parseAlgorithmSettings, algorithmSettingsValidator
from pkg.suggestion.v1alpha3.base_health_service import HealthServicer


Expand Down Expand Up @@ -66,9 +66,9 @@ def _get_experiment_param(self):

self.print_search_space()

# Get Experiment Parameters
params_raw = self.experiment.spec.algorithm.algorithm_setting
self.algorithm_settings = parseAlgorithmSettings(params_raw, self.logger)
# Get Experiment Algorithm Settings
settings_raw = self.experiment.spec.algorithm.algorithm_setting
self.algorithm_settings = parseAlgorithmSettings(settings_raw)

self.print_algorithm_settings()

Expand Down Expand Up @@ -150,7 +150,6 @@ def ValidateAlgorithmSettings(self, request, context):
self.logger.info("Validate Algorithm Settings start")
graph_config = request.experiment.spec.nas_config.graph_config

# TODO: Refactor this since we validate it in Katib Controller
# Validate GraphConfig
# Check InputSize
if not graph_config.input_sizes:
Expand Down Expand Up @@ -202,6 +201,30 @@ def ValidateAlgorithmSettings(self, request, context):
if parameter.parameter_type == api_pb2.DOUBLE and (not parameter.feasible_space.step or float(parameter.feasible_space.step) <= 0):
return self.SetValidateContextError(context, "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(parameter))

# Validate Algorithm Settings
settings_raw = request.experiment.spec.algorithm.algorithm_setting
for setting in settings_raw:
if setting.name in algorithmSettingsValidator.keys():
setting_type = algorithmSettingsValidator[setting.name][0]
setting_range = algorithmSettingsValidator[setting.name][1]
try:
converted_value = setting_type(setting.value)
except:
return self.SetValidateContextError(context, "Algorithm Setting {} must be {} type".format(setting.name, setting_type.__name__))

if setting_type == float:
if converted_value <= setting_range[0] or (setting_range[1] != 'inf' and converted_value > setting_range[1]):
return self.SetValidateContextError(context, "Algorithm Setting {}: {} with {} type must be in range ({}, {}]".format(
setting.name, converted_value, setting_type.__name__, setting_range[0], setting_range[1]
))

elif converted_value < setting_range[0]:
return self.SetValidateContextError(context, "Algorithm Setting {}: {} with {} type must be in range [{}, {})".format(
setting.name, converted_value, setting_type.__name__, setting_range[0], setting_range[1]
))
else:
return self.SetValidateContextError(context, "Unknown Algorithm Setting name: {}".format(setting.name))

self.logger.info("All Experiment Settings are Valid")
return api_pb2.ValidateAlgorithmSettingsReply()

Expand Down

0 comments on commit 976d8e0

Please sign in to comment.