def _opname_creation_prehook(layer, inputs):
    """Forward pre-hook: enter a name scope labelled with the layer's name.

    The scope manager is pushed on the module-level ``_name_scope_stack`` so
    the matching post-hook can leave it once the layer's forward pass ends.

    Args:
        layer: Layer whose forward pass is about to run; its ``full_name()``
            is used as the scope name.
        inputs: Forward inputs (unused, required by the hook signature).
    """
    scope = name_scope(layer.full_name())
    _name_scope_stack.append(scope)
    scope.__enter__()


def _opname_creation_posthook(layer, inputs, outputs):
    """Forward post-hook: leave the name scope opened by the pre-hook.

    Pops the most recently pushed scope manager from ``_name_scope_stack``
    and exits it without exception information.

    Args:
        layer: Layer whose forward pass just finished (unused).
        inputs: Forward inputs (unused).
        outputs: Forward outputs (unused).
    """
    scope = _name_scope_stack.pop()
    scope.__exit__(None, None, None)
def is_VDLGraph_file(path):
    """Determine whether a file name denotes a graph file VisualDL can load.

    A name qualifies when it contains ``vdlgraph`` (a VDL graph log) or
    ``pdmodel`` (a Paddle model file).

    Args:
        path: File name to determine.
    Returns:
        True if the name contains either marker, otherwise False.
    """
    return "vdlgraph" in path or 'pdmodel' in path


class GraphReader(object):
    """Graph reader to read vdl graph files, support for frontend api in lib.py.

    Attributes are plain dictionaries keyed by run (directory) name:
    ``walks`` maps a run to its newest graph file, ``graph_buffer`` caches the
    parsed model per run and ``walks_buffer`` remembers which file each cached
    model was parsed from, so a model is re-read only when the file changes.
    """

    def __init__(self, logdir=''):
        """Instance of GraphReader

        Args:
            logdir: The dir include vdl graph files, multiple subfolders
                allowed. Either a single path string or a list of paths.
        """
        if isinstance(logdir, str):
            self.dir = [logdir]
        else:
            self.dir = logdir

        self.walks = {}
        self.displayname2runs = {}
        self.runs2displayname = {}
        self.graph_buffer = {}
        self.walks_buffer = {}
        # Temp file backing a manually supplied model (see set_input_graph).
        self.tempfile = None

    @property
    def logdir(self):
        """List of directories this reader scans."""
        return self.dir

    def get_all_walk(self):
        """Collect ``{directory: [file names]}`` for every scanned directory.

        A manually supplied model, if any, is carried over under the
        ``manual_input_model`` key so a rescan does not drop it.
        """
        flush_walks = {}
        if 'manual_input_model' in self.walks:
            flush_walks['manual_input_model'] = [
                self.walks['manual_input_model']
            ]
        for logdir in self.dir:
            for root, _, files in bfile.walk(logdir):
                flush_walks.update({root: files})
        return flush_walks

    def graphs(self, update=False):
        """Get graph files.

        Every dir(means `run` in vdl) has only one graph file(means `actual log file`);
        when several candidates exist, the one sorting last by name is used.

        Args:
            update: Rescan the directories even if a previous scan is cached.
        Returns:
            walks: A dict like {"exp1": "vdlgraph.1587375595.log",
                                "exp2": "vdlgraph.1587375685.log"}
        """
        if not self.walks or update is True:
            walks_temp = {}
            for run, filenames in self.get_all_walk().items():
                candidates = [
                    filename for filename in filenames
                    if is_VDLGraph_file(filename)
                ]
                if candidates:
                    walks_temp[run] = max(candidates)
            self.walks = walks_temp
        return self.walks

    def runs(self, update=True):
        """Return the names of all runs that contain a graph file."""
        self.graphs(update=update)
        return list(self.walks.keys())

    def _load_model(self, run):
        """Return the parsed model for `run`, re-reading it only when stale.

        The caller must have checked that `run` is present in ``self.walks``.
        """
        current_file = self.walks[run]
        if run in self.walks_buffer and current_file == self.walks_buffer[run]:
            return self.graph_buffer[run]

        data = bfile.BFile(bfile.join(run, current_file), 'rb').read()
        graph_model = Model(json.loads(data.decode()))
        self.graph_buffer[run] = graph_model
        self.walks_buffer[run] = current_file
        return graph_model

    def get_graph(self,
                  run,
                  nodeid=None,
                  expand=False,
                  keep_state=False,
                  expand_all=False,
                  refresh=False):
        """Build the graph of `run`, optionally adjusting one node first.

        Args:
            run: Run whose graph is requested.
            nodeid: Node to adjust visibility for; skipped when None.
            expand: Passed to ``adjust_visible`` together with `nodeid`.
            keep_state: Passed to ``adjust_visible`` together with `nodeid`.
            expand_all: Forwarded to ``make_graph``.
            refresh: Forwarded to ``make_graph``.
        Returns:
            The result of ``make_graph``, or None if `run` is unknown.
        """
        if run not in self.walks:
            return None
        graph_model = self._load_model(run)
        if nodeid is not None:
            graph_model.adjust_visible(nodeid, expand, keep_state)
        return graph_model.make_graph(refresh=refresh, expand_all=expand_all)

    def search_graph_node(self, run, nodeid, keep_state=False, is_node=True):
        """Make the searched node visible and return the resulting graph.

        Returns:
            The result of ``make_graph``, or None if `run` is unknown.
        """
        if run not in self.walks:
            return None
        graph_model = self._load_model(run)
        graph_model.adjust_search_node_visible(
            nodeid, keep_state=keep_state, is_node=is_node)
        return graph_model.make_graph(refresh=False, expand_all=False)

    def get_all_nodes(self, run):
        """Return all leaf nodes of `run`'s graph, or None if `run` is unknown."""
        if run not in self.walks:
            return None
        return self._load_model(run).get_all_leaf_nodes()

    def set_displayname(self, log_reader):
        """Share the run <-> display name mappings of `log_reader`."""
        self.displayname2runs = log_reader.name2tags
        self.runs2displayname = log_reader.tags2name

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def _remove_tempfile(self):
        """Delete the temp file of the manual model if one is registered.

        A file that no longer exists is not an error: it may already have
        been removed, e.g. by another reader object that referenced the same
        path or by external cleanup.
        """
        # getattr: __del__ may run on an instance whose __init__ never
        # reached the `tempfile` assignment.
        temp = getattr(self, 'tempfile', None)
        if temp:
            try:
                os.unlink(temp.name)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass

    def __del__(self):
        self._remove_tempfile()

    def set_input_graph(self, content, file_type='pdmodel'):
        """Register a manually supplied model as run ``manual_input_model``.

        Args:
            content: Either a path string (its type is then derived from the
                name and the file is read) or the raw bytes of the model.
            file_type: ``'pdmodel'`` or ``'vdlgraph'``; ignored when `content`
                is a path. Any other value makes the call a no-op.
        """
        if isinstance(content, str):
            if not is_VDLGraph_file(content):
                return
            file_type = 'pdmodel' if 'pdmodel' in content else 'vdlgraph'
            content = bfile.BFile(content, 'rb').read()

        if file_type == 'pdmodel':
            data = analyse_model(content)
            self.graph_buffer['manual_input_model'] = Model(data)
            # Persist the analysed (JSON) form, not the raw model bytes.
            temp = tempfile.NamedTemporaryFile(suffix='.pdmodel', delete=False)
            temp.write(json.dumps(data).encode())
            temp.close()

        elif file_type == 'vdlgraph':
            self.graph_buffer['manual_input_model'] = Model(
                json.loads(content.decode()))
            temp = tempfile.NamedTemporaryFile(
                suffix='.log', prefix='vdlgraph.', delete=False)
            temp.write(content)
            temp.close()

        else:
            return

        # Drop the previous manual model's file before tracking the new one.
        self._remove_tempfile()
        self.tempfile = temp
        self.walks['manual_input_model'] = temp.name
        self.walks_buffer['manual_input_model'] = temp.name
def try_call(function, *args, **kwargs):
    """Run `function` through ``lib.retry`` and log when it yields nothing.

    Args:
        function: Callable to invoke.
        *args: Positional arguments forwarded to `function`.
        **kwargs: Keyword arguments forwarded to `function`.
    Returns:
        Whatever ``lib.retry`` returns; a falsy value is returned unchanged
        after an error has been logged.
    """
    outcome = lib.retry(error_retry_times, function, error_sleep_time, *args,
                        **kwargs)
    if outcome:
        return outcome
    logger.error("Internal server error. Retry later.")
    return outcome
cache = MemCache(timeout=cache_timeout) self._cache = lib.cache_get(cache) @@ -79,6 +89,9 @@ def _get(self, key, func, *args, **kwargs): def _get_with_retry(self, key, func, *args, **kwargs): return self._cache(key, try_call, func, self._reader, *args, **kwargs) + def _get_with_reader(self, key, func, reader, *args, **kwargs): + return self._cache(key, func, reader, *args, **kwargs) + @result() def components(self): return self._get('data/components', lib.get_components) @@ -87,6 +100,11 @@ def components(self): def runs(self): return self._get('data/runs', lib.get_runs) + @result() + def graph_runs(self): + return self._get_with_reader('data/graph_runs', lib.get_graph_runs, + self._graph_reader) + @result() def tags(self): return self._get('data/tags', lib.get_tags) @@ -97,11 +115,13 @@ def logs(self): @result() def scalar_tags(self): - return self._get_with_retry('data/plugin/scalars/tags', lib.get_scalar_tags) + return self._get_with_retry('data/plugin/scalars/tags', + lib.get_scalar_tags) @result() def image_tags(self): - return self._get_with_retry('data/plugin/images/tags', lib.get_image_tags) + return self._get_with_retry('data/plugin/images/tags', + lib.get_image_tags) @result() def text_tags(self): @@ -109,31 +129,38 @@ def text_tags(self): @result() def audio_tags(self): - return self._get_with_retry('data/plugin/audio/tags', lib.get_audio_tags) + return self._get_with_retry('data/plugin/audio/tags', + lib.get_audio_tags) @result() def embedding_tags(self): - return self._get_with_retry('data/plugin/embeddings/tags', lib.get_embeddings_tags) + return self._get_with_retry('data/plugin/embeddings/tags', + lib.get_embeddings_tags) @result() def pr_curve_tags(self): - return self._get_with_retry('data/plugin/pr_curves/tags', lib.get_pr_curve_tags) + return self._get_with_retry('data/plugin/pr_curves/tags', + lib.get_pr_curve_tags) @result() def roc_curve_tags(self): - return self._get_with_retry('data/plugin/roc_curves/tags', lib.get_roc_curve_tags) + return 
self._get_with_retry('data/plugin/roc_curves/tags', + lib.get_roc_curve_tags) @result() def hparam_importance(self): - return self._get_with_retry('data/plugin/hparams/importance', lib.get_hparam_importance) + return self._get_with_retry('data/plugin/hparams/importance', + lib.get_hparam_importance) @result() def hparam_indicator(self): - return self._get_with_retry('data/plugin/hparams/indicators', lib.get_hparam_indicator) + return self._get_with_retry('data/plugin/hparams/indicators', + lib.get_hparam_indicator) @result() def hparam_list(self): - return self._get_with_retry('data/plugin/hparams/list', lib.get_hparam_list) + return self._get_with_retry('data/plugin/hparams/list', + lib.get_hparam_list) @result() def hparam_metric(self, run, metric): @@ -163,8 +190,10 @@ def image_list(self, mode, tag): @result('image/png') def image_image(self, mode, tag, index=0): index = int(index) - key = os.path.join('data/plugin/images/individualImage', mode, tag, str(index)) - return self._get_with_retry(key, lib.get_individual_image, mode, tag, index) + key = os.path.join('data/plugin/images/individualImage', mode, tag, + str(index)) + return self._get_with_retry(key, lib.get_individual_image, mode, tag, + index) @result() def text_list(self, mode, tag): @@ -174,8 +203,10 @@ def text_list(self, mode, tag): @result('text/plain') def text_text(self, mode, tag, index=0): index = int(index) - key = os.path.join('data/plugin/text/individualText', mode, tag, str(index)) - return self._get_with_retry(key, lib.get_individual_text, mode, tag, index) + key = os.path.join('data/plugin/text/individualText', mode, tag, + str(index)) + return self._get_with_retry(key, lib.get_individual_text, mode, tag, + index) @result() def audio_list(self, run, tag): @@ -185,18 +216,27 @@ def audio_list(self, run, tag): @result('audio/wav') def audio_audio(self, run, tag, index=0): index = int(index) - key = os.path.join('data/plugin/audio/individualAudio', run, tag, str(index)) - return 
self._get_with_retry(key, lib.get_individual_audio, run, tag, index) + key = os.path.join('data/plugin/audio/individualAudio', run, tag, + str(index)) + return self._get_with_retry(key, lib.get_individual_audio, run, tag, + index) @result() - def embedding_embedding(self, run, tag='default', reduction='pca', dimension=2): + def embedding_embedding(self, + run, + tag='default', + reduction='pca', + dimension=2): dimension = int(dimension) - key = os.path.join('data/plugin/embeddings/embeddings', run, str(dimension), reduction) - return self._get_with_retry(key, lib.get_embeddings, run, tag, reduction, dimension) + key = os.path.join('data/plugin/embeddings/embeddings', run, + str(dimension), reduction) + return self._get_with_retry(key, lib.get_embeddings, run, tag, + reduction, dimension) @result() def embedding_list(self): - return self._get_with_retry('data/plugin/embeddings/list', lib.get_embeddings_list) + return self._get_with_retry('data/plugin/embeddings/list', + lib.get_embeddings_list) @result('text/tab-separated-values') def embedding_metadata(self, name): @@ -210,7 +250,8 @@ def embedding_tensor(self, name): @result() def histogram_tags(self): - return self._get_with_retry('data/plugin/histogram/tags', lib.get_histogram_tags) + return self._get_with_retry('data/plugin/histogram/tags', + lib.get_histogram_tags) @result() def histogram_list(self, run, tag): @@ -237,13 +278,88 @@ def roc_curves_steps(self, run): key = os.path.join('data/plugin/roc_curves/steps', run) return self._get_with_retry(key, lib.get_roc_curve_step, run) - @result( - 'application/octet-stream', - lambda s: {"Content-Disposition": 'attachment; filename="%s"' % s.model_name} if len(s.model_name) else None - ) - def graph_graph(self): - key = os.path.join('data/plugin/graphs/graph') - return self._get_with_retry(key, lib.get_graph) + @result() + def graph_graph(self, run, expand_all, refresh): + client_ip = request.remote_addr + graph_reader = 
@result()
def graph_search(self, run, nodeid, keep_state, is_node):
    """Search a node in the requesting client's graph of `run`.

    Args:
        run: Run whose graph is searched.
        nodeid: Identifier of the node to look for.
        keep_state: Query-string flag; only the text ``true`` (any case)
            counts as True, a missing value counts as False.
        is_node: Query-string flag, parsed the same way as `keep_state`.
    Returns:
        The result of ``lib.get_graph_search`` for this client's reader.
    """
    # Each client address gets its own reader from the client manager.
    reader = self.graph_reader_client_manager.get_data(request.remote_addr)
    keep_state = keep_state is not None and keep_state.lower() == 'true'
    is_node = is_node is not None and is_node.lower() == 'true'
    return lib.get_graph_search(reader, run, nodeid, keep_state, is_node)
request.remote_addr + graph_reader = self.graph_reader_client_manager.get_data(client_ip) + return lib.get_graph_all_nodes(graph_reader, run) def create_api_call(logdir, model, cache_timeout): @@ -251,6 +367,7 @@ def create_api_call(logdir, model, cache_timeout): routes = { 'components': (api.components, []), 'runs': (api.runs, []), + 'graph_runs': (api.graph_runs, []), 'tags': (api.tags, []), 'logs': (api.logs, []), 'scalar/tags': (api.scalar_tags, []), @@ -269,12 +386,19 @@ def create_api_call(logdir, model, cache_timeout): 'text/text': (api.text_text, ['run', 'tag', 'index']), 'audio/list': (api.audio_list, ['run', 'tag']), 'audio/audio': (api.audio_audio, ['run', 'tag', 'index']), - 'embedding/embedding': (api.embedding_embedding, ['run', 'tag', 'reduction', 'dimension']), + 'embedding/embedding': (api.embedding_embedding, + ['run', 'tag', 'reduction', 'dimension']), 'embedding/list': (api.embedding_list, []), 'embedding/tensor': (api.embedding_tensor, ['name']), 'embedding/metadata': (api.embedding_metadata, ['name']), 'histogram/list': (api.histogram_list, ['run', 'tag']), - 'graph/graph': (api.graph_graph, []), + 'graph/graph': (api.graph_graph, ['run', 'expand_all', 'refresh']), + 'graph/upload': (api.graph_upload, []), + 'graph/search': (api.graph_search, + ['run', 'nodeid', 'keep_state', 'is_node']), + 'graph/get_all_nodes': (api.graph_get_all_nodes, ['run']), + 'graph/manipulate': (api.graph_manipulate, + ['run', 'nodeid', 'expand', 'keep_state']), 'pr-curve/list': (api.pr_curves_pr_curve, ['run', 'tag']), 'roc-curve/list': (api.roc_curves_roc_curve, ['run', 'tag']), 'pr-curve/steps': (api.pr_curves_steps, ['run']), @@ -289,7 +413,8 @@ def create_api_call(logdir, model, cache_timeout): def call(path: str, args): route = routes.get(path) if not route: - return json.dumps(gen_result(status=1, msg='api not found')), 'application/json', None + return json.dumps(gen_result( + status=1, msg='api not found')), 'application/json', None method, call_arg_names = 
route call_args = [args.get(name) for name in call_arg_names] return method(*call_args) diff --git a/visualdl/server/app.py b/visualdl/server/app.py index cf10b5b33..ea6731e94 100644 --- a/visualdl/server/app.py +++ b/visualdl/server/app.py @@ -1,5 +1,4 @@ #!/user/bin/env python - # Copyright (c) 2017 VisualDL Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,28 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. # ======================================================================= - +import multiprocessing import os -import time +import re import sys -import multiprocessing import threading -import re +import time import webbrowser -import requests -from visualdl import __version__ -from visualdl.utils import update_util - -from flask import (Flask, Response, redirect, request, send_file, make_response) +import requests +from flask import Flask +from flask import make_response +from flask import redirect +from flask import request +from flask import Response +from flask import send_file from flask_babel import Babel import visualdl.server +from visualdl import __version__ from visualdl.server.api import create_api_call -from visualdl.server.serve import upload_to_dev -from visualdl.server.args import (ParseArgs, parse_args) +from visualdl.server.args import parse_args +from visualdl.server.args import ParseArgs from visualdl.server.log import info +from visualdl.server.serve import upload_to_dev from visualdl.server.template import Template +from visualdl.utils import update_util SERVER_DIR = os.path.join(visualdl.ROOT, 'server') @@ -90,9 +93,9 @@ def get_locale(): PUBLIC_PATH=public_path, BASE_URI=public_path, API_URL=api_path, - TELEMETRY_ID='63a600296f8a71f576c4806376a9245b' if args.telemetry else '', - THEME='' if args.theme is None else args.theme - ) + TELEMETRY_ID='63a600296f8a71f576c4806376a9245b' + if args.telemetry else '', + THEME='' if 
args.theme is None else args.theme) @app.route('/') def base(): @@ -107,20 +110,29 @@ def favicon(): @app.route(public_path + '/') def index(): - return redirect(append_query_string(public_path + '/index'), code=302) + return redirect( + append_query_string(public_path + '/index'), code=302) @app.route(public_path + '/') def serve_static(filename): is_not_page_request = re.search(r'\..+$', filename) - response = template.render(filename if is_not_page_request else 'index.html') + response = template.render( + filename if is_not_page_request else 'index.html') if not is_not_page_request: - response.set_cookie('vdl_lng', get_locale(), path='/', samesite='Strict', secure=False, httponly=False) + response.set_cookie( + 'vdl_lng', + get_locale(), + path='/', + samesite='Strict', + secure=False, + httponly=False) return response - @app.route(api_path + '/') + @app.route(api_path + '/', methods=["GET", "POST"]) def serve_api(method): data, mimetype, headers = api_call(method, request.args) - return make_response(Response(data, mimetype=mimetype, headers=headers)) + return make_response( + Response(data, mimetype=mimetype, headers=headers)) @app.route(check_live_path) def check_live(): @@ -134,13 +146,17 @@ def wait_until_live(args: ParseArgs): while True: try: requests.get(url + check_live_path) - info('Running VisualDL at http://%s:%s/ (Press CTRL+C to quit)', args.host, args.port) + info('Running VisualDL at http://%s:%s/ (Press CTRL+C to quit)', + args.host, args.port) if args.host == 'localhost': - info('Serving VisualDL on localhost; to expose to the network, use a proxy or pass --host 0.0.0.0') + info( + 'Serving VisualDL on localhost; to expose to the network, use a proxy or pass --host 0.0.0.0' + ) if args.api_only: - info('Running in API mode, only %s/* will be served.', args.public_path + '/api') + info('Running in API mode, only %s/* will be served.', + args.public_path + '/api') break except Exception: @@ -154,16 +170,14 @@ def _run(args): os.system('') 
class ClientManager:
    '''
    Hands out one private copy of a stateful object (such as a graph) per client.

    The backend keeps the state, so every client address works on its own
    deep copy of the prototype and clients cannot disturb one another.
    '''

    def __init__(self, data):
        # Prototype that is never handed out itself, only copied.
        self._proto_data = data
        # Client address -> that client's private copy.
        self.ip_data_map = {}

    def get_data(self, ip):
        '''
        Return the copy belonging to `ip`, creating it on first request.

        Args:
            ip: Key identifying the client.
        Returns:
            The same object on every call made with the same `ip`.
        '''
        try:
            return self.ip_data_map[ip]
        except KeyError:
            private_copy = copy.deepcopy(self._proto_data)
            self.ip_data_map[ip] = private_copy
            return private_copy
+180,12 @@ def get_hparam_data(log_reader, type='tsv'): def get_hparam_importance(log_reader): indicator = get_hparam_indicator(log_reader) - hparams = [item for item in indicator['hparams'] if (item['type'] != 'string')] - metrics = [item for item in indicator['metrics'] if (item['type'] != 'string')] + hparams = [ + item for item in indicator['hparams'] if (item['type'] != 'string') + ] + metrics = [ + item for item in indicator['metrics'] if (item['type'] != 'string') + ] result = calc_all_hyper_param_importance(hparams, metrics) @@ -186,8 +202,8 @@ def get_hparam_indicator(log_reader): for run in runs: run = log_reader.name2tags[run] if run in log_reader.name2tags else run log_reader.load_new_data() - records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items( - run, decode_tag('hparam')) + records = log_reader.data_manager.get_reservoir( + "hyper_parameters").get_items(run, decode_tag('hparam')) records_list.append([records, run]) records_list.sort(key=lambda x: x[0][0].timestamp) runs = [run for r, run in records_list] @@ -196,46 +212,67 @@ def get_hparam_indicator(log_reader): type = hparamInfo.WhichOneof("type") if "float_value" == type: if hparamInfo.name not in hparams.keys(): - hparams[hparamInfo.name] = {'name': hparamInfo.name, - 'type': 'continuous', - 'values': [hparamInfo.float_value]} - elif hparamInfo.float_value not in hparams[hparamInfo.name]['values']: - hparams[hparamInfo.name]['values'].append(hparamInfo.float_value) + hparams[hparamInfo.name] = { + 'name': hparamInfo.name, + 'type': 'continuous', + 'values': [hparamInfo.float_value] + } + elif hparamInfo.float_value not in hparams[ + hparamInfo.name]['values']: + hparams[hparamInfo.name]['values'].append( + hparamInfo.float_value) elif "string_value" == type: if hparamInfo.name not in hparams.keys(): - hparams[hparamInfo.name] = {'name': hparamInfo.name, - 'type': 'string', - 'values': [hparamInfo.string_value]} - elif hparamInfo.string_value not in 
hparams[hparamInfo.name]['values']: - hparams[hparamInfo.name]['values'].append(hparamInfo.string_value) + hparams[hparamInfo.name] = { + 'name': hparamInfo.name, + 'type': 'string', + 'values': [hparamInfo.string_value] + } + elif hparamInfo.string_value not in hparams[ + hparamInfo.name]['values']: + hparams[hparamInfo.name]['values'].append( + hparamInfo.string_value) elif "int_value" == type: if hparamInfo.name not in hparams.keys(): - hparams[hparamInfo.name] = {'name': hparamInfo.name, - 'type': 'numeric', - 'values': [hparamInfo.int_value]} - elif hparamInfo.int_value not in hparams[hparamInfo.name]['values']: - hparams[hparamInfo.name]['values'].append(hparamInfo.int_value) + hparams[hparamInfo.name] = { + 'name': hparamInfo.name, + 'type': 'numeric', + 'values': [hparamInfo.int_value] + } + elif hparamInfo.int_value not in hparams[ + hparamInfo.name]['values']: + hparams[hparamInfo.name]['values'].append( + hparamInfo.int_value) else: - raise TypeError("Invalid hparams param value type `%s`." % type) + raise TypeError( + "Invalid hparams param value type `%s`." % type) for metricInfo in records[0].hparam.metricInfos: - metrics[metricInfo.name] = {'name': metricInfo.name, - 'type': 'continuous', - 'values': []} + metrics[metricInfo.name] = { + 'name': metricInfo.name, + 'type': 'continuous', + 'values': [] + } for run in runs: try: - metrics_data = get_hparam_metric(log_reader, run, metricInfo.name) - metrics[metricInfo.name]['values'].append(metrics_data[-1][-1]) + metrics_data = get_hparam_metric(log_reader, run, + metricInfo.name) + metrics[metricInfo.name]['values'].append( + metrics_data[-1][-1]) break except: - logger.error('Missing data of metrics! Please make sure use add_scalar to log metrics data.') + logger.error( + 'Missing data of metrics! Please make sure use add_scalar to log metrics data.' 
+ ) if len(metrics[metricInfo.name]['values']) == 0: metrics.pop(metricInfo.name) else: metrics[metricInfo.name].pop('values') - results = {'hparams': [value for key, value in hparams.items()], - 'metrics': [value for key, value in metrics.items()]} + results = { + 'hparams': [value for key, value in hparams.items()], + 'metrics': [value for key, value in metrics.items()] + } return results @@ -245,7 +282,10 @@ def get_hparam_metric(log_reader, run, tag): log_reader.load_new_data() records = log_reader.data_manager.get_reservoir("scalar").get_items( run, decode_tag(tag)) - results = [[s2ms(item.timestamp), item.id, transfer_abnomal_scalar_value(item.value)] for item in records] + results = [[ + s2ms(item.timestamp), item.id, + transfer_abnomal_scalar_value(item.value) + ] for item in records] return results @@ -258,8 +298,8 @@ def get_hparam_list(log_reader): for run in runs: run = log_reader.name2tags[run] if run in log_reader.name2tags else run log_reader.load_new_data() - records = log_reader.data_manager.get_reservoir("hyper_parameters").get_items( - run, decode_tag('hparam')) + records = log_reader.data_manager.get_reservoir( + "hyper_parameters").get_items(run, decode_tag('hparam')) records_list.append([records, run]) records_list.sort(key=lambda x: x[0][0].timestamp) for records, run in records_list: @@ -273,20 +313,22 @@ def get_hparam_list(log_reader): elif "int_value" == hparam_type: hparams[hparamInfo.name] = hparamInfo.int_value else: - raise TypeError("Invalid hparams param value type `%s`." % hparam_type) + raise TypeError( + "Invalid hparams param value type `%s`." % hparam_type) metrics = {} for metricInfo in records[0].hparam.metricInfos: try: - metrics_data = get_hparam_metric(log_reader, run, metricInfo.name) + metrics_data = get_hparam_metric(log_reader, run, + metricInfo.name) metrics[metricInfo.name] = metrics_data[-1][-1] except: - logger.error('Missing data of metrics! 
Please make sure use add_scalar to log metrics data.') + logger.error( + 'Missing data of metrics! Please make sure use add_scalar to log metrics data.' + ) metrics[metricInfo.name] = None - results.append({'name': run, - 'hparams': hparams, - 'metrics': metrics}) + results.append({'name': run, 'hparams': hparams, 'metrics': metrics}) return results @@ -295,7 +337,10 @@ def get_scalar(log_reader, run, tag): log_reader.load_new_data() records = log_reader.data_manager.get_reservoir("scalar").get_items( run, decode_tag(tag)) - results = [[s2ms(item.timestamp), item.id, transfer_abnomal_scalar_value(item.value)] for item in records] + results = [[ + s2ms(item.timestamp), item.id, + transfer_abnomal_scalar_value(item.value) + ] for item in records] return results @@ -383,15 +428,15 @@ def get_pr_curve(log_reader, run, tag): pr_curve = item.pr_curve length = len(pr_curve.precision) num_thresholds = [float(v) / length for v in range(1, length + 1)] - results.append([s2ms(item.timestamp), - item.id, - list(pr_curve.precision), - list(pr_curve.recall), - list(pr_curve.TP), - list(pr_curve.FP), - list(pr_curve.TN), - list(pr_curve.FN), - num_thresholds]) + results.append([ + s2ms(item.timestamp), item.id, + list(pr_curve.precision), + list(pr_curve.recall), + list(pr_curve.TP), + list(pr_curve.FP), + list(pr_curve.TN), + list(pr_curve.FN), num_thresholds + ]) return results @@ -405,15 +450,15 @@ def get_roc_curve(log_reader, run, tag): roc_curve = item.roc_curve length = len(roc_curve.tpr) num_thresholds = [float(v) / length for v in range(1, length + 1)] - results.append([s2ms(item.timestamp), - item.id, - list(roc_curve.tpr), - list(roc_curve.fpr), - list(roc_curve.TP), - list(roc_curve.FP), - list(roc_curve.TN), - list(roc_curve.FN), - num_thresholds]) + results.append([ + s2ms(item.timestamp), item.id, + list(roc_curve.tpr), + list(roc_curve.fpr), + list(roc_curve.TP), + list(roc_curve.FP), + list(roc_curve.TN), + list(roc_curve.FN), num_thresholds + ]) return results 
@@ -451,12 +496,16 @@ def get_embeddings_list(log_reader): if name in EMBEDDING_NAME: return embedding_names EMBEDDING_NAME.update({name: {'run': run, 'tag': tag}}) - records = log_reader.data_manager.get_reservoir("embeddings").get_items( - run, decode_tag(tag)) + records = log_reader.data_manager.get_reservoir( + "embeddings").get_items(run, decode_tag(tag)) row_len = len(records[0].embeddings.embeddings) col_len = len(records[0].embeddings.embeddings[0].vectors) shape = [row_len, col_len] - embedding_names.append({'name': name, 'shape': shape, 'path': path}) + embedding_names.append({ + 'name': name, + 'shape': shape, + 'path': path + }) return embedding_names @@ -509,17 +558,45 @@ def get_histogram(log_reader, run, tag): bin_edges = histogram.bin_edges histogram_data = [] for index in range(len(hist)): - histogram_data.append([bin_edges[index], bin_edges[index+1], hist[index]]) + histogram_data.append( + [bin_edges[index], bin_edges[index + 1], hist[index]]) results.append([s2ms(item.timestamp), item.id, histogram_data]) return results -def get_graph(log_reader): - result = b"" - if log_reader.model: - with bfile.BFile(log_reader.model, 'rb') as bfp: - result = bfp.read_file(log_reader.model) +def get_graph(graph_reader, + run, + nodeid=None, + expand=False, + keep_state=False, + expand_all=False, + refresh=True): + result = "" + run = graph_reader.displayname2runs[ + run] if run in graph_reader.displayname2runs else run + if nodeid is not None: + refresh = False + result = graph_reader.get_graph(run, nodeid, expand, keep_state, + expand_all, refresh) + return result + + +def get_graph_search(graph_reader, run, nodeid, keep_state=False, + is_node=True): + result = "" + run = graph_reader.displayname2runs[ + run] if run in graph_reader.displayname2runs else run + result = graph_reader.search_graph_node( + run, nodeid, keep_state=keep_state, is_node=is_node) + return result + + +def get_graph_all_nodes(graph_reader, run): + result = "" + run = 
graph_reader.displayname2runs[ + run] if run in graph_reader.displayname2runs else run + result = graph_reader.get_all_nodes(run) return result @@ -532,7 +609,7 @@ def retry(ntimes, function, time2sleep, *args, **kwargs): try: return function(*args, **kwargs) except Exception: - if i < ntimes-1: + if i < ntimes - 1: error_info = '\n'.join(map(str, sys.exc_info())) logger.error("Unexpected error: %s" % error_info) time.sleep(time2sleep) diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index b7a6b0045..73b65eebc 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -12,17 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ======================================================================= - import os import time + import numpy as np -from visualdl.writer.record_writer import RecordFileWriter + +from visualdl.component.base_component import audio +from visualdl.component.base_component import embedding +from visualdl.component.base_component import histogram +from visualdl.component.base_component import hparam +from visualdl.component.base_component import image +from visualdl.component.base_component import meta_data +from visualdl.component.base_component import pr_curve +from visualdl.component.base_component import roc_curve +from visualdl.component.base_component import scalar +from visualdl.component.base_component import text +from visualdl.component.graph import translate_graph +from visualdl.io import bfile from visualdl.server.log import logger -from visualdl.utils.img_util import merge_images from visualdl.utils.figure_util import figure_to_image +from visualdl.utils.img_util import merge_images from visualdl.utils.md5_util import md5 -from visualdl.component.base_component import scalar, image, embedding, audio, \ - histogram, pr_curve, roc_curve, meta_data, text, hparam +from visualdl.writer.record_writer import RecordFileWriter class 
DummyFileWriter(object): @@ -133,7 +144,11 @@ def _get_file_writer(self): def file_name(self): return self._file_writer.get_filename() - def add_meta(self, tag='meta_data_tag', display_name='', step=0, walltime=None): + def add_meta(self, + tag='meta_data_tag', + display_name='', + step=0, + walltime=None): """Add a meta to vdl record file. Args: @@ -146,8 +161,11 @@ def add_meta(self, tag='meta_data_tag', display_name='', step=0, walltime=None): raise RuntimeError("% can't appear in tag!") walltime = round(time.time() * 1000) if walltime is None else walltime self._get_file_writer().add_record( - meta_data(tag=tag, display_name=display_name, step=step, - walltime=walltime)) + meta_data( + tag=tag, + display_name=display_name, + step=step, + walltime=walltime)) def add_scalar(self, tag, value, step, walltime=None): """Add a scalar to vdl record file. @@ -191,8 +209,12 @@ def add_image(self, tag, img, step, walltime=None, dataformats="HWC"): raise RuntimeError("% can't appear in tag!") walltime = round(time.time() * 1000) if walltime is None else walltime self._get_file_writer().add_record( - image(tag=tag, image_array=img, step=step, walltime=walltime, - dataformats=dataformats)) + image( + tag=tag, + image_array=img, + step=step, + walltime=walltime, + dataformats=dataformats)) def add_figure(self, tag, figure, step, walltime=None): """Add an figure to vdl record file. 
@@ -234,14 +256,20 @@ def add_text(self, tag, text_string, step=None, walltime=None): """ if '%' in tag: raise RuntimeError("% can't appear in tag!") - walltime = round( - time.time() * 1000) if walltime is None else walltime + walltime = round(time.time() * 1000) if walltime is None else walltime self._get_file_writer().add_record( text( tag=tag, text_string=text_string, step=step, walltime=walltime)) - def add_image_matrix(self, tag, imgs, step, rows=-1, scale=1.0, walltime=None, dataformats="HWC"): + def add_image_matrix(self, + tag, + imgs, + step, + rows=-1, + scale=1.0, + walltime=None, + dataformats="HWC"): """Add an image to vdl record file. Args: @@ -264,16 +292,24 @@ def add_image_matrix(self, tag, imgs, step, rows=-1, scale=1.0, walltime=None, d if '%' in tag: raise RuntimeError("% can't appear in tag!") walltime = round(time.time() * 1000) if walltime is None else walltime - img = merge_images(imgs=imgs, dataformats=dataformats, scale=scale, rows=rows) - self.add_image(tag=tag, - img=img, - step=step, - walltime=walltime, - dataformats=dataformats) - - def add_embeddings(self, tag, mat=None, metadata=None, - metadata_header=None, walltime=None, labels=None, - hot_vectors=None, labels_meta=None): + img = merge_images( + imgs=imgs, dataformats=dataformats, scale=scale, rows=rows) + self.add_image( + tag=tag, + img=img, + step=step, + walltime=walltime, + dataformats=dataformats) + + def add_embeddings(self, + tag, + mat=None, + metadata=None, + metadata_header=None, + walltime=None, + labels=None, + hot_vectors=None, + labels_meta=None): """Add embeddings to vdl record file. Args: @@ -407,12 +443,7 @@ def add_audio(self, step=step, walltime=walltime)) - def add_histogram(self, - tag, - values, - step, - walltime=None, - buckets=10): + def add_histogram(self, tag, values, step, walltime=None, buckets=10): """Add an histogram to vdl record file. 
Args: @@ -522,8 +553,7 @@ def add_pr_curve(self, step=step, walltime=walltime, num_thresholds=num_thresholds, - weights=weights - )) + weights=weights)) def add_roc_curve(self, tag, @@ -564,8 +594,32 @@ def add_roc_curve(self, step=step, walltime=walltime, num_thresholds=num_thresholds, - weights=weights - )) + weights=weights)) + + def add_graph(self, model, input_spec, verbose=False): + """ + Add a model graph to vdl graph file. + Args: + model (paddle.nn.Layer): Model to draw. + input_spec (list[paddle.static.InputSpec]): Describes the input of the saved model's forward arguments. + verbose (bool): Whether to print graph structure in console. + Example: + with LogWriter(logdir="./log/graph_test/train") as writer: + writer.add_graph(model=net, + input_spec=[[1,3,256,256]], + verbose=True) + """ + try: + result = translate_graph(model, input_spec, verbose) + except Exception as e: + print("Failed to save model graph, error: {}".format(e)) + return + graph_file_name = bfile.join( + self.logdir, + "vdlgraph.%010d.log%s" % (time.time(), self._filename_suffix)) + writer = bfile.BFile(graph_file_name, "w") + writer.write(result) + writer.close() def flush(self): """Flush all data in cache to disk. 
From 88667ca647b48feadacd00827aecd865c215358e Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 13 Jun 2022 16:40:22 +0800 Subject: [PATCH 2/9] fix some styles --- visualdl/component/graph/__init__.py | 19 ++ visualdl/component/graph/exporter.py | 36 +++ visualdl/component/graph/graph_component.py | 318 +++++++++++++++++++ visualdl/component/graph/netron_graph.py | 240 ++++++++++++++ visualdl/component/{graph => graph/utils.py} | 8 +- visualdl/reader/graph_reader.py | 4 +- 6 files changed, 620 insertions(+), 5 deletions(-) create mode 100644 visualdl/component/graph/__init__.py create mode 100644 visualdl/component/graph/exporter.py create mode 100644 visualdl/component/graph/graph_component.py create mode 100644 visualdl/component/graph/netron_graph.py rename visualdl/component/{graph => graph/utils.py} (88%) diff --git a/visualdl/component/graph/__init__.py b/visualdl/component/graph/__init__.py new file mode 100644 index 000000000..bb89f7dc2 --- /dev/null +++ b/visualdl/component/graph/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ======================================================================= +from .exporter import translate_graph +from .graph_component import analyse_model +from .netron_graph import Model + +__all__ = ['translate_graph', 'analyse_model', 'Model'] diff --git a/visualdl/component/graph/exporter.py b/visualdl/component/graph/exporter.py new file mode 100644 index 000000000..580763b1a --- /dev/null +++ b/visualdl/component/graph/exporter.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ======================================================================= +import json +import os +import tempfile + +from .graph_component import analyse_model +from .utils import create_opname_scope + + +def translate_graph(model, input_spec, verbose=True): + import paddle + with tempfile.TemporaryDirectory() as tmp: + model._full_name = '{}[{}]'.format(model.__class__.__name__, "model") + create_opname_scope(model) + paddle.jit.save(model, os.path.join(tmp, 'temp'), input_spec) + model_data = open(os.path.join(tmp, 'temp.pdmodel'), 'rb').read() + result = analyse_model(model_data) + if verbose: + from paddle.core import ProgramDesc + program_desc = ProgramDesc(model_data) + print(program_desc) + result = json.dumps(result, indent=2) + return result diff --git a/visualdl/component/graph/graph_component.py b/visualdl/component/graph/graph_component.py new file mode 100644 index 000000000..9c392af18 --- /dev/null +++ b/visualdl/component/graph/graph_component.py @@ -0,0 +1,318 @@ +# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ======================================================================= +import collections +import os.path +import re + +from paddle.core import AttrType + +attr_type_name = { + AttrType.INT: "INT", + AttrType.INTS: "INTS", + AttrType.LONG: "LONG", + AttrType.LONGS: "LONGS", + AttrType.FLOAT: "FLOAT", + AttrType.FLOATS: "FLOATS", + AttrType.STRING: "STRING", + AttrType.STRINGS: "STRINGS", + AttrType.BOOL: "BOOL", + AttrType.BOOLS: "BOOLS", + AttrType.BLOCK: "BLOCK", + AttrType.BLOCKS: "BLOCKS" +} + +_graph_version = '1.0.0' + + +def analyse_model(model_pb): + from paddle.core import ProgramDesc + from paddle.utils.unique_name import generate + program_desc = ProgramDesc(model_pb) + all_ops = {} + all_vars = {} + all_edges = {} + op_inputvars_dict = collections.defaultdict(list) + op_outputvars_dict = collections.defaultdict(list) + for i in range(program_desc.num_blocks()): + block_desc = program_desc.block(i) + # vars info + for i, var_desc in enumerate(block_desc.all_vars()): + try: + var_name = var_desc.name() + all_vars[var_name] = {} + all_vars[var_name]['name'] = var_name + all_vars[var_name]['shape'] = var_desc.shape() + all_vars[var_name]['type'] = str(var_desc.type()) + all_vars[var_name]['dtype'] = str(var_desc.dtype()) + all_vars[var_name]['value'] = [] + all_vars[var_name]['persistable'] = var_desc.persistable() + attr_dict = {} + for attr_name in var_desc.attr_names(): + attr_dict[attr_name] = var_desc.attr(attr_name) + all_vars[var_name]['attrs'] = attr_dict + all_vars[var_name]['from_node'] = '' + all_vars[var_name]['to_nodes'] = [] + + except Exception: + # feed, fetch var + var_name = var_desc.name() + all_vars[var_name] = {} + all_vars[var_name]['name'] = var_name + all_vars[var_name]['shape'] = '' + all_vars[var_name]['type'] = str(var_desc.type()) + all_vars[var_name]['dtype'] = '' + all_vars[var_name]['value'] = [] + all_vars[var_name]['persistable'] = var_desc.persistable() + attr_dict = {} + for attr_name in var_desc.attr_names(): + 
attr_dict[attr_name] = var_desc.attr(attr_name) + all_vars[var_name]['attrs'] = attr_dict + all_vars[var_name]['from_node'] = '' + all_vars[var_name]['to_nodes'] = [] + + # ops info + for i in range(block_desc.op_size()): + op_desc = block_desc.op(i) + op_name = op_desc.attr('op_namescope') + generate( + str(op_desc.type())) + all_ops[op_name] = {} + all_ops[op_name]['name'] = op_name + all_ops[op_name]['show_name'] = re.sub(r'\[(\w|\.)*\]', '', + op_name) + all_ops[op_name]['type'] = str(op_desc.type()) + all_ops[op_name]['input_vars'] = {} + all_ops[op_name]['is_leaf_node'] = True + for input_name, variable_list in op_desc.inputs().items(): + all_ops[op_name]['input_vars'][input_name] = variable_list + op_inputvars_dict[op_name].extend(variable_list) + # fill var 'to_nodes' + for variable_name in variable_list: + all_vars[variable_name]['to_nodes'].append(op_name) + all_ops[op_name]['output_vars'] = {} + for output_name, variable_list in op_desc.outputs().items(): + all_ops[op_name]['output_vars'][output_name] = variable_list + op_outputvars_dict[op_name].extend(variable_list) + # fill var 'from_node' + for variable_name in variable_list: + all_vars[variable_name]['from_node'] = op_name + + attr_dict = {} + attr_type_dict = {} + for attr_name in op_desc.attr_names(): + attr_dict[attr_name] = op_desc.attr(attr_name) + attr_type_dict[attr_name] = attr_type_name[op_desc.attr_type( + attr_name)] + all_ops[op_name]['attrs'] = attr_dict + all_ops[op_name]['attr_types'] = attr_type_dict + all_ops[op_name]['children_node'] = [] + all_ops[op_name]['input_nodes'] = [] + all_ops[op_name]['output_nodes'] = [] + all_ops[op_name]['edge_input_nodes'] = [] + all_ops[op_name]['edge_output_nodes'] = [] + # second pass, create non-leaf nodes, fill 'parent_node', 'children_nodes' of nodes. 
+ for variable_name in all_vars: + if all_vars[variable_name]['from_node'] == '': + continue + # some variable's input and output node are the same, we should prevent to show this situation as a cycle + from_node_name = all_vars[variable_name]['from_node'] + for to_node_name in all_vars[variable_name]['to_nodes']: + if to_node_name != from_node_name: + all_ops[from_node_name]['output_nodes'].append( + to_node_name) + all_ops[to_node_name]['input_nodes'].append(from_node_name) + + general_children_dict = collections.defaultdict(set) + + def create_non_leaf_nodes(parent_node_name, child_node_name): + if parent_node_name == '/' or parent_node_name == '': # root node + parent_node_name = '/' + if parent_node_name not in all_ops: + all_ops[parent_node_name] = {} + all_ops[parent_node_name]['children_node'] = set() + all_ops[parent_node_name]['name'] = parent_node_name + all_ops[parent_node_name]['show_name'] = os.path.dirname( + all_ops[child_node_name]['show_name']) + all_ops[parent_node_name]['attrs'] = {} + all_ops[parent_node_name]['input_nodes'] = set() + all_ops[parent_node_name]['output_nodes'] = set() + all_ops[parent_node_name]['type'] = os.path.basename( + all_ops[parent_node_name]['show_name']) + all_ops[parent_node_name]['input_vars'] = set() + all_ops[parent_node_name]['output_vars'] = set() + all_ops[parent_node_name]['parent_node'] = '' + all_ops[parent_node_name]['edge_input_nodes'] = [] + all_ops[parent_node_name]['edge_output_nodes'] = [] + all_ops[parent_node_name]['is_leaf_node'] = False + + all_ops[child_node_name]['parent_node'] = parent_node_name + all_ops[parent_node_name]['children_node'].add(child_node_name) + general_children_dict[parent_node_name].add(child_node_name) + general_children_dict[parent_node_name].update( + general_children_dict[child_node_name]) + if parent_node_name == '/': # root node + return + else: + create_non_leaf_nodes( + os.path.dirname(parent_node_name), parent_node_name) + + def construct_edges(var_name): + ''' + 
Construct path edges from var's from_node to to_nodes. + Algorithm: + 1. Judge if src_node and dst_node have the same parent node, if yes, link them directly + and fill information in all_edges, return. + 2. Find the closest common ancestor, repeat link node and its parent until reach the common ancestor. + Every time construct a new edge, fill information in all_edges. + ''' + from_node = all_vars[var_name]['from_node'] + to_nodes = all_vars[var_name]['to_nodes'] + + def _construct_edge(src_node, dst_node): + if all_ops[src_node]['parent_node'] == all_ops[dst_node][ + 'parent_node']: + if (src_node, dst_node) not in all_edges: + all_edges[(src_node, dst_node)] = { + 'from_node': src_node, + 'to_node': dst_node, + 'vars': {var_name}, + 'label': '' + } + else: + all_edges[(src_node, dst_node)]['vars'].add(var_name) + else: + common_ancestor = os.path.commonpath([src_node, dst_node]) + src_base_node = src_node + while True: + parent_node = all_ops[src_base_node]['parent_node'] + if parent_node == common_ancestor: + break + if (src_base_node, parent_node) not in all_edges: + all_edges[(src_base_node, parent_node)] = { + 'from_node': src_base_node, + 'to_node': parent_node, + 'vars': {var_name}, + 'label': '' + } + else: + all_edges[(src_base_node, + parent_node)]['vars'].add(var_name) + src_base_node = parent_node + dst_base_node = dst_node + while True: + parent_node = all_ops[dst_base_node]['parent_node'] + if parent_node == common_ancestor: + break + if (parent_node, dst_base_node) not in all_edges: + all_edges[(parent_node, dst_base_node)] = { + 'from_node': parent_node, + 'to_node': dst_base_node, + 'vars': {var_name}, + 'label': '' + } + else: + all_edges[(parent_node, + dst_base_node)]['vars'].add(var_name) + dst_base_node = parent_node + if (src_base_node, dst_base_node) not in all_edges: + all_edges[(src_base_node, dst_base_node)] = { + 'from_node': src_base_node, + 'to_node': dst_base_node, + 'vars': {var_name}, + 'label': '' + } + else: + 
all_edges[(src_base_node, + dst_base_node)]['vars'].add(var_name) + return + + if from_node and to_nodes: + for to_node in to_nodes: + if from_node == to_node: + continue + _construct_edge(from_node, to_node) + + all_op_names = list(all_ops.keys()) + for op_name in all_op_names: + create_non_leaf_nodes(os.path.dirname(op_name), op_name) + + # fill all non-leaf node's 'output_nodes' 'input_nodes' 'output_vars' 'input_vars' + # post-order traverse tree + post_order_results = [] + + def post_order_traverse(root): + for child in all_ops[root]['children_node']: + post_order_traverse(child) + nonlocal post_order_results + post_order_results.append(root) + return + + post_order_traverse('/') + + for op_name in post_order_results: + op = all_ops[op_name] + op['children_node'] = list(op['children_node']) + + if op['children_node']: + for child_op in op['children_node']: + for input_node in all_ops[child_op]['input_nodes']: + if input_node in general_children_dict[op_name]: + continue + else: + op['input_nodes'].add(input_node) + for output_node in all_ops[child_op]['output_nodes']: + if output_node in general_children_dict[op_name]: + continue + else: + op['output_nodes'].add(output_node) + for input_var in op_inputvars_dict[child_op]: + if all_vars[input_var][ + 'from_node'] not in general_children_dict[ + op_name]: + op['input_vars'].add(input_var) + for output_var in op_outputvars_dict[child_op]: + for to_node_name in all_vars[output_var]['to_nodes']: + if to_node_name not in general_children_dict[ + op_name]: + op['output_vars'].add(output_var) + op['input_nodes'] = list(op['input_nodes']) + op['output_nodes'] = list(op['output_nodes']) + op_inputvars_dict[op_name] = list(op['input_vars']) + op_outputvars_dict[op_name] = list(op['output_vars']) + op['input_vars'] = {'X': list(op['input_vars'])} + op['output_vars'] = {'Y': list(op['output_vars'])} + + # Supplement edges and 'edge_input_nodes', 'edge_output_nodes' in op to help draw in frontend + for var_name in 
all_vars.keys(): + construct_edges(var_name) + + for src_node, to_node in all_edges.keys(): + all_ops[src_node]['edge_output_nodes'].append(to_node) + all_ops[to_node]['edge_input_nodes'].append(src_node) + all_edges[(src_node, to_node)]['vars'] = list( + all_edges[(src_node, to_node)]['vars']) + if len(all_edges[(src_node, to_node)]['vars']) > 1: + all_edges[(src_node, to_node)]['label'] = str( + len(all_edges[(src_node, to_node)]['vars'])) + ' tensors' + elif len(all_edges[(src_node, to_node)]['vars']) == 1: + all_edges[(src_node, to_node)]['label'] = str( + all_vars[all_edges[(src_node, to_node)]['vars'][0]]['shape']) + + final_data = { + 'version': _graph_version, + 'nodes': list(all_ops.values()), + 'vars': list(all_vars.values()), + 'edges': list(all_edges.values()) + } + return final_data diff --git a/visualdl/component/graph/netron_graph.py b/visualdl/component/graph/netron_graph.py new file mode 100644 index 000000000..7f6e3cbfc --- /dev/null +++ b/visualdl/component/graph/netron_graph.py @@ -0,0 +1,240 @@ +from collections import defaultdict +from collections import deque + + +class Model: + def __init__(self, graph_data): + self.name = 'Paddle Graph' + self.version = graph_data['version'] + self.all_nodes = {node['name']: node for node in graph_data['nodes']} + self.all_vars = {var['name']: var for var in graph_data['vars']} + self.all_edges = {(edge['from_node'], edge['to_node']): edge + for edge in graph_data['edges']} + self.visible_maps = { + node['name']: (True if not node['children_node'] else False) + for node in graph_data['nodes'] + } + root_node = self.all_nodes['/'] + for child_name in root_node['children_node']: + self.visible_maps[child_name] = True + + def make_graph(self, refresh=False, expand_all=False): + if refresh is True: + self.visible_maps = { + node['name']: (True if not node['children_node'] else False) + for node in self.all_nodes.values() + } + root_node = self.all_nodes['/'] + for child_name in root_node['children_node']: + 
self.visible_maps[child_name] = True + if expand_all is True: + self.visible_maps = { + node['name']: (True if not node['children_node'] else False) + for node in self.all_nodes.values() + } + self.current_nodes = { + node_name: self.all_nodes[node_name] + for node_name in self.get_current_visible_nodes() + } + return Graph(self.current_nodes, self.all_vars) + + def get_all_leaf_nodes(self): + return Graph(self.all_nodes, self.all_vars) + + def get_current_visible_nodes(self): + # bfs traversal to get current visible nodes + # if one node is visible now, all its children nodes are invisible + current_visible_nodes = [] + travesal_queue = deque() + visited_map = defaultdict(bool) + travesal_queue.append('/') + visited_map['/'] = True + while travesal_queue: + current_name = travesal_queue.popleft() + current_node = self.all_nodes[current_name] + if self.visible_maps[current_name] is True: + current_visible_nodes.append(current_name) + else: + for child_name in current_node['children_node']: + if visited_map[child_name] is False: + travesal_queue.append(child_name) + visited_map[child_name] = True + return current_visible_nodes + + def adjust_visible(self, node_name, expand=True, keep_state=False): + if (expand): + if self.all_nodes[node_name]['is_leaf_node'] is True: + return + if keep_state: + self.visible_maps[node_name] = False + else: + self.visible_maps[node_name] = False + current_node = self.all_nodes[node_name] + for child_name in current_node['children_node']: + self.visible_maps[child_name] = True + else: + self.visible_maps[node_name] = True + + def adjust_search_node_visible(self, + node_name, + keep_state=False, + is_node=True): + if node_name is None: + return + node_names = [] + if is_node is False: + var = self.all_vars[node_name] + node_names.append(var['from_node']) + node_names.extend(var['to_nodes']) + else: + node_names.append(node_name) + for node_name in node_names: + topmost_parent = None + parent_node_name = 
self.all_nodes[node_name]['parent_node'] + while (parent_node_name != '/'): + if self.visible_maps[parent_node_name] is True: + topmost_parent = parent_node_name + parent_node_name = self.all_nodes[parent_node_name][ + 'parent_node'] + if topmost_parent is not None: + self.visible_maps[topmost_parent] = False + parent_node_name = self.all_nodes[node_name]['parent_node'] + if (keep_state): + self.visible_maps[node_name] = True + while (parent_node_name != topmost_parent): + self.visible_maps[parent_node_name] = False + parent_node_name = self.all_nodes[parent_node_name][ + 'parent_node'] + else: + for child_name in self.all_nodes[parent_node_name][ + 'children_node']: + self.visible_maps[child_name] = True + self.visible_maps[parent_node_name] = False + key_path_node_name = parent_node_name + while (parent_node_name != topmost_parent): + parent_node_name = self.all_nodes[parent_node_name][ + 'parent_node'] + for child_name in self.all_nodes[parent_node_name][ + 'children_node']: + if child_name != key_path_node_name: + self.visible_maps[child_name] = True + else: + self.visible_maps[child_name] = False + key_path_node_name = parent_node_name + + +class Graph(dict): + def __init__(self, nodes, all_vars): + self.nodes = [] + self.inputs = [] + self.outputs = [] + self.name = 'Paddle Graph' + output_idx = 0 + for op_node in nodes.values(): + if op_node['type'] == 'feed': + for key, value in op_node["output_vars"].items(): + self.inputs.append( + Parameter( + value[0], + [Argument(name, all_vars[name]) + for name in value])) + continue + if op_node['type'] == 'fetch': + for key, value in op_node["input_vars"].items(): + self.outputs.append( + Parameter( + 'Output{}'.format(output_idx), + [Argument(name, all_vars[name]) + for name in value])) + output_idx += 1 + continue + self.nodes.append(Node(op_node, all_vars)) + + super(Graph, self).__init__( + name=self.name, + nodes=self.nodes, + inputs=self.inputs, + outputs=self.outputs) + + +class Node(dict): + def 
__init__(self, node, all_vars): + self.name = node['name'] + self.show_name = node['show_name'] + self.type = node['type'] + self.attributes = [ + Attribute(key, value, node['attr_types'][key]) + for key, value in node['attrs'].items() + ] + self.inputs = [ + Parameter(key, [Argument(name, all_vars[name]) for name in value]) + for key, value in node["input_vars"].items() + ] + self.outputs = [ + Parameter(key, [Argument(name, all_vars[name]) for name in value]) + for key, value in node["output_vars"].items() + ] + self.chain = [] + self.visible = True + self.is_leaf = node['is_leaf_node'] + super(Node, self).__init__( + name=self.name, + type=self.type, + attributes=self.attributes, + inputs=self.inputs, + outputs=self.outputs, + chain=self.chain, + visible=self.visible, + is_leaf=self.is_leaf, + show_name=self.show_name) + + +class Attribute(dict): + def __init__(self, key, value, attr_type): + self.name = key + self.value = value + self.type = attr_type + self.visible = True if key not in [ + 'use_mkldnn', 'use_cudnn', 'op_callstack', 'op_role', + 'op_role_var', 'op_namescope', 'is_test' + ] else False + super(Attribute, self).__init__( + name=self.name, + value=self.value, + type=self.type, + visible=self.visible) + + +class Parameter(dict): + def __init__(self, name, args): + self.name = name + self.visible = True + self.arguments = args + super(Parameter, self).__init__( + name=self.name, visible=self.visible, arguments=self.arguments) + + +class Argument(dict): + def __init__(self, name, var): + self.name = name + self.type = TensorType(var['dtype'], var['shape']) + self.initializer = None if var['persistable'] is False else self.type + super(Argument, self).__init__( + name=self.name, type=self.type, initializer=self.initializer) + + +class TensorType(dict): + def __init__(self, datatype, shape): + self.dataType = datatype + self.shape = TensorShape(shape) + self.denotation = None + super(TensorType, self).__init__( + dataType=self.dataType, + 
shape=self.shape, + denotation=self.denotation) + + +class TensorShape(dict): + def __init__(self, dimensions): + self.dimensions = dimensions + super(TensorShape, self).__init__(dimensions=self.dimensions) diff --git a/visualdl/component/graph b/visualdl/component/graph/utils.py similarity index 88% rename from visualdl/component/graph rename to visualdl/component/graph/utils.py index cb904a0b4..7a15e92dd 100644 --- a/visualdl/component/graph +++ b/visualdl/component/graph/utils.py @@ -1,16 +1,17 @@ from collections import deque import paddle.nn as nn -from paddle.fluid.framework import name_scope -from paddle.fluid.core import AttrType +from paddle.static import name_scope _name_scope_stack = deque() + def _opname_creation_prehook(layer, inputs): global _name_scope_stack _name_scope_stack.append(name_scope(layer.full_name())) _name_scope_stack[-1].__enter__() + def _opname_creation_posthook(layer, inputs, outputs): global _name_scope_stack name_scope_manager = _name_scope_stack.pop() @@ -20,6 +21,7 @@ def _opname_creation_posthook(layer, inputs, outputs): def create_opname_scope(layer: nn.Layer): layer.register_forward_pre_hook(_opname_creation_prehook) for name, sublayer in layer.named_children(): - sublayer._full_name = '{}[{}]'.format(sublayer.__class__.__name__, name) + sublayer._full_name = '{}[{}]'.format(sublayer.__class__.__name__, + name) create_opname_scope(sublayer) layer.register_forward_post_hook(_opname_creation_posthook) diff --git a/visualdl/reader/graph_reader.py b/visualdl/reader/graph_reader.py index 7e4c77582..6a59b3285 100644 --- a/visualdl/reader/graph_reader.py +++ b/visualdl/reader/graph_reader.py @@ -16,8 +16,8 @@ import os import tempfile -from visualdl.component.graph.graph_component import analyse_model -from visualdl.component.graph.netron_graph import Model +from visualdl.component.graph import analyse_model +from visualdl.component.graph import Model from visualdl.io import bfile From 4931603c87ab8009903a5396b4386617b0d6c3db Mon Sep 
17 00:00:00 2001 From: chenjian Date: Thu, 16 Jun 2022 13:22:24 +0800 Subject: [PATCH 3/9] fix import bug --- visualdl/component/graph/exporter.py | 3 ++- visualdl/component/graph/graph_component.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/visualdl/component/graph/exporter.py b/visualdl/component/graph/exporter.py index 580763b1a..526892bf1 100644 --- a/visualdl/component/graph/exporter.py +++ b/visualdl/component/graph/exporter.py @@ -29,7 +29,8 @@ def translate_graph(model, input_spec, verbose=True): model_data = open(os.path.join(tmp, 'temp.pdmodel'), 'rb').read() result = analyse_model(model_data) if verbose: - from paddle.core import ProgramDesc + from paddle.framework import core + ProgramDesc = core.ProgramDesc program_desc = ProgramDesc(model_data) print(program_desc) result = json.dumps(result, indent=2) diff --git a/visualdl/component/graph/graph_component.py b/visualdl/component/graph/graph_component.py index 9c392af18..338680e1f 100644 --- a/visualdl/component/graph/graph_component.py +++ b/visualdl/component/graph/graph_component.py @@ -16,7 +16,8 @@ import os.path import re -from paddle.core import AttrType +from paddle.framework import core +AttrType = core.AttrType attr_type_name = { AttrType.INT: "INT", @@ -37,7 +38,7 @@ def analyse_model(model_pb): - from paddle.core import ProgramDesc + ProgramDesc = core.ProgramDesc from paddle.utils.unique_name import generate program_desc = ProgramDesc(model_pb) all_ops = {} From 008df42fb279c9053606ca7c299e67ccf63cfce6 Mon Sep 17 00:00:00 2001 From: chenjian Date: Fri, 17 Jun 2022 18:11:55 +0800 Subject: [PATCH 4/9] update --- visualdl/component/graph/exporter.py | 6 +- visualdl/component/graph/graph_component.py | 43 ++++++----- visualdl/component/graph/netron_graph.py | 14 ++++ visualdl/component/graph/utils.py | 82 +++++++++++++++++++++ visualdl/writer/writer.py | 39 ++++++++-- 5 files changed, 155 insertions(+), 29 deletions(-) diff --git 
a/visualdl/component/graph/exporter.py b/visualdl/component/graph/exporter.py index 526892bf1..9f43d7dcd 100644 --- a/visualdl/component/graph/exporter.py +++ b/visualdl/component/graph/exporter.py @@ -18,6 +18,7 @@ from .graph_component import analyse_model from .utils import create_opname_scope +from .utils import print_model def translate_graph(model, input_spec, verbose=True): @@ -29,9 +30,6 @@ def translate_graph(model, input_spec, verbose=True): model_data = open(os.path.join(tmp, 'temp.pdmodel'), 'rb').read() result = analyse_model(model_data) if verbose: - from paddle.framework import core - ProgramDesc = core.ProgramDesc - program_desc = ProgramDesc(model_data) - print(program_desc) + print_model(result) result = json.dumps(result, indent=2) return result diff --git a/visualdl/component/graph/graph_component.py b/visualdl/component/graph/graph_component.py index 338680e1f..6ea1c4fa0 100644 --- a/visualdl/component/graph/graph_component.py +++ b/visualdl/component/graph/graph_component.py @@ -16,28 +16,33 @@ import os.path import re -from paddle.framework import core -AttrType = core.AttrType - -attr_type_name = { - AttrType.INT: "INT", - AttrType.INTS: "INTS", - AttrType.LONG: "LONG", - AttrType.LONGS: "LONGS", - AttrType.FLOAT: "FLOAT", - AttrType.FLOATS: "FLOATS", - AttrType.STRING: "STRING", - AttrType.STRINGS: "STRINGS", - AttrType.BOOL: "BOOL", - AttrType.BOOLS: "BOOLS", - AttrType.BLOCK: "BLOCK", - AttrType.BLOCKS: "BLOCKS" -} - _graph_version = '1.0.0' -def analyse_model(model_pb): +def analyse_model(model_pb): # noqa: C901 + try: + from paddle.framework import core + except Exception: + print("Paddlepaddle is required to use add_graph interface.\n\ + Please refer to \ + https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\ + to install paddlepaddle.") + + AttrType = core.AttrType + attr_type_name = { + AttrType.INT: "INT", + AttrType.INTS: "INTS", + AttrType.LONG: "LONG", + AttrType.LONGS: "LONGS", 
+ AttrType.FLOAT: "FLOAT", + AttrType.FLOATS: "FLOATS", + AttrType.STRING: "STRING", + AttrType.STRINGS: "STRINGS", + AttrType.BOOL: "BOOL", + AttrType.BOOLS: "BOOLS", + AttrType.BLOCK: "BLOCK", + AttrType.BLOCKS: "BLOCKS" + } ProgramDesc = core.ProgramDesc from paddle.utils.unique_name import generate program_desc = ProgramDesc(model_pb) diff --git a/visualdl/component/graph/netron_graph.py b/visualdl/component/graph/netron_graph.py index 7f6e3cbfc..91a74ea5f 100644 --- a/visualdl/component/graph/netron_graph.py +++ b/visualdl/component/graph/netron_graph.py @@ -1,3 +1,17 @@ +# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= from collections import defaultdict from collections import deque diff --git a/visualdl/component/graph/utils.py b/visualdl/component/graph/utils.py index 7a15e92dd..d84b06a96 100644 --- a/visualdl/component/graph/utils.py +++ b/visualdl/component/graph/utils.py @@ -1,3 +1,18 @@ +# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +from collections import Counter from collections import deque import paddle.nn as nn @@ -25,3 +40,70 @@ def create_opname_scope(layer: nn.Layer): name) create_opname_scope(sublayer) layer.register_forward_post_hook(_opname_creation_posthook) + + +def print_model(analyse_result): + ''' + Print some information about model for users, we count numbers of ops and layers. + ''' + result = [] + # statistics + op_counter = Counter() + layer_counter = Counter() + nodes = analyse_result['nodes'] + total_ops = 0 + total_layers = 0 + for node in nodes: + if node['name'] == '/': + continue + if not node['children_node']: + op_counter[node['type']] += 1 + total_ops += 1 + else: + layer_counter[node['type']] += 1 + total_layers += 1 + + SPACING_SIZE = 2 + row_format_list = [""] + header_sep_list = [""] + line_length_list = [-SPACING_SIZE] + + def add_title(padding, text): + left_length = padding - len(text) + half = left_length // 2 + return '-' * half + text + '-' * (left_length - half) + + def add_column(padding, text_dir='<'): + row_format_list[0] += '{: ' + text_dir + str(padding) + '}' + ( + ' ' * SPACING_SIZE) + header_sep_list[0] += '-' * padding + (' ' * SPACING_SIZE) + line_length_list[0] += padding + SPACING_SIZE + + def append(s): + result.append(s) + result.append('\n') + + headers = ['Name', 'Type', 'Count'] + column_width = 20 + for _ in headers: + add_column(column_width) + + row_format = row_format_list[0] + header_sep = header_sep_list[0] + line_length = line_length_list[0] + + # 
construct table string + append(add_title(line_length, "Graph Summary")) + append('total operators: {}\ttotal layers:{}'.format( + total_ops, total_layers)) + append(header_sep) + append(row_format.format(*headers)) + append(header_sep) + for op_type, count in op_counter.items(): + row_values = [op_type, 'operator', count] + append(row_format.format(*row_values)) + for layer_type, count in layer_counter.items(): + row_values = [layer_type, 'layer', count] + append(row_format.format(*row_values)) + append('-' * line_length) + print(''.join(result)) diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index 73b65eebc..77265ad21 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -601,13 +601,40 @@ def add_graph(self, model, input_spec, verbose=False): Add a model graph to vdl graph file. Args: model (paddle.nn.Layer): Model to draw. - input_spec (list[paddle.static.InputSpec]): Describes the input of the saved model's forward arguments. - verbose (bool): Whether to print graph structure in console. + input_spec (list[paddle.static.InputSpec|Tensor]): Describes the input \ + of the saved model's forward arguments. + verbose (bool): Whether to print some graph statistic information in console. + Note: + Paddlepaddle is required to use add_graph interface. 
Example: - with LogWriter(logdir="./log/graph_test/train") as writer: - writer.add_graph(model=net, - input_spec=[[1,3,256,256]], - verbose=True) + import paddle + import paddle.nn as nn + import paddle.nn.functional as F + from visualdl import LogWriter + class MyNet(nn.Layer): + def __init__(self): + super(MyNet, self).__init__() + self.conv1 = nn.Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2) + self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2) + self.conv2 = nn.Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2) + self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2) + self.fc = nn.Linear(in_features=980, out_features=10) + def forward(self, inputs): + x = self.conv1(inputs) + x = F.relu(x) + x = self.max_pool1(x) + x = self.conv2(x) + x = F.relu(x) + x = self.max_pool2(x) + x = paddle.reshape(x, [x.shape[0], -1]) + x = self.fc(x) + return x + net = MyNet() + with LogWriter(logdir="./log/graph_test/") as writer: + writer.add_graph( + model=net, + input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')], + verbose=True) """ try: result = translate_graph(model, input_spec, verbose) From bfc75b0246ab8f15f7832066496e1a648860f0ab Mon Sep 17 00:00:00 2001 From: chenjian Date: Fri, 17 Jun 2022 18:13:57 +0800 Subject: [PATCH 5/9] add demo --- demo/components/graph_test.py | 54 +++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 demo/components/graph_test.py diff --git a/demo/components/graph_test.py b/demo/components/graph_test.py new file mode 100644 index 000000000..1a5aa27c3 --- /dev/null +++ b/demo/components/graph_test.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from visualdl import LogWriter + + +class MyNet(nn.Layer): + def __init__(self): + super(MyNet, self).__init__() + self.conv1 = nn.Conv2D( + in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2) + self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2) + self.conv2 = nn.Conv2D( + in_channels=20, + out_channels=20, + kernel_size=5, + stride=1, + padding=2) + self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2) + self.fc = nn.Linear(in_features=980, out_features=10) + + def forward(self, inputs): + x = self.conv1(inputs) + x = F.relu(x) + x = self.max_pool1(x) + x = self.conv2(x) + x = F.relu(x) + x = self.max_pool2(x) + x = paddle.reshape(x, [x.shape[0], -1]) + x = self.fc(x) + return x + + +net = MyNet() +with LogWriter(logdir="./log/graph_test/") as writer: + writer.add_graph( + model=net, + input_spec=[paddle.static.InputSpec([-1, 1, 28, 28], 'float32')], + verbose=True) From 718bacaa3e0406b751218da0e3523d387807d8c1 Mon Sep 17 00:00:00 2001 From: chenjian Date: Fri, 17 Jun 2022 18:17:56 +0800 Subject: [PATCH 6/9] fix --- visualdl/server/lib.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/visualdl/server/lib.py b/visualdl/server/lib.py index 09762a6cf..fe6e6c887 100644 --- a/visualdl/server/lib.py +++ b/visualdl/server/lib.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ======================================================================= -from __future__ import absolute_import - import csv import io import math From 38fb847875c4630c0881866473c4d0391e1a0a05 Mon Sep 17 00:00:00 2001 From: chenjian Date: Fri, 17 Jun 2022 18:21:38 +0800 Subject: [PATCH 7/9] fix --- visualdl/component/graph/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/visualdl/component/graph/utils.py b/visualdl/component/graph/utils.py index d84b06a96..4dd33abdf 100644 --- a/visualdl/component/graph/utils.py +++ b/visualdl/component/graph/utils.py @@ -15,13 +15,11 @@ from collections import Counter from collections import deque -import paddle.nn as nn -from paddle.static import name_scope - _name_scope_stack = deque() def _opname_creation_prehook(layer, inputs): + from paddle.static import name_scope global _name_scope_stack _name_scope_stack.append(name_scope(layer.full_name())) _name_scope_stack[-1].__enter__() @@ -33,7 +31,7 @@ def _opname_creation_posthook(layer, inputs, outputs): name_scope_manager.__exit__(None, None, None) -def create_opname_scope(layer: nn.Layer): +def create_opname_scope(layer): layer.register_forward_pre_hook(_opname_creation_prehook) for name, sublayer in layer.named_children(): sublayer._full_name = '{}[{}]'.format(sublayer.__class__.__name__, From 54eeab78436800f75bfbcfbf5da28aaec4146eff Mon Sep 17 00:00:00 2001 From: chenjian Date: Mon, 20 Jun 2022 16:23:53 +0800 Subject: [PATCH 8/9] fix a bug --- visualdl/component/graph/graph_component.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/visualdl/component/graph/graph_component.py b/visualdl/component/graph/graph_component.py index 6ea1c4fa0..ee2b1f151 100644 --- a/visualdl/component/graph/graph_component.py +++ b/visualdl/component/graph/graph_component.py @@ -118,8 +118,10 @@ def analyse_model(model_pb): # noqa: C901 attr_type_dict = {} for attr_name in op_desc.attr_names(): attr_dict[attr_name] = op_desc.attr(attr_name) 
- attr_type_dict[attr_name] = attr_type_name[op_desc.attr_type( - attr_name)] + attr_type = op_desc.attr_type(attr_name) + attr_type_dict[attr_name] = attr_type_name[ + attr_type] if attr_type in attr_type_name else str( + attr_type).split('.')[1] all_ops[op_name]['attrs'] = attr_dict all_ops[op_name]['attr_types'] = attr_type_dict all_ops[op_name]['children_node'] = [] From b79a6c07ab762567095e69765855b00441242023 Mon Sep 17 00:00:00 2001 From: chenjian Date: Tue, 21 Jun 2022 11:23:01 +0800 Subject: [PATCH 9/9] fix a bug --- visualdl/server/api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/visualdl/server/api.py b/visualdl/server/api.py index 41034d560..c03d68b2b 100644 --- a/visualdl/server/api.py +++ b/visualdl/server/api.py @@ -102,8 +102,10 @@ def runs(self): @result() def graph_runs(self): + client_ip = request.remote_addr + graph_reader = self.graph_reader_client_manager.get_data(client_ip) return self._get_with_reader('data/graph_runs', lib.get_graph_runs, - self._graph_reader) + graph_reader) @result() def tags(self):