Elementwise Concat Op's forward function should pack input tensors (#971)

* Modify Elementwise Concat Op definition

Signed-off-by: Sundar Raman <quic_sundarr@quicinc.com>
quic-sundarr authored and quic-bharathr committed Jan 27, 2022
1 parent c839be4 commit 5f97eba
Showing 4 changed files with 140 additions and 19 deletions.
10 changes: 6 additions & 4 deletions TrainingExtensions/torch/src/python/aimet_torch/elementwise_ops.py
@@ -37,7 +37,6 @@
 # =============================================================================
 
 """ Modules for functional elementwise ops """
-from typing import List, Tuple, Union
 import torch
 import torch.nn
 
@@ -88,13 +87,16 @@ def forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
 class Concat(torch.nn.Module):
     """ Concat module for a functional concat"""
+    def __init__(self, axis: int = 0):
+        super(Concat, self).__init__()
+        self._axis = axis
+
     # pylint:disable=arguments-differ
-    @staticmethod
-    def forward(x: Union[Tuple[torch.Tensor], List[torch.Tensor]], dim: int = 0) -> torch.Tensor:
+    def forward(self, *x) -> torch.Tensor:
         """
         Forward-pass routine for cat op
         """
-        return torch.cat(x, dim=dim)
+        return torch.cat(x, dim=self._axis)
 
 
 class MatMul(torch.nn.Module):
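The interface change above in a nutshell: `Concat` no longer takes a pre-packed sequence plus a `dim` argument in `forward`; the axis is fixed at construction and `forward` packs its positional inputs itself. A minimal usage sketch (not part of the commit; it assumes the `aimet_torch.elementwise_ops` import path referenced by `model_preparer.py` below):

```python
import torch
from aimet_torch import elementwise_ops

t1 = torch.rand(2, 4)
t2 = torch.rand(2, 4)

# Before this commit, callers had to pre-pack the tensors and pass dim to forward:
#   out = elementwise_ops.Concat()([t1, t2], 0)

# After this commit, the axis is set once in the constructor and forward
# receives the tensors as separate positional arguments, packing them itself:
cat = elementwise_ops.Concat(axis=0)
out = cat(t1, t2)
assert torch.equal(out, torch.cat((t1, t2), dim=0))
```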
91 changes: 80 additions & 11 deletions TrainingExtensions/torch/src/python/aimet_torch/model_preparer.py
@@ -83,10 +83,67 @@
     'subtract' : elementwise_ops.Subtract,
     'mul' : elementwise_ops.Multiply,
     'div' : elementwise_ops.Divide,
-    'cat' : elementwise_ops.Concat,
     'matmul' : elementwise_ops.MatMul
 }
 
+functional_to_module_special_handling_map = {
+
+    # Operations that require special transformation
+    'cat' : elementwise_ops.Concat
+}
+
+
+def concat_create_node(symbolic_traced_model: torch.fx.GraphModule, module_name: str, node: torch.fx.node) \
+        -> torch.fx.node:
+    """
+    Create the node to be inserted in the graph model.
+    :param symbolic_traced_model: Symbolically traced model
+    :param module_name: Qualified module name in symbolic_traced_model hierarchy corresponding to new node
+    :param node: Current node in the graph after which new node will be inserted
+    :return: torch.fx.node to be inserted in the graph
+    """
+
+    with symbolic_traced_model.graph.inserting_after(node):
+        num_args = len(node.args)
+        if num_args == 1:
+            # Handle torch.cat being called with default parameter dim
+            forward_args = node.args[0]
+            new_node = symbolic_traced_model.graph.call_module(module_name, args=forward_args)
+        else:
+            forward_args = tuple(node.args[0])
+            new_node = symbolic_traced_model.graph.call_module(module_name, args=forward_args)
+        return new_node
+
+
+def concat_create_module(node: torch.fx.node) -> torch.nn.Module:
+    """
+    Create the replacement module.
+    :param node: Current node in the graph after which new node will be inserted
+    :return:
+    """
+
+    num_args = len(node.args)
+    if num_args == 1:
+        # Handle torch.cat being called with default parameter dim
+        kwargs = node.kwargs
+        module = elementwise_ops.Concat()
+    else:
+        module = elementwise_ops.Concat(node.args[1])
+        kwargs = {'axis': node.args[1]}
+
+    for key, value in kwargs.items():
+        setattr(module, key, value)
+
+    return module
+
+
+special_handler_functions = {
+    # Special handling functions for creating node and module
+    'cat': {'node_fn': concat_create_node, 'module_fn': concat_create_module}
+}
+
+
 def prepare_model(model: torch.nn.Module, concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
     """
@@ -194,7 +251,7 @@ def forward(self, *inputs):
             new_nodule_name = 'module_' + node.name
             setattr(symbolic_traced_model, new_nodule_name, new_module)
             # Create the node for new module in the graph
-            _create_node_for_new_module(symbolic_traced_model, node, new_nodule_name)
+            _create_node_for_new_module(symbolic_traced_model, node, new_nodule_name, functional_name)
             logger.info("Functional : Adding new module for node: {%s} ", node.name)
 
         # Create new module for reused/duplicate nodes
@@ -230,17 +287,25 @@ def _verify_symbolic_traced_model(symbolic_traced_model: torch.fx.GraphModule):
 
 
 def _create_node_for_new_module(symbolic_traced_model: torch.fx.GraphModule, node: torch.fx.node,
-                                module_name: str):
+                                module_name: str, functional_name: str = None):
     """
     Insert 'call module' node into graph and replace all the uses of 'node' with newly added node and erase
     the old node from graph.
     :param symbolic_traced_model: Symbolically traced model
     :param node: Current node in the graph after which new node will be inserted
     :param module_name: Qualified module name in symbolic_traced_model hierarchy corresponding to new node
+    :param functional_name: Original functional name
     :return: None
     """
     with symbolic_traced_model.graph.inserting_after(node):
-        new_node = symbolic_traced_model.graph.call_module(module_name, args=node.args)
+        if functional_name:
+            if functional_name in functional_to_module_special_handling_map.keys():
+                new_node = special_handler_functions[functional_name]['node_fn'](symbolic_traced_model, module_name, node)
+            else:
+                new_node = symbolic_traced_model.graph.call_module(module_name, args=node.args)
+        else:
+            new_node = symbolic_traced_model.graph.call_module(module_name, args=node.args)
 
     node.replace_all_uses_with(new_node)
     symbolic_traced_model.graph.erase_node(node)
 
@@ -251,8 +316,9 @@ def _find_functional_name_for_node(node: torch.fx.node) -> Union[str, None]:
     :param node: torch.fx Node
     :return: corresponding functional name if found, else None
     """
 
-    for functional_name in functional_to_module_map:
+    combined_ops_map = {**functional_to_module_map, **functional_to_module_special_handling_map}
+    for functional_name in combined_ops_map:
         # \b boundary character to find the exact match from the functional_to_module lookup
         pattern = r"\b" + functional_name + r"\b"
         if search(pattern, str(node.target)):
@@ -271,12 +337,15 @@ def _create_module_for_functional_node(node: torch.fx.node, functional_name: str
     kwargs = node.kwargs
 
     # Instantiate new module from lookup
-    module = functional_to_module_map[functional_name]()
-
-    # Set the parameters for module from node.kwargs
-    for key, value in kwargs.items():
-        setattr(module, key, value)
-
+    if functional_name in functional_to_module_map.keys():
+        module = functional_to_module_map[functional_name]()
+        # Set the parameters for module from node.kwargs
+        for key, value in kwargs.items():
+            setattr(module, key, value)
+    elif functional_name in functional_to_module_special_handling_map:
+        module = special_handler_functions[functional_name]['module_fn'](node)
+    else:
+        raise ValueError("Unsupported module: {}".format(functional_name))
     return module
 
 
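For context on the two branches in `concat_create_node` and `concat_create_module` above: when `torch.fx` traces a functional `torch.cat` call, the resulting node's `args` holds the packed tensor sequence as its first element, and the dim as its second element only if the caller passed one. A small sketch (not part of the commit) showing the shape those handlers rely on:

```python
import torch
import torch.fx


class TwoWayCat(torch.nn.Module):
    def forward(self, a, b):
        return torch.cat([a, b], 1)


traced = torch.fx.symbolic_trace(TwoWayCat())
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is torch.cat:
        # Prints something like ([a, b], 1): args[0] is the list of input
        # Nodes, args[1] is the dim (absent when torch.cat uses the default).
        print(node.args)
```

`concat_create_node` flattens `node.args[0]` so each tensor becomes a separate input of the new `call_module` node, matching the new `Concat.forward(self, *x)` signature, while `concat_create_module` passes `node.args[1]` (when present) to `Concat`'s constructor.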
45 changes: 41 additions & 4 deletions TrainingExtensions/torch/test/python/test_elementwise_ops.py
@@ -68,9 +68,27 @@ def forward(self, input):
         return x
 
 
+class Model3(nn.Module):
+    def __init__(self, op):
+        super(Model3, self).__init__()
+        self.op1 = op
+
+    def forward(self, *x):
+        x = self.op1(*x)
+        return x
+
+
+def dummy_forward_pass(model, args):
+    model.eval()
+    with torch.no_grad():
+        output = model(torch.randn((5, 10, 10, 20)))
+    return output
+
+
 def forward_pass(model, iterations):
     return torch.rand(1)
 
 
 class TestTrainingExtensionElementwiseOps(unittest.TestCase):
     def test_add_op(self):
         torch.manual_seed(10)
@@ -103,7 +121,6 @@ def test_quantsim_export(self):
         self.assertTrue(len(data['activation_encodings']) == 3)
         self.assertTrue(len(data['param_encodings']) == 2)
 
-
     def test_subtract_op(self):
         torch.manual_seed(10)
         model = Model(Subtract())
@@ -131,15 +148,35 @@ def test_divide_op(self):
         out1 = torch.div(input1, input2)
         self.assertTrue(np.allclose(out, out1))
 
-    def test_concat_op(self):
+    def test_concat_op_two_input_tensors(self):
         torch.manual_seed(10)
-        model = Model(Concat())
+        model = Model3(Concat())
         input1 = torch.rand((5, 10, 10, 20))
         input2 = torch.rand((5, 10, 10, 20))
-        out = model([input1, input2], 0)
+        out = model(input1, input2)
         out1 = torch.cat((input1, input2), 0)
         self.assertTrue(np.allclose(out, out1))
 
+    def test_concat_op_four_input_tensors(self):
+        torch.manual_seed(10)
+        model = Model3(Concat())
+        input1 = torch.rand((5, 10, 10, 20))
+        input2 = torch.rand((5, 10, 10, 20))
+        input3 = torch.rand((5, 10, 10, 20))
+        input4 = torch.rand((5, 10, 10, 20))
+        out = model(input1, input2, input3, input4)
+        out1 = torch.cat((input1, input2, input3, input4), 0)
+        self.assertTrue(np.allclose(out, out1))
+
+    def test_concat_compute_encodings(self):
+        torch.manual_seed(10)
+        model = Model3(Concat())
+        dummy_input = torch.randn(5, 10, 10, 20)
+        sim = QuantizationSimModel(model, dummy_input)
+        sim.compute_encodings(dummy_forward_pass, None)
+        print(sim)
+        sim.export(path='./data', filename_prefix='concat_model', dummy_input=dummy_input)
+
     def test_matmul_op(self):
         torch.manual_seed(10)
         model = Model(MatMul())
13 changes: 13 additions & 0 deletions TrainingExtensions/torch/test/python/test_model_preparer.py
@@ -1076,3 +1076,16 @@ def forward(self, *inputs):
 
         assert torch.allclose(model_transformed(*input_tensor),
                               model_with_branch_true(*input_tensor))
+
+    def test_inception_v3_compute_encodings(self):
+        model = models.inception_v3().eval()
+        model_transformed = prepare_model(model)
+        print(model_transformed)
+        input_shape = (1, 3, 299, 299)
+        input_tensor = torch.randn(*input_shape)
+        assert torch.allclose(model_transformed(input_tensor),
+                              model(input_tensor))
+        quant_sim = QuantizationSimModel(model_transformed, dummy_input=input_tensor)
+        quant_sim.compute_encodings(evaluate, input_tensor)
+        quant_sim.model(input_tensor)
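Taken together, the changes let `prepare_model` rewrite a functional `torch.cat` into the new module form. An end-to-end sketch (not part of the commit; it assumes the `prepare_model` entry point from `aimet_torch.model_preparer` shown in the diff above):

```python
import torch
from aimet_torch.model_preparer import prepare_model


class CatModel(torch.nn.Module):
    def forward(self, a, b):
        return torch.cat((a, b), 1)


model = CatModel().eval()
# The functional cat is replaced by an elementwise_ops.Concat module
# constructed with axis=1, via the special handling path above.
prepared = prepare_model(model)

a, b = torch.rand(1, 2), torch.rand(1, 2)
assert torch.equal(prepared(a, b), model(a, b))
```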