From e0524fc54388a837d95ee02ae4efa1845d1fd09d Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Wed, 3 Apr 2024 09:35:29 -0700 Subject: [PATCH] Fix bug regression in experimental value info export (#1341) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #1341 Previous fix didn't export subgraph value info. --- .../function_libs/torch_lib/graph_building.py | 6 +- .../torch_lib/graph_building_test.py | 62 ++++++++++--------- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py index a35a61605..de6de9323 100644 --- a/onnxscript/function_libs/torch_lib/graph_building.py +++ b/onnxscript/function_libs/torch_lib/graph_building.py @@ -822,7 +822,9 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto): # nn.Modules exported by dynamo exporter have unique call sites, their function # op_type name can serve to form the unique identifier for value info. # Store inside top level GraphProto. - new_value_info = self.generate_maingraph_value_info_proto() + new_value_info = self.generate_subgraphs_value_info_proto() + # Insert value info for nodes in top level graph. + new_value_info.update(self.generate_maingraph_value_info_proto()) # Do not store input, output or initializer into value_info for input in onnx_model.graph.input: new_value_info.pop(input.name, None) @@ -908,7 +910,7 @@ def generate_function_value_info_proto( return named_value_info @runtime_typing.checked - def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]: + def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: """Unique naming strategies for values inside subgraphs, i.e. local functions. {function_domain::function_op_type}/{value_name} diff --git a/onnxscript/function_libs/torch_lib/graph_building_test.py b/onnxscript/function_libs/torch_lib/graph_building_test.py index 3ce29c26a..5e12a7e58 100644 --- a/onnxscript/function_libs/torch_lib/graph_building_test.py +++ b/onnxscript/function_libs/torch_lib/graph_building_test.py @@ -142,31 +142,32 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel graph.add_initializer("x", x_tensor) +class _MLP(torch.nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.fc1 = torch.nn.Linear(input_size, hidden_size) + self.fc2 = torch.nn.Linear(hidden_size, output_size) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + @unittest.skipIf( IS_WINDOWS and version_utils.torch_older_than("2.3"), "dynamo_export not supported on Windows in PyTorch<2.3", ) class TestModelSaving(unittest.TestCase): def test_save_initializer_to_files_for_large_model(self): - class MLP(torch.nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super().__init__() - self.fc1 = torch.nn.Linear(input_size, hidden_size) - self.fc2 = torch.nn.Linear(hidden_size, output_size) - self.relu = torch.nn.ReLU() - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - # # of model parameters: # input_size x hidden_size + hidden_size + # hidden_size x output_size + output_size # ~= 3GB below batch_size, input_size, hidden_size, output_size = 1, 4, 50000000, 10 - model = MLP(input_size, hidden_size, output_size) + model = _MLP(input_size, hidden_size, output_size) x = torch.randn(batch_size, input_size) model_proto = torch.onnx.dynamo_export(model, x).model_proto @@ -174,21 +175,8 @@ def forward(self, x): self.assertGreater(model_proto.ByteSize(), 2**31) def test_input_output_and_initializer_are_not_stored_in_value_info(self): - class MLP(torch.nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super().__init__() - self.fc1 = torch.nn.Linear(input_size, hidden_size) - self.fc2 = torch.nn.Linear(hidden_size, output_size) - self.relu = torch.nn.ReLU() - - def forward(self, x): - out = self.fc1(x) - out = self.relu(out) - out = self.fc2(out) - return out - batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10 - model = MLP(input_size, hidden_size, output_size) + model = _MLP(input_size, hidden_size, output_size) x = torch.randn(batch_size, input_size) model_proto = torch.onnx.dynamo_export(model, x).model_proto @@ -201,6 +189,24 @@ def forward(self, x): for i in model_proto.graph.initializer: self.assertNotIn(i.name, v_names) + def test_experimental_function_value_info_are_stored_in_graph_value_info(self): + batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10 + model = _MLP(input_size, hidden_size, output_size) + x = torch.randn(batch_size, input_size) + + model_proto = torch.onnx.dynamo_export(model, x).model_proto + v_names = {v.name for v in model_proto.graph.value_info} + torch_functions = [ + f for f in model_proto.functions if f.domain.startswith("pkg.torch") + ] + self.assertNotEqual(len(torch_functions), 0) + for f in torch_functions: + for n in f.node: + for i in n.input: + self.assertIn(f"{f.domain}::{f.name}/{i}", v_names) + for o in n.output: + self.assertIn(f"{f.domain}::{f.name}/{o}", v_names) + if __name__ == "__main__": unittest.main()