From e5121167e6eae0904424dc314455b7625366d5eb Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Wed, 19 May 2021 16:09:28 +0800 Subject: [PATCH 1/4] add hp be api --- visualdl/component/base_component.py | 58 +++++++++ visualdl/proto/record.proto | 15 +++ visualdl/proto/record_pb2.py | 152 +++++++++++++++++++++-- visualdl/reader/reader.py | 2 + visualdl/server/api.py | 29 ++++- visualdl/server/data_manager.py | 7 +- visualdl/server/lib.py | 178 +++++++++++++++++++++++++++ visualdl/utils/importance.py | 55 +++++++++ visualdl/writer/writer.py | 48 +++++++- 9 files changed, 533 insertions(+), 11 deletions(-) create mode 100644 visualdl/utils/importance.py diff --git a/visualdl/component/base_component.py b/visualdl/component/base_component.py index d0342c385..9ac6d8e2d 100644 --- a/visualdl/component/base_component.py +++ b/visualdl/component/base_component.py @@ -293,6 +293,64 @@ def histogram(tag, hist, bin_edges, step, walltime): ]) +def hparam(name, hparam_dict, metric_dict, walltime): + """Package data to one histogram. + + Args: + name (str): Name of hparam. + hparam_dict (dictionary): Each key-value pair in the dictionary is the + name of the hyper parameter and it's corresponding value. The type of the value + can be one of `bool`, `string`, `float`, `int`, or `None`. + metric_dict (dictionary): Each key-value pair in the dictionary is the + name of the metric and it's corresponding value. + walltime (int): Wall time of hparam. + + Return: + Package with format of record_pb2.Record + """ + + hm = Record.HParam() + hm.name = name + for k, v in hparam_dict.items(): + if v is None: + continue + hparamInfo = Record.HParam.HparamInfo() + hparamInfo.name = k + if isinstance(v, int): + hparamInfo.int_value = v + hm.hparamInfos.append(hparamInfo) + elif isinstance(v, float): + hparamInfo.float_value = v + hm.hparamInfos.append(hparamInfo) + elif isinstance(v, str): + hparamInfo.string_value = v + hm.hparamInfos.append(hparamInfo) + else: + print("The value of %s must be int, float or str, not %s" % (k, str(type(v)))) + + for k, v in metric_dict.items(): + if v is None: + continue + metricInfo = Record.HParam.HparamInfo() + metricInfo.name = k + if isinstance(v, int): + metricInfo.int_value = v + hm.metricInfos.append(metricInfo) + elif isinstance(v, float): + metricInfo.float_value = v + hm.metricInfos.append(metricInfo) + elif isinstance(v, str): + metricInfo.string_value = v + hm.metricInfos.append(metricInfo) + else: + print("The value of %s must be int, float or str, not %s" % (k, str(type(v)))) + + return Record(values=[ + Record.Value( + id=1, tag="hparam", timestamp=walltime, hparam=hm) + ]) + + def compute_curve(labels, predictions, num_thresholds=None, weights=None): """ Compute precision-recall curve data by labels and predictions. diff --git a/visualdl/proto/record.proto b/visualdl/proto/record.proto index 952218b18..680d0adbb 100644 --- a/visualdl/proto/record.proto +++ b/visualdl/proto/record.proto @@ -57,6 +57,20 @@ message Record { repeated double fpr = 6; } + message HParam { + message HparamInfo { + oneof type { + int64 int_value = 1; + double float_value = 2; + string string_value = 3; + }; + string name = 4; + } + repeated HparamInfo hparamInfos = 1; + repeated HparamInfo metricInfos = 2; + string name = 3; + } + message MetaData { string display_name = 1; } @@ -75,6 +89,7 @@ message Record { MetaData meta_data = 10; ROC_Curve roc_curve = 11; Text text = 12; + HParam hparam = 13; } } diff --git a/visualdl/proto/record_pb2.py b/visualdl/proto/record_pb2.py index 223dc5876..1dc3fd22a 100644 --- a/visualdl/proto/record_pb2.py +++ b/visualdl/proto/record_pb2.py @@ -18,7 +18,7 @@ package='visualdl', syntax='proto3', serialized_options=None, - serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xac\t\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a#\n\x04Text\x12\x1b\n\x13\x65ncoded_text_string\x18\x01 \x01(\t\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\x65\n\tROC_Curve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x0b\n\x03tpr\x18\x05 \x03(\x01\x12\x0b\n\x03\x66pr\x18\x06 \x03(\x01\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xbd\x03\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x12/\n\troc_curve\x18\x0b \x01(\x0b\x32\x1a.visualdl.Record.ROC_CurveH\x00\x12%\n\x04text\x18\x0c \x01(\x0b\x32\x15.visualdl.Record.TextH\x00\x42\x0b\n\tone_valueb\x06proto3' + serialized_pb=b'\n\x0crecord.proto\x12\x08visualdl\"\xca\x0b\n\x06Record\x12&\n\x06values\x18\x01 \x03(\x0b\x32\x16.visualdl.Record.Value\x1a%\n\x05Image\x12\x1c\n\x14\x65ncoded_image_string\x18\x04 \x01(\x0c\x1a#\n\x04Text\x12\x1b\n\x13\x65ncoded_text_string\x18\x01 \x01(\t\x1a}\n\x05\x41udio\x12\x13\n\x0bsample_rate\x18\x01 \x01(\x02\x12\x14\n\x0cnum_channels\x18\x02 \x01(\x03\x12\x15\n\rlength_frames\x18\x03 \x01(\x03\x12\x1c\n\x14\x65ncoded_audio_string\x18\x04 \x01(\x0c\x12\x14\n\x0c\x63ontent_type\x18\x05 \x01(\t\x1a+\n\tEmbedding\x12\r\n\x05label\x18\x01 \x03(\t\x12\x0f\n\x07vectors\x18\x02 \x03(\x02\x1aP\n\nEmbeddings\x12.\n\nembeddings\x18\x01 \x03(\x0b\x32\x1a.visualdl.Record.Embedding\x12\x12\n\nlabel_meta\x18\x02 \x03(\t\x1a\x43\n\x10\x62ytes_embeddings\x12\x16\n\x0e\x65ncoded_labels\x18\x01 \x01(\x0c\x12\x17\n\x0f\x65ncoded_vectors\x18\x02 \x01(\x0c\x1a\x34\n\tHistogram\x12\x10\n\x04hist\x18\x01 \x03(\x01\x42\x02\x10\x01\x12\x15\n\tbin_edges\x18\x02 \x03(\x01\x42\x02\x10\x01\x1al\n\x07PRCurve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x11\n\tprecision\x18\x05 \x03(\x01\x12\x0e\n\x06recall\x18\x06 \x03(\x01\x1a\x65\n\tROC_Curve\x12\x0e\n\x02TP\x18\x01 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46P\x18\x02 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02TN\x18\x03 \x03(\x03\x42\x02\x10\x01\x12\x0e\n\x02\x46N\x18\x04 \x03(\x03\x42\x02\x10\x01\x12\x0b\n\x03tpr\x18\x05 \x03(\x01\x12\x0b\n\x03\x66pr\x18\x06 \x03(\x01\x1a\xf0\x01\n\x06HParam\x12\x37\n\x0bhparamInfos\x18\x01 \x03(\x0b\x32\".visualdl.Record.HParam.HparamInfo\x12\x37\n\x0bmetricInfos\x18\x02 \x03(\x0b\x32\".visualdl.Record.HParam.HparamInfo\x12\x0c\n\x04name\x18\x03 \x01(\t\x1a\x66\n\nHparamInfo\x12\x13\n\tint_value\x18\x01 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x02 \x01(\x01H\x00\x12\x16\n\x0cstring_value\x18\x03 \x01(\tH\x00\x12\x0c\n\x04name\x18\x04 \x01(\tB\x06\n\x04type\x1a \n\x08MetaData\x12\x14\n\x0c\x64isplay_name\x18\x01 \x01(\t\x1a\xe8\x03\n\x05Value\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x0b\n\x03tag\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x12\x0f\n\x05value\x18\x04 \x01(\x02H\x00\x12\'\n\x05image\x18\x05 \x01(\x0b\x32\x16.visualdl.Record.ImageH\x00\x12\'\n\x05\x61udio\x18\x06 \x01(\x0b\x32\x16.visualdl.Record.AudioH\x00\x12\x31\n\nembeddings\x18\x07 \x01(\x0b\x32\x1b.visualdl.Record.EmbeddingsH\x00\x12/\n\thistogram\x18\x08 \x01(\x0b\x32\x1a.visualdl.Record.HistogramH\x00\x12,\n\x08pr_curve\x18\t \x01(\x0b\x32\x18.visualdl.Record.PRCurveH\x00\x12.\n\tmeta_data\x18\n \x01(\x0b\x32\x19.visualdl.Record.MetaDataH\x00\x12/\n\troc_curve\x18\x0b \x01(\x0b\x32\x1a.visualdl.Record.ROC_CurveH\x00\x12%\n\x04text\x18\x0c \x01(\x0b\x32\x15.visualdl.Record.TextH\x00\x12)\n\x06hparam\x18\r \x01(\x0b\x32\x17.visualdl.Record.HParamH\x00\x42\x0b\n\tone_valueb\x06proto3' ) @@ -420,6 +420,104 @@ serialized_end=741, ) +_RECORD_HPARAM_HPARAMINFO = _descriptor.Descriptor( + name='HparamInfo', + full_name='visualdl.Record.HParam.HparamInfo', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='int_value', full_name='visualdl.Record.HParam.HparamInfo.int_value', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='float_value', full_name='visualdl.Record.HParam.HparamInfo.float_value', index=1, + number=2, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='string_value', full_name='visualdl.Record.HParam.HparamInfo.string_value', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='name', full_name='visualdl.Record.HParam.HparamInfo.name', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='type', full_name='visualdl.Record.HParam.HparamInfo.type', + index=0, containing_type=None, fields=[]), + ], + serialized_start=882, + serialized_end=984, +) + +_RECORD_HPARAM = _descriptor.Descriptor( + name='HParam', + full_name='visualdl.Record.HParam', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='hparamInfos', full_name='visualdl.Record.HParam.hparamInfos', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='metricInfos', full_name='visualdl.Record.HParam.metricInfos', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='name', full_name='visualdl.Record.HParam.name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_RECORD_HPARAM_HPARAMINFO, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=744, + serialized_end=984, +) + _RECORD_METADATA = _descriptor.Descriptor( name='MetaData', full_name='visualdl.Record.MetaData', @@ -446,8 +544,8 @@ extension_ranges=[], oneofs=[ ], - serialized_start=743, - serialized_end=775, + serialized_start=986, + serialized_end=1018, ) _RECORD_VALUE = _descriptor.Descriptor( @@ -541,6 +639,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='hparam', full_name='visualdl.Record.Value.hparam', index=12, + number=13, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -556,8 +661,8 @@ name='one_value', full_name='visualdl.Record.Value.one_value', index=0, containing_type=None, fields=[]), ], - serialized_start=778, - serialized_end=1223, + serialized_start=1021, + serialized_end=1509, ) _RECORD = _descriptor.Descriptor( @@ -577,7 +682,7 @@ ], extensions=[ ], - nested_types=[_RECORD_IMAGE, _RECORD_TEXT, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_ROC_CURVE, _RECORD_METADATA, _RECORD_VALUE, ], + nested_types=[_RECORD_IMAGE, _RECORD_TEXT, _RECORD_AUDIO, _RECORD_EMBEDDING, _RECORD_EMBEDDINGS, _RECORD_BYTES_EMBEDDINGS, _RECORD_HISTOGRAM, _RECORD_PRCURVE, _RECORD_ROC_CURVE, _RECORD_HPARAM, _RECORD_METADATA, _RECORD_VALUE, ], enum_types=[ ], serialized_options=None, @@ -587,7 +692,7 @@ oneofs=[ ], serialized_start=27, - serialized_end=1223, + serialized_end=1509, ) _RECORD_IMAGE.containing_type = _RECORD @@ -600,6 +705,19 @@ _RECORD_HISTOGRAM.containing_type = _RECORD _RECORD_PRCURVE.containing_type = _RECORD _RECORD_ROC_CURVE.containing_type = _RECORD +_RECORD_HPARAM_HPARAMINFO.containing_type = _RECORD_HPARAM +_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append( + _RECORD_HPARAM_HPARAMINFO.fields_by_name['int_value']) +_RECORD_HPARAM_HPARAMINFO.fields_by_name['int_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'] +_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append( + _RECORD_HPARAM_HPARAMINFO.fields_by_name['float_value']) +_RECORD_HPARAM_HPARAMINFO.fields_by_name['float_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'] +_RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'].fields.append( + _RECORD_HPARAM_HPARAMINFO.fields_by_name['string_value']) +_RECORD_HPARAM_HPARAMINFO.fields_by_name['string_value'].containing_oneof = _RECORD_HPARAM_HPARAMINFO.oneofs_by_name['type'] +_RECORD_HPARAM.fields_by_name['hparamInfos'].message_type = _RECORD_HPARAM_HPARAMINFO +_RECORD_HPARAM.fields_by_name['metricInfos'].message_type = _RECORD_HPARAM_HPARAMINFO +_RECORD_HPARAM.containing_type = _RECORD _RECORD_METADATA.containing_type = _RECORD _RECORD_VALUE.fields_by_name['image'].message_type = _RECORD_IMAGE _RECORD_VALUE.fields_by_name['audio'].message_type = _RECORD_AUDIO @@ -609,6 +727,7 @@ _RECORD_VALUE.fields_by_name['meta_data'].message_type = _RECORD_METADATA _RECORD_VALUE.fields_by_name['roc_curve'].message_type = _RECORD_ROC_CURVE _RECORD_VALUE.fields_by_name['text'].message_type = _RECORD_TEXT +_RECORD_VALUE.fields_by_name['hparam'].message_type = _RECORD_HPARAM _RECORD_VALUE.containing_type = _RECORD _RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.fields_by_name['value']) @@ -637,6 +756,9 @@ _RECORD_VALUE.oneofs_by_name['one_value'].fields.append( _RECORD_VALUE.fields_by_name['text']) _RECORD_VALUE.fields_by_name['text'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] +_RECORD_VALUE.oneofs_by_name['one_value'].fields.append( + _RECORD_VALUE.fields_by_name['hparam']) +_RECORD_VALUE.fields_by_name['hparam'].containing_oneof = _RECORD_VALUE.oneofs_by_name['one_value'] _RECORD.fields_by_name['values'].message_type = _RECORD_VALUE DESCRIPTOR.message_types_by_name['Record'] = _RECORD _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -706,6 +828,20 @@ }) , + 'HParam' : _reflection.GeneratedProtocolMessageType('HParam', (_message.Message,), { + + 'HparamInfo' : _reflection.GeneratedProtocolMessageType('HparamInfo', (_message.Message,), { + 'DESCRIPTOR' : _RECORD_HPARAM_HPARAMINFO, + '__module__' : 'record_pb2' + # @@protoc_insertion_point(class_scope:visualdl.Record.HParam.HparamInfo) + }) + , + 'DESCRIPTOR' : _RECORD_HPARAM, + '__module__' : 'record_pb2' + # @@protoc_insertion_point(class_scope:visualdl.Record.HParam) + }) + , + 'MetaData' : _reflection.GeneratedProtocolMessageType('MetaData', (_message.Message,), { 'DESCRIPTOR' : _RECORD_METADATA, '__module__' : 'record_pb2' @@ -733,6 +869,8 @@ _sym_db.RegisterMessage(Record.Histogram) _sym_db.RegisterMessage(Record.PRCurve) _sym_db.RegisterMessage(Record.ROC_Curve) +_sym_db.RegisterMessage(Record.HParam) +_sym_db.RegisterMessage(Record.HParam.HparamInfo) _sym_db.RegisterMessage(Record.MetaData) _sym_db.RegisterMessage(Record.Value) diff --git a/visualdl/reader/reader.py b/visualdl/reader/reader.py index e3bef98b9..5dfcd9f91 100644 --- a/visualdl/reader/reader.py +++ b/visualdl/reader/reader.py @@ -182,6 +182,8 @@ def parse_from_bin(self, record_bin): component = "meta_data" elif "text" == value_type: component = "text" + elif "hparam" == value_type: + component = "hyper_parameters" else: raise TypeError("Invalid value type `%s`." % value_type) self._tags[path] = component diff --git a/visualdl/server/api.py b/visualdl/server/api.py index 89a388201..7f9c72199 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -121,6 +121,28 @@ def pr_curve_tags(self): def roc_curve_tags(self): return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags) + @result() + def hparam_importance(self): + return self._get_with_retry('data/plugin/hparams/importance', lib.get_hparam_importance) + + @result() + def hparam_indicator(self): + return self._get_with_retry('data/plugin/hparams/indicators', lib.get_hparam_indicator) + + @result() + def hparam_list(self): + return self._get_with_retry('data/plugin/hparams/list', lib.get_hparam_list) + + @result() + def hparam_metric(self, run, metric): + key = os.path.join('data/plugin/hparams/metric', run, metric) + return self._get_with_retry(key, lib.get_hparam_metric, run, metric) + + @result('text/csv') + def hparam_data(self, type='tsv'): + key = os.path.join('data/plugin/hparams/data', type) + return self._get_with_retry(key, lib.get_hparam_data, type) + @result() def scalar_list(self, run, tag): key = os.path.join('data/plugin/scalars/scalars', run, tag) @@ -254,7 +276,12 @@ def create_api_call(logdir, model, cache_timeout): 'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']), 'roc-curve/list': (api.roc_curves_roc_curve, ['run', 'tag']), 'pr-curve/steps': (api.pr_curves_steps, ['run']), - 'roc-curve/steps': (api.roc_curves_steps, ['run']) + 'roc-curve/steps': (api.roc_curves_steps, ['run']), + 'hparams/importance': (api.hparam_importance, []), + 'hparams/data': (api.hparam_data, ['type']), + 'hparams/indicators': (api.hparam_indicator, []), + 'hparams/list': (api.hparam_list, []), + 'hparams/metric': (api.hparam_metric, ['run', 'metric']) } def call(path: str, args): diff --git a/visualdl/server/data_manager.py b/visualdl/server/data_manager.py index 8fe27ccb7..d4272cc77 100644 --- a/visualdl/server/data_manager.py +++ b/visualdl/server/data_manager.py @@ -26,7 +26,8 @@ "pr_curve": 300, "roc_curve": 300, "meta_data": 100, - "text": 10 + "text": 10, + "hyper_parameters": 10000 } @@ -353,7 +354,9 @@ def __init__(self): "meta_data": Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"]), "text": - Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"]) + Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"]), + "hyper_parameters": + Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["hyper_parameters"]) } self._mutex = threading.Lock() diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index c7266f3cd..f455988d4 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -24,6 +24,7 @@ from visualdl.server.log import logger from visualdl.io import bfile from visualdl.utils.string_util import encode_tag, decode_tag +from visualdl.utils.importance import calc_all_hyper_param_importance from visualdl.component import components @@ -115,6 +116,183 @@ def get_logs(log_reader, component): exec("get_%s_tags=partial(get_logs, component='%s')" % (name, name)) +def get_hparam_data(log_reader, type='tsv'): + result = get_hparam_list(log_reader) + delimeter = '\t' if 'tsv' == type else ',' + header = ['Trial ID'] + hparams_header = [] + metrics_header = [] + for item in result: + hparams_header += item['hparams'].keys() + metrics_header += item['metrics'].keys() + name_set = set() + h_header = [] + for hparam in hparams_header: + if hparam in name_set: + continue + name_set.add(hparam) + h_header.append(hparam) + name_set = set() + m_header = [] + for metric in metrics_header: + if metric in name_set: + continue + name_set.add(metric) + m_header.append(metric) + trans_result = [] + for item in result: + temp = {'Trial ID': item.get('name', '')} + temp.update(item.get('hparams', {})) + temp.update(item.get('metrics', {})) + trans_result.append(temp) + header = header + h_header + m_header + with io.StringIO() as fp: + csv_writer = csv.writer(fp, delimiter=delimeter) + csv_writer.writerow(header) + for item in trans_result: + row = [] + for col_name in header: + row.append(item.get(col_name, '')) + csv_writer.writerow(row) + result = fp.getvalue() + return result + + +def get_hparam_importance(log_reader): + indicator = get_hparam_indicator(log_reader) + hparams = [item for item in indicator['hparams'] if (item['type'] != 'string')] + metrics = [item for item in indicator['metrics'] if (item['type'] != 'string')] + + result = calc_all_hyper_param_importance(hparams, metrics) + + return result + + +# flake8: noqa: C901 +def get_hparam_indicator(log_reader): + run2tag = get_logs(log_reader, 'hyper_parameters') + runs = run2tag['runs'] + hparams = {} + metrics = {} + records_list = [] + for run in runs: + run = log_reader.name2tags[run] if run in log_reader.name2tags else run + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items( + run, decode_tag('hparam')) + records_list.append(records) + records_list.sort(key=lambda x: x[0].timestamp) + for records in records_list: + for hparamInfo in records[0].hparam.hparamInfos: + type = hparamInfo.WhichOneof("type") + if "float_value" == type: + if hparamInfo.name not in hparams.keys(): + hparams[hparamInfo.name] = {'name': hparamInfo.name, + 'type': 'continuous', + 'values': [hparamInfo.float_value]} + else: + hparams[hparamInfo.name]['values'].append(hparamInfo.float_value) + elif "string_value" == type: + if hparamInfo.name not in hparams.keys(): + hparams[hparamInfo.name] = {'name': hparamInfo.name, + 'type': 'string', + 'values': [hparamInfo.string_value]} + else: + hparams[hparamInfo.name]['values'].append(hparamInfo.string_value) + elif "int_value" == type: + if hparamInfo.name not in hparams.keys(): + hparams[hparamInfo.name] = {'name': hparamInfo.name, + 'type': 'numeric', + 'values': [hparamInfo.int_value]} + else: + hparams[hparamInfo.name]['values'].append(hparamInfo.int_value) + else: + raise TypeError("Invalid hparams param value type `%s`." % type) + + for metricInfo in records[0].hparam.metricInfos: + type = metricInfo.WhichOneof("type") + if "float_value" == type: + if metricInfo.name not in metrics.keys(): + metrics[metricInfo.name] = {'name': metricInfo.name, + 'type': 'continuous', + 'values': [metricInfo.float_value]} + else: + metrics[metricInfo.name]['values'].append(metricInfo.float_value) + elif "string_value" == type: + if metricInfo.name not in metrics.keys(): + metrics[metricInfo.name] = {'name': metricInfo.name, + 'type': 'string', + 'values': [metricInfo.string_value]} + else: + metrics[metricInfo.name]['values'].append(metricInfo.string_value) + elif "int_value" == type: + if metricInfo.name not in metrics.keys(): + metrics[metricInfo.name] = {'name': metricInfo.name, + 'type': 'numeric', + 'values': [metricInfo.int_value]} + else: + metrics[metricInfo.name]['values'].append(metricInfo.int_value) + else: + raise TypeError("Invalid hparams param value type `%s`." % type) + results = {'hparams': [value for key, value in hparams.items()], + 'metrics': [value for key, value in metrics.items()]} + + return results + + +def get_hparam_metric(log_reader, run, tag): + run = log_reader.name2tags[run] if run in log_reader.name2tags else run + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("scalar").get_items( + run, decode_tag(tag)) + results = [[s2ms(item.timestamp), item.id, item.value] for item in records] + return results + + +def get_hparam_list(log_reader): + run2tag = get_logs(log_reader, 'hyper_parameters') + runs = run2tag['runs'] + results = [] + + records_list = [] + for run in runs: + run = log_reader.name2tags[run] if run in log_reader.name2tags else run + log_reader.load_new_data() + records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items( + run, decode_tag('hparam')) + records_list.append([records, run]) + records_list.sort(key=lambda x: x[0][0].timestamp) + for records, run in records_list: + hparams = {} + for hparamInfo in records[0].hparam.hparamInfos: + type = hparamInfo.WhichOneof("type") + if "float_value" == type: + hparams[hparamInfo.name] = hparamInfo.float_value + elif "string_value" == type: + hparams[hparamInfo.name] = hparamInfo.string_value + elif "int_value" == type: + hparams[hparamInfo.name] = hparamInfo.int_value + else: + raise TypeError("Invalid hparams param value type `%s`." % type) + + metrics = {} + for metricInfo in records[0].hparam.metricInfos: + type = metricInfo.WhichOneof("type") + if "float_value" == type: + metrics[metricInfo.name] = metricInfo.float_value + elif "string_value" == type: + metrics[metricInfo.name] = metricInfo.string_value + elif "int_value" == type: + metrics[metricInfo.name] = metricInfo.int_value + else: + raise TypeError("Invalid hparams metric value type `%s`." % type) + + results.append({'name': run, + 'hparams': hparams, + 'metrics': metrics}) + return results + + def get_scalar(log_reader, run, tag): run = log_reader.name2tags[run] if run in log_reader.name2tags else run log_reader.load_new_data() diff --git a/visualdl/utils/importance.py b/visualdl/utils/importance.py new file mode 100644 index 000000000..f2e911feb --- /dev/null +++ b/visualdl/utils/importance.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +from functools import reduce + +import numpy as np +import pandas as pd + +from visualdl.server.log import logger + + +def calc_hyper_param_importance(df, hyper_param, target): + new_df = df[[hyper_param, target]] + no_missing_value_df = new_df.dropna() + + # Can not calc pearson correlation coefficient when number of samples is less or equal than 2 + if len(no_missing_value_df) <= 2: + logger.error("Number of samples is less or equal than 2.") + return 0 + + correlation = no_missing_value_df[target].corr(no_missing_value_df[hyper_param]) + if np.isnan(correlation): + logger.warning("Correlation is nan!") + return 0 + + return abs(correlation) + + +def calc_all_hyper_param_importance(hparams, metrics): + results = {} + for metric in metrics: + for hparam in hparams: + flattened_lineage = {hparam['name']: hparam['values'], metric['name']: metric['values']} + result = calc_hyper_param_importance(pd.DataFrame(flattened_lineage), hparam['name'], metric['name']) + # print('%s - %s : result=' % (hparam, metric), result) + if hparam['name'] not in results.keys(): + results[hparam['name']] = result + else: + results[hparam['name']] += result + sum_score = reduce(lambda x, y: x+y, results.values()) + for key, value in results.items(): + results[key] = value/sum_score + result = [{'name': key, 'value': value} for key, value in results.items()] + return result diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index 214708fd8..8f80184d1 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -20,8 +20,9 @@ from visualdl.server.log import logger from visualdl.utils.img_util import merge_images from visualdl.utils.figure_util import figure_to_image +from visualdl.utils.md5_util import md5 from visualdl.component.base_component import scalar, image, embedding, audio, \ - histogram, pr_curve, roc_curve, meta_data, text + histogram, pr_curve, roc_curve, meta_data, text, hparam class DummyFileWriter(object): @@ -108,6 +109,8 @@ def __init__(self, self.loggers = {} self.add_meta(display_name=display_name) + self.hparam_write = False + @property def logdir(self): return self._logdir @@ -441,6 +444,49 @@ def add_histogram(self, step=step, walltime=walltime)) + def add_hparams(self, hparam_dict, metric_dict, walltime=None): + """Add an histogram to vdl record file. + + Args: + hparam_dict (dictionary): Each key-value pair in the dictionary is the + name of the hyper parameter and it's corresponding value. The type of the value + can be one of `bool`, `string`, `float`, `int`, or `None`. + metric_dict (dictionary): Each key-value pair in the dictionary is the + name of the metric and it's corresponding value. + walltime (int): Wall time of hparams. + + Examples:: + from visualdl import LogWriter + + with LogWriter('./log/hparams_test/train/run1') as writer: + writer.add_hparams({'lr': 0.1, 'bsize':1, 'opt': 'sgd'}, {'hparam/accuracy':10, 'hparam/loss': 2}) + writer.add_scalar('hparam/accuracy', 10, 0) + writer.add_scalar('hparam/loss', 2, 0) + + with LogWriter('./log/hparams_test/train/run2') as writer: + writer.add_hparams({'lr': 0.2, 'bsize':2, 'opt': 'relu'}, {'hparam/accuracy':12, 'hparam/loss': 3}) + writer.add_scalar('hparam/accuracy', 12, 0) + writer.add_scalar('hparam/loss', 3, 0) + + with LogWriter('./log/hparams_test/train/run3') as writer: + writer.add_hparams({'lr': 0.3, 'bsize':3, 'opt': 'line'}, {'hparam/accuracy':14, 'hparam/loss': -2.3}) + writer.add_scalar('hparam/accuracy', 14, 0) + writer.add_scalar('hparam/loss', -2.3, 0) + """ + if self.hparam_write: + logger.warning('Each log file should have only one hparams info. ' + 'Only last hparams info will be displayed on board.') + if type(hparam_dict) is not dict or type(metric_dict) is not dict: + raise TypeError('hparam_dict and metric_dict should be dictionary.') + walltime = round(time.time() * 1000) if walltime is None else walltime + + self._get_file_writer().add_record( + hparam( + name=md5(self.file_name), + hparam_dict=hparam_dict, + metric_dict=metric_dict, + walltime=walltime)) + def add_pr_curve(self, tag, labels, From 718931c3a853986317bd418886a88f310af6cf59 Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Thu, 20 May 2021 01:52:35 +0800 Subject: [PATCH 2/4] add duplicate_removal --- visualdl/server/lib.py | 13 +++++++------ visualdl/utils/list_util.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) create mode 100644 visualdl/utils/list_util.py diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index f455988d4..433ca2a9a 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -25,6 +25,7 @@ from visualdl.io import bfile from visualdl.utils.string_util import encode_tag, decode_tag from visualdl.utils.importance import calc_all_hyper_param_importance +from visualdl.utils.list_util import duplicate_removal from visualdl.component import components @@ -190,21 +191,21 @@ def get_hparam_indicator(log_reader): hparams[hparamInfo.name] = {'name': hparamInfo.name, 'type': 'continuous', 'values': [hparamInfo.float_value]} - else: + elif hparamInfo.float_value not in hparams[hparamInfo.name]['values']: hparams[hparamInfo.name]['values'].append(hparamInfo.float_value) elif "string_value" == type: if hparamInfo.name not in hparams.keys(): hparams[hparamInfo.name] = {'name': hparamInfo.name, 'type': 'string', 'values': [hparamInfo.string_value]} - else: + elif hparamInfo.string_value not in hparams[hparamInfo.name]['values']: hparams[hparamInfo.name]['values'].append(hparamInfo.string_value) elif "int_value" == type: if hparamInfo.name not in hparams.keys(): hparams[hparamInfo.name] = {'name': hparamInfo.name, 'type': 'numeric', 'values': [hparamInfo.int_value]} - else: + elif hparamInfo.int_value not in hparams[hparamInfo.name]['values']: hparams[hparamInfo.name]['values'].append(hparamInfo.int_value) else: raise TypeError("Invalid hparams param value type `%s`." % type) @@ -216,21 +217,21 @@ def get_hparam_indicator(log_reader): metrics[metricInfo.name] = {'name': metricInfo.name, 'type': 'continuous', 'values': [metricInfo.float_value]} - else: + elif metricInfo.float_value not in metrics[metricInfo.name]['values']: metrics[metricInfo.name]['values'].append(metricInfo.float_value) elif "string_value" == type: if metricInfo.name not in metrics.keys(): metrics[metricInfo.name] = {'name': metricInfo.name, 'type': 'string', 'values': [metricInfo.string_value]} - else: + elif metricInfo.string_value not in metrics[metricInfo.name]['values']: metrics[metricInfo.name]['values'].append(metricInfo.string_value) elif "int_value" == type: if metricInfo.name not in metrics.keys(): metrics[metricInfo.name] = {'name': metricInfo.name, 'type': 'numeric', 'values': [metricInfo.int_value]} - else: + elif metricInfo.int_value not in metrics[metricInfo.name]['values']: metrics[metricInfo.name]['values'].append(metricInfo.int_value) else: raise TypeError("Invalid hparams param value type `%s`." % type) diff --git a/visualdl/utils/list_util.py b/visualdl/utils/list_util.py new file mode 100644 index 000000000..cff1af876 --- /dev/null +++ b/visualdl/utils/list_util.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= + + +def duplicate_removal(src_list): + name_scope = set() + dest_list = [] + for item in src_list: + if item in name_scope: + continue + name_scope.add(item) + dest_list.append(item) + return dest_list From c6473435041c41859083a11f91b457f325a321b0 Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Fri, 21 May 2021 21:42:32 +0800 Subject: [PATCH 3/4] change usage of add_hparam --- visualdl/component/base_component.py | 25 +++------- visualdl/server/lib.py | 69 ++++++++++++---------------- visualdl/writer/writer.py | 38 +++++++-------- 3 files changed, 51 insertions(+), 81 deletions(-) diff --git a/visualdl/component/base_component.py b/visualdl/component/base_component.py index 9ac6d8e2d..b562c6173 100644 --- a/visualdl/component/base_component.py +++ b/visualdl/component/base_component.py @@ -293,7 +293,7 @@ def histogram(tag, hist, bin_edges, step, walltime): ]) -def hparam(name, hparam_dict, metric_dict, walltime): +def hparam(name, hparam_dict, metric_list, walltime): """Package data to one histogram. Args: @@ -301,8 +301,7 @@ def hparam(name, hparam_dict, metric_dict, walltime): hparam_dict (dictionary): Each key-value pair in the dictionary is the name of the hyper parameter and it's corresponding value. The type of the value can be one of `bool`, `string`, `float`, `int`, or `None`. - metric_dict (dictionary): Each key-value pair in the dictionary is the - name of the metric and it's corresponding value. + metric_list (list): Name of all metrics. walltime (int): Wall time of hparam. Return: @@ -327,23 +326,11 @@ def hparam(name, hparam_dict, metric_dict, walltime): hm.hparamInfos.append(hparamInfo) else: print("The value of %s must be int, float or str, not %s" % (k, str(type(v)))) - - for k, v in metric_dict.items(): - if v is None: - continue + for metric in metric_list: metricInfo = Record.HParam.HparamInfo() - metricInfo.name = k - if isinstance(v, int): - metricInfo.int_value = v - hm.metricInfos.append(metricInfo) - elif isinstance(v, float): - metricInfo.float_value = v - hm.metricInfos.append(metricInfo) - elif isinstance(v, str): - metricInfo.string_value = v - hm.metricInfos.append(metricInfo) - else: - print("The value of %s must be int, float or str, not %s" % (k, str(type(v)))) + metricInfo.name = metric + metricInfo.float_value = 0 + hm.metricInfos.append(metricInfo) return Record(values=[ Record.Value( diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index 433ca2a9a..b6791853a 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -181,9 +181,10 @@ def get_hparam_indicator(log_reader): log_reader.load_new_data() records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items( run, decode_tag('hparam')) - records_list.append(records) - records_list.sort(key=lambda x: x[0].timestamp) - for records in records_list: + records_list.append([records, run]) + records_list.sort(key=lambda x: x[0][0].timestamp) + runs = [run for r, run in records_list] + for records, run in records_list: for hparamInfo in records[0].hparam.hparamInfos: type = hparamInfo.WhichOneof("type") if "float_value" == type: @@ -211,30 +212,21 @@ def get_hparam_indicator(log_reader): raise TypeError("Invalid hparams param value type `%s`." % type) for metricInfo in records[0].hparam.metricInfos: - type = metricInfo.WhichOneof("type") - if "float_value" == type: - if metricInfo.name not in metrics.keys(): - metrics[metricInfo.name] = {'name': metricInfo.name, - 'type': 'continuous', - 'values': [metricInfo.float_value]} - elif metricInfo.float_value not in metrics[metricInfo.name]['values']: - metrics[metricInfo.name]['values'].append(metricInfo.float_value) - elif "string_value" == type: - if metricInfo.name not in metrics.keys(): - metrics[metricInfo.name] = {'name': metricInfo.name, - 'type': 'string', - 'values': [metricInfo.string_value]} - elif metricInfo.string_value not in metrics[metricInfo.name]['values']: - metrics[metricInfo.name]['values'].append(metricInfo.string_value) - elif "int_value" == type: - if metricInfo.name not in metrics.keys(): - metrics[metricInfo.name] = {'name': metricInfo.name, - 'type': 'numeric', - 'values': [metricInfo.int_value]} - elif metricInfo.int_value not in metrics[metricInfo.name]['values']: - metrics[metricInfo.name]['values'].append(metricInfo.int_value) + metrics[metricInfo.name] = {'name': metricInfo.name, + 'type': 'continuous', + 'values': []} + for run in runs: + try: + metrics_data = get_hparam_metric(log_reader, run, metricInfo.name) + metrics[metricInfo.name]['values'].append(metrics_data[-1][-1]) + break + except: + logger.error('Missing data of metrics! Please make sure use add_scalar to log metrics data.') + if len(metrics[metricInfo.name]['values']) == 0: + metrics.pop(metricInfo.name) else: - raise TypeError("Invalid hparams param value type `%s`." % type) + metrics[metricInfo.name].pop('values') + results = {'hparams': [value for key, value in hparams.items()], 'metrics': [value for key, value in metrics.items()]} @@ -266,27 +258,24 @@ def get_hparam_list(log_reader): for records, run in records_list: hparams = {} for hparamInfo in records[0].hparam.hparamInfos: - type = hparamInfo.WhichOneof("type") - if "float_value" == type: + hparam_type = hparamInfo.WhichOneof("type") + if "float_value" == hparam_type: hparams[hparamInfo.name] = hparamInfo.float_value - elif "string_value" == type: + elif "string_value" == hparam_type: hparams[hparamInfo.name] = hparamInfo.string_value - elif "int_value" == type: + elif "int_value" == hparam_type: hparams[hparamInfo.name] = hparamInfo.int_value else: - raise TypeError("Invalid hparams param value type `%s`." % type) + raise TypeError("Invalid hparams param value type `%s`." % hparam_type) metrics = {} for metricInfo in records[0].hparam.metricInfos: - type = metricInfo.WhichOneof("type") - if "float_value" == type: - metrics[metricInfo.name] = metricInfo.float_value - elif "string_value" == type: - metrics[metricInfo.name] = metricInfo.string_value - elif "int_value" == type: - metrics[metricInfo.name] = metricInfo.int_value - else: - raise TypeError("Invalid hparams metric value type `%s`." % type) + try: + metrics_data = get_hparam_metric(log_reader, run, metricInfo.name) + metrics[metricInfo.name] = metrics_data[-1][-1] + except: + logger.error('Missing data of metrics! Please make sure use add_scalar to log metrics data.') + metrics[metricInfo.name] = None results.append({'name': run, 'hparams': hparams, diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index 8f80184d1..ccc795566 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -109,8 +109,6 @@ def __init__(self, self.loggers = {} self.add_meta(display_name=display_name) - self.hparam_write = False - @property def logdir(self): return self._logdir @@ -444,47 +442,43 @@ def add_histogram(self, step=step, walltime=walltime)) - def add_hparams(self, hparam_dict, metric_dict, walltime=None): + def add_hparams(self, hparam_dict, metric_list, walltime=None): """Add an histogram to vdl record file. Args: hparam_dict (dictionary): Each key-value pair in the dictionary is the name of the hyper parameter and it's corresponding value. The type of the value can be one of `bool`, `string`, `float`, `int`, or `None`. - metric_dict (dictionary): Each key-value pair in the dictionary is the - name of the metric and it's corresponding value. + metric_list (list): Name of all metrics. walltime (int): Wall time of hparams. Examples:: from visualdl import LogWriter + # Remember use add_scalar to log your metrics data! with LogWriter('./log/hparams_test/train/run1') as writer: - writer.add_hparams({'lr': 0.1, 'bsize':1, 'opt': 'sgd'}, {'hparam/accuracy':10, 'hparam/loss': 2}) - writer.add_scalar('hparam/accuracy', 10, 0) - writer.add_scalar('hparam/loss', 2, 0) + writer.add_hparams({'lr': 0.1, 'bsize': 1, 'opt': 'sgd'}, ['hparam/accuracy', 'hparam/loss']) + for i in range(10): + writer.add_scalar('hparam/accuracy', i, i) + writer.add_scalar('hparam/loss', 2*i, i) with LogWriter('./log/hparams_test/train/run2') as writer: - writer.add_hparams({'lr': 0.2, 'bsize':2, 'opt': 'relu'}, {'hparam/accuracy':12, 'hparam/loss': 3}) - writer.add_scalar('hparam/accuracy', 12, 0) - writer.add_scalar('hparam/loss', 3, 0) - - with LogWriter('./log/hparams_test/train/run3') as writer: - writer.add_hparams({'lr': 0.3, 'bsize':3, 'opt': 'line'}, {'hparam/accuracy':14, 'hparam/loss': -2.3}) - writer.add_scalar('hparam/accuracy', 14, 0) - writer.add_scalar('hparam/loss', -2.3, 0) + writer.add_hparams({'lr': 0.2, 'bsize': 2, 'opt': 'relu'}, ['hparam/accuracy', 'hparam/loss']) + for i in range(10): + writer.add_scalar('hparam/accuracy', 1.0/(i+1), i) + writer.add_scalar('hparam/loss', 5*i, i) """ - if self.hparam_write: - logger.warning('Each log file should have only one hparams info. ' - 'Only last hparams info will be displayed on board.') - if type(hparam_dict) is not dict or type(metric_dict) is not dict: - raise TypeError('hparam_dict and metric_dict should be dictionary.') + if type(hparam_dict) is not dict: + raise TypeError('hparam_dict should be dictionary.') + if type(metric_list) is not list: + raise TypeError('metric_list should be list.') walltime = round(time.time() * 1000) if walltime is None else walltime self._get_file_writer().add_record( hparam( name=md5(self.file_name), hparam_dict=hparam_dict, - metric_dict=metric_dict, + metric_list=metric_list, walltime=walltime)) def add_pr_curve(self, From 45d9fccfaebf1483af931f9788c4424de49936d9 Mon Sep 17 00:00:00 2001 From: ShenYuhan Date: Fri, 21 May 2021 21:46:41 +0800 Subject: [PATCH 4/4] add requirement of pandas --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 5a4e55129..17af24d61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ requests shellcheck-py six >= 1.14.0 matplotlib +pandas