Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add hp be api #961

Merged
merged 7 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ requests
shellcheck-py
six >= 1.14.0
matplotlib
pandas
45 changes: 45 additions & 0 deletions visualdl/component/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,51 @@ def histogram(tag, hist, bin_edges, step, walltime):
])


def hparam(name, hparam_dict, metric_list, walltime):
"""Package data to one histogram.
Args:
name (str): Name of hparam.
hparam_dict (dictionary): Each key-value pair in the dictionary is the
name of the hyper parameter and it's corresponding value. The type of the value
can be one of `bool`, `string`, `float`, `int`, or `None`.
metric_list (list): Name of all metrics.
walltime (int): Wall time of hparam.
Return:
Package with format of record_pb2.Record
"""

hm = Record.HParam()
hm.name = name
for k, v in hparam_dict.items():
if v is None:
continue
hparamInfo = Record.HParam.HparamInfo()
hparamInfo.name = k
if isinstance(v, int):
hparamInfo.int_value = v
hm.hparamInfos.append(hparamInfo)
elif isinstance(v, float):
hparamInfo.float_value = v
hm.hparamInfos.append(hparamInfo)
elif isinstance(v, str):
hparamInfo.string_value = v
hm.hparamInfos.append(hparamInfo)
else:
print("The value of %s must be int, float or str, not %s" % (k, str(type(v))))
for metric in metric_list:
metricInfo = Record.HParam.HparamInfo()
metricInfo.name = metric
metricInfo.float_value = 0
hm.metricInfos.append(metricInfo)

return Record(values=[
Record.Value(
id=1, tag="hparam", timestamp=walltime, hparam=hm)
])


def compute_curve(labels, predictions, num_thresholds=None, weights=None):
""" Compute precision-recall curve data by labels and predictions.
Expand Down
15 changes: 15 additions & 0 deletions visualdl/proto/record.proto
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ message Record {
repeated double fpr = 6;
}

message HParam {
message HparamInfo {
oneof type {
int64 int_value = 1;
double float_value = 2;
string string_value = 3;
};
string name = 4;
}
repeated HparamInfo hparamInfos = 1;
repeated HparamInfo metricInfos = 2;
string name = 3;
}

message MetaData {
string display_name = 1;
}
Expand All @@ -75,6 +89,7 @@ message Record {
MetaData meta_data = 10;
ROC_Curve roc_curve = 11;
Text text = 12;
HParam hparam = 13;
}
}

Expand Down
152 changes: 145 additions & 7 deletions visualdl/proto/record_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions visualdl/reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def parse_from_bin(self, record_bin):
component = "meta_data"
elif "text" == value_type:
component = "text"
elif "hparam" == value_type:
component = "hyper_parameters"
else:
raise TypeError("Invalid value type `%s`." % value_type)
self._tags[path] = component
Expand Down
29 changes: 28 additions & 1 deletion visualdl/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,28 @@ def pr_curve_tags(self):
def roc_curve_tags(self):
return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags)

@result()
def hparam_importance(self):
return self._get_with_retry('data/plugin/hparams/importance', lib.get_hparam_importance)

@result()
def hparam_indicator(self):
return self._get_with_retry('data/plugin/hparams/indicators', lib.get_hparam_indicator)

@result()
def hparam_list(self):
return self._get_with_retry('data/plugin/hparams/list', lib.get_hparam_list)

@result()
def hparam_metric(self, run, metric):
key = os.path.join('data/plugin/hparams/metric', run, metric)
return self._get_with_retry(key, lib.get_hparam_metric, run, metric)

@result('text/csv')
def hparam_data(self, type='tsv'):
key = os.path.join('data/plugin/hparams/data', type)
return self._get_with_retry(key, lib.get_hparam_data, type)

@result()
def scalar_list(self, run, tag):
key = os.path.join('data/plugin/scalars/scalars', run, tag)
Expand Down Expand Up @@ -254,7 +276,12 @@ def create_api_call(logdir, model, cache_timeout):
'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']),
'roc-curve/list': (api.roc_curves_roc_curve, ['run', 'tag']),
'pr-curve/steps': (api.pr_curves_steps, ['run']),
'roc-curve/steps': (api.roc_curves_steps, ['run'])
'roc-curve/steps': (api.roc_curves_steps, ['run']),
'hparams/importance': (api.hparam_importance, []),
'hparams/data': (api.hparam_data, ['type']),
'hparams/indicators': (api.hparam_indicator, []),
'hparams/list': (api.hparam_list, []),
'hparams/metric': (api.hparam_metric, ['run', 'metric'])
}

def call(path: str, args):
Expand Down
7 changes: 5 additions & 2 deletions visualdl/server/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
"pr_curve": 300,
"roc_curve": 300,
"meta_data": 100,
"text": 10
"text": 10,
"hyper_parameters": 10000
}


Expand Down Expand Up @@ -353,7 +354,9 @@ def __init__(self):
"meta_data":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"]),
"text":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"])
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"]),
"hyper_parameters":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["hyper_parameters"])
}
self._mutex = threading.Lock()

Expand Down
Loading