Commit
Add dynamism support to aten.embedding and aten.split_with_sizes (#6781)

Co-authored-by: Siyuan Liu <lsiyuan@google.com>
lsy323 and Siyuan Liu authored Mar 20, 2024
1 parent fcf24b6 commit 6b76c7f
Showing 3 changed files with 182 additions and 23 deletions.
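For orientation, here is a minimal sketch of the export path these changes target, with a dynamic batch dimension on the embedding indices. The module and shapes mirror the tests below; the exact torch_xla import paths are assumptions.

```python
import torch
from torch.export import Dim, export

# Assumed import paths; the tests below exercise these helpers from torch_xla.
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.tf_saved_model_integration import save_torch_module_as_tf_saved_model


class M(torch.nn.Module):

  def forward(self, weight, indices):
    return torch.ops.aten.embedding.default(weight, indices)


args = (torch.rand((20, 768)), torch.randint(0, 15, (3, 10)).to(torch.int64))
dynamic_shapes = (None, {0: Dim("dim")})  # the batch dimension of the indices is symbolic
m = M()
ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
shlo = exported_program_to_stablehlo(ep)  # lowers with an unbounded ?x10 indices shape
save_torch_module_as_tf_saved_model(
    m, args, "/tmp/embedding_saved_model", dynamic_shapes=dynamic_shapes)
```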
39 changes: 39 additions & 0 deletions test/stablehlo/test_export_fx_passes.py
@@ -32,6 +32,45 @@ def test_decompose_dynamic_shape_select(self):
out2 = ep.module()(*args)
self.assertTrue(torch.allclose(out1, out2))

def test_decompose_dynamic_split_with_sizes(self):

class M(torch.nn.Module):

def forward(self, x):
res = torch.ops.aten.split_with_sizes.default(x, [1, 2, 3], -1)
return res[0], res[1]

args = (torch.rand((3, 10, 6)),)
dynamic_shapes = ({0: Dim("dim")},)
m = M()
ep = export(m, args, dynamic_shapes=dynamic_shapes)
out1 = ep.module()(*args)
decompose_split_with_sizes(ep.graph_module)
ep.graph_module.recompile()
self.assertTrue('split_with_sizes' in ep.graph_module.code)
out2 = ep.module()(*args)
self.assertTrue(torch.allclose(out1[0], out2[0]))
self.assertTrue(torch.allclose(out1[1], out2[1]))

def test_embedding_indices_flatten(self):
args = (torch.rand((20, 768)), torch.randint(0, 15,
(3, 10)).to(torch.int64))
dynamic_shapes = ([None, {0: Dim("bs")}],)
m = wrap_func_as_nn_module(torch.ops.aten.embedding.default)
ep = export(m, args, dynamic_shapes=dynamic_shapes)
print(ep)
out1 = ep.module()(*args)
flatten_embedding_indices_tensor(ep.graph_module)
ep.graph_module.recompile()
print(ep)
self.assertTrue('aten.view' in ep.graph_module.code)
replace_dynamic_view_with_xla_op(ep.graph_module)
ep.graph_module.recompile()
self.assertTrue('aten.view' not in ep.graph_module.code)
self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
out2 = ep.module()(*args)
self.assertTrue(torch.allclose(out1, out2))

def test_no_op_slice_removal(self):

class M(torch.nn.Module):
69 changes: 48 additions & 21 deletions test/stablehlo/test_unbounded_dynamism.py
@@ -293,25 +293,30 @@ def test_gelu(self):
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

@unittest.skip("Unbounded dynamism is not supported yet.")
def test_embedding(self):

class M(torch.nn.Module):

def forward(self, x, y):
res = torch.ops.aten.embedding.default(x, y)
return res[0], res[1]
return res

args = (torch.rand((1, 768)), torch.randint(0, 15, (3, 10)).to(torch.int64))
args = (torch.rand((20, 768)), torch.randint(0, 15,
(3, 10)).to(torch.int64))
dynamic_shapes = (None, {0: Dim("dim")})
# dynamic_shapes = None
m = M()
ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r"%arg.: tensor<\?x5xf32>.*->.*tensor<\?x5xi32>", shlo_text)
is not None)
re.search(r"%arg.: tensor<\?x10xi64>.*->.*tensor<\?x10x768xf32>",
shlo_text) is not None)
if has_tf_package():
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(
m, args, tempdir, dynamic_shapes=dynamic_shapes)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_mean(self):

@@ -372,7 +377,7 @@ def test_mul_scalar(self):
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

@unittest.skip("Implicit broadcasting logic is broken.")
@unittest.skip("Unbounded dynamism is not supported.")
def test_ne_scalar(self):

class M(torch.nn.Module):
@@ -382,20 +387,19 @@ def forward(self, x):

args = (torch.rand((3, 5)).to(torch.int64),)
dynamic_shapes = ({0: Dim("dim")},)
# dynamic_shapes = None
m = M()
ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r"%arg.: tensor<\?x5xf32>.*->.*tensor<\?x5xi32>", shlo_text)
is not None)
# if has_tf_package():
# with tempfile.TemporaryDirectory() as tempdir:
# save_torch_module_as_tf_saved_model(
# m, args, tempdir, dynamic_shapes=dynamic_shapes)
# self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
# compare_exported_program_and_saved_model_result(ep, tempdir, args)
if has_tf_package():
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(
m, args, tempdir, dynamic_shapes=dynamic_shapes)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_var(self):

@@ -534,6 +538,24 @@ def test_slice(self):
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_slice_2(self):
args = (torch.rand((10, 3, 224, 224)), 1, 0, 2)
dynamic_shapes = ([{0: Dim("dim")}, None, None, None],)
m = wrap_func_as_nn_module(torch.ops.aten.slice.Tensor)
ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(
r"%arg.: tensor<\?x3x224x224xf32>.*->.*tensor<\?x2x224x224xf32>",
shlo_text) is not None)
if has_tf_package():
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(
m, args, tempdir, dynamic_shapes=dynamic_shapes)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_softmax(self):
args = (torch.rand((10, 12, 197, 197)), -1, False)
dynamic_shapes = ([{0: Dim("dim")}, None, None],)
@@ -624,25 +646,30 @@ def test_sub_scalar(self):
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

@unittest.skip("Unbounded dynamism is not supported yet.")
def test_split_with_sizes(self):

class M(torch.nn.Module):

def forward(self, x):
res = torch.ops.aten.split_with_sizes.default(x, [1, 1], -1)
return res[0], res[1]
res = torch.ops.aten.split_with_sizes.default(x, [1, 2, 3], -1)
return res[0], res[1], res[2]

args = (torch.rand((3, 10, 2)),)
args = (torch.rand((3, 10, 6)),)
dynamic_shapes = ({0: Dim("dim")},)
# dynamic_shapes = None
m = M()
ep = export(m, args=args, dynamic_shapes=dynamic_shapes)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text()
self.assertTrue(
re.search(r"%arg.: tensor<\?x5xf32>.*->.*tensor<\?x5xi32>", shlo_text)
is not None)
re.search(
r"%arg.: tensor<\?x10x6xf32>.*->.*tensor<\?x10x1xf32>.*tensor<\?x10x2xf32>.*tensor<\?x10x3xf32>",
shlo_text) is not None)
if has_tf_package():
with tempfile.TemporaryDirectory() as tempdir:
save_torch_module_as_tf_saved_model(
m, args, tempdir, dynamic_shapes=dynamic_shapes)
self.assertTrue(os.path.exists(os.path.join(tempdir, 'saved_model.pb')))
compare_exported_program_and_saved_model_result(ep, tempdir, args)

def test_transpose_on_dynamic_dim(self):
args = (torch.rand((1, 8, 3, 256)),)
97 changes: 95 additions & 2 deletions torch_xla/experimental/unbounded_dynamism_export.py
@@ -117,6 +117,94 @@ def decompose_dynamic_shape_select(gm: GraphModule):
graph.erase_node(n)


def decompose_split_with_sizes(gm: GraphModule):
'''
Decompose `split_with_sizes` with a symbolic input shape into `aten.slice` ops,
so that the unbounded-dynamism support already added to `aten.slice` is reused.
'''
graph = gm.graph
for n in graph.nodes:
if n.op == "call_function" and n.target == aten.split_with_sizes.default:
src_node = n.args[0]
src_shape = src_node.meta['val'].size()
symbolic_dims = [
i for i, x in enumerate(src_shape) if not isinstance(x, int)
]
if len(symbolic_dims) == 0:
continue
assert len(symbolic_dims) == 1, "Only 1 dimension can be symbolic."
split_sizes = n.args[1]
split_dim = n.args[2]
assert symbolic_dims[
0] != split_dim, "Split along symbolic dim is not supported."
with graph.inserting_before(n):
start_idx = 0
decomposed_slices = []
for size in split_sizes:
slice_args = (src_node, split_dim, start_idx, start_idx + size)
slice_node = graph.call_function(torch.ops.aten.slice.Tensor,
slice_args)
start_idx += size
decomposed_slices.append(slice_node)
consumers = n.users
for consumer in consumers:
assert consumer.op == "call_function" and consumer.target.__name__ == "getitem"
slice_idx = consumer.args[1]
consumer.replace_all_uses_with(decomposed_slices[slice_idx])

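As an illustrative aside (not part of the file), the rewrite relies on the eager-mode equivalence between `split_with_sizes` and a chain of `aten.slice` calls along the split dimension, which is roughly:

```python
import torch

x = torch.rand(3, 10, 6)
split = torch.ops.aten.split_with_sizes.default(x, [1, 2, 3], -1)
# The same outputs, expressed with aten.slice.Tensor along the split dimension.
slices = (
    torch.ops.aten.slice.Tensor(x, -1, 0, 1),
    torch.ops.aten.slice.Tensor(x, -1, 1, 3),
    torch.ops.aten.slice.Tensor(x, -1, 3, 6),
)
assert all(torch.allclose(a, b) for a, b in zip(split, slices))
```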

def flatten_embedding_indices_tensor(gm: GraphModule):
'''
Flatten the indices tensor of `aten.embedding.default` to avoid a `view` op with
a symbolic input shape during LTC tracing.
The indices tensor is flattened to 1-D before the lookup, and the embedding output
is reshaped back to its original shape afterwards. The symbolic shape in the
recovering `aten.view` is then handled by the `aten.view` -> `xla.dynamic_view` pass.
'''
graph = gm.graph
for n in graph.nodes:
if n.op == "call_function" and n.target == aten.embedding.default:
select_src_shape = n.args[1].meta['val'].shape
symbolic_dims = [
i for i, x in enumerate(select_src_shape) if not isinstance(x, int)
]
if len(symbolic_dims) > 0:
assert len(symbolic_dims) == 1, "Only 1 dimension can be symbolic."
with graph.inserting_before(n):
indices_node = n.args[1]
indices_shape = indices_node.meta['val'].size()
flatten_mul_scale = 1
get_dim_size_node = None
recover_view_shape = []
for dim, size in enumerate(indices_shape):
if not isinstance(size, int):
get_dim_size_args = (indices_node, dim)
get_dim_size_node = graph.call_function(aten.sym_size.int,
get_dim_size_args)
recover_view_shape.append(get_dim_size_node)
else:
flatten_mul_scale *= size
recover_view_shape.append(size)
weight_shape = n.args[0].meta['val'].size()
recover_view_shape.append(weight_shape[-1])

mul_args = (get_dim_size_node, flatten_mul_scale)
flatten_size_node = graph.call_function(aten.mul.Scalar, mul_args)
view_args = (indices_node, [flatten_size_node])
view_node = graph.call_function(aten.view, view_args)
new_embedding_args = n.args[0:1] + (view_node,)
if len(n.args) > 2:
new_embedding_args += n.args[2:]
n.args = new_embedding_args
with graph.inserting_after(n):
recover_view_args = (n, recover_view_shape)
recover_view_node = graph.call_function(aten.view, recover_view_args)
n.replace_all_uses_with(recover_view_node)
recover_view_node.update_arg(0, n)

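As an illustrative aside (not part of the file), the pass relies on `aten.embedding` being a pure row gather, so flattening the indices and reshaping the output afterwards is semantically a no-op:

```python
import torch

weight = torch.rand(20, 768)
indices = torch.randint(0, 20, (3, 10))
direct = torch.ops.aten.embedding.default(weight, indices)
# Flatten the indices, look up, then view back to (*indices.shape, embedding_dim).
flat = torch.ops.aten.embedding.default(weight, indices.view(-1))
recovered = flat.view(*indices.shape, weight.shape[-1])
assert torch.equal(direct, recovered)
```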

def _is_no_op_slice(n):
assert n.op == "call_function" and n.target == aten.slice.Tensor
return n.args[2] == 0 and n.args[3] == torch.iinfo(torch.int64).max
@@ -202,8 +290,10 @@ def replace_dynamic_view_with_xla_op(gm: GraphModule):
mul_scaler = 1
sym_size_node = dynamic_src
mul_node = None
if hasattr(dynamic_src.target,
"__name__") and dynamic_src.target.__name__ == "mul":
if hasattr(
dynamic_src.target,
"__name__") and (dynamic_src.target.__name__ == "mul" or
dynamic_src.target.__name__ == "mul.Scalar"):
assert isinstance(dynamic_src.args[0], int) or isinstance(
dynamic_src.args[1], int)
mul_node = dynamic_src
@@ -236,6 +326,7 @@ def dynamic_unsqueeze_to_view(gm: GraphModule):
]
if len(symbolic_dims) == 0:
continue
assert len(symbolic_dims) == 1, "Only 1 dimension can be symbolic."
view_args = list(src_shape)
with graph.inserting_before(n):
for dim in symbolic_dims:
@@ -266,10 +357,12 @@ def exported_program_has_symbolic_input_shape(ep):
def process_exported_program_with_symbolic_input(ep):
passes = [
decompose_dynamic_shape_select,
decompose_split_with_sizes,
remove_no_op_slice,
decompose_dynamic_native_group_norm,
decompose_dynamic_native_layer_norm,
dynamic_unsqueeze_to_view,
flatten_embedding_indices_tensor,
replace_dynamic_expand_with_xla_op,
replace_dynamic_view_with_xla_op,
]
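For completeness, a hedged sketch of how this pass pipeline is typically driven; the two entry points appear in the hunk above and the module path matches this file, but treat the exact call sequence as an assumption.

```python
import torch
from torch.export import Dim, export

# Module path matches this file; the invocation below is a sketch, not a verbatim API example.
from torch_xla.experimental.unbounded_dynamism_export import (
    exported_program_has_symbolic_input_shape,
    process_exported_program_with_symbolic_input,
)


class M(torch.nn.Module):

  def forward(self, x):
    parts = torch.ops.aten.split_with_sizes.default(x, [1, 2, 3], -1)
    return parts[0], parts[1], parts[2]


args = (torch.rand((3, 10, 6)),)
ep = export(M(), args=args, dynamic_shapes=({0: Dim("batch")},))
if exported_program_has_symbolic_input_shape(ep):
  # Runs the passes in the order listed above (including the two new ones),
  # rewriting ep.graph_module in place before StableHLO lowering.
  process_exported_program_with_symbolic_input(ep)
```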
