Skip to content

Commit

Permalink
Update base for rebase with regression fix on "Skip full model shape …
Browse files Browse the repository at this point in the history
…inference if model > 2GB | feat(optimizer)"

[ghstack-poisoned]
  • Loading branch information
BowenBao committed Apr 3, 2024
2 parents 695a6e2 + e0524fc commit a4a4632
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 30 deletions.
6 changes: 4 additions & 2 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,9 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto):
# nn.Modules exported by dynamo exporter have unique call sites, their function
# op_type name can serve to form the unique identifier for value info.
# Store inside top level GraphProto.
new_value_info = self.generate_maingraph_value_info_proto()
new_value_info = self.generate_subgraphs_value_info_proto()
# Insert value info for nodes in top level graph.
new_value_info.update(self.generate_maingraph_value_info_proto())
# Do not store input, output or initializer into value_info
for input in onnx_model.graph.input:
new_value_info.pop(input.name, None)
Expand Down Expand Up @@ -908,7 +910,7 @@ def generate_function_value_info_proto(
return named_value_info

@runtime_typing.checked
def generate_subgraphs_value_info_proto(self) -> Mapping[str, onnx.ValueInfoProto]:
def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]:
"""Unique naming strategies for values inside subgraphs, i.e. local functions.
{function_domain::function_op_type}/{value_name}
Expand Down
62 changes: 34 additions & 28 deletions onnxscript/function_libs/torch_lib/graph_building_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,53 +142,41 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel
graph.add_initializer("x", x_tensor)


class _MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.fc2 = torch.nn.Linear(hidden_size, output_size)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out


@unittest.skipIf(
IS_WINDOWS and version_utils.torch_older_than("2.3"),
"dynamo_export not supported on Windows in PyTorch<2.3",
)
class TestModelSaving(unittest.TestCase):
def test_save_initializer_to_files_for_large_model(self):
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.fc2 = torch.nn.Linear(hidden_size, output_size)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out

# # of model parameters:
# input_size x hidden_size + hidden_size +
# hidden_size x output_size + output_size
# ~= 3GB below
batch_size, input_size, hidden_size, output_size = 1, 4, 50000000, 10
model = MLP(input_size, hidden_size, output_size)
model = _MLP(input_size, hidden_size, output_size)
x = torch.randn(batch_size, input_size)

model_proto = torch.onnx.dynamo_export(model, x).model_proto
# Assert model is larger than 2GB (~=3GB)
self.assertGreater(model_proto.ByteSize(), 2**31)

def test_input_output_and_initializer_are_not_stored_in_value_info(self):
class MLP(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, hidden_size)
self.fc2 = torch.nn.Linear(hidden_size, output_size)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out

batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
model = MLP(input_size, hidden_size, output_size)
model = _MLP(input_size, hidden_size, output_size)
x = torch.randn(batch_size, input_size)

model_proto = torch.onnx.dynamo_export(model, x).model_proto
Expand All @@ -201,6 +189,24 @@ def forward(self, x):
for i in model_proto.graph.initializer:
self.assertNotIn(i.name, v_names)

def test_experimental_function_value_info_are_stored_in_graph_value_info(self):
batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10
model = _MLP(input_size, hidden_size, output_size)
x = torch.randn(batch_size, input_size)

model_proto = torch.onnx.dynamo_export(model, x).model_proto
v_names = {v.name for v in model_proto.graph.value_info}
torch_functions = [
f for f in model_proto.functions if f.domain.startswith("pkg.torch")
]
self.assertNotEqual(len(torch_functions), 0)
for f in torch_functions:
for n in f.node:
for i in n.input:
self.assertIn(f"{f.domain}::{f.name}/{i}", v_names)
for o in n.output:
self.assertIn(f"{f.domain}::{f.name}/{o}", v_names)


if __name__ == "__main__":
unittest.main()

0 comments on commit a4a4632

Please sign in to comment.