diff --git a/examples/BuddyLlama/README.md b/examples/BuddyLlama/README.md
index 4416ef3a60..416c92707f 100644
--- a/examples/BuddyLlama/README.md
+++ b/examples/BuddyLlama/README.md
@@ -1,6 +1,6 @@
 # Buddy Compiler LLaMA Example
 
-1. Download LLaMA2 model
+1. Download LLaMA model
 
 You should download llama model. You can get model from [meta ai](https://ai.meta.com/llama/).
 
@@ -14,16 +14,17 @@
 $ cd buddy-mlir
 $ pip install -r requirements.txt
 ```
 
-3. LLaMA2 model convert to HuggingFace format
+3. Convert the LLaMA model to HuggingFace format
 
-You should convert LLaMA2 model which download from meta ai to HuggingFace format. Because we use HuggingFace api to get LLaMA2 model.
+You should convert the LLaMA model downloaded from Meta AI to the HuggingFace format, because we use the HuggingFace API to load the LLaMA model.
 
 ```
 $ cd examples/BuddyLlama
-$ python llama2-to-hf.py --input_dir path-to-llama2-model --model_size 7B --output_dir path-to-save-llama-hf-model
+$ python llama-to-hf.py --input_dir path-to-llama-model --model_size 7B/8B --output_dir path-to-save-llama-hf-model [--llama_version 3]
 ```
 
 Such as you have a 7B LLaMA2 model, in your input_dir path-to-llama-model, you should have a tokenizer.model and a directory named "7B". You should put your 7B LLaMA2 model inside the "7B" directory.
+To convert LLaMA 3 models, add the option `--llama_version 3` to the command, and make sure the working environment has torch == 2.4.0, transformers >= 4.42.3, tokenizers >= 0.19.1, sentencepiece >= 0.2.0, tiktoken, and blobfile installed.
 
 In addition, set an environment variable for the generated LLaMA model.
 
 ```
diff --git a/examples/BuddyLlama/import-llama.py b/examples/BuddyLlama/import-llama.py
new file mode 100644
index 0000000000..caa31574b6
--- /dev/null
+++ b/examples/BuddyLlama/import-llama.py
@@ -0,0 +1,80 @@
+# ===- import-llama.py --------------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This imports the LLaMA model into MLIR and exports its parameters.
+#
+# ===---------------------------------------------------------------------------
+
+import os
+import torch
+import torch._dynamo as dynamo
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from torch._inductor.decomposition import decompositions as inductor_decomp
+from torch._decomp import get_decompositions
+
+import numpy
+
+from buddy.compiler.frontend import DynamoCompiler
+from buddy.compiler.ops import tosa
+from buddy.compiler.graph import GraphDriver
+from buddy.compiler.graph.transform import simply_fuse
+
+# Retrieve the LLaMA model path from environment variables.
+model_path = os.environ.get("LLAMA_MODEL_PATH")
+if model_path is None:
+    raise EnvironmentError(
+        "The environment variable 'LLAMA_MODEL_PATH' is not set or is invalid."
+    )
+
+# Initialize the tokenizer and model from the specified model path.
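+# Note: `torchscript=True` makes the HuggingFace model return plain tuples
+# instead of ModelOutput objects, which keeps the traced graph export-friendly,
+# and `use_cache=False` below drops the KV-cache branches so that a single
+# static forward graph is imported.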
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+model = AutoModelForCausalLM.from_pretrained(model_path, torchscript=True)
+model.config.use_cache = False
+
+DEFAULT_DECOMPOSITIONS = [
+    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+]
+
+decomp = get_decompositions(DEFAULT_DECOMPOSITIONS)
+
+# Initialize Dynamo Compiler with specific configurations as an importer.
+dynamo_compiler = DynamoCompiler(
+    primary_registry=tosa.ops_registry,
+    aot_autograd_decomposition={**inductor_decomp, **decomp},
+)
+
+# Import the model into MLIR module and parameters.
+with torch.no_grad():
+    data = torch.tensor([[1 for i in range(40)]], dtype=torch.int64)
+    graphs = dynamo_compiler.importer(model, data)
+
+assert len(graphs) == 1
+graph = graphs[0]
+params = dynamo_compiler.imported_params[graph]
+pattern_list = [simply_fuse]
+graphs[0].fuse_ops(pattern_list)
+driver = GraphDriver(graphs[0])
+driver.subgraphs[0].lower_to_top_level_ir()
+path_prefix = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
+    print(driver.subgraphs[0]._imported_module, file=module_file)
+with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
+    print(driver.construct_main_graph(True), file=module_file)
+all_param = numpy.concatenate(
+    [param.detach().numpy().reshape([-1]) for param in params]
+)
+all_param.tofile(os.path.join(path_prefix, "arg0.data"))
diff --git a/examples/BuddyLlama/llama-to-hf.py b/examples/BuddyLlama/llama-to-hf.py
new file mode 100644
index 0000000000..5a05460c9f
--- /dev/null
+++ b/examples/BuddyLlama/llama-to-hf.py
@@ -0,0 +1,23 @@
+# ===- llama-to-hf.py ---------------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This converts the LLaMA model to the HuggingFace format.
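+# It delegates to the converter bundled with transformers; the command-line
+# flags shown in the README (--input_dir, --model_size, --output_dir,
+# --llama_version) are parsed by convert_llama_weights_to_hf.main().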
+# +# ===--------------------------------------------------------------------------- + +from transformers.models.llama import convert_llama_weights_to_hf + +convert_llama_weights_to_hf.main() diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 9d8c80f014..a5bb2b6794 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -124,6 +124,7 @@ def __init__( "mean.dim": MeanOp, "rsqrt.default": RsqrtOp, "mul.Tensor": MulOp, + "mul.Scalar": MulOp, "t.default": TOp, "mm.default": MatmulOp, "transpose.int": TransposeOp, @@ -164,9 +165,11 @@ def __init__( "cos.default": CosOp, "sin.default": SinOp, "argmax.default": ArgMaxOp, - "split.Tensor":SplitOp, - "max.default":MaxOp, - "gt.Scalar":GtOp, + "split.Tensor": SplitOp, + "max.default": MaxOp, + "gt.Scalar": GtOp, + "ge.Scalar": GeOp, + "gt.Tensor": GreaterThanOp, } @property @@ -230,7 +233,9 @@ def _create_node( buddy_node.add_argument(str(input_arg)) buddy_node.add_parent(str(input_arg)) elif isinstance(input_arg, torch.dtype): - buddy_node.add_argument(self._torch_dtype_translate(str(input_arg))) + buddy_node.add_argument( + self._torch_dtype_translate(str(input_arg)) + ) else: buddy_node.add_argument(input_arg) for user in node_users: @@ -305,10 +310,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): elif gm_node.op == "output": buddy_node = self._create_node( - gm_node.op, - gm_node.name, - gm_node.args, - node_users + gm_node.op, gm_node.name, gm_node.args, node_users ) elif gm_node.target is operator.getitem: @@ -438,7 +440,7 @@ def get_lib_extension(): def cast_c_ptr(outdata_ptr, memref_ptr): """ - Casts a C pointer (`outdata_ptr`) to the type of another C pointer + Casts a C pointer (`outdata_ptr`) to the type of another C pointer (`memref_ptr`). Args: @@ -449,14 +451,14 @@ def cast_c_ptr(outdata_ptr, memref_ptr): Returns: ctypes.POINTER - A new C pointer with the type of `memref_ptr`, representing the + A new C pointer with the type of `memref_ptr`, representing the same memory location as `outdata_ptr`. Example: outdata = ctypes.pointer(ctypes.c_int()) memref = ctypes.pointer(ctypes.c_float()) casted_ptr = cast_c_ptr(outdata, memref) - # Now `casted_ptr` points to the same memory location as `outdata`, + # Now `casted_ptr` points to the same memory location as `outdata`, but with the type of `memref`. """ outdata_addr = ctypes.addressof(outdata_ptr.contents) @@ -465,15 +467,15 @@ def cast_c_ptr(outdata_ptr, memref_ptr): def move_c_ptr(outdata_ptr, memref_ptr): """ - Moves a C pointer (`outdata_ptr`) to the next element in memory, - based on the size of the referenced type in another C pointer + Moves a C pointer (`outdata_ptr`) to the next element in memory, + based on the size of the referenced type in another C pointer (`memref_ptr`). Args: outdata_ptr: ctypes.POINTER The C pointer whose position needs to be moved. memref_ptr: ctypes.POINTER - The reference C pointer whose type determines the size of each + The reference C pointer whose type determines the size of each element for the move. Returns: @@ -496,7 +498,7 @@ def exec_buddy_graph(*args): Returns: List[torch.Tensor] - The result of executing the graph, represented as a list of + The result of executing the graph, represented as a list of output tensors. 
""" # A list of ctypes pointers representing memory references for input @@ -509,13 +511,13 @@ def exec_buddy_graph(*args): ) for tensor in args ] - # A list of ctypes pointers representing memory references for + # A list of ctypes pointers representing memory references for # output tensors. output_memref = [ ctypes.pointer(ctypes.pointer(graph._output_descriptor())) ] args_memref = output_memref + input_memref - # Invoke the graph's function using the provided execution engine + # Invoke the graph's function using the provided execution engine # and memory references ee.invoke(graph._func_name, *args_memref) @@ -532,7 +534,7 @@ def exec_buddy_graph(*args): # Move to the next element in memory based on the size of the # current output type outdata_ptr = move_c_ptr(outdata_ptr, output_ptr[0]) - # Convert each NumPy array to a PyTorch tensor and return the list + # Convert each NumPy array to a PyTorch tensor and return the list # of tensors return [torch.from_numpy(tensor) for tensor in output_tensor] diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index 0eb31fd961..02b8b9311b 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -534,3 +534,15 @@ class GtOp(Op): def __init__(self) -> None: super().__init__() self._op_type = OpType.ElementwiseType + + +class GeOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ElementwiseType + + +class GreaterThanOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.BroadcastType diff --git a/frontend/Python/ops/func.py b/frontend/Python/ops/func.py index a7dcc5e11b..6b35c03410 100644 --- a/frontend/Python/ops/func.py +++ b/frontend/Python/ops/func.py @@ -106,7 +106,7 @@ def param_extract( TensorDType.Int64: ir.IntegerType.get_signless(64), } memref_element_type = dtype_mapping[node.tensor_meta["dtype"]] - if(len(node.tensor_meta['shape'])== 0): + if len(node.tensor_meta["shape"]) == 0: output_shape = [1] else: output_shape = list(node.tensor_meta["shape"]) diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index b561b3433a..137ae8c033 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -1073,6 +1073,26 @@ def mul_op( element = mlir_element_attr_get(dtype, node.args[1]) attr = ir.DenseElementsAttr.get_splat(tensor_type, element) input2 = arith.ConstantOp(tensor_type, attr).result + + input1_type = ir.RankedTensorType(input1.type) + input2_type = ir.RankedTensorType(input2.type) + if input1_type != mlir_dtype: + input1 = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(input1.type).shape, + mlir_dtype, + ), + input1, + ) + if input2_type != mlir_dtype: + input2 = tosa.CastOp( + ir.RankedTensorType.get( + ir.RankedTensorType(input2.type).shape, + mlir_dtype, + ), + input2, + ) + if input1 is None or input2 is None: return mul_result_tensor_type = ir.RankedTensorType.get(shape, mlir_dtype) @@ -1782,6 +1802,9 @@ def where_op( if input1 is None or input2 is None or input3 is None: return + if isinstance(input2.type, ir.RankedTensorType): + input2, input3 = input3, input2 + output_shape = list(node.tensor_meta["shape"]) dtype = node.tensor_meta["dtype"] mlir_dtype = mlir_element_type_get(dtype) @@ -1818,6 +1841,7 @@ def where_op( * len(output_shape) ), ) + block = ir.Block.create_at_start( op.region, [ @@ -1965,6 +1989,129 @@ def gt_op(node: GtOp, symbol_table): return cmp_op +def ge_op( + node: GeOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + 
""" + Import the tensor greater equal operation. + From buddy GreaterEqualOp to MLIR arith `constant` operation. + + Note: This op, campare two input nodes, and output bool tensor to represent + compare result. + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding + operations. + + Returns: + op: The operation return the linalg.generic op. + """ + input_tensor = symbol_table.get((str(node.args[0]), 0), node.args[0]) + input_dtype = ir.RankedTensorType(input_tensor.type).element_type + input_shape = ir.RankedTensorType(input_tensor.type).shape + tensor_type = ir.RankedTensorType.get(input_shape, input_dtype) + + scalar = arith.ConstantOp(input_dtype, node.args[1]) + rhs = tensor.SplatOp(tensor_type, scalar) + + if str(input_dtype).find("i") != -1: + cmp_op = arith.CmpIOp(5, input_tensor, rhs) + else: + cmp_op = arith.CmpFOp(3, input_tensor, rhs) + + return cmp_op + +def greater_than_op( + node: GreaterThanOp, + symbol_table: Dict[Tuple[str, int], ir.Operation], +): + """ + Import the tensor greater than operation. + From buddy GreaterThanOp to MLIR arith `constant` operation. + Note: This op, campare two input nodes, and output bool tensor to represent + compare result. + Args: + node: Containing information from the input graph node. + symbol_table: A dictionary mapping symbols to their corresponding + operations. + Returns: + op: The operation return the linalg.generic op. + """ + input1 = symbol_table.get((str(node.args[0]), 0)) + input2 = symbol_table.get((str(node.args[1]), 0)) + output_shape = list(node.tensor_meta["shape"]) + dtype = node.tensor_meta["dtype"] + value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 4) + shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape) + shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape) + dtype = mlir_element_type_get(dtype) + tensor_type = ir.RankedTensorType.get(output_shape, dtype) + output = tensor.EmptyOp(output_shape, dtype) + if len(shp1) < len(shp2): + if int(shp1[-1]) > 1 and shp2[-1] == 1: + generic_map = ir.AffineMap.get_permutation( + [i for i in range(len(shp2) + 1)] + ) + op = linalg.GenericOp( + [tensor_type], + [input1, input2], + [output], + ir.ArrayAttr.get( + [ + ir.AffineMapAttr.get( + generic_map.get_submap( + [ + i + for i in range( + len(shp2) - len(shp1), len(shp2) + ) + ] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(0, len(shp2) - 1)] + + [len(shp2)] + ) + ), + ir.AffineMapAttr.get( + generic_map.get_submap( + [i for i in range(0, len(shp2))] + ) + ), + ] + ), + ir.ArrayAttr.get( + [ir.Attribute.parse("#linalg.iterator_type")] + * len(shp2) + + [ir.Attribute.parse("#linalg.iterator_type")] + ), + ) + block = ir.Block.create_at_start( + op.region, + [ + ir.RankedTensorType(input2.type).element_type, + ir.RankedTensorType(input2.type).element_type, + dtype, + ], + ) + if ( + str(ir.RankedTensorType(input2.type).element_type).find("i") + != -1 + ): + cmpop = arith.CmpIOp( + value, block.arguments[0], block.arguments[1] + ) + else: + cmpop = arith.CmpFOp( + value, block.arguments[0], block.arguments[1] + ) + block.append(cmpop) + block.append(linalg.YieldOp([cmpop.result])) + + return op + ops_registry = { "MatmulOp": matmul_op, @@ -2001,4 +2148,6 @@ def gt_op(node: GtOp, symbol_table): "SplitOp": split_op, "MaxOp": max_op, "GtOp": gt_op, + "GeOp": ge_op, + "GreaterThanOp": greater_than_op, } diff --git a/frontend/Python/ops/math.py b/frontend/Python/ops/math.py index 
0e2f8631ef..6ce2e868d5 100644 --- a/frontend/Python/ops/math.py +++ b/frontend/Python/ops/math.py @@ -26,6 +26,7 @@ def erf_op(node, symbol_table): op = math.ErfOp(input_tensor) return op + def sqrt_op(node, symbol_table): input_tensor = symbol_table.get((str(node.args[0]), 0)) op = math.SqrtOp(input_tensor) @@ -42,9 +43,19 @@ def sin_op(node, symbol_table): return op +def cos_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + return math.CosOp(input_tensor) + + +def sin_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + return math.SinOp(input_tensor) + + ops_registry = { "ErfOp": erf_op, "SqrtOp": sqrt_op, "CosOp": cos_op, - "SinOp": sin_op + "SinOp": sin_op, } diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py index 797fdfd6d2..a0a22477a5 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -800,6 +800,7 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: the result. """ to_expand_tensor = symbol_table.get((str(node.args[0]), 0)) + original_size = to_expand_tensor.type.shape new_size = node.args[1] result_element_type = ir.RankedTensorType( to_expand_tensor.type @@ -813,8 +814,14 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: element = ir.FloatAttr.get(result_element_type, 0.0) else: raise NotImplementedError("Unsupported element type!") + expanded_size = [] + for dim, size in zip(original_size, new_size): + if size == -1: + expanded_size.append(dim) + else: + expanded_size.append(size) new_size_tensor_type = ir.RankedTensorType.get( - new_size, result_element_type + expanded_size, result_element_type ) new_size_attr = ir.DenseElementsAttr.get_splat( new_size_tensor_type, element @@ -1481,7 +1488,7 @@ def argmax_op(node: ArgMaxOp, symbol_table): ops_registry = { "AddOp": add_op, - "MulOp": mul_op, + # "MulOp": mul_op, "SubOp": sub_op, "SumDimOp": sum_op, "TanhOp": tanh_op, diff --git a/frontend/Python/ops/utils.py b/frontend/Python/ops/utils.py index 337f5a6b49..dad07bd68c 100644 --- a/frontend/Python/ops/utils.py +++ b/frontend/Python/ops/utils.py @@ -53,4 +53,3 @@ def mlir_element_attr_get(type_name, value): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) case TensorDType.Bool: return ir.IntegerAttr.get(ir.IntegerType.get_signless(1), value) - diff --git a/tests/Python/test_cos.py b/tests/Python/test_cos.py new file mode 100644 index 0000000000..78efe70e31 --- /dev/null +++ b/tests/Python/test_cos.py @@ -0,0 +1,35 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import math + + +def foo(x): + return torch.ops.aten.cos(x) + + +x = torch.randn(10, 3, 6) + +# Initialize the dynamo compiler. 
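+# math.ops_registry maps the buddy CosOp/SinOp/SqrtOp/ErfOp nodes directly to
+# the corresponding MLIR math-dialect operations, so aten.cos should lower to
+# a single math.cos (verified by the FileCheck lines below).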
+dynamo_compiler = DynamoCompiler( + primary_registry=math.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = torch.compile(foo, backend=dynamo_compiler) +assert torch.allclose(foo_mlir(x), foo(x), equal_nan=True) + +graphs = dynamo_compiler.importer(foo, x) +graph = graphs[0] +graph.lower_to_top_level_ir() +print(graph._imported_module) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = math.cos +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } diff --git a/tests/Python/test_embedding.py b/tests/Python/test_embedding.py index 484bb617b5..c3ae33672e 100644 --- a/tests/Python/test_embedding.py +++ b/tests/Python/test_embedding.py @@ -70,4 +70,4 @@ def foo(weight, indices): # CHECK: %{{.*}} = tosa.reshape # CHECK: return %{{.*}} # CHECK: } -# CHECK: } \ No newline at end of file +# CHECK: } diff --git a/tests/Python/test_expand.py b/tests/Python/test_expand.py new file mode 100644 index 0000000000..713bea84fb --- /dev/null +++ b/tests/Python/test_expand.py @@ -0,0 +1,37 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp +from torch._functorch.aot_autograd import aot_autograd_decompositions + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y): + return torch.ops.aten.expand(x, y) + + +in1 = torch.tensor([[1], [2], [3]], dtype=torch.float32) +in2 = [-1, 4] + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=aot_autograd_decompositions, +) + +graphs = dynamo_compiler.importer(foo, in1, in2) +assert len(graphs) == 1 +graph = graphs[0] +graph.lower_to_top_level_ir() +print(graph._imported_module) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = "tosa.const" +# CHECK: %{{.*}} = tosa.add +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } diff --git a/tests/Python/test_ge.py b/tests/Python/test_ge.py new file mode 100644 index 0000000000..95230324c3 --- /dev/null +++ b/tests/Python/test_ge.py @@ -0,0 +1,36 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import linalg + + +def foo(x, y): + return torch.ops.aten.ge(x, y) + + +in1 = torch.ones([13, 5], dtype=torch.int64) +in2 = 0 +# Initialize the dynamo compiler. 
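+# With linalg.ops_registry, aten.ge against a scalar lowers to an arith.cmpi
+# with the `sge` predicate over a tensor splatted from the scalar operand
+# (see ge_op in frontend/Python/ops/linalg.py).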
+dynamo_compiler = DynamoCompiler(
+    primary_registry=linalg.ops_registry,
+    aot_autograd_decomposition=inductor_decomp,
+)
+
+graphs = dynamo_compiler.importer(foo, in1, in2)
+assert len(graphs) == 1
+graph = graphs[0]
+graph.lower_to_top_level_ir()
+print(graph._imported_module)
+
+# CHECK: "builtin.module"() ({
+# CHECK-LABEL: "func.func"() <{function_type = ({{.*}} -> {{.*}}, sym_name = "forward"}
+# CHECK: %{{.*}} = "arith.constant"
+# CHECK: %{{.*}} = "tensor.empty"
+# CHECK: %{{.*}} = "linalg.generic"
+# CHECK: "func.return"(%{{.*}}) : {{.*}} -> ()
+# CHECK: }) : () -> ()
+# CHECK: }) : () -> ()
diff --git a/tests/Python/test_gt.py b/tests/Python/test_gt.py
new file mode 100644
index 0000000000..48d677b9a2
--- /dev/null
+++ b/tests/Python/test_gt.py
@@ -0,0 +1,35 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import torch
+import torch._dynamo as dynamo
+from torch._inductor.decomposition import decompositions as inductor_decomp
+
+from buddy.compiler.frontend import DynamoCompiler
+from buddy.compiler.ops import linalg
+
+
+def foo(x, y):
+    return torch.ops.aten.gt(x, y)
+
+
+in1 = torch.ones([13], dtype=torch.int64)
+in2 = torch.ones([13, 1], dtype=torch.int64)
+# Initialize the dynamo compiler.
+dynamo_compiler = DynamoCompiler(
+    primary_registry=linalg.ops_registry,
+    aot_autograd_decomposition=inductor_decomp,
+)
+
+graphs = dynamo_compiler.importer(foo, in1, in2)
+assert len(graphs) == 1
+graph = graphs[0]
+graph.lower_to_top_level_ir()
+print(graph._imported_module)
+
+# CHECK: module {
+# CHECK-LABEL: func.func @forward
+# CHECK: %{{.*}} = tensor.empty
+# CHECK: %{{.*}} = linalg.generic
+# CHECK: return %{{.*}}
+# CHECK: }
+# CHECK: }
diff --git a/tests/Python/test_sin.py b/tests/Python/test_sin.py
new file mode 100644
index 0000000000..d6cb52eedd
--- /dev/null
+++ b/tests/Python/test_sin.py
@@ -0,0 +1,36 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import torch
+from torch._inductor.decomposition import decompositions as inductor_decomp
+
+from buddy.compiler.frontend import DynamoCompiler
+from buddy.compiler.ops import math
+
+
+def foo(x):
+    return torch.ops.aten.sin(x)
+
+
+x = torch.randn(10, 3, 6)
+
+# Initialize the dynamo compiler.
+dynamo_compiler = DynamoCompiler(
+    primary_registry=math.ops_registry,
+    aot_autograd_decomposition=inductor_decomp,
+)
+
+foo_mlir = torch.compile(foo, backend=dynamo_compiler)
+assert torch.allclose(foo_mlir(x), foo(x), equal_nan=True)
+
+graphs = dynamo_compiler.importer(foo, x)
+
+graph = graphs[0]
+graph.lower_to_top_level_ir()
+print(graph._imported_module)
+
+# CHECK: module {
+# CHECK-LABEL: func.func @forward
+# CHECK: %{{.*}} = math.sin
+# CHECK: return %{{.*}}
+# CHECK: }
+# CHECK: }
diff --git a/tests/Python/test_where.py b/tests/Python/test_where.py
index 5266f00b74..7c8013e028 100644
--- a/tests/Python/test_where.py
+++ b/tests/Python/test_where.py
@@ -14,8 +14,10 @@ def foo(x, y, z):
 
 
 in1 = torch.ones([13, 13], dtype=torch.bool)
-in2 = 0
-in3 = torch.ones([13, 13], dtype=torch.float32)
+in2 = 0  # in2 may also be a tensor, e.g. in2 = torch.zeros([13, 13], dtype=torch.float32)
+in3 = torch.ones(
+    [13, 13], dtype=torch.float32
+)  # in3 may also be a scalar, e.g. in3 = 1
 # Initialize the dynamo compiler.
 dynamo_compiler = DynamoCompiler(
     primary_registry=linalg.ops_registry,