From aebd9ac334c138f53e360cc1fb9e2b32335495da Mon Sep 17 00:00:00 2001
From: effrey-liu <2318266514@qq.com>
Date: Fri, 12 Jul 2024 23:05:02 +0800
Subject: [PATCH 1/2] support for llama3

---
 examples/BuddyLlama/README.md       |   8 +-
 examples/BuddyLlama/import-llama.py |  82 +++++++++++++
 examples/BuddyLlama/llama-to-hf.py  |  23 ++++
 frontend/Python/frontend.py         |  34 +++---
 frontend/Python/graph/operation.py  |  37 ++++++
 frontend/Python/ops/linalg.py       | 181 +++++++++++++++++++++++++++-
 frontend/Python/ops/math.py         |  13 ++
 frontend/Python/ops/tosa.py         |   9 +-
 requirements.txt                    |   6 +-
 tests/Python/test_cos.py            |  35 ++++++
 tests/Python/test_expand.py         |  38 ++++++
 tests/Python/test_ge.py             |  35 ++++++
 tests/Python/test_gt.py             |  35 ++++++
 tests/Python/test_sin.py            |  36 ++++++
 tests/Python/test_where.py          |   6 +-
 15 files changed, 552 insertions(+), 26 deletions(-)
 create mode 100644 examples/BuddyLlama/import-llama.py
 create mode 100644 examples/BuddyLlama/llama-to-hf.py
 create mode 100644 tests/Python/test_cos.py
 create mode 100644 tests/Python/test_expand.py
 create mode 100644 tests/Python/test_ge.py
 create mode 100644 tests/Python/test_gt.py
 create mode 100644 tests/Python/test_sin.py

diff --git a/examples/BuddyLlama/README.md b/examples/BuddyLlama/README.md
index 4416ef3a6..747b85bd3 100644
--- a/examples/BuddyLlama/README.md
+++ b/examples/BuddyLlama/README.md
@@ -1,6 +1,6 @@
 # Buddy Compiler LLaMA Example
 
-1. Download LLaMA2 model
+1. Download LLaMA model
 
 You should download llama model. You can get model from [meta ai](https://ai.meta.com/llama/).
 
@@ -14,13 +14,13 @@
 $ cd buddy-mlir
 $ pip install -r requirements.txt
 ```
 
-3. LLaMA2 model convert to HuggingFace format
+3. Convert the LLaMA model to HuggingFace format
 
-You should convert LLaMA2 model which download from meta ai to HuggingFace format. Because we use HuggingFace api to get LLaMA2 model.
+You should convert the LLaMA model downloaded from Meta AI to the HuggingFace format, because we use the HuggingFace API to load the LLaMA model.
 
 ```
 $ cd examples/BuddyLlama
-$ python llama2-to-hf.py --input_dir path-to-llama2-model --model_size 7B --output_dir path-to-save-llama-hf-model
+$ python llama-to-hf.py --input_dir path-to-llama-model --model_size 7B/8B --output_dir path-to-save-llama-hf-model
 ```
 
 Such as you have a 7B LLaMA2 model, in your input_dir path-to-llama-model, you should have a tokenizer.model and a directory named "7B". You should put your 7B LLaMA2 model inside the "7B" directory.
diff --git a/examples/BuddyLlama/import-llama.py b/examples/BuddyLlama/import-llama.py
new file mode 100644
index 000000000..41c93f904
--- /dev/null
+++ b/examples/BuddyLlama/import-llama.py
@@ -0,0 +1,82 @@
+###### import-llama.py
+import os
+import torch
+import torch._dynamo as dynamo
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from torch._inductor.decomposition import decompositions as inductor_decomp
+from torch._decomp import get_decompositions
+
+import numpy
+
+from buddy.compiler.frontend import DynamoCompiler
+
+# ===- import-llama.py --------------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This is an example of importing the LLaMA model.
+#
+# ===---------------------------------------------------------------------------
+from buddy.compiler.ops import tosa
+from buddy.compiler.graph import GraphDriver
+from buddy.compiler.graph.transform import simply_fuse
+
+# Retrieve the LLaMA model path from environment variables.
+# model_path = os.environ.get("LLAMA_MODEL_PATH")
+# if model_path is None:
+#     raise EnvironmentError(
+#         "The environment variable 'LLAMA_MODEL_PATH' is not set or is invalid."
+#     )
+model_path = "../../../download/llama3_model/llama3_8B_save/"
+
+# Initialize the tokenizer and model from the specified model path.
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+model = AutoModelForCausalLM.from_pretrained(model_path, torchscript=True)
+model.config.use_cache = False
+
+DEFAULT_DECOMPOSITIONS = [
+    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
+]
+
+decomp = get_decompositions(DEFAULT_DECOMPOSITIONS)
+
+# Initialize Dynamo Compiler with specific configurations as an importer.
+dynamo_compiler = DynamoCompiler(
+    primary_registry=tosa.ops_registry,
+    aot_autograd_decomposition={**inductor_decomp, **decomp},
+)
+
+# Import the model into MLIR module and parameters.
+with torch.no_grad():
+    data = torch.tensor([[1 for i in range(40)]], dtype=torch.int64)
+    graphs = dynamo_compiler.importer(model, data)
+
+assert len(graphs) == 1
+graph = graphs[0]
+params = dynamo_compiler.imported_params[graph]
+pattern_list = [simply_fuse]
+graphs[0].fuse_ops(pattern_list)
+driver = GraphDriver(graphs[0])
+driver.subgraphs[0].lower_to_top_level_ir()
+path_prefix = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
+    print(driver.subgraphs[0]._imported_module, file=module_file)
+with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
+    print(driver.construct_main_graph(True), file=module_file)
+all_param = numpy.concatenate(
+    [param.detach().numpy().reshape([-1]) for param in params]
+)
+all_param.tofile(os.path.join(path_prefix, "arg0.data"))
diff --git a/examples/BuddyLlama/llama-to-hf.py b/examples/BuddyLlama/llama-to-hf.py
new file mode 100644
index 000000000..5a05460c9
--- /dev/null
+++ b/examples/BuddyLlama/llama-to-hf.py
@@ -0,0 +1,23 @@
+# ===- llama-to-hf.py ---------------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This converts the LLaMA model weights to the HuggingFace format.
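+#
+# Example invocation (placeholder paths, mirroring the README):
+#
+#   python llama-to-hf.py --input_dir path-to-llama-model \
+#       --model_size 8B --output_dir path-to-save-llama-hf-model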
+# +# ===--------------------------------------------------------------------------- + +from transformers.models.llama import convert_llama_weights_to_hf + +convert_llama_weights_to_hf.main() diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index f30eb2a28..8b1b954fd 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -124,6 +124,7 @@ def __init__( "mean.dim": MeanOp, "rsqrt.default": RsqrtOp, "mul.Tensor": MulOp, + "mul.Scalar": MulOp, "t.default": TOp, "mm.default": MatmulOp, "transpose.int": TransposeOp, @@ -160,6 +161,10 @@ def __init__( "reciprocal.default": ReciprocalOp, "clamp_min.default": ClampMinOp, "clamp_max.default": ClampMaxOp, + "ge.Scalar": GreaterEqualOp, + "gt.Tensor": GreaterThanOp, + "cos.default": CosOp, + "sin.default": SinOp, } @property @@ -223,7 +228,9 @@ def _create_node( buddy_node.add_argument(str(input_arg)) buddy_node.add_parent(str(input_arg)) elif isinstance(input_arg, torch.dtype): - buddy_node.add_argument(self._torch_dtype_translate(str(input_arg))) + buddy_node.add_argument( + self._torch_dtype_translate(str(input_arg)) + ) else: buddy_node.add_argument(input_arg) for user in node_users: @@ -297,10 +304,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): elif gm_node.op == "output": buddy_node = self._create_node( - gm_node.op, - gm_node.name, - gm_node.args, - node_users + gm_node.op, gm_node.name, gm_node.args, node_users ) elif gm_node.target is operator.getitem: @@ -429,7 +433,7 @@ def get_lib_extension(): def cast_c_ptr(outdata_ptr, memref_ptr): """ - Casts a C pointer (`outdata_ptr`) to the type of another C pointer + Casts a C pointer (`outdata_ptr`) to the type of another C pointer (`memref_ptr`). Args: @@ -440,14 +444,14 @@ def cast_c_ptr(outdata_ptr, memref_ptr): Returns: ctypes.POINTER - A new C pointer with the type of `memref_ptr`, representing the + A new C pointer with the type of `memref_ptr`, representing the same memory location as `outdata_ptr`. Example: outdata = ctypes.pointer(ctypes.c_int()) memref = ctypes.pointer(ctypes.c_float()) casted_ptr = cast_c_ptr(outdata, memref) - # Now `casted_ptr` points to the same memory location as `outdata`, + # Now `casted_ptr` points to the same memory location as `outdata`, but with the type of `memref`. """ outdata_addr = ctypes.addressof(outdata_ptr.contents) @@ -456,15 +460,15 @@ def cast_c_ptr(outdata_ptr, memref_ptr): def move_c_ptr(outdata_ptr, memref_ptr): """ - Moves a C pointer (`outdata_ptr`) to the next element in memory, - based on the size of the referenced type in another C pointer + Moves a C pointer (`outdata_ptr`) to the next element in memory, + based on the size of the referenced type in another C pointer (`memref_ptr`). Args: outdata_ptr: ctypes.POINTER The C pointer whose position needs to be moved. memref_ptr: ctypes.POINTER - The reference C pointer whose type determines the size of each + The reference C pointer whose type determines the size of each element for the move. Returns: @@ -487,7 +491,7 @@ def exec_buddy_graph(*args): Returns: List[torch.Tensor] - The result of executing the graph, represented as a list of + The result of executing the graph, represented as a list of output tensors. """ # A list of ctypes pointers representing memory references for input @@ -500,13 +504,13 @@ def exec_buddy_graph(*args): ) for tensor in args ] - # A list of ctypes pointers representing memory references for + # A list of ctypes pointers representing memory references for # output tensors. 
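+            # Each entry is a pointer to a pointer of the output memref
+            # descriptor, so the invoked function can write the result
+            # descriptor back through it.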
             output_memref = [
                 ctypes.pointer(ctypes.pointer(graph._output_descriptor()))
             ]
             args_memref = output_memref + input_memref
-            # Invoke the graph's function using the provided execution engine 
+            # Invoke the graph's function using the provided execution engine
             # and memory references
             ee.invoke(graph._func_name, *args_memref)
 
@@ -523,7 +527,7 @@ def exec_buddy_graph(*args):
             # Move to the next element in memory based on the size of the
             # current output type
             outdata_ptr = move_c_ptr(outdata_ptr, output_ptr[0])
-        # Convert each NumPy array to a PyTorch tensor and return the list 
+        # Convert each NumPy array to a PyTorch tensor and return the list
         # of tensors
         return [torch.from_numpy(tensor) for tensor in output_tensor]
 
diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py
index 14bfbf275..a9cd18520 100644
--- a/frontend/Python/graph/operation.py
+++ b/frontend/Python/graph/operation.py
@@ -128,6 +128,7 @@ def kwargs(self):
     @property
     def name(self):
         return self._name
+
     @name.setter
     def name(self, new_name):
         self._name = new_name
@@ -404,37 +405,44 @@ def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ElementwiseType
 
+
 class Conv2dOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ReduceType
         self._layout = "NCHW_FCHW"
 
+
 class ReluOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ElementwiseType
 
+
 class SigmoidOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ElementwiseType
 
+
 class IotaOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.PlaceholderType
 
+
 class ScalarTensorOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.PlaceholderType
 
+
 class WhereOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ElementwiseType
 
+
 class MaxPool2dWithIndicesOp(Op):
     def __init__(self) -> None:
         super().__init__()
@@ -448,17 +456,20 @@ def __init__(self) -> None:
         self._op_type = OpType.ReduceType
         self._layout = "NCHW"
 
+
 class CallOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self.call_func_name = ""
         self._op_type = OpType.Unfusable
 
+
 class FuncOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.Unfusable
 
+
 class ReciprocalOp(Op):
     def __init__(self) -> None:
         super().__init__()
@@ -470,12 +481,38 @@ def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ElementwiseType
 
+
 class ClampMinOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ElementwiseType
 
+
 class ClampMaxOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ElementwiseType
+
+
+class GreaterEqualOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.BroadcastType
+
+
+class GreaterThanOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.BroadcastType
+
+
+class CosOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ElementwiseType
+
+
+class SinOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ElementwiseType
diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py
index 38e1e68ba..3cdae6c0a 100644
--- a/frontend/Python/ops/linalg.py
+++ b/frontend/Python/ops/linalg.py
@@ -268,6 +268,7 @@ def ones_op(
 
     return op
 
+
 def full_op(
     node: FullOp,
     symbol_table: Dict[Tuple[str, int], ir.Operation],
@@ -1756,6 +1757,7 @@ def silu_op(
 
     return op
 
+
 def where_op(
     node: WhereOp,
     symbol_table: Dict[Tuple[str, int], ir.Operation],
@@ -1816,11 +1818,18 @@ def where_op(
             * len(output_shape)
         ),
     )
+
+    def get_element_type(value):
+        if isinstance(value.type, ir.RankedTensorType):
+            return value.type.element_type
+        else:
+            return value.type
+
     block = ir.Block.create_at_start(
         op.region,
         [
             ir.RankedTensorType(input1.type).element_type,
-            ir.RankedTensorType(input3.type).element_type,
+            get_element_type(input3),
             ir.RankedTensorType(output.result.type).element_type,
         ],
     )
@@ -1830,6 +1839,7 @@ def where_op(
 
     return op
 
+
 def scalar_tensor_op(node: ScalarTensorOp, symbol_table):
     """
     Import the tensor Scalar_Tensor operation.
@@ -1842,6 +1852,173 @@ def scalar_tensor_op(node: ScalarTensorOp, symbol_table):
 
     return op
 
+
+def ge_op(
+    node: GreaterEqualOp,
+    symbol_table: Dict[Tuple[str, int], ir.Operation],
+):
+    """
+    Import the tensor greater-equal operation.
+    From buddy GreaterEqualOp to MLIR linalg `generic` operation.
+
+    Note: this op compares the input tensor with a scalar and outputs a
+    bool tensor representing the comparison result.
+    Args:
+        node: Containing information from the input graph node.
+        symbol_table: A dictionary mapping symbols to their corresponding
+        operations.
+
+    Returns:
+        op: The linalg.generic operation performing the comparison.
+    """
+    input1 = symbol_table.get((str(node.args[0]), 0))
+    scalar = node.args[1]
+    # Predicate 5 of arith.cmpi is the signed comparison "sge".
+    value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 5)
+    output_shape = list(node.tensor_meta["shape"])
+    dtype = node.tensor_meta["dtype"]
+    mlir_dtype = mlir_element_type_get(dtype)
+    if not isinstance(scalar, str):
+        scalar = arith.ConstantOp(
+            mlir_dtype, mlir_element_attr_get(dtype, scalar)
+        )
+    generic_map = ir.AffineMap.get_permutation(
+        [i for i in range(len(output_shape))]
+    )
+    tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
+    output = tensor.EmptyOp(output_shape, mlir_dtype)
+    op = linalg.GenericOp(
+        [tensor_type],
+        [input1],
+        [output],
+        ir.ArrayAttr.get(
+            [
+                ir.AffineMapAttr.get(
+                    generic_map.get_submap(
+                        [i for i in range(len(output_shape))]
+                    )
+                ),
+                ir.AffineMapAttr.get(
+                    generic_map.get_submap(
+                        [i for i in range(len(output_shape))]
+                    )
+                ),
+            ]
+        ),
+        ir.ArrayAttr.get(
+            [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
+            * len(output_shape)
+        ),
+    )
+    block = ir.Block.create_at_start(
+        op.region,
+        [
+            ir.RankedTensorType(input1.type).element_type,
+            ir.RankedTensorType(output.result.type).element_type,
+        ],
+    )
+    if str(ir.RankedTensorType(input1.type).element_type).find("i") != -1:
+        cmpop = arith.CmpIOp(value, block.arguments[0], block.arguments[1])
+    else:
+        # For floats, use predicate 3 ("oge"); reusing the integer
+        # predicate value 5 would select "ole" for arith.cmpf.
+        float_pred = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 3)
+        cmpop = arith.CmpFOp(
+            float_pred, block.arguments[0], block.arguments[1]
+        )
+    block.append(cmpop)
+    block.append(linalg.YieldOp([cmpop.result]))
+
+    return op
+
+
+def gt_op(
+    node: GreaterThanOp,
+    symbol_table: Dict[Tuple[str, int], ir.Operation],
+):
+    """
+    Import the tensor greater-than operation.
+    From buddy GreaterThanOp to MLIR linalg `generic` operation.
+
+    Note: this op compares two input tensors and outputs a bool tensor
+    representing the comparison result.
+    Args:
+        node: Containing information from the input graph node.
+        symbol_table: A dictionary mapping symbols to their corresponding
+        operations.
+
+    Returns:
+        op: The linalg.generic operation performing the comparison.
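+
+    Example (shapes from tests/Python/test_gt.py): comparing a [13]
+    tensor with a [13, 1] tensor broadcasts both operands to [13, 13]
+    and yields a [13, 13] boolean result.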
+    """
+    input1 = symbol_table.get((str(node.args[0]), 0))
+    input2 = symbol_table.get((str(node.args[1]), 0))
+    output_shape = list(node.tensor_meta["shape"])
+    dtype = node.tensor_meta["dtype"]
+    # Predicate 4 of arith.cmpi is the signed comparison "sgt".
+    value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 4)
+    shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape)
+    shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape)
+    dtype = mlir_element_type_get(dtype)
+    tensor_type = ir.RankedTensorType.get(output_shape, dtype)
+    output = tensor.EmptyOp(output_shape, dtype)
+    if len(shp1) < len(shp2):
+        if int(shp1[-1]) > 1 and shp2[-1] == 1:
+            generic_map = ir.AffineMap.get_permutation(
+                [i for i in range(len(shp2) + 1)]
+            )
+            op = linalg.GenericOp(
+                [tensor_type],
+                [input1, input2],
+                [output],
+                ir.ArrayAttr.get(
+                    [
+                        ir.AffineMapAttr.get(
+                            generic_map.get_submap(
+                                [
+                                    i
+                                    for i in range(
+                                        len(shp2) - len(shp1), len(shp2)
+                                    )
+                                ]
+                            )
+                        ),
+                        ir.AffineMapAttr.get(
+                            generic_map.get_submap(
+                                [i for i in range(0, len(shp2) - 1)]
+                                + [len(shp2)]
+                            )
+                        ),
+                        ir.AffineMapAttr.get(
+                            generic_map.get_submap(
+                                [i for i in range(0, len(shp2))]
+                            )
+                        ),
+                    ]
+                ),
+                ir.ArrayAttr.get(
+                    [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
+                    * len(shp2)
+                    + [ir.Attribute.parse("#linalg.iterator_type<parallel>")]
+                ),
+            )
+            block = ir.Block.create_at_start(
+                op.region,
+                [
+                    ir.RankedTensorType(input2.type).element_type,
+                    ir.RankedTensorType(input2.type).element_type,
+                    dtype,
+                ],
+            )
+            if (
+                str(ir.RankedTensorType(input2.type).element_type).find("i")
+                != -1
+            ):
+                cmpop = arith.CmpIOp(
+                    value, block.arguments[0], block.arguments[1]
+                )
+            else:
+                # For floats, use predicate 2 ("ogt"); reusing the integer
+                # predicate value 4 would select "olt" for arith.cmpf.
+                float_pred = ir.IntegerAttr.get(
+                    ir.IntegerType.get_signless(64), 2
+                )
+                cmpop = arith.CmpFOp(
+                    float_pred, block.arguments[0], block.arguments[1]
+                )
+            block.append(cmpop)
+            block.append(linalg.YieldOp([cmpop.result]))
+
+    return op
+
+
 ops_registry = {
     "MatmulOp": matmul_op,
     "ArangeOp": arange_op,
@@ -1874,4 +2051,6 @@ def scalar_tensor_op(node: ScalarTensorOp, symbol_table):
     "AddOp": add_op,
     "WhereOp": where_op,
     "ScalarTensorOp": scalar_tensor_op,
+    "GreaterEqualOp": ge_op,
+    "GreaterThanOp": gt_op,
 }
diff --git a/frontend/Python/ops/math.py b/frontend/Python/ops/math.py
index f1afc2161..cc2ab2634 100644
--- a/frontend/Python/ops/math.py
+++ b/frontend/Python/ops/math.py
@@ -26,13 +26,26 @@ def erf_op(node, symbol_table):
     op = math.ErfOp(input_tensor)
     return op
 
+
 def sqrt_op(node, symbol_table):
     input_tensor = symbol_table.get((str(node.args[0]), 0))
     op = math.SqrtOp(input_tensor)
     return op
 
 
+def cos_op(node, symbol_table):
+    input_tensor = symbol_table.get((str(node.args[0]), 0))
+    return math.CosOp(input_tensor)
+
+
+def sin_op(node, symbol_table):
+    input_tensor = symbol_table.get((str(node.args[0]), 0))
+    return math.SinOp(input_tensor)
+
+
 ops_registry = {
     "ErfOp": erf_op,
     "SqrtOp": sqrt_op,
+    "CosOp": cos_op,
+    "SinOp": sin_op,
 }
diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py
index 5de51ca56..9a3c3db4f 100644
--- a/frontend/Python/ops/tosa.py
+++ b/frontend/Python/ops/tosa.py
@@ -797,6 +797,7 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation:
         the result.
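+
+        Note: following torch.Tensor.expand semantics, a target size of
+        -1 keeps the corresponding input dimension unchanged (handled by
+        the original_size logic added below).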
""" to_expand_tensor = symbol_table.get((str(node.args[0]), 0)) + original_size = to_expand_tensor.type.shape new_size = node.args[1] result_element_type = ir.RankedTensorType( to_expand_tensor.type @@ -807,8 +808,14 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: element = ir.FloatAttr.get(result_element_type, 0.0) else: raise NotImplementedError("Unsupported element type!") + expanded_size = [] + for dim, size in zip(original_size, new_size): + if size == -1: + expanded_size.append(dim) + else: + expanded_size.append(size) new_size_tensor_type = ir.RankedTensorType.get( - new_size, result_element_type + expanded_size, result_element_type ) new_size_attr = ir.DenseElementsAttr.get_splat( new_size_tensor_type, element diff --git a/requirements.txt b/requirements.txt index 9818b8ec7..c644918ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,9 @@ --pre --extra-index-url https://download.pytorch.org/whl/cpu torch == 2.1.2 numpy < 2 -transformers == 4.33.1 -tokenizers == 0.13.3 -sentencepiece == 0.1.99 +transformers >= 4.42.3 +tokenizers >= 0.19.1 +sentencepiece >= 0.2.0 accelerate protobuf pybind11 == 2.11.1 diff --git a/tests/Python/test_cos.py b/tests/Python/test_cos.py new file mode 100644 index 000000000..78efe70e3 --- /dev/null +++ b/tests/Python/test_cos.py @@ -0,0 +1,35 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import math + + +def foo(x): + return torch.ops.aten.cos(x) + + +x = torch.randn(10, 3, 6) + +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=math.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +foo_mlir = torch.compile(foo, backend=dynamo_compiler) +assert torch.allclose(foo_mlir(x), foo(x), equal_nan=True) + +graphs = dynamo_compiler.importer(foo, x) +graph = graphs[0] +graph.lower_to_top_level_ir() +print(graph._imported_module) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = math.cos +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } diff --git a/tests/Python/test_expand.py b/tests/Python/test_expand.py new file mode 100644 index 000000000..80642b084 --- /dev/null +++ b/tests/Python/test_expand.py @@ -0,0 +1,38 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp +from torch._functorch.aot_autograd import aot_autograd_decompositions + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import tosa + + +def foo(x, y): + return torch.ops.aten.expand(x, y) + + +in1 = torch.tensor([[1], [2], [3]], dtype=torch.float32) +in2 = [-1, 4] + +# Initialize the dynamo compiler. 
+dynamo_compiler = DynamoCompiler( + primary_registry=tosa.ops_registry, + aot_autograd_decomposition=aot_autograd_decompositions, +) + +graphs = dynamo_compiler.importer(foo, in1, in2) +assert len(graphs) == 1 +graph = graphs[0] +graph.lower_to_top_level_ir() +print(graph._imported_module) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = arith.constant +# CHECK: %{{.*}} = tensor.empty +# CHECK: %{{.*}} = linalg.generic +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } diff --git a/tests/Python/test_ge.py b/tests/Python/test_ge.py new file mode 100644 index 000000000..24e202a18 --- /dev/null +++ b/tests/Python/test_ge.py @@ -0,0 +1,35 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import linalg + + +def foo(x, y): + return torch.ops.aten.ge(x, y) + + +in1 = torch.ones([13, 5], dtype=torch.int64) +in2 = 0 +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=linalg.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +graphs = dynamo_compiler.importer(foo, in1, in2) +assert len(graphs) == 1 +graph = graphs[0] +graph.lower_to_top_level_ir() +print(graph._imported_module) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = tensor.empty +# CHECK: %{{.*}} = linalg.generic +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } diff --git a/tests/Python/test_gt.py b/tests/Python/test_gt.py new file mode 100644 index 000000000..48d677b9a --- /dev/null +++ b/tests/Python/test_gt.py @@ -0,0 +1,35 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +import torch._dynamo as dynamo +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import linalg + + +def foo(x, y): + return torch.ops.aten.gt(x, y) + + +in1 = torch.ones([13], dtype=torch.int64) +in2 = torch.ones([13, 1], dtype=torch.int64) +# Initialize the dynamo compiler. +dynamo_compiler = DynamoCompiler( + primary_registry=linalg.ops_registry, + aot_autograd_decomposition=inductor_decomp, +) + +graphs = dynamo_compiler.importer(foo, in1, in2) +assert len(graphs) == 1 +graph = graphs[0] +graph.lower_to_top_level_ir() +print(graph._imported_module) + +# CHECK: module { +# CHECK-LABEL: func.func @forward +# CHECK: %{{.*}} = tensor.empty +# CHECK: %{{.*}} = linalg.generic +# CHECK: return %{{.*}} +# CHECK: } +# CHECK: } diff --git a/tests/Python/test_sin.py b/tests/Python/test_sin.py new file mode 100644 index 000000000..d6cb52eed --- /dev/null +++ b/tests/Python/test_sin.py @@ -0,0 +1,36 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import torch +from torch._inductor.decomposition import decompositions as inductor_decomp + +from buddy.compiler.frontend import DynamoCompiler +from buddy.compiler.ops import math + + +def foo(x): + return torch.ops.aten.sin(x) + + +x = torch.randn(10, 3, 6) + +# Initialize the dynamo compiler. 
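+# The math registry lowers aten.sin to the MLIR math dialect op that
+# the CHECK lines below match (math.sin).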
+dynamo_compiler = DynamoCompiler(
+    primary_registry=math.ops_registry,
+    aot_autograd_decomposition=inductor_decomp,
+)
+
+foo_mlir = torch.compile(foo, backend=dynamo_compiler)
+assert torch.allclose(foo_mlir(x), foo(x), equal_nan=True)
+
+graphs = dynamo_compiler.importer(foo, x)
+
+graph = graphs[0]
+graph.lower_to_top_level_ir()
+print(graph._imported_module)
+
+# CHECK: module {
+# CHECK-LABEL: func.func @forward
+# CHECK: %{{.*}} = math.sin
+# CHECK: return %{{.*}}
+# CHECK: }
+# CHECK: }
diff --git a/tests/Python/test_where.py b/tests/Python/test_where.py
index 5266f00b7..7c8013e02 100644
--- a/tests/Python/test_where.py
+++ b/tests/Python/test_where.py
@@ -14,8 +14,10 @@ def foo(x, y, z):
 
 in1 = torch.ones([13, 13], dtype=torch.bool)
-in2 = 0
-in3 = torch.ones([13, 13], dtype=torch.float32)
+in2 = 0  # or a tensor value, e.g. in2 = torch.zeros([13, 13], dtype=torch.float32)
+in3 = torch.ones(
+    [13, 13], dtype=torch.float32
+)  # or a scalar value, e.g. in3 = 1
 # Initialize the dynamo compiler.
 dynamo_compiler = DynamoCompiler(
     primary_registry=linalg.ops_registry,

From 7adb1fdadf9e6fe29e3a799762173dfbfb4be330 Mon Sep 17 00:00:00 2001
From: effrey-liu <2318266514@qq.com>
Date: Fri, 12 Jul 2024 23:40:01 +0800
Subject: [PATCH 2/2] uncomment model_path

---
 examples/BuddyLlama/import-llama.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/examples/BuddyLlama/import-llama.py b/examples/BuddyLlama/import-llama.py
index 41c93f904..57b2ede11 100644
--- a/examples/BuddyLlama/import-llama.py
+++ b/examples/BuddyLlama/import-llama.py
@@ -35,12 +35,11 @@ from buddy.compiler.graph.transform import simply_fuse
 
 # Retrieve the LLaMA model path from environment variables.
-# model_path = os.environ.get("LLAMA_MODEL_PATH")
-# if model_path is None:
-#     raise EnvironmentError(
-#         "The environment variable 'LLAMA_MODEL_PATH' is not set or is invalid."
-#     )
-model_path = "../../../download/llama3_model/llama3_8B_save/"
+model_path = os.environ.get("LLAMA_MODEL_PATH")
+if model_path is None:
+    raise EnvironmentError(
+        "The environment variable 'LLAMA_MODEL_PATH' is not set or is invalid."
+    )
 
 # Initialize the tokenizer and model from the specified model path.
 tokenizer = AutoTokenizer.from_pretrained(model_path)
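
With both patches applied, an end-to-end run of the example might look like this (placeholder paths as in the README; `8B` assumes a LLaMA3 checkpoint):

```
$ cd examples/BuddyLlama
$ python llama-to-hf.py --input_dir path-to-llama-model --model_size 8B --output_dir path-to-save-llama-hf-model
$ export LLAMA_MODEL_PATH=path-to-save-llama-hf-model
$ python import-llama.py
```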