diff --git a/visualdl/__init__.py b/visualdl/__init__.py index 14fd8f746..9a7dcd64d 100644 --- a/visualdl/__init__.py +++ b/visualdl/__init__.py @@ -18,6 +18,7 @@ import os from visualdl.writer.writer import LogWriter # noqa +from visualdl.reader.reader import LogReader from visualdl.version import vdl_version as __version__ from visualdl.utils.dir import init_vdl_config diff --git a/visualdl/component/__init__.py b/visualdl/component/__init__.py index 249ab41f6..192c6b49b 100644 --- a/visualdl/component/__init__.py +++ b/visualdl/component/__init__.py @@ -30,5 +30,11 @@ }, "graph": { "enabled": False + }, + "pr_curve": { + "enabled": False + }, + "meta_data": { + "enabled": False } } diff --git a/visualdl/reader/reader.py b/visualdl/reader/reader.py index a801cd092..6c1217bf3 100644 --- a/visualdl/reader/reader.py +++ b/visualdl/reader/reader.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ======================================================================= -import os +import collections +from functools import partial from visualdl.io import bfile from visualdl.component import components from visualdl.reader.record_reader import RecordReader @@ -40,7 +41,7 @@ class LogReader(object): """ - def __init__(self, logdir): + def __init__(self, logdir='', file_name=''): """Instance of LogReader Args: @@ -59,13 +60,23 @@ def __init__(self, logdir): self.tags2name = {} self.file_readers = {} - self._environments = components - self.data_manager = default_data_manager - self.load_new_data(update=True) - self._a_tags = {} - self._model = "" + if file_name: + self._log_data = collections.defaultdict(lambda: collections.defaultdict(list)) + self.get_file_reader(file_name=file_name) + remain = self.get_remain() + self.read_log_data(remain=remain) + components_name = components.keys() + + for name in components_name: + exec("self.get_%s=partial(self.get_data, '%s')" % (name, name)) + elif logdir: + self.data_manager = default_data_manager + self.load_new_data(update=True) + self._a_tags = {} + + self._model = "" @property def model(self): @@ -87,6 +98,19 @@ def model(self, model_path): def logdir(self): return self.dir + def _get_log_tags(self): + component_keys = self._log_data.keys() + log_tags = {} + for key in component_keys: + log_tags[key] = list(self._log_data[key].keys()) + return log_tags + + def get_tags(self): + return self._get_log_tags() + + def get_data(self, component, tag): + return self._log_data[component][tag] + def parse_from_bin(self, record_bin): """Register to self._tags by component type. @@ -134,21 +158,6 @@ def get_all_walk(self): for root, dirs, files in bfile.walk(dir): self.walks.update({root: files}) - def components_listing(self): - """Get available component types. - - Indicates what components are included. - - Returns: - self._environments: A dict like `{"image": False, "scalar": - True}` - """ - keys_enable = self.data_manager.get_keys() - for key in self._environments.keys(): - if key in keys_enable: - self._environments[key].update({"enable": True}) - return self._environments - def logs(self, update=False): """Get logs. @@ -191,7 +200,20 @@ def get_log_reader(self, dir, log): self.reader = self.readers[filepath] return self.reader - def _register_reader(self, path, dir): + def get_file_reader(self, file_name): + """Get file reader for specified vdl log file. + + Get instance of class RecordReader base on BFile. + + Args: + file_name: Vdl log file name. + """ + self._register_reader(file_name) + self.reader = self.readers[file_name] + self.reader.dir = file_name + return self.reader + + def _register_reader(self, path, dir=None): if path not in list(self.readers.keys()): reader = RecordReader(filepath=path, dir=dir) self.readers[path] = reader @@ -215,10 +237,10 @@ def add_remain(self): """ for reader in self.readers.values(): self.reader = reader - remain = self.reader.get_remain() for item in remain: component, dir, tag, record = self.parse_from_bin(item) + self.data_manager.add_item(component, self.reader.dir, tag, record) @@ -229,6 +251,20 @@ def get_remain(self): raise RuntimeError("Please specify log path!") return self.reader.get_remain() + def read_log_data(self, remain): + """Parse data from log file without sampling. + + Args: + remain: Raw data from log file. + """ + for item in remain: + component, dir, tag, record = self.parse_from_bin(item) + self._log_data[component][tag].append(record) + + @property + def log_data(self): + return self._log_data + def runs(self, update=True): self.logs(update=update) return list(self.walks.keys()) @@ -259,3 +295,9 @@ def load_new_data(self, update=True): if self.logdir is not None: self.register_readers(update=update) self.add_remain() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass diff --git a/visualdl/reader/record_reader.py b/visualdl/reader/record_reader.py index 557d051d1..9ea5a623d 100644 --- a/visualdl/reader/record_reader.py +++ b/visualdl/reader/record_reader.py @@ -109,3 +109,7 @@ def get_remain(self, update=False): @property def dir(self): return self._dir + + @dir.setter + def dir(self, value): + self._dir = value diff --git a/visualdl/server/__init__.py b/visualdl/server/__init__.py index 09b6fc7d3..6bab744b7 100644 --- a/visualdl/server/__init__.py +++ b/visualdl/server/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017 VisualDL Authors. All Rights Reserve. +# Copyright (c) 2020 VisualDL Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # ======================================================================= - -from __future__ import absolute_import -from . import (log, app, api) - -__all__ = ['log', 'app', 'api'] diff --git a/visualdl/server/api.py b/visualdl/server/api.py index 13bfd366c..ad2387781 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -19,7 +19,7 @@ import json import os -from visualdl.reader.reader import LogReader +from visualdl import LogReader from visualdl.server import lib from visualdl.server.log import logger from visualdl.python.cache import MemCache diff --git a/visualdl/server/data_manager.py b/visualdl/server/data_manager.py index 83ee8a0c0..d32cb6cf4 100644 --- a/visualdl/server/data_manager.py +++ b/visualdl/server/data_manager.py @@ -13,7 +13,6 @@ # limitations under the License. # ======================================================================= -from __future__ import absolute_import import threading import random import collections diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index e812cff79..db402344b 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -17,10 +17,12 @@ import sys import time import os +from functools import partial import numpy as np from visualdl.server.log import logger from visualdl.io import bfile from visualdl.utils.string_util import encode_tag, decode_tag +from visualdl.component import components MODIFY_PREFIX = {} @@ -106,8 +108,8 @@ def get_logs(log_reader, component): return run2tag -def get_scalar_tags(log_reader): - return get_logs(log_reader, "scalar") +for name in components.keys(): + exec("get_%s_tags=partial(get_logs, component='%s')" % (name, name)) def get_scalar(log_reader, run, tag): @@ -119,10 +121,6 @@ def get_scalar(log_reader, run, tag): return results -def get_image_tags(log_reader): - return get_logs(log_reader, "image") - - def get_image_tag_steps(log_reader, run, tag): run = log_reader.name2tags[run] if run in log_reader.name2tags else run log_reader.load_new_data() @@ -143,10 +141,6 @@ def get_individual_image(log_reader, run, tag, step_index): return records[step_index].image.encoded_image_string -def get_audio_tags(log_reader): - return get_logs(log_reader, "audio") - - def get_audio_tag_steps(log_reader, run, tag): run = log_reader.name2tags[run] if run in log_reader.name2tags else run log_reader.load_new_data() @@ -168,18 +162,6 @@ def get_individual_audio(log_reader, run, tag, step_index): return result -def get_embeddings_tags(log_reader): - return get_logs(log_reader, "embeddings") - - -def get_histogram_tags(log_reader): - return get_logs(log_reader, "histogram") - - -def get_pr_curve_tags(log_reader): - return get_logs(log_reader, "pr_curve") - - def get_pr_curve(log_reader, run, tag): run = log_reader.name2tags[run] if run in log_reader.name2tags else run log_reader.load_new_data()