Skip to content

Commit

Permalink
Add hyper parameter api.
Browse files Browse the repository at this point in the history
  • Loading branch information
ShenYuhan authored May 21, 2021
1 parent 92eb7c1 commit d946813
Show file tree
Hide file tree
Showing 11 changed files with 530 additions and 11 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ requests
shellcheck-py
six >= 1.14.0
matplotlib
pandas
45 changes: 45 additions & 0 deletions visualdl/component/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,51 @@ def histogram(tag, hist, bin_edges, step, walltime):
])


def hparam(name, hparam_dict, metric_list, walltime):
"""Package data to one histogram.
Args:
name (str): Name of hparam.
hparam_dict (dictionary): Each key-value pair in the dictionary is the
name of the hyper parameter and it's corresponding value. The type of the value
can be one of `bool`, `string`, `float`, `int`, or `None`.
metric_list (list): Name of all metrics.
walltime (int): Wall time of hparam.
Return:
Package with format of record_pb2.Record
"""

hm = Record.HParam()
hm.name = name
for k, v in hparam_dict.items():
if v is None:
continue
hparamInfo = Record.HParam.HparamInfo()
hparamInfo.name = k
if isinstance(v, int):
hparamInfo.int_value = v
hm.hparamInfos.append(hparamInfo)
elif isinstance(v, float):
hparamInfo.float_value = v
hm.hparamInfos.append(hparamInfo)
elif isinstance(v, str):
hparamInfo.string_value = v
hm.hparamInfos.append(hparamInfo)
else:
print("The value of %s must be int, float or str, not %s" % (k, str(type(v))))
for metric in metric_list:
metricInfo = Record.HParam.HparamInfo()
metricInfo.name = metric
metricInfo.float_value = 0
hm.metricInfos.append(metricInfo)

return Record(values=[
Record.Value(
id=1, tag="hparam", timestamp=walltime, hparam=hm)
])


def compute_curve(labels, predictions, num_thresholds=None, weights=None):
""" Compute precision-recall curve data by labels and predictions.
Expand Down
15 changes: 15 additions & 0 deletions visualdl/proto/record.proto
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ message Record {
repeated double fpr = 6;
}

message HParam {
message HparamInfo {
oneof type {
int64 int_value = 1;
double float_value = 2;
string string_value = 3;
};
string name = 4;
}
repeated HparamInfo hparamInfos = 1;
repeated HparamInfo metricInfos = 2;
string name = 3;
}

message MetaData {
string display_name = 1;
}
Expand All @@ -75,6 +89,7 @@ message Record {
MetaData meta_data = 10;
ROC_Curve roc_curve = 11;
Text text = 12;
HParam hparam = 13;
}
}

Expand Down
152 changes: 145 additions & 7 deletions visualdl/proto/record_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions visualdl/reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def parse_from_bin(self, record_bin):
component = "meta_data"
elif "text" == value_type:
component = "text"
elif "hparam" == value_type:
component = "hyper_parameters"
else:
raise TypeError("Invalid value type `%s`." % value_type)
self._tags[path] = component
Expand Down
29 changes: 28 additions & 1 deletion visualdl/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,28 @@ def pr_curve_tags(self):
def roc_curve_tags(self):
return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags)

@result()
def hparam_importance(self):
return self._get_with_retry('data/plugin/hparams/importance', lib.get_hparam_importance)

@result()
def hparam_indicator(self):
return self._get_with_retry('data/plugin/hparams/indicators', lib.get_hparam_indicator)

@result()
def hparam_list(self):
return self._get_with_retry('data/plugin/hparams/list', lib.get_hparam_list)

@result()
def hparam_metric(self, run, metric):
key = os.path.join('data/plugin/hparams/metric', run, metric)
return self._get_with_retry(key, lib.get_hparam_metric, run, metric)

@result('text/csv')
def hparam_data(self, type='tsv'):
key = os.path.join('data/plugin/hparams/data', type)
return self._get_with_retry(key, lib.get_hparam_data, type)

@result()
def scalar_list(self, run, tag):
key = os.path.join('data/plugin/scalars/scalars', run, tag)
Expand Down Expand Up @@ -254,7 +276,12 @@ def create_api_call(logdir, model, cache_timeout):
'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']),
'roc-curve/list': (api.roc_curves_roc_curve, ['run', 'tag']),
'pr-curve/steps': (api.pr_curves_steps, ['run']),
'roc-curve/steps': (api.roc_curves_steps, ['run'])
'roc-curve/steps': (api.roc_curves_steps, ['run']),
'hparams/importance': (api.hparam_importance, []),
'hparams/data': (api.hparam_data, ['type']),
'hparams/indicators': (api.hparam_indicator, []),
'hparams/list': (api.hparam_list, []),
'hparams/metric': (api.hparam_metric, ['run', 'metric'])
}

def call(path: str, args):
Expand Down
7 changes: 5 additions & 2 deletions visualdl/server/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
"pr_curve": 300,
"roc_curve": 300,
"meta_data": 100,
"text": 10
"text": 10,
"hyper_parameters": 10000
}


Expand Down Expand Up @@ -353,7 +354,9 @@ def __init__(self):
"meta_data":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"]),
"text":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"])
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"]),
"hyper_parameters":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["hyper_parameters"])
}
self._mutex = threading.Lock()

Expand Down
Loading

0 comments on commit d946813

Please sign in to comment.