diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py
index 66148898e1d..e0fcce201f7 100644
--- a/test/stablehlo/test_mark_pattern.py
+++ b/test/stablehlo/test_mark_pattern.py
@@ -39,9 +39,9 @@ def test_basic(self):
 
     def f(x):
       x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
       x = x + 2
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
       return x
 
     input_args = (torch.randn(5),)
@@ -60,24 +60,29 @@ def __init__(self):
      def forward(self, x, y):
        q, k, v = x.split(128, dim=-2)
        q = torch.ops.xla_pattern_marking.mark_tensor(
-            q, "sdpa", pos=0, id=0, is_input=True)
+            q, "sdpa", pos=0, id="0", is_input=True)
        k = torch.ops.xla_pattern_marking.mark_tensor(
-            k, "sdpa", pos=1, id=0, is_input=True)
+            k, "sdpa", pos=1, id="0", is_input=True)
        v = torch.ops.xla_pattern_marking.mark_tensor(
-            v, "sdpa", pos=2, id=0, is_input=True)
+            v, "sdpa", pos=2, id="0", is_input=True)
        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
        attn_out = torch.ops.xla_pattern_marking.mark_tensor(
-            attn_out, "sdpa", pos=0, id=0, is_input=False, attr={"scale": 0.25})
+            attn_out,
+            "sdpa",
+            pos=0,
+            id="0",
+            is_input=False,
+            attr={"scale": 0.25})
        q, k, v = y.split(128, dim=-2)
        q = torch.ops.xla_pattern_marking.mark_tensor(
-            q, "sdpa", pos=0, id=1, is_input=True)
+            q, "sdpa", pos=0, id="1", is_input=True)
        k = torch.ops.xla_pattern_marking.mark_tensor(
-            k, "sdpa", pos=1, id=1, is_input=True)
+            k, "sdpa", pos=1, id="1", is_input=True)
        v = torch.ops.xla_pattern_marking.mark_tensor(
-            v, "sdpa", pos=2, id=1, is_input=True)
+            v, "sdpa", pos=2, id="1", is_input=True)
        attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
        attn_out2 = torch.ops.xla_pattern_marking.mark_tensor(
-            attn_out2, "sdpa", pos=0, id=1, is_input=False, attr={"scale": 2})
+            attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
        return attn_out, attn_out2
 
    input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
@@ -119,15 +124,6 @@ def forward(self, x, y):
    self.assertTrue(
        '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
 
-  def test_uuid_ser_des(self):
-    import uuid
-    from torch_xla.experimental.xla_marker import _get_uuid_tensor_internal, decode_uuid_tensor
-    id = uuid.uuid4()
-    id_hex = id.hex
-    id_tensor = _get_uuid_tensor_internal(id)
-    decoded = decode_uuid_tensor(id_tensor)
-    self.assertTrue(decoded, id_hex)
-
  def test_composite_builder_export_sdpa_pattern(self):
 
    class M(torch.nn.Module):
@@ -161,14 +157,47 @@ def forward(self, x, y):
        '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
    self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))
 
+  def test_inlined_composite_builder_export_sdpa_pattern(self):
+
+    class M(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+
+      def forward(self, x, y):
+        b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
+        q, k, v = x.split(128, dim=-2)
+        q, k, v = b.mark_inputs(q, k, v)
+        attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
+        attn_out = b.mark_outputs(attn_out)
+
+        b2 = StableHLOCompositeBuilder("sdpa", {"scale": 2})
+        q, k, v = y.split(128, dim=-2)
+        q, k, v = b2.mark_inputs(q, k, v)
+        attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
+        attn_out2 = b2.mark_outputs(attn_out2)
+        return attn_out, attn_out2
+
+    input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
+    tmp_path = tempfile.mkdtemp()
+    stablehlo_gm = self.export_func(M(), input_args, tmp_path)
+    stablehlo = stablehlo_gm.get_stablehlo_text()
+    self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
+    self.assertTrue(
+        '{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
+        stablehlo)
+    self.assertTrue(
+        '{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)
+    self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))
+
  def test_multiple_input(self):
 
    def f(x, y):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, 0, True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
      out = x + y
      out = out * x * y
-      out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, 0, False)
+      out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False)
      return out
 
    input_args = (torch.ones(5), torch.ones(5))
@@ -180,12 +209,12 @@ def f(x, y):
  def test_multiple_output(self):
 
    def f(x, y):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, 0, True)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, 0, True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
      out1 = x + y
      out2 = x * y
-      out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, 0, False)
-      out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, 0, False)
+      out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
+      out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
      return out1, out2
 
    input_args = (torch.ones(5), torch.ones(5))
@@ -195,14 +224,15 @@ def f(x, y):
  def test_nested_pattern(self):
 
    def f(x):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
      x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
      x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
      x = x * 2
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0",
+                                                    False)
      return x
 
    input_args = (torch.ones(5),)
    stablehlo = self.run_func_get_stablehlo(f, input_args)
@@ -211,14 +241,15 @@ def f(x):
  def test_tangent_output(self):
    # Special case of nested pattern, outputs don't have dependencies.
    def f(x):
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, 0, True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
      x = x + 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, True)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
      x = x + 1
      y = x - 1
-      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, 0, False)
-      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, 0, False)
+      x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
+      y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0",
+                                                    False)
      return x, y
 
    input_args = (torch.ones(5),)
    stablehlo = self.run_func_get_stablehlo(f, input_args)
diff --git a/torch_xla/experimental/mark_pattern_utils.py b/torch_xla/experimental/mark_pattern_utils.py
index 40330a5b063..4665f9b14e0 100644
--- a/torch_xla/experimental/mark_pattern_utils.py
+++ b/torch_xla/experimental/mark_pattern_utils.py
@@ -1,8 +1,14 @@
+import uuid
 from typing import Dict, Union
 
 import torch
+import torch._dynamo as torchdynamo
 import torch_xla.experimental.xla_marker
-from torch_xla.experimental.xla_marker import get_uuid_tensor
+
+
+@torchdynamo.assume_constant_result
+def _get_uuid() -> str:
+  return uuid.uuid4().hex
 
 
 class StableHLOCompositeBuilder:
@@ -21,7 +27,7 @@ def __init__(self, name: str, attr: Dict[str, Union[int, float, str]] = None):
 
    self.attr = attr
    self.name = name
-    self.id = get_uuid_tensor()
+    self.id = _get_uuid()
    self._inputs = []
    self._outputs = []
 
diff --git a/torch_xla/experimental/xla_marker.py b/torch_xla/experimental/xla_marker.py
index ab600bb3e2f..2c4a6d97fa2 100644
--- a/torch_xla/experimental/xla_marker.py
+++ b/torch_xla/experimental/xla_marker.py
@@ -1,8 +1,7 @@
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import Dict, Union
-import uuid
+from typing import Dict
 
 import torch
 import torch_xla
@@ -11,44 +10,15 @@
 xla_pattern_marking_lib = Library("xla_pattern_marking", "DEF")
 
 xla_pattern_marking_lib.define(
-    "mark_tensor(Tensor x, str name, int pos, int id, bool is_input, Any? attr=None) -> Tensor"
+    "mark_tensor(Tensor x, str name, int pos, str id, bool is_input, Any? attr=None) -> Tensor"
 )
 
-xla_pattern_marking_lib.define(
-    "mark_tensor.tensor(Tensor x, str name, int pos, Tensor id, bool is_input, Any? attr=None) -> Tensor"
-)
-
-
-def _get_uuid_tensor_internal(id: uuid.UUID):
-  int_arr = []
-  for i in range(4):
-    int_arr.append(int(id.int >> (128 - 32 * (i + 1)) & 0xFFFFFFFF))
-  # Need to use int64 here to avoid an overflow issue in torch.
-  return torch.tensor(int_arr, dtype=torch.int64)
-
-
-def get_uuid_tensor():
-  id = uuid.uuid4()
-  return _get_uuid_tensor_internal(id)
-
-
-def decode_uuid_tensor(x):
-  assert len(
-      x.shape
-  ) == 1, f"The uuid tensor is expected to be a 1D tensor. Getting shape : {x.shape}."
-  assert x.numel(
-  ) == 4, f"The uuid tensor is expected to have 4 elements. Tensor has {x.numel()} elements."
-  uuid_int = 0
-  for i in range(4):
-    uuid_int += x.cpu()[i] << (32 * i)
-  return hex(uuid_int)
-
 
 @dataclass
 class BoundaryMetadata:
  name: str  # Name of the Patttern.
  pos: int  # Arg/return position.
-  id: Union[int, torch.Tensor]  # Patten instance id.
+  id: str  # Pattern instance id.
  is_input: bool = True  # If the marked tensor is input/output.
  attr: dict = None  # Attribute of the pattern, expected to be attached to output.
 
@@ -57,14 +27,8 @@ class BoundaryMetadataSerializer(json.JSONEncoder):
 
  def default(self, obj):
    if dataclasses.is_dataclass(obj):
-      if isinstance(obj, BoundaryMetadata):
-        if isinstance(obj.id, torch.Tensor):
-          obj.id = decode_uuid_tensor(obj.id)
-        else:
-          obj.id = str(obj.id)
      return dataclasses.asdict(obj)
-    else:
-      return super().default(obj)
+    return super().default(obj)
 
 
 def _assert_valid_composite_attr(attr):
@@ -85,7 +49,7 @@ def _assert_valid_composite_attr(attr):
 def mark_tensor_xla(x: torch.Tensor,
                    name: str,
                    pos: int,
-                    id: int,
+                    id: str,
                    is_input: bool,
                    attr: Dict = None):
  """Attach pattern boundary metadata to a XLA Tensor.
@@ -94,7 +58,7 @@ def mark_tensor_xla(x: torch.Tensor,
    x: torch.Tensor (On XLA device) - the marked tensor.
    name: str - The name of the pattern, it will be the name of the stablehlo composite op.
    pos: int - Input/output Position of the annotated tensor in the pattern.
-    id: int - Unique identifier of the pattern instance.
+    id: str - Unique identifier of the pattern instance.
    is_input: bool - If the annotated tensor is the input to the pattern.
    attr: dict - Attribute of the pattern, it will be passed down to the attribute field
                 in the stablehlo composite.
@@ -109,7 +73,7 @@ def mark_tensor_xla(x: torch.Tensor,
 def mark_tensor(x: torch.Tensor,
                name: str,
                pos: int,
-                id: int,
+                id: str,
                is_input: bool,
                attr: Dict = None):
  # Do nothing for non-xla tensor.
@@ -120,44 +84,7 @@ def mark_tensor(x: torch.Tensor,
 def mark_tensor_meta(x: torch.Tensor,
                     name: str,
                     pos: int,
-                     id: int,
-                     is_input: bool,
-                     attr: Dict = None):
-  return torch.empty_like(x)
-
-
-@impl(xla_pattern_marking_lib, "mark_tensor.tensor", "XLA")
-def mark_tensor_xla(x: torch.Tensor,
-                    name: str,
-                    pos: int,
-                    id: torch.Tensor,
-                    is_input: bool,
-                    attr: Dict = None):
-  """Variant: `id` is a torch.Tensor, which is generated from `get_uuid_tensor`.
-  """
-  _assert_valid_composite_attr(attr)
-  pattern_info = BoundaryMetadata(name, pos, id, is_input, attr)
-  return torch_xla._XLAC._xla_mark_tensor(
-      x, json.dumps(pattern_info, cls=BoundaryMetadataSerializer))
-
-
-@impl(xla_pattern_marking_lib, "mark_tensor.tensor",
-      "CompositeExplicitAutograd")
-def mark_tensor(x: torch.Tensor,
-                name: str,
-                pos: int,
-                id: torch.Tensor,
-                is_input: bool,
-                attr: Dict = None):
-  # Do nothing for non-xla tensor.
-  return x
-
-
-@impl(xla_pattern_marking_lib, "mark_tensor.tensor", "Meta")
-def mark_tensor_meta(x: torch.Tensor,
-                     name: str,
-                     pos: int,
-                     id: torch.Tensor,
+                     id: str,
                     is_input: bool,
                     attr: Dict = None):
-  return torch.empty_like(x)
+  return torch.empty_like(x)
\ No newline at end of file
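
With the `mark_tensor.tensor` overload gone, `StableHLOCompositeBuilder` is the one entry point for pattern marking. A minimal usage sketch follows; the module name `SDPABlock` and the input shape are illustrative, while `torch.export.export`, `exported_program_to_stablehlo`, and `get_stablehlo_text` are the same entry points the tests above use:

```python
import torch
import torch.nn.functional as F
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


class SDPABlock(torch.nn.Module):

  def forward(self, x):
    # One builder per pattern instance; its constant uuid hex string becomes
    # the `id` passed to every mark_tensor call the builder emits.
    b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
    q, k, v = x.split(128, dim=-2)  # (32, 8, 384, 64) -> three (32, 8, 128, 64)
    q, k, v = b.mark_inputs(q, k, v)
    out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
    return b.mark_outputs(out)


ep = torch.export.export(SDPABlock(), (torch.randn(32, 8, 384, 64),))
shlo = exported_program_to_stablehlo(ep)
# The marked region lowers to a stablehlo.composite op named "sdpa",
# carrying {"scale": 0.25} in its attribute dictionary.
print(shlo.get_stablehlo_text())
```

Because `_get_uuid` is wrapped in `torchdynamo.assume_constant_result`, the freshly drawn uuid is baked into the traced graph as a plain string constant instead of `uuid.uuid4()` being traced, which is what lets a `str`-typed `id` survive `torch.export` while still giving each builder instance a distinct identifier.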