Skip to content

Commit

Permalink
Merge pull request #3 from cse0001/develop
Browse files Browse the repository at this point in the history
Fix paddle dependency
  • Loading branch information
cse0001 authored Oct 14, 2024
2 parents 2a4d48e + 78f3739 commit 8f78c74
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
8 changes: 7 additions & 1 deletion visualdl/component/graph/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import json
import os
import tempfile
import paddle

from .graph_component import analyse_model
from .graph_component import analyse_pir
Expand All @@ -24,6 +23,13 @@


def translate_graph(model, input_spec, verbose=True, **kwargs):
try:
import paddle
except Exception:
print("Paddlepaddle is required to use add_graph interface.\n\
Please refer to \
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\
to install paddlepaddle.")
is_pir = kwargs.get('is_pir', False)
with tempfile.TemporaryDirectory() as tmp:
if (not is_pir):
Expand Down
18 changes: 15 additions & 3 deletions visualdl/component/graph/graph_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import os.path
import pathlib
import re
import paddle

from . import utils

Expand Down Expand Up @@ -444,9 +443,16 @@ def safe_get_persistable(op):


def get_sub_ops(op, op_name, all_ops, all_vars):
try:
from paddle.utils.unique_name import generate
except Exception:
print("Paddlepaddle is required to use add_graph interface.\n\
Please refer to \
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\
to install paddlepaddle.")
for sub_block in op.blocks():
for sub_op in sub_block.ops:
sub_op_name0 = paddle.utils.unique_name.generate(sub_op.name())
sub_op_name0 = generate(sub_op.name())
sub_op_name = op_name + '/' + sub_op_name0
all_ops[sub_op_name] = {}
all_ops[sub_op_name]['name'] = sub_op_name
Expand Down Expand Up @@ -561,7 +567,13 @@ def update_node_connections(all_vars, all_ops):


def analyse_pir(program):
from paddle.utils.unique_name import generate
try:
from paddle.utils.unique_name import generate
except Exception:
print("Paddlepaddle is required to use add_graph interface.\n\
Please refer to \
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\
to install paddlepaddle.")

all_ops = {}
all_vars = {}
Expand Down
29 changes: 28 additions & 1 deletion visualdl/reader/graph_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from visualdl.component.graph import analyse_pir
from visualdl.component.graph import Model
from visualdl.io import bfile
from paddle.jit import load


def is_VDLGraph_file(path):
Expand Down Expand Up @@ -140,6 +139,13 @@ def get_graph(self,
if 'pdmodel' in self.walks[run]:
graph_model = Model(analyse_model(data))
elif 'json' in self.walks[run]:
try:
from paddle.jit import load
except Exception:
print("Paddlepaddle is required to load json file.\n\
Please refer to \
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\
to install paddlepaddle.")
json_object = json.loads(data)
with tempfile.TemporaryDirectory() as tmp:
with open(os.path.join(tmp, 'temp.json'), 'w') as json_file:
Expand Down Expand Up @@ -174,6 +180,13 @@ def search_graph_node(self, run, nodeid, keep_state=False, is_node=True):
if 'pdmodel' in self.walks[run]:
graph_model = Model(analyse_model(data))
elif 'json' in self.walks[run]:
try:
from paddle.jit import load
except Exception:
print("Paddlepaddle is required to load json file.\n\
Please refer to \
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\
to install paddlepaddle.")
json_object = json.loads(data)
with tempfile.TemporaryDirectory() as tmp:
with open(os.path.join(tmp, 'temp.json'), 'w') as json_file:
Expand Down Expand Up @@ -202,6 +215,13 @@ def get_all_nodes(self, run):
if 'pdmodel' in self.walks[run]:
graph_model = Model(analyse_model(data))
elif 'json' in self.walks[run]:
try:
from paddle.jit import load
except Exception:
print("Paddlepaddle is required to load json file.\n\
Please refer to \
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\
to install paddlepaddle.")
json_object = json.loads(data)
with tempfile.TemporaryDirectory() as tmp:
with open(os.path.join(tmp, 'temp.json'), 'w') as json_file:
Expand Down Expand Up @@ -241,6 +261,13 @@ def set_input_graph(self, content, file_type='pdmodel'):
self.graph_buffer['manual_input_model'] = Model(data)

elif file_type == 'json':
try:
from paddle.jit import load
except Exception:
print("Paddlepaddle is required to load json file.\n\
Please refer to \
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html\
to install paddlepaddle.")
json_object = json.loads(content)
with tempfile.TemporaryDirectory() as tmp:
with open(os.path.join(tmp, 'temp.json'), 'w') as json_file:
Expand Down

0 comments on commit 8f78c74

Please sign in to comment.