Skip to content

Commit

Permalink
Use stablehlo.composite instead of stablehlo.custom_call (#6789)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mlevesquedion authored Mar 21, 2024
1 parent d70b6c4 commit 3eeb15d
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 86 deletions.
130 changes: 72 additions & 58 deletions test/stablehlo/test_mark_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def forward(self, input_pos, k_val, v_val):
return self.update_cache_with_hlfb(input_pos, k_val, v_val)

def update_cache_with_hlfb(self, input_pos, k_val, v_val):
builder = StableHLOCompositeBuilder('update_kv_cache')
builder = StableHLOCompositeBuilder('test.update_kv_cache')
k_cache, v_cache, input_pos, k_val, v_val = builder.mark_inputs(
self.k_cache, self.v_cache, input_pos, k_val, v_val)
updated_k = k_cache.index_copy_(1, input_pos, k_val)
Expand Down Expand Up @@ -82,15 +82,15 @@ def test_basic(self):

def f(x):
x = x + 1
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "test.p", 0, "0", True)
x = x + 2
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", False)
x = torch.ops.xla.mark_tensor(x, "test.p", 0, "0", False)
return x

input_args = (torch.randn(5),)
stablehlo = self.run_func_get_stablehlo(f, input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertTrue('{attributes = {}, name = "p"}' in stablehlo)
self.assertEqual(stablehlo.count("stablehlo.composite \"test.p\""), 1)
self.assertEqual(stablehlo.count('{decomposition = @test.p.impl}'), 1)

def test_sdpa_pattern(self):
import torch.nn.functional as F
Expand All @@ -99,25 +99,31 @@ class M(torch.nn.Module):

def forward(self, x, y):
q, k, v = x.split(128, dim=-2)
q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="0", is_input=True)
k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="0", is_input=True)
v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="0", is_input=True)
q = torch.ops.xla.mark_tensor(
q, "test.sdpa", pos=0, id="0", is_input=True)
k = torch.ops.xla.mark_tensor(
k, "test.sdpa", pos=1, id="0", is_input=True)
v = torch.ops.xla.mark_tensor(
v, "test.sdpa", pos=2, id="0", is_input=True)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = torch.ops.xla.mark_tensor(
attn_out,
"sdpa",
"test.sdpa",
pos=0,
id="0",
is_input=False,
attr=xla_marker.serialize_composite_attr({"scale": 0.25}))
q, k, v = y.split(128, dim=-2)
q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="1", is_input=True)
k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="1", is_input=True)
v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="1", is_input=True)
q = torch.ops.xla.mark_tensor(
q, "test.sdpa", pos=0, id="1", is_input=True)
k = torch.ops.xla.mark_tensor(
k, "test.sdpa", pos=1, id="1", is_input=True)
v = torch.ops.xla.mark_tensor(
v, "test.sdpa", pos=2, id="1", is_input=True)
attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
attn_out2 = torch.ops.xla.mark_tensor(
attn_out2,
"sdpa",
"test.sdpa",
pos=0,
id="1",
is_input=False,
Expand All @@ -126,25 +132,26 @@ def forward(self, x, y):

input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2)
self.assertTrue(
'{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
stablehlo)
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}'
in stablehlo)
self.assertTrue(
'{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}'
in stablehlo)

def test_composite_builder_sdpa_pattern(self):

class M(torch.nn.Module):

def forward(self, x, y):
b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25})
q, k, v = x.split(128, dim=-2)
q, k, v = b.mark_inputs(q, k, v)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = b.mark_outputs(attn_out)

b2 = StableHLOCompositeBuilder("sdpa", {"scale": 2})
b2 = StableHLOCompositeBuilder("test.sdpa", {"scale": 2})
q, k, v = y.split(128, dim=-2)
q, k, v = b2.mark_inputs(q, k, v)
attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
Expand All @@ -153,21 +160,22 @@ def forward(self, x, y):

input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2)
self.assertTrue(
'{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
stablehlo)
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}'
in stablehlo)
self.assertTrue(
'{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}'
in stablehlo)

def test_composite_builder_export_sdpa_pattern(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()
self.b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
self.b2 = StableHLOCompositeBuilder("sdpa", {"scale": 2})
self.b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25})
self.b2 = StableHLOCompositeBuilder("test.sdpa", {"scale": 2})

def forward(self, x, y):
q, k, v = x.split(128, dim=-2)
Expand All @@ -185,12 +193,13 @@ def forward(self, x, y):
tmp_path = tempfile.mkdtemp() if has_tf_package() else None
stablehlo_gm = self.export_func(M(), input_args, tmp_path)
stablehlo = stablehlo_gm.get_stablehlo_text()
self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2)
self.assertTrue(
'{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
stablehlo)
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}'
in stablehlo)
self.assertTrue(
'{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}'
in stablehlo)
if has_tf_package():
self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

Expand All @@ -199,13 +208,13 @@ def test_inlined_composite_builder_export_sdpa_pattern(self):
class M(torch.nn.Module):

def forward(self, x, y):
b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
b = StableHLOCompositeBuilder("test.sdpa", {"scale": 0.25})
q, k, v = x.split(128, dim=-2)
q, k, v = b.mark_inputs(q, k, v)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = b.mark_outputs(attn_out)

b2 = StableHLOCompositeBuilder("sdpa", {"scale": 2})
b2 = StableHLOCompositeBuilder("test.sdpa", {"scale": 2})
q, k, v = y.split(128, dim=-2)
q, k, v = b2.mark_inputs(q, k, v)
attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
Expand All @@ -216,12 +225,13 @@ def forward(self, x, y):
tmp_path = tempfile.mkdtemp() if has_tf_package() else None
stablehlo_gm = self.export_func(M(), input_args, tmp_path)
stablehlo = stablehlo_gm.get_stablehlo_text()
self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
self.assertEqual(stablehlo.count("stablehlo.composite \"test.sdpa\""), 2)
self.assertTrue(
'{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
stablehlo)
'{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0}'
in stablehlo)
self.assertTrue(
'{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
'{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl}'
in stablehlo)
if has_tf_package():
self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

Expand All @@ -230,7 +240,7 @@ def test_composite_builder_multiple_outputs(self):
class M(torch.nn.Module):

def forward(self, x, y):
builder = StableHLOCompositeBuilder("sample_composite")
builder = StableHLOCompositeBuilder("test.sample_composite")
x, y = builder.mark_inputs(x, y)
a = x + y
b = x - y
Expand All @@ -240,15 +250,16 @@ def forward(self, x, y):

input_args = (torch.randn((5, 5)), torch.randn((5, 5)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertEqual(
stablehlo.count("stablehlo.composite \"test.sample_composite\""), 1)

def test_composite_builder_mix_attr_value_types(self):

class M(torch.nn.Module):

def forward(self, x, y):
builder = StableHLOCompositeBuilder(
"sample_composite", {
"test.sample_composite", {
"int_attr": 1,
"float_attr": 2.3,
"bool_attr": True,
Expand All @@ -261,7 +272,8 @@ def forward(self, x, y):

input_args = (torch.randn((5, 5)), torch.randn((5, 5)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertEqual(
stablehlo.count("stablehlo.composite \"test.sample_composite\""), 1)
self.assertEqual(stablehlo.count('int_attr = 1 : i64'), 1)
self.assertEqual(stablehlo.count('float_attr = 2.300000e+00 : f32'), 1)
self.assertEqual(stablehlo.count('bool_attr = true'), 1)
Expand All @@ -270,44 +282,45 @@ def forward(self, x, y):
def test_multiple_inputs(self):

def f(x, y):
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True)
x = torch.ops.xla.mark_tensor(x, "test.p", 0, "0", True)
y = torch.ops.xla.mark_tensor(y, "test.p", 1, "0", True)
out = x + y
out = out * x * y
out = torch.ops.xla.mark_tensor(out, "p", 0, "0", False)
out = torch.ops.xla.mark_tensor(out, "test.p", 0, "0", False)
return out

input_args = (torch.ones(5), torch.ones(5))
stablehlo = self.run_func_get_stablehlo(f, input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertTrue('{attributes = {}, name = "p"}' in stablehlo)
self.assertEqual(stablehlo.count("stablehlo.composite \"test.p\""), 1)
self.assertEqual(stablehlo.count('{decomposition = @test.p.impl}'), 1)

def test_multiple_outputs(self):

def f(x, y):
x = torch.ops.xla.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla.mark_tensor(y, "p", 1, "0", True)
x = torch.ops.xla.mark_tensor(x, "test.p", 0, "0", True)
y = torch.ops.xla.mark_tensor(y, "test.p", 1, "0", True)
out1 = x + y
out2 = x * y
out1 = torch.ops.xla.mark_tensor(out1, "p", 0, "0", False)
out2 = torch.ops.xla.mark_tensor(out2, "p", 1, "0", False)
out1 = torch.ops.xla.mark_tensor(out1, "test.p", 0, "0", False)
out2 = torch.ops.xla.mark_tensor(out2, "test.p", 1, "0", False)
return out1, out2

input_args = (torch.ones(5), torch.ones(5))
stablehlo = self.run_func_get_stablehlo(f, input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertEqual(stablehlo.count("stablehlo.composite \"test.p\""), 1)
self.assertEqual(stablehlo.count('{decomposition = @test.p.impl}'), 1)

@unittest.skip("Nested pattern is not supported now.")
def test_nested_pattern(self):

def f(x):
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "test.p_outter", 0, "0", True)
x = x + 1
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "test.p_inner", 0, "0", True)
x = x + 1
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False)
x = torch.ops.xla.mark_tensor(x, "test.p_inner", 0, "0", False)
x = x * 2
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", False)
x = torch.ops.xla.mark_tensor(x, "test.p_outter", 0, "0", False)

input_args = (torch.ones(5),)
stablehlo = self.run_func_get_stablehlo(f, input_args)
Expand All @@ -316,13 +329,13 @@ def f(x):
def test_tangent_output(self):
# Special case of nested pattern, outputs don't have dependencies.
def f(x):
x = torch.ops.xla.mark_tensor(x, "p_outter", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "test.p_outter", 0, "0", True)
x = x + 1
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", True)
x = torch.ops.xla.mark_tensor(x, "test.p_inner", 0, "0", True)
x = x + 1
y = x - 1
x = torch.ops.xla.mark_tensor(x, "p_inner", 0, "0", False)
y = torch.ops.xla.mark_tensor(y, "p_outter", 0, "0", False)
x = torch.ops.xla.mark_tensor(x, "test.p_inner", 0, "0", False)
y = torch.ops.xla.mark_tensor(y, "test.p_outter", 0, "0", False)

input_args = (torch.ones(5),)
stablehlo = self.run_func_get_stablehlo(f, input_args)
Expand All @@ -340,7 +353,8 @@ def test_update_kv_cache(self):
exported = torch.export.export(model, (input_pos, k_val, v_val))
shlo = stablehlo.exported_program_to_stablehlo(exported)
shlo_text = shlo.get_stablehlo_text()
self.assertEqual(shlo_text.count("@stablehlo.composite"), 1)
self.assertEqual(
shlo_text.count("stablehlo.composite \"test.update_kv_cache\""), 1)


if __name__ == '__main__':
Expand Down
32 changes: 4 additions & 28 deletions torch_xla/csrc/runtime/stablehlo_composite_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,37 +437,13 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
"JSON into composite attributes.";
}

builder.setInsertionPointAfter(boundary_output_op);
llvm::SmallVector<mlir::NamedAttribute> call_attrs{
{
builder.getStringAttr("call_target_name"),
builder.getStringAttr("stablehlo.composite"),
},
{
builder.getStringAttr("called_computations"),
builder.getArrayAttr(mlir::FlatSymbolRefAttr::get(
builder.getContext(), impl_func.getSymName())),
},
{
builder.getStringAttr("composite.backend_config"),
builder.getDictionaryAttr(llvm::SmallVector<mlir::NamedAttribute>{
{
builder.getStringAttr("attributes"),
*attributes_or,
},
{
builder.getStringAttr("name"),
builder.getStringAttr(metadata.name),
},
}),
},
};

// Creates and inserts composite call op.
builder.setInsertionPointAfter(boundary_output_op);
mlir::Operation* composite_op =
builder.create<mlir::stablehlo::CustomCallOp>(
builder.create<mlir::stablehlo::CompositeOp>(
boundary_output_op->getLoc(),
impl_func.getFunctionType().getResults(), args, call_attrs);
impl_func.getFunctionType().getResults(), args, metadata.name,
*attributes_or, impl_func.getSymName());
return composite_op;
}
};
Expand Down

0 comments on commit 3eeb15d

Please sign in to comment.