From 6b76c7f6c55cf6a46d0a87094006b51d4704078b Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Wed, 20 Mar 2024 16:15:37 -0700 Subject: [PATCH] Add dynamism support to `aten.embedding` and `aten.split_with_sizes` (#6781) Co-authored-by: Siyuan Liu --- test/stablehlo/test_export_fx_passes.py | 39 ++++++++ test/stablehlo/test_unbounded_dynamism.py | 69 +++++++++---- .../experimental/unbounded_dynamism_export.py | 97 ++++++++++++++++++- 3 files changed, 182 insertions(+), 23 deletions(-) diff --git a/test/stablehlo/test_export_fx_passes.py b/test/stablehlo/test_export_fx_passes.py index 6a30475122c..9827bc56188 100644 --- a/test/stablehlo/test_export_fx_passes.py +++ b/test/stablehlo/test_export_fx_passes.py @@ -32,6 +32,45 @@ def test_decompose_dynamic_shape_select(self): out2 = ep.module()(*args) self.assertTrue(torch.allclose(out1, out2)) + def test_decompose_dynamic_split_with_sizes(self): + + class M(torch.nn.Module): + + def forward(self, x): + res = torch.ops.aten.split_with_sizes.default(x, [1, 2, 3], -1) + return res[0], res[1] + + args = (torch.rand((3, 10, 6)),) + dynamic_shapes = ({0: Dim("dim")},) + m = M() + ep = export(m, args, dynamic_shapes=dynamic_shapes) + out1 = ep.module()(*args) + decompose_split_with_sizes(ep.graph_module) + ep.graph_module.recompile() + self.assertTrue('split_with_sizes' in ep.graph_module.code) + out2 = ep.module()(*args) + self.assertTrue(torch.allclose(out1[0], out2[0])) + self.assertTrue(torch.allclose(out1[1], out2[1])) + + def test_embedding_indices_flatten(self): + args = (torch.rand((20, 768)), torch.randint(0, 15, + (3, 10)).to(torch.int64)) + dynamic_shapes = ([None, {0: Dim("bs")}],) + m = wrap_func_as_nn_module(torch.ops.aten.embedding.default) + ep = export(m, args, dynamic_shapes=dynamic_shapes) + print(ep) + out1 = ep.module()(*args) + flatten_embedding_indices_tensor(ep.graph_module) + ep.graph_module.recompile() + print(ep) + self.assertTrue('aten.view' in ep.graph_module.code) + replace_dynamic_view_with_xla_op(ep.graph_module) + ep.graph_module.recompile() + self.assertTrue('aten.view' not in ep.graph_module.code) + self.assertTrue('xla.dynamic_view' in ep.graph_module.code) + out2 = ep.module()(*args) + self.assertTrue(torch.allclose(out1, out2)) + def test_no_op_slice_removal(self): class M(torch.nn.Module): diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py index 11e1b87c06d..8c346f1cfbb 100644 --- a/test/stablehlo/test_unbounded_dynamism.py +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -293,25 +293,30 @@ def test_gelu(self): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) - @unittest.skip("Unbounded dynamism is not supported yet.") def test_embedding(self): class M(torch.nn.Module): def forward(self, x, y): res = torch.ops.aten.embedding.default(x, y) - return res[0], res[1] + return res - args = (torch.rand((1, 768)), torch.randint(0, 15, (3, 10)).to(torch.int64)) + args = (torch.rand((20, 768)), torch.randint(0, 15, + (3, 10)).to(torch.int64)) dynamic_shapes = (None, {0: Dim("dim")}) - # dynamic_shapes = None m = M() ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) shlo_text = shlo_module.get_stablehlo_text() self.assertTrue( - re.search(r"%arg.: tensor<\?x5xf32>.*->.*tensor<\?x5xi32>", shlo_text) - is not None) + re.search(r"%arg.: tensor<\?x10xi64>.*->.*tensor<\?x10x768xf32>", + shlo_text) is not None) + if has_tf_package(): + 
with tempfile.TemporaryDirectory() as tempdir: + save_torch_module_as_tf_saved_model( + m, args, tempdir, dynamic_shapes=dynamic_shapes) + self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) + compare_exported_program_and_saved_model_result(ep, tempdir, args) def test_mean(self): @@ -372,7 +377,7 @@ def test_mul_scalar(self): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) - @unittest.skip("Implicit broadcasting logic is broken.") + @unittest.skip("Unbounded dynamism is not supported.") def test_ne_scalar(self): class M(torch.nn.Module): @@ -382,7 +387,6 @@ def forward(self, x): args = (torch.rand((3, 5)).to(torch.int64),) dynamic_shapes = ({0: Dim("dim")},) - # dynamic_shapes = None m = M() ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = exported_program_to_stablehlo(ep) @@ -390,12 +394,12 @@ def forward(self, x): self.assertTrue( re.search(r"%arg.: tensor<\?x5xf32>.*->.*tensor<\?x5xi32>", shlo_text) is not None) - # if has_tf_package(): - # with tempfile.TemporaryDirectory() as tempdir: - # save_torch_module_as_tf_saved_model( - # m, args, tempdir, dynamic_shapes=dynamic_shapes) - # self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) - # compare_exported_program_and_saved_model_result(ep, tempdir, args) + if has_tf_package(): + with tempfile.TemporaryDirectory() as tempdir: + save_torch_module_as_tf_saved_model( + m, args, tempdir, dynamic_shapes=dynamic_shapes) + self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) + compare_exported_program_and_saved_model_result(ep, tempdir, args) def test_var(self): @@ -534,6 +538,24 @@ def test_slice(self): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) + def test_slice_2(self): + args = (torch.rand((10, 3, 224, 224)), 1, 0, 2) + dynamic_shapes = ([{0: Dim("dim")}, None, None, None],) + m = wrap_func_as_nn_module(torch.ops.aten.slice.Tensor) + ep = export(m, args=args, dynamic_shapes=dynamic_shapes) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text() + self.assertTrue( + re.search( + r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x2x224x224xf32>", + shlo_text) is not None) + if has_tf_package(): + with tempfile.TemporaryDirectory() as tempdir: + save_torch_module_as_tf_saved_model( + m, args, tempdir, dynamic_shapes=dynamic_shapes) + self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) + compare_exported_program_and_saved_model_result(ep, tempdir, args) + def test_softmax(self): args = (torch.rand((10, 12, 197, 197)), -1, False) dynamic_shapes = ([{0: Dim("dim")}, None, None],) @@ -624,25 +646,30 @@ def test_sub_scalar(self): self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb'))) compare_exported_program_and_saved_model_result(ep, tempdir, args) - @unittest.skip("Unbounded dynamism is not supported yet.") def test_split_with_sizes(self): class M(torch.nn.Module): def forward(self, x): - res = torch.ops.aten.split_with_sizes.default(x, [1, 1], -1) - return res[0], res[1] + res = torch.ops.aten.split_with_sizes.default(x, [1, 2, 3], -1) + return res[0], res[1], res[2] - args = (torch.rand((3, 10, 2)),) + args = (torch.rand((3, 10, 6)),) dynamic_shapes = ({0: Dim("dim")},) - # dynamic_shapes = None m = M() ep = export(m, args=args, dynamic_shapes=dynamic_shapes) shlo_module = 
exported_program_to_stablehlo(ep)
    shlo_text = shlo_module.get_stablehlo_text()
    self.assertTrue(
-        re.search(r"%arg.: tensor<\?x5xf32>.*->.*tensor<\?x5xi32>", shlo_text)
-        is not None)
+        re.search(
+            r"%arg.: tensor<\?x10x6xf32>.*->.*tensor<\?x10x1xf32>.*tensor<\?x10x2xf32>.*tensor<\?x10x3xf32>",
+            shlo_text) is not None)
+    if has_tf_package():
+      with tempfile.TemporaryDirectory() as tempdir:
+        save_torch_module_as_tf_saved_model(
+            m, args, tempdir, dynamic_shapes=dynamic_shapes)
+        self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
+        compare_exported_program_and_saved_model_result(ep, tempdir, args)
 
   def test_transpose_on_dynamic_dim(self):
     args = (torch.rand((1, 8, 3, 256)),)
diff --git a/torch_xla/experimental/unbounded_dynamism_export.py b/torch_xla/experimental/unbounded_dynamism_export.py
index d9fdb7a7165..042ba98d2c0 100644
--- a/torch_xla/experimental/unbounded_dynamism_export.py
+++ b/torch_xla/experimental/unbounded_dynamism_export.py
@@ -117,6 +117,94 @@ def decompose_dynamic_shape_select(gm: GraphModule):
         graph.erase_node(n)
 
 
+def decompose_split_with_sizes(gm: GraphModule):
+  '''
+  Decompose `split_with_sizes` with a symbolic input shape into `aten.slice` ops.
+
+  This reuses the unbounded dynamism support added to `aten.slice`.
+  '''
+  graph = gm.graph
+  for n in graph.nodes:
+    if n.op == "call_function" and n.target == aten.split_with_sizes.default:
+      src_node = n.args[0]
+      src_shape = src_node.meta['val'].size()
+      symbolic_dims = [
+          i for i, x in enumerate(src_shape) if not isinstance(x, int)
+      ]
+      if len(symbolic_dims) == 0:
+        continue
+      assert len(symbolic_dims) == 1, "Only 1 dimension can be symbolic."
+      split_sizes = n.args[1]
+      split_dim = n.args[2]
+      assert symbolic_dims[
+          0] != split_dim, "Split along symbolic dim is not supported."
+      with graph.inserting_before(n):
+        start_idx = 0
+        decomposed_slices = []
+        for size in split_sizes:
+          slice_args = (src_node, split_dim, start_idx, start_idx + size)
+          slice_node = graph.call_function(torch.ops.aten.slice.Tensor,
+                                           slice_args)
+          start_idx += size
+          decomposed_slices.append(slice_node)
+        consumers = n.users
+        for consumer in consumers:
+          assert consumer.op == "call_function" and consumer.target.__name__ == "getitem"
+          slice_idx = consumer.args[1]
+          consumer.replace_all_uses_with(decomposed_slices[slice_idx])
+
+
+def flatten_embedding_indices_tensor(gm: GraphModule):
+  '''
+  Flatten the indices tensor of `aten.embedding.default` to avoid a `view` op with
+  symbolic input during LTC tracing.
+
+  The indices tensor is flattened first, and the original embedding output shape is recovered afterwards.
+
+  The symbolic shape in `aten.view` will be handled by the `aten.view` -> `xla.dynamic_view` pass.
+  '''
+  graph = gm.graph
+  for n in graph.nodes:
+    if n.op == "call_function" and n.target == aten.embedding.default:
+      select_src_shape = n.args[1].meta['val'].shape
+      symbolic_dims = [
+          i for i, x in enumerate(select_src_shape) if not isinstance(x, int)
+      ]
+      if len(symbolic_dims) > 0:
+        assert len(symbolic_dims) == 1, "Only 1 dimension can be symbolic."
+        with graph.inserting_before(n):
+          indices_node = n.args[1]
+          indices_shape = indices_node.meta['val'].size()
+          flatten_mul_scale = 1
+          get_dim_size_node = None
+          recover_view_shape = []
+          for dim, size in enumerate(indices_shape):
+            if not isinstance(size, int):
+              get_dim_size_args = (indices_node, dim)
+              get_dim_size_node = graph.call_function(aten.sym_size.int,
+                                                      get_dim_size_args)
+              recover_view_shape.append(get_dim_size_node)
+            else:
+              flatten_mul_scale *= size
+              recover_view_shape.append(size)
+          weight_shape = n.args[0].meta['val'].size()
+          recover_view_shape.append(weight_shape[-1])
+
+          mul_args = (get_dim_size_node, flatten_mul_scale)
+          flatten_size_node = graph.call_function(aten.mul.Scalar, mul_args)
+          view_args = (indices_node, [flatten_size_node])
+          view_node = graph.call_function(aten.view, view_args)
+          new_embedding_args = n.args[0:1] + (view_node,)
+          if len(n.args) > 2:
+            new_embedding_args += n.args[2:]
+          n.args = new_embedding_args
+        with graph.inserting_after(n):
+          recover_view_args = (n, recover_view_shape)
+          recover_view_node = graph.call_function(aten.view, recover_view_args)
+          n.replace_all_uses_with(recover_view_node)
+          recover_view_node.update_arg(0, n)
+
+
 def _is_no_op_slice(n):
   assert n.op == "call_function" and n.target == aten.slice.Tensor
   return n.args[2] == 0 and n.args[3] == torch.iinfo(torch.int64).max
@@ -202,8 +290,10 @@ def replace_dynamic_view_with_xla_op(gm: GraphModule):
       mul_scaler = 1
       sym_size_node = dynamic_src
       mul_node = None
-      if hasattr(dynamic_src.target,
-                 "__name__") and dynamic_src.target.__name__ == "mul":
+      if hasattr(
+          dynamic_src.target,
+          "__name__") and (dynamic_src.target.__name__ == "mul" or
+                           dynamic_src.target.__name__ == "mul.Scalar"):
         assert isinstance(dynamic_src.args[0], int) or isinstance(
             dynamic_src.args[1], int)
         mul_node = dynamic_src
@@ -236,6 +326,7 @@ def dynamic_unsqueeze_to_view(gm: GraphModule):
       ]
       if len(symbolic_dims) == 0:
         continue
+      assert len(symbolic_dims) == 1, "Only 1 dimension can be symbolic."
       view_args = list(src_shape)
       with graph.inserting_before(n):
         for dim in symbolic_dims:
@@ -266,10 +357,12 @@ def exported_program_has_symbolic_input_shape(ep):
 def process_exported_program_with_symbolic_input(ep):
   passes = [
       decompose_dynamic_shape_select,
+      decompose_split_with_sizes,
       remove_no_op_slice,
       decompose_dynamic_native_group_norm,
       decompose_dynamic_native_layer_norm,
       dynamic_unsqueeze_to_view,
+      flatten_embedding_indices_tensor,
       replace_dynamic_expand_with_xla_op,
       replace_dynamic_view_with_xla_op,
   ]
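
A minimal sketch (assuming `torch` and `torch_xla` are installed) of how the new `decompose_split_with_sizes` pass can be exercised on an exported program. It mirrors `test_decompose_dynamic_split_with_sizes` above; the module name `M` and the closing check are illustrative only.

import torch
from torch.export import Dim, export

from torch_xla.experimental.unbounded_dynamism_export import (
    decompose_split_with_sizes)


class M(torch.nn.Module):

  def forward(self, x):
    # Split the static last dim into chunks of size 1, 2 and 3.
    res = torch.ops.aten.split_with_sizes.default(x, [1, 2, 3], -1)
    return res[0], res[1]


args = (torch.rand((3, 10, 6)),)
# Only dim 0 is symbolic; splitting along a symbolic dim is not supported.
ep = export(M(), args, dynamic_shapes=({0: Dim("dim")},))
out_before = ep.module()(*args)

# Rewrite the getitem consumers of split_with_sizes into aten.slice calls.
decompose_split_with_sizes(ep.graph_module)
ep.graph_module.recompile()

out_after = ep.module()(*args)
assert all(torch.allclose(a, b) for a, b in zip(out_before, out_after))

Because the pass only rewires the `getitem` consumers onto the new `aten.slice` nodes and does not erase the original node, `split_with_sizes` remains in the recompiled code as dead code, which is what the `'split_with_sizes' in ep.graph_module.code` assertion in `test_decompose_dynamic_split_with_sizes` reflects.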
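A similar sketch for the embedding path, mirroring `test_embedding_indices_flatten` and `test_embedding` above; the wrapper module `Embed` is illustrative, and the pass order (flatten the indices, then replace the dynamic view) follows the pipeline in `process_exported_program_with_symbolic_input`.

import torch
from torch.export import Dim, export

from torch_xla.experimental.unbounded_dynamism_export import (
    flatten_embedding_indices_tensor, replace_dynamic_view_with_xla_op)


class Embed(torch.nn.Module):

  def forward(self, weight, indices):
    return torch.ops.aten.embedding.default(weight, indices)


args = (torch.rand((20, 768)), torch.randint(0, 15, (3, 10)).to(torch.int64))
# Only the leading (batch) dim of the indices tensor is symbolic.
ep = export(Embed(), args, dynamic_shapes=(None, {0: Dim("bs")}))
out_before = ep.module()(*args)

# Flatten the indices tensor so the embedding sees a rank-1 index input.
flatten_embedding_indices_tensor(ep.graph_module)
ep.graph_module.recompile()

# Lower the dynamic aten.view ops introduced above to xla.dynamic_view.
replace_dynamic_view_with_xla_op(ep.graph_module)
ep.graph_module.recompile()

out_after = ep.module()(*args)
assert torch.allclose(out_before, out_after)

`flatten_embedding_indices_tensor` expresses the flattened length as a `sym_size` node times a static factor feeding an `aten.view`, and restores the original output shape with a second `aten.view` after the embedding; `replace_dynamic_view_with_xla_op` then lowers those views to `xla.dynamic_view`, which is why that pass now also matches `mul.Scalar` in addition to `mul` when tracing back the dynamic dimension.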