Skip to content

Commit

Permalink
Support building StableHLO composite with different attr value types (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc authored and amithrm committed Mar 1, 2024
1 parent 62ad221 commit 5c50b95
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 12 deletions.
35 changes: 33 additions & 2 deletions test/stablehlo/test_mark_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch_xla.experimental.xla_marker
from torch.utils import _pytree as pytree
from torch_xla import stablehlo
from torch_xla.experimental import xla_marker
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
from utils import has_tf_package

Expand Down Expand Up @@ -108,14 +109,19 @@ def forward(self, x, y):
pos=0,
id="0",
is_input=False,
attr={"scale": 0.25})
attr=xla_marker.serialize_composite_attr({"scale": 0.25}))
q, k, v = y.split(128, dim=-2)
q = torch.ops.xla.mark_tensor(q, "sdpa", pos=0, id="1", is_input=True)
k = torch.ops.xla.mark_tensor(k, "sdpa", pos=1, id="1", is_input=True)
v = torch.ops.xla.mark_tensor(v, "sdpa", pos=2, id="1", is_input=True)
attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
attn_out2 = torch.ops.xla.mark_tensor(
attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
attn_out2,
"sdpa",
pos=0,
id="1",
is_input=False,
attr=xla_marker.serialize_composite_attr({"scale": 2}))
return attn_out, attn_out2

input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
Expand Down Expand Up @@ -236,6 +242,31 @@ def forward(self, x, y):
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)

def test_composite_builder_mix_attr_value_types(self):

class M(torch.nn.Module):

def forward(self, x, y):
builder = StableHLOCompositeBuilder(
"sample_composite", {
"int_attr": 1,
"float_attr": 2.3,
"bool_attr": True,
"str_attr": "helloworld",
})
x, y = builder.mark_inputs(x, y)
z = x + y
z = builder.mark_outputs(z)
return z

input_args = (torch.randn((5, 5)), torch.randn((5, 5)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertEqual(stablehlo.count('int_attr = 1 : i64'), 1)
self.assertEqual(stablehlo.count('float_attr = 2.300000e+00 : f32'), 1)
self.assertEqual(stablehlo.count('bool_attr = true'), 1)
self.assertEqual(stablehlo.count('str_attr = "helloworld"'), 1)

def test_multiple_inputs(self):

def f(x, y):
Expand Down
7 changes: 5 additions & 2 deletions torch_xla/experimental/mark_pattern_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
import torch._dynamo as torchdynamo
import torch_xla.experimental.xla_marker
from torch_xla.experimental import xla_marker


@torchdynamo.assume_constant_result
Expand Down Expand Up @@ -33,6 +33,9 @@ def __init__(self, name: str, attr: Dict[str, Union[int, float, str]] = None):

def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
marked_tensors = []
serialized_attr = xla_marker.serialize_composite_attr(
self.attr) if not is_input else None

for pos, tensor in enumerate(tensors):
if not isinstance(tensor, torch.Tensor):
raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.")
Expand All @@ -43,7 +46,7 @@ def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
pos=pos,
id=self.id,
is_input=is_input,
attr=self.attr if not is_input else None,
attr=serialized_attr,
))

if len(marked_tensors) == 1:
Expand Down
34 changes: 26 additions & 8 deletions torch_xla/experimental/xla_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict

import torch
import torch._dynamo as torchdynamo
from torch.library import impl
import torch_xla
from torch_xla.core.xla_model import XLA_LIB
Expand Down Expand Up @@ -39,9 +40,25 @@ def _assert_valid_composite_attr(attr):
for k, v in attr.items():
if not isinstance(k, str):
raise ValueError("Composite attr name must be a Python str.")
if type(k) not in [str, float, int]:
if type(k) not in (str, float, int, bool):
raise ValueError(
"Composite attr value must be either Python str, float, or int.")
"Composite attr value must be either Python str, float, int, or bool."
)


@torchdynamo.assume_constant_result
def serialize_composite_attr(attr: Dict):
if attr is None:
return None
_assert_valid_composite_attr(attr)
return tuple(attr.items())


@torchdynamo.assume_constant_result
def deserialize_composite_attr(attr) -> Dict:
if attr is None:
return None
return dict(attr)


@impl(XLA_LIB, "mark_tensor", "XLA")
Expand All @@ -50,7 +67,7 @@ def mark_tensor_xla(x: torch.Tensor,
pos: int,
id: str,
is_input: bool,
attr: Dict = None):
attr=None):
"""Attach pattern boundary metadata to a XLA Tensor.
Args:
Expand All @@ -59,9 +76,10 @@ def mark_tensor_xla(x: torch.Tensor,
pos: int - Input/output Position of the annotated tensor in the pattern.
id: str - Unique identifier of the pattern instance.
is_input: bool - If the annotated tensor is the input to the pattern.
attr: dict - Attribute of the pattern, it will be passed down to the attribute field
in the stablehlo composite.
attr - Attribute of the pattern. It must be a value generated by serialize_composite_attr
and will be passed down to the attribute field in the stablehlo composite.
"""
attr = deserialize_composite_attr(attr)
_assert_valid_composite_attr(attr)
pattern_info = BoundaryMetadata(name, pos, id, is_input, attr)
return torch_xla._XLAC._xla_mark_tensor(
Expand All @@ -74,7 +92,7 @@ def mark_tensor(x: torch.Tensor,
pos: int,
id: str,
is_input: bool,
attr: Dict = None):
attr=None):
# Do nothing for non-xla tensor.
return x

Expand All @@ -85,5 +103,5 @@ def mark_tensor_meta(x: torch.Tensor,
pos: int,
id: str,
is_input: bool,
attr: Dict = None):
return torch.empty_like(x)
attr=None):
return torch.empty_like(x)

0 comments on commit 5c50b95

Please sign in to comment.