Add fx passes to support exporting unbounded dynamism (#6653)
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
lsy323 and Siyuan Liu authored Mar 20, 2024
1 parent 1ccd6a6 commit 7e0d3a5
Showing 22 changed files with 1,925 additions and 395 deletions.
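The passes added here slot in between torch.export and StableHLO lowering: export with symbolic Dims, rewrite the FX graph so dynamic-shape ops survive lowering, then recompile. A minimal sketch of that workflow, based on the test patterns in test_export_fx_passes.py below (the pass names are the ones the tests exercise; treating exported_program_to_stablehlo as the final step is an assumption, and extra flags may be needed for unbounded dynamism):

import torch
from torch.export import Dim, export
from torch_xla.experimental.unbounded_dynamism_export import \
    replace_dynamic_view_with_xla_op
from torch_xla.stablehlo import exported_program_to_stablehlo


class M(torch.nn.Module):

  def forward(self, x):
    # A view whose leading dim follows the (unbounded) batch size.
    return x.view(x.shape[0], -1)


ep = export(M(), (torch.rand(10, 3, 4),), dynamic_shapes=({0: Dim("bs")},))
replace_dynamic_view_with_xla_op(ep.graph_module)  # view -> xla.dynamic_view
ep.graph_module.recompile()
shlo = exported_program_to_stablehlo(ep)  # assumed final lowering step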
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -212,6 +212,7 @@ function run_xla_op_tests2 {
function run_xla_op_tests3 {
  # TODO(qihqi): these tests require tensorflow to run. Need to set up a
  # separate CI with tf.
+  run_test "$CDIR/stablehlo/test_export_fx_passes.py"
  run_test "$CDIR/stablehlo/test_implicit_broadcasting.py"
  run_test "$CDIR/stablehlo/test_mark_pattern.py"
  run_test "$CDIR/stablehlo/test_pt2e_qdq.py"
250 changes: 250 additions & 0 deletions test/stablehlo/test_export_fx_passes.py
@@ -0,0 +1,250 @@
import os
import re
import sys
import unittest

import numpy as np
import torch
import torch.utils._pytree as pytree
import torch_xla
import torch_xla.core.xla_model as xm
from torch.export import Dim, export
from torch_xla.experimental.unbounded_dynamism_export import *
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.utils.stablehlo_test_utils import wrap_func_as_nn_module


class ExportFxPassTest(unittest.TestCase):

  def test_decompose_dynamic_shape_select(self):
    args = (torch.rand((10, 197, 768)), 1, 0)
    dynamic_shapes = ([{0: Dim("bs")}, None, None],)
    m = wrap_func_as_nn_module(torch.ops.aten.select.int)
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    out1 = ep.module()(*args)
    decompose_dynamic_shape_select(ep.graph_module)
    ep.graph_module.recompile()
    self.assertTrue('aten.view' in ep.graph_module.code)
    replace_dynamic_view_with_xla_op(ep.graph_module)
    ep.graph_module.recompile()
    self.assertTrue('aten.view' not in ep.graph_module.code)
    self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
    out2 = ep.module()(*args)
    self.assertTrue(torch.allclose(out1, out2))

  def test_no_op_slice_removal(self):

    class M(torch.nn.Module):

      def forward(self, x):
        x = x * 2
        return torch.ops.aten.slice(x, 1, 0, 9223372036854775807)

    m = M()
    args = (torch.rand((10, 197, 768)),)
    dynamic_shapes = ({0: Dim("bs")},)
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    out1 = ep.module()(*args)
    self.assertTrue('aten.slice' in ep.graph_module.code)
    remove_no_op_slice(ep.graph_module)
    ep.graph_module.recompile()
    self.assertTrue('aten.slice' not in ep.graph_module.code)
    out2 = ep.module()(*args)
    self.assertTrue(torch.allclose(out1, out2))

  def test_dynamic_view(self):

    class M(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, [16, 16])

      def forward(self, x):
        x = self.conv(x)
        return x.view(x.shape[0], x.shape[1], -1)

    m = M()
    args = (torch.rand((10, 3, 224, 224)),)
    dynamic_shapes = ({0: Dim("bs")},)
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    out1 = ep.module()(*args)
    replace_dynamic_view_with_xla_op(ep.graph_module)
    ep.graph_module.recompile()
    self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
    out2 = ep.module()(*args)
    self.assertTrue(torch.allclose(out1, out2))

  def test_dynamic_view_non_bs(self):

    class M(torch.nn.Module):

      def forward(self, x):
        return x.view(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])

    m = M()
    args = (torch.rand((1, 3, 2, 16)),)
    dynamic_shapes = ({1: Dim("bs")},)
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    out1 = ep.module()(*args)
    replace_dynamic_view_with_xla_op(ep.graph_module)
    ep.graph_module.recompile()
    self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
    out2 = ep.module()(*args)
    self.assertTrue(torch.allclose(out1, out2))

  def test_dynamic_view_multiplier(self):

    class M(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, [16, 16])

      def forward(self, x):
        x = self.conv(x)
        return x.view(x.shape[0] * x.shape[1], -1)

    m = M()
    args = (torch.rand((10, 3, 224, 224)),)
    dynamic_shapes = ({0: Dim("bs")},)
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    out1 = ep.module()(*args)
    replace_dynamic_view_with_xla_op(ep.graph_module)
    ep.graph_module.recompile()
    self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
    out2 = ep.module()(*args)
    self.assertTrue(torch.allclose(out1, out2))

  def test_dynamic_expand(self):

    class M(torch.nn.Module):

      def forward(self, x, image):
        return x.expand([image.shape[0], -1, -1])

    m = M()
    args = (torch.rand((1, 1, 5)), torch.rand((3, 4)))
    dynamic_shapes = (None, {0: Dim("bs")})
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    out1 = ep.module()(*args)
    replace_dynamic_expand_with_xla_op(ep.graph_module)
    ep.graph_module.recompile()
    self.assertTrue('xla.dynamic_expand' in ep.graph_module.code)
    out2 = ep.module()(*args)
    self.assertTrue(torch.allclose(out1, out2))

  def test_dynamic_expand_2(self):

    class M(torch.nn.Module):

      def forward(self, x, range):
        return x.expand(1, 1, 8, range.shape[0], 256)

    m = M()
    args = (torch.rand((1, 1, 1, 3, 256)), torch.arange(3))
    dynamic_shapes = ({3: Dim("bs")}, {0: Dim("bs")})
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    out1 = ep.module()(*args)
    replace_dynamic_expand_with_xla_op(ep.graph_module)
    ep.graph_module.recompile()
    self.assertTrue('xla.dynamic_expand' in ep.graph_module.code)
    out2 = ep.module()(*args)
    self.assertTrue(torch.allclose(out1, out2))

  def test_layer_norm_decomp(self):

    class M(torch.nn.Module):

      def forward(self, x, dim, weight, bias, eps):
        return torch.ops.aten.native_layer_norm.default(x, dim, weight, bias,
                                                        eps)[0]

    args = (torch.rand(10, 197, 768), [768], torch.rand(768), torch.rand(768),
            1e-12)
    dynamic_shapes = ({0: Dim("bs")}, [None], None, None, None)
    m = M().eval()
    before_decomp_out = m(*args)
    after_decomp_out = native_layer_norm_impl(*args)
    self.assertTrue(
        torch.allclose(before_decomp_out, after_decomp_out, atol=1e-6))
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    decompose_dynamic_native_layer_norm(ep.graph_module)
    ep.graph_module.recompile()
    self.assertFalse('aten.native_layer_norm' in ep.graph_module.code)
    after_decomp_out_2 = ep.module()(*args)
    self.assertTrue(
        torch.allclose(before_decomp_out, after_decomp_out_2, atol=1e-6))

  def test_group_norm_to_layer_norm(self):

    class M(torch.nn.Module):

      def forward(self, x, weight, bias, N, C, HxW, group, eps):
        return torch.ops.aten.native_group_norm.default(x, weight, bias, N, C,
                                                        HxW, group, eps)[0]

    class M2(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.group_norm = torch.nn.GroupNorm(
            num_groups=512, num_channels=512, affine=True)

      def forward(self, x):
        return self.group_norm(x)[0]

    args = (torch.rand(10, 512, 159), torch.rand(512), torch.rand(512), 10,
            512, 159, 512, 1e-12)
    export_args = (torch.rand(10, 512, 159),)
    dynamic_shapes = ({0: Dim("bs")},)
    m = M().eval()
    before_decomp_out = m(*args)
    after_decomp_out = native_group_norm_impl(*args)
    self.assertTrue(
        torch.allclose(before_decomp_out, after_decomp_out, atol=1e-6))
    # Test the export path with a different module (M2) to work around an
    # export issue.
    m2 = M2().eval()
    ep = export(m2, export_args, dynamic_shapes=dynamic_shapes)
    before_decomp_ep_out = m2(*export_args)
    decompose_dynamic_native_group_norm(ep.graph_module)
    ep.graph_module.recompile()
    self.assertFalse('aten.native_group_norm' in ep.graph_module.code)
    after_decomp_ep_out = ep.module()(*export_args)
    self.assertTrue(
        torch.allclose(before_decomp_ep_out, after_decomp_ep_out, atol=1e-6))

  def test_dynamic_unsqueeze_to_view(self):

    class M(torch.nn.Module):

      def forward(self, x):
        return torch.ops.aten.unsqueeze.default(x, 2)

    args = (torch.rand((1, 1, 3, 256)),)
    dynamic_shapes = ({2: Dim("dim")},)
    m = M().eval()
    ep = export(m, args, dynamic_shapes=dynamic_shapes)
    out1 = ep.module()(*args)
    dynamic_unsqueeze_to_view(ep.graph_module)
    ep.graph_module.recompile()
    self.assertFalse('aten.unsqueeze' in ep.graph_module.code)
    self.assertTrue('aten.view' in ep.graph_module.code)
    out2 = ep.module()(*args)
    self.assertTrue(torch.allclose(out1, out2))


if __name__ == "__main__":
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
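A detail worth noting in test_no_op_slice_removal above: 9223372036854775807 is 2**63 - 1, the int64-max sentinel PyTorch uses for an open-ended slice. Slicing from 0 to that sentinel returns the dimension unchanged, which is what makes it safe for remove_no_op_slice to drop the op. A quick sanity check in plain eager PyTorch (no XLA assumed):

import torch

x = torch.rand(10, 197, 768)
# start=0 with the int64-max sentinel end covers the whole dimension,
# so this slice is a no-op.
y = torch.ops.aten.slice(x, 1, 0, 9223372036854775807)
assert 9223372036854775807 == 2**63 - 1
assert torch.equal(x, y)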
2 changes: 1 addition & 1 deletion test/stablehlo/test_mark_pattern.py
@@ -11,7 +11,7 @@
from torch_xla import stablehlo
from torch_xla.experimental import xla_marker
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
-from utils import has_tf_package
+from torch_xla.utils.stablehlo_test_utils import has_tf_package

try:
  from torch_xla.tf_saved_model_integration import \
2 changes: 1 addition & 1 deletion test/stablehlo/test_pt2e_qdq.py
@@ -11,7 +11,7 @@
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer, get_symmetric_quantization_config)
from torch_xla import stablehlo
-from utils import has_tf_package
+from torch_xla.utils.stablehlo_test_utils import has_tf_package

try:
  from torch_xla.tf_saved_model_integration import \
5 changes: 3 additions & 2 deletions test/stablehlo/test_saved_model.py
@@ -14,8 +14,9 @@
from torch_xla.tf_saved_model_integration import (
    make_tf_function, save_stablehlo_graph_as_tf,
    save_torch_module_as_tf_saved_model)
-from utils import (compare_exported_program_and_saved_model_result,
-                   has_tf_package, wrap_func_as_nn_module)
+from torch_xla.utils.stablehlo_test_utils import (
+    compare_exported_program_and_saved_model_result, has_tf_package,
+    wrap_func_as_nn_module)


class StableHLOInferenceTest(unittest.TestCase):