Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[Frontend] Add Span filling for frontends to Relay (#9723)" #10072

Merged
merged 1 commit into from
Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,10 @@ class TupleGetItem(ExprWithOp):
index: int
The index.
span: Optional[tvm.relay.Span]
Span that points to original source code
"""

def __init__(self, tuple_value, index, span=None):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)
def __init__(self, tuple_value, index):
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)


@tvm._ffi.register_object("relay.RefCreate")
Expand Down
53 changes: 0 additions & 53 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from tvm.topi.utils import get_const_tuple

from .. import expr as _expr
from ..expr_functor import ExprMutator
from .. import function as _function
from .. import transform as _transform
from .. import op as _op
Expand Down Expand Up @@ -955,55 +954,3 @@ def try_resolve_var_to_const(x, graph_params):
return _op.const(value, dtype)

return x


def set_span(sym, node_name):
"""Set up the span of relay expression(s) while converting OP"""

class SpanFiller(ExprMutator):
"""SpanFiller"""

def __init__(self, node_name, suffix_str="_PART_"):
ExprMutator.__init__(self)
self.node_name = node_name
self.suffix_str = suffix_str
self.counter = 0
self.distance_from_leaf = -1

def _create_span(self):
if self.distance_from_leaf == 0:
return tvm.relay.Span(tvm.relay.SourceName(self.node_name), 0, 0, 0, 0)
self.distance_from_leaf -= 1
span_str = "{}{}{}".format(self.node_name, self.suffix_str, str(self.counter))
self.counter += 1
return tvm.relay.Span(tvm.relay.SourceName(span_str), 0, 0, 0, 0)

def visit_call(self, call):
if call.span is None:
self.distance_from_leaf += 1
new_args = [self.visit(arg) for arg in call.args]
return _expr.Call(
call.op, new_args, call.attrs, call.type_args, self._create_span()
)
return call

def visit_tuple(self, tup):
if tup.span is None:
self.distance_from_leaf += 1
return _expr.Tuple([self.visit(field) for field in tup.fields], self._create_span())
return tup

def visit_tuple_getitem(self, op):
if op.span is None:
self.distance_from_leaf += 1
return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._create_span())
return op

def fill(self, sym):
if isinstance(sym, _expr.TupleWrapper):
return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
if isinstance(sym, _expr.RelayExpr):
return self.visit(sym)
return sym

return SpanFiller(node_name).fill(sym)
19 changes: 0 additions & 19 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import lstm_cell, try_infer_value, unbind
from .common import set_span
from .pytorch_utils import is_version_greater_than

__all__ = ["from_pytorch"]
Expand Down Expand Up @@ -3276,9 +3275,6 @@ def body(*current_vals):

def convert_operators(self, operators, outputs, ret_names):
"""Convert each Torch IR operators to Relay equivalent"""
# an op node might not belong to any of scope in trace info natively
# use a cunter to prevent from messing up its scope in span
empty_counter = 0
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs)
Expand Down Expand Up @@ -3339,9 +3335,6 @@ def _handel_nested_input(inputs):
relay_out = relay_op(
inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
)
span_str, empty_counter = self._get_torch_span(op_node, empty_counter)
relay_out = set_span(relay_out, span_str)

self.record_output_type(relay_out)

if isinstance(relay_out, tuple):
Expand All @@ -3355,18 +3348,6 @@ def _handel_nested_input(inputs):

return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]

def _get_torch_span(self, node, empty_counter):
# torch span looks like
# %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu # ${torch}/nn file
# the scope part might not exist
if node.scopeName():
scope_name_str = "jit._trace.TopLevelTracedModule: " + node.scopeName()
else:
scope_name_str = "warning: no trace info " + str(empty_counter)
empty_counter += 1
span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str)
return span_str, empty_counter


def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import set_span

from .tensorflow_ops import _convert_map
from .tensorflow_ops import _need_prelude_for_shape_inference
Expand Down Expand Up @@ -1029,10 +1028,24 @@ def _convert_operator(
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))

sym = set_span(sym, node_name)
sym = self._set_span(sym, node_name)

return sym

@staticmethod
def _set_span(sym, node_name):
span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
if isinstance(sym, _expr.Call) and sym.span is None:
sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
elif isinstance(sym, _expr.TupleWrapper):
tuple_value = sym.tuple_value
if isinstance(tuple_value, _expr.Call) and tuple_value.span is None:
tuple_value = _expr.Call(
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
)
sym = _expr.TupleWrapper(tuple_value, sym.size)
return sym

def _licm_construct(self, loop_name, node_name):
"""Construct a node by considering whether it is
loop invariant with the given while loop. If yes, we
Expand Down
17 changes: 16 additions & 1 deletion python/tvm/relay/frontend/tensorflow2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from .. import function as _function
from ..loops import while_loop as _while_loop
from .common import infer_type as _infer_type
from .common import set_span

from .tensorflow_ops import _convert_map as _convert_map_common
from .tensorflow_ops import _get_more_static_shape_rank
Expand All @@ -59,6 +58,22 @@ def _infer_type_with_prelude(val, prelude):
return body.checked_type


def set_span(sym, node_name):
"""set span of symbol"""

span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
if isinstance(sym, _expr.Call):
sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
elif isinstance(sym, _expr.TupleWrapper):
tuple_value = sym.tuple_value
if isinstance(tuple_value, _expr.Call):
tuple_value = _expr.Call(
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
)
sym = _expr.TupleWrapper(tuple_value, sym.size)
return sym


def is_tensor_list_constuctor(tf_node):
"""Check whether is tensor list constructor node."""
return tf_node.op == "TensorListReserve"
Expand Down
16 changes: 5 additions & 11 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from .. import qnn as _qnn
from .common import ExprTable
from .common import infer_shape as _infer_shape
from .common import set_span
from .common import to_int_list
from .tflite_flexbuffer import FlexBufferDecoder

Expand Down Expand Up @@ -240,17 +239,12 @@ def convert_op_to_relay(self):

if len(output_tensors) == 1:
tensor_idx = output_tensors[0].tensor_idx
curr_output = get_tensor_name(self.subgraph, tensor_idx)
ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output))
self.exp_tab.set_expr(curr_output, ret)
self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
else:
out_names = []
for output_tensor in output_tensors:
out_names.append(get_tensor_name(self.subgraph, output_tensor.tensor_idx))
curr_output = ", ".join(out_names)
ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output))
for idx, out_name in enumerate(out_names):
self.exp_tab.set_expr(out_name, ret[idx])
for idx, output_tensor in enumerate(output_tensors):
self.exp_tab.set_expr(
get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx]
)

def get_op_code_str(self, op):
"""Get TFLite ops string representation"""
Expand Down
23 changes: 6 additions & 17 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,21 +389,12 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
if (op->fields.size() == 1) {
doc << ",";
}
doc << ")";
if (op->span.defined()) {
doc << " /* " << PrintSpan(op->span) << " */";
}
return doc;
return doc << ")";
}

Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
Doc doc;
doc << Print(op->tuple) << "." << op->index;

if (op->span.defined()) {
doc << " /* " << PrintSpan(op->span) << " */";
}
return doc;
return doc << Print(op->tuple) << "." << op->index;
}

Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
Expand Down Expand Up @@ -977,13 +968,11 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>&
return doc;
}

Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) {
Doc RelayTextPrinter::PrintSpan(const Span& span) {
Doc doc;
if (include_spans) {
const auto* span_node = span.as<SpanNode>();
ICHECK(span_node);
doc << span_node->source_name->name;
}
const auto* span_node = span.as<SpanNode>();
ICHECK(span_node);
doc << span_node->source_name->name;
return doc;
}

Expand Down
2 changes: 1 addition & 1 deletion src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
*/
Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);

Doc PrintSpan(const Span& span, bool include_spans = true);
Doc PrintSpan(const Span& span);

Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);

Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,

TVM_REGISTER_NODE_TYPE(TupleGetItemNode);

TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) {
return TupleGetItem(tuple, index, span);
TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) {
return TupleGetItem(tuple, index);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down
47 changes: 0 additions & 47 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,53 +247,6 @@ def visit(op):
torch.cuda.empty_cache()


def verify_span(model_name, input_data=[], custom_convert_map={}):
if isinstance(model_name, str):
baseline_model, baseline_input = load_model(model_name)
elif isinstance(input_data, list):
baseline_model = model_name
baseline_input = input_data
elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
baseline_model = model_name
baseline_input = [input_data]
else:
assert False, "Unexpected input format"

trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input])
if isinstance(baseline_model, torch.nn.Module):
trace = trace.float().eval()

if torch.cuda.is_available():
trace = trace.cuda()
else:
trace = trace.cpu()

input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)

# collect fail cases for the convenience of further improvement
fail_cases = []
mod_main_start = False
for line in str(mod.__str__).split("\n"):
if "@main" in line:
mod_main_start = True
continue

if mod_main_start == True:
if "}" == line:
break
elif not ("/*" in line and "*/" in line):
fail_cases.append(line)

print(fail_cases)
assert len(fail_cases) == 0


def test_span():
verify_span("resnet18")


# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
Expand Down
54 changes: 0 additions & 54 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,60 +298,6 @@ def is_gpu_available():
return False


def verify_span(mod):
# collect fail cases for the convenience of further improvement
fail_cases = []
mod_main_start = False
for line in str(mod.__str__).split("\n"):
if "@main" in line:
mod_main_start = True
continue

if mod_main_start == True:
if "}" == line:
break
elif not ("/*" in line and "*/" in line):
fail_cases.append(line)

print(fail_cases)
assert len(fail_cases) == 0


def simple_model():
input_node = tf.placeholder(shape=[None, None, 3, 1], dtype=np.float32, name="input")

shape = tf.shape(input_node)
stack = tf.stack([shape[0], 3, 3], axis=0)
output_node = tf.reshape(input_node, stack, name="output")
return output_node


#######################################################################
# Span fill up
# -------
def test_span_complement_simple_model():
with tf.Graph().as_default() as graph:
model_graph = simple_model()
graph_def = graph.as_graph_def()

graph_def = tf_testing.ProcessGraphDefParam(graph_def)

mod, params = relay.frontend.from_tensorflow(graph_def, shape={"input:0", (1, 3, 3, 1)})
verify_span(mod)


def test_span_complement_big_model():
with tf.Graph().as_default() as graph:
graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
# Call the utility to import the graph definition into default graph.
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

mod, params = relay.frontend.from_tensorflow(
graph_def, shape={"input_tensor:0", (128, 224, 224, 3)}
)
verify_span(mod)


#######################################################################
# Pooling
# -------
Expand Down
Loading