Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add text #917

Merged
merged 1 commit into from
Jan 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions visualdl/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
"image": {
"enabled": False
},
"text": {
"enabled": False
},
"embedding": {
"enabled": False
},
Expand Down
16 changes: 16 additions & 0 deletions visualdl/component/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,22 @@ def audio(tag, audio_array, sample_rate, step, walltime):
])


def text(tag, text_string, step, walltime=None):
"""Package data to one image.
Args:
tag (string): Data identifier
text_string (string): Value of text
step (int): Step of text
walltime (int): Wall time of text
Return:
Package with format of record_pb2.Record
"""
_text = Record.Text(encoded_text_string=text_string)
return Record(values=[
Record.Value(id=step, tag=tag, timestamp=walltime, text=_text)
])


def histogram(tag, hist, bin_edges, step, walltime):
"""Package data to one histogram.

Expand Down
5 changes: 5 additions & 0 deletions visualdl/proto/record.proto
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ message Record {
bytes encoded_image_string = 4;
}

message Text {
string encoded_text_string = 1;
}

message Audio {
float sample_rate = 1;
int64 num_channels = 2;
Expand Down Expand Up @@ -70,6 +74,7 @@ message Record {
PRCurve pr_curve = 9;
MetaData meta_data = 10;
ROC_Curve roc_curve = 11;
Text text = 12;
}
}

Expand Down
92 changes: 71 additions & 21 deletions visualdl/proto/record_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions visualdl/reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def parse_from_bin(self, record_bin):
elif "meta_data" == value_type:
self.update_meta_data(record)
component = "meta_data"
elif "text" == value_type:
component = "text"
else:
raise TypeError("Invalid value type `%s`." % value_type)
self._tags[path] = component
Expand Down
18 changes: 18 additions & 0 deletions visualdl/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def scalar_tags(self):
def image_tags(self):
return self._get_with_retry('data/plugin/images/tags', lib.get_image_tags)

@result()
def text_tags(self):
return self._get_with_retry('data/plugin/text/tags', lib.get_text_tags)

@result()
def audio_tags(self):
return self._get_with_retry('data/plugin/audio/tags', lib.get_audio_tags)
Expand Down Expand Up @@ -138,6 +142,17 @@ def image_image(self, mode, tag, index=0):
key = os.path.join('data/plugin/images/individualImage', mode, tag, str(index))
return self._get_with_retry(key, lib.get_individual_image, mode, tag, index)

@result()
def text_list(self, mode, tag):
key = os.path.join('data/plugin/text/text', mode, tag)
return self._get_with_retry(key, lib.get_text_tag_steps, mode, tag)

@result('text/plain')
def text_text(self, mode, tag, index=0):
index = int(index)
key = os.path.join('data/plugin/text/individualText', mode, tag, str(index))
return self._get_with_retry(key, lib.get_individual_text, mode, tag, index)

@result()
def audio_list(self, run, tag):
key = os.path.join('data/plugin/audio/audio', run, tag)
Expand Down Expand Up @@ -216,6 +231,7 @@ def create_api_call(logdir, model, cache_timeout):
'logs': (api.logs, []),
'scalar/tags': (api.scalar_tags, []),
'image/tags': (api.image_tags, []),
'text/tags': (api.text_tags, []),
'audio/tags': (api.audio_tags, []),
'embedding/tags': (api.embedding_tags, []),
'histogram/tags': (api.histogram_tags, []),
Expand All @@ -225,6 +241,8 @@ def create_api_call(logdir, model, cache_timeout):
'scalar/data': (api.scalar_data, ['run', 'tag', 'type']),
'image/list': (api.image_list, ['run', 'tag']),
'image/image': (api.image_image, ['run', 'tag', 'index']),
'text/list': (api.text_list, ['run', 'tag']),
'text/text': (api.text_text, ['run', 'tag', 'index']),
'audio/list': (api.audio_list, ['run', 'tag']),
'audio/audio': (api.audio_audio, ['run', 'tag', 'index']),
'embedding/embedding': (api.embedding_embedding, ['run', 'tag', 'reduction', 'dimension']),
Expand Down
7 changes: 5 additions & 2 deletions visualdl/server/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"audio": 10,
"pr_curve": 300,
"roc_curve": 300,
"meta_data": 100
"meta_data": 100,
"text": 10
}


Expand Down Expand Up @@ -350,7 +351,9 @@ def __init__(self):
"roc_curve":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["roc_curve"]),
"meta_data":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"])
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["meta_data"]),
"text":
Reservoir(max_size=DEFAULT_PLUGIN_MAXSIZE["text"])
}
self._mutex = threading.Lock()

Expand Down
20 changes: 20 additions & 0 deletions visualdl/server/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,26 @@ def get_individual_image(log_reader, run, tag, step_index):
return records[step_index].image.encoded_image_string


def get_text_tag_steps(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("text").get_items(
run, decode_tag(tag))
result = [{
"step": item.id,
"wallTime": s2ms(item.timestamp)
} for item in records]
return result


def get_individual_text(log_reader, run, tag, step_index):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
records = log_reader.data_manager.get_reservoir("text").get_items(
run, decode_tag(tag))
return records[step_index].text.encoded_text_string


def get_audio_tag_steps(log_reader, run, tag):
run = log_reader.name2tags[run] if run in log_reader.name2tags else run
log_reader.load_new_data()
Expand Down
Loading