support for llama3 #345

Open · wants to merge 2 commits into base: main
8 changes: 4 additions & 4 deletions examples/BuddyLlama/README.md
@@ -1,6 +1,6 @@
# Buddy Compiler LLaMA Example

-1. Download LLaMA2 model
+1. Download LLaMA model

You should download the LLaMA model. You can get the model from [meta ai](https://ai.meta.com/llama/).

@@ -14,13 +14,13 @@ $ cd buddy-mlir
$ pip install -r requirements.txt
```

-3. LLaMA2 model convert to HuggingFace format
+3. Convert the LLaMA model to HuggingFace format

-You should convert LLaMA2 model which download from meta ai to HuggingFace format. Because we use HuggingFace api to get LLaMA2 model.
+You should convert the LLaMA model downloaded from meta ai to the HuggingFace format, because we use the HuggingFace API to load the LLaMA model.

```
$ cd examples/BuddyLlama
-$ python llama2-to-hf.py --input_dir path-to-llama2-model --model_size 7B --output_dir path-to-save-llama-hf-model
+$ python llama-to-hf.py --input_dir path-to-llama-model --model_size 7B/8B --output_dir path-to-save-llama-hf-model
```

For example, if you have a 7B LLaMA model, your input_dir path-to-llama-model should contain a tokenizer.model file and a directory named "7B"; put the 7B model weights inside that "7B" directory.
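
Before running the conversion, you can sanity-check this layout with a small script like the one below (the helper is illustrative, not part of the repository):

```
import os

def check_llama_layout(input_dir: str, model_size: str = "7B") -> None:
    # The conversion script expects tokenizer.model next to a directory
    # named after the model size (e.g. "7B" or "8B").
    tokenizer = os.path.join(input_dir, "tokenizer.model")
    weights_dir = os.path.join(input_dir, model_size)
    if not os.path.isfile(tokenizer):
        raise FileNotFoundError(f"missing {tokenizer}")
    if not os.path.isdir(weights_dir):
        raise FileNotFoundError(f"missing directory {weights_dir}")

check_llama_layout("path-to-llama-model", "7B")
```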
81 changes: 81 additions & 0 deletions examples/BuddyLlama/import-llama.py
@@ -0,0 +1,81 @@
# ===- import-llama.py --------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# This is the importer of the LLaMA model: it traces the model, lowers it to
# MLIR, and exports the parameters.
#
# ===---------------------------------------------------------------------------

import os

import numpy
import torch
import torch._dynamo as dynamo
from torch._decomp import get_decompositions
from torch._inductor.decomposition import decompositions as inductor_decomp
from transformers import AutoModelForCausalLM, AutoTokenizer

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse
from buddy.compiler.ops import tosa

# Retrieve the LLaMA model path from environment variables.
model_path = os.environ.get("LLAMA_MODEL_PATH")
if model_path is None:
    raise EnvironmentError(
        "The environment variable 'LLAMA_MODEL_PATH' is not set or is invalid."
    )

# Initialize the tokenizer and model from the specified model path.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torchscript=True)
model.config.use_cache = False

# Decompose the fused scaled-dot-product-attention kernel so the importer
# sees primitive aten ops instead of the fused kernel.
DEFAULT_DECOMPOSITIONS = [
    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
]

decomp = get_decompositions(DEFAULT_DECOMPOSITIONS)

# Initialize Dynamo Compiler with specific configurations as an importer.
dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition={**inductor_decomp, **decomp},
)

# Import the model into MLIR module and parameters.
with torch.no_grad():
    # The dummy input is a 1x40 batch of token ids; its values only drive
    # tracing, not text generation.
    data = torch.tensor([[1 for i in range(40)]], dtype=torch.int64)
    graphs = dynamo_compiler.importer(model, data)

assert len(graphs) == 1
graph = graphs[0]
params = dynamo_compiler.imported_params[graph]
pattern_list = [simply_fuse]
graphs[0].fuse_ops(pattern_list)
driver = GraphDriver(graphs[0])
driver.subgraphs[0].lower_to_top_level_ir()
path_prefix = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
    print(driver.subgraphs[0]._imported_module, file=module_file)
with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
    print(driver.construct_main_graph(True), file=module_file)
all_param = numpy.concatenate(
    [param.detach().numpy().reshape([-1]) for param in params]
)
all_param.tofile(os.path.join(path_prefix, "arg0.data"))
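
To sanity-check the artifacts written by import-llama.py (subgraph0.mlir, forward.mlir, arg0.data), a minimal sketch along these lines can be used; the checker itself is hypothetical and assumes float32 parameters, which is what from_pretrained loads by default:

```
import os

import numpy

def check_export(prefix: str) -> None:
    # arg0.data holds every parameter flattened into one buffer; float32 is
    # assumed here (from_pretrained loads weights as float32 by default).
    params = numpy.fromfile(
        os.path.join(prefix, "arg0.data"), dtype=numpy.float32
    )
    print(f"exported {params.size} parameter elements")
    # Both MLIR files should be non-empty after a successful import.
    for name in ("subgraph0.mlir", "forward.mlir"):
        assert os.path.getsize(os.path.join(prefix, name)) > 0, name
```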
23 changes: 23 additions & 0 deletions examples/BuddyLlama/llama-to-hf.py
@@ -0,0 +1,23 @@
# ===- llama-to-hf.py ---------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# This converts the LLaMA model weights to the HuggingFace format.
#
# ===---------------------------------------------------------------------------

from transformers.models.llama import convert_llama_weights_to_hf
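
# main() parses --input_dir, --model_size, and --output_dir from the command
# line, so this script simply forwards the flags shown in the README.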

convert_llama_weights_to_hf.main()
34 changes: 19 additions & 15 deletions frontend/Python/frontend.py
@@ -124,6 +124,7 @@ def __init__(
            "mean.dim": MeanOp,
            "rsqrt.default": RsqrtOp,
            "mul.Tensor": MulOp,
+           "mul.Scalar": MulOp,
            "t.default": TOp,
            "mm.default": MatmulOp,
            "transpose.int": TransposeOp,
@@ -160,6 +161,10 @@ def __init__(
            "reciprocal.default": ReciprocalOp,
            "clamp_min.default": ClampMinOp,
            "clamp_max.default": ClampMaxOp,
+           "ge.Scalar": GreaterEqualOp,
+           "gt.Tensor": GreaterThanOp,
+           "cos.default": CosOp,
+           "sin.default": SinOp,
        }
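
For context, here is a minimal sketch of how such a registry is typically consulted during import; the function name and error handling are illustrative, not the actual frontend code:

```
def resolve_op(ops_registry: dict, target_name: str):
    # The aten target name (e.g. "ge.Scalar") selects which Buddy graph Op
    # class to instantiate; unknown ops fail loudly.
    op_class = ops_registry.get(target_name)
    if op_class is None:
        raise NotImplementedError(f"unsupported aten op: {target_name}")
    return op_class()
```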

@property
@@ -223,7 +228,9 @@ def _create_node(
                buddy_node.add_argument(str(input_arg))
                buddy_node.add_parent(str(input_arg))
            elif isinstance(input_arg, torch.dtype):
-               buddy_node.add_argument(self._torch_dtype_translate(str(input_arg)))
+               buddy_node.add_argument(
+                   self._torch_dtype_translate(str(input_arg))
+               )
            else:
                buddy_node.add_argument(input_arg)
        for user in node_users:
@@ -297,10 +304,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):

            elif gm_node.op == "output":
                buddy_node = self._create_node(
-                   gm_node.op,
-                   gm_node.name,
-                   gm_node.args,
-                   node_users
+                   gm_node.op, gm_node.name, gm_node.args, node_users
                )

            elif gm_node.target is operator.getitem:
@@ -429,7 +433,7 @@ def get_lib_extension():

def cast_c_ptr(outdata_ptr, memref_ptr):
"""
Casts a C pointer (`outdata_ptr`) to the type of another C pointer
(`memref_ptr`).

Args:
@@ -440,14 +444,14 @@ def cast_c_ptr(outdata_ptr, memref_ptr):

Returns:
ctypes.POINTER
A new C pointer with the type of `memref_ptr`, representing the
same memory location as `outdata_ptr`.

Example:
outdata = ctypes.pointer(ctypes.c_int())
memref = ctypes.pointer(ctypes.c_float())
casted_ptr = cast_c_ptr(outdata, memref)
# Now `casted_ptr` points to the same memory location as `outdata`,
but with the type of `memref`.
"""
outdata_addr = ctypes.addressof(outdata_ptr.contents)
@@ -456,15 +460,15 @@ def cast_c_ptr(outdata_ptr, memref_ptr):

def move_c_ptr(outdata_ptr, memref_ptr):
"""
Moves a C pointer (`outdata_ptr`) to the next element in memory,
based on the size of the referenced type in another C pointer
(`memref_ptr`).

Args:
outdata_ptr: ctypes.POINTER
The C pointer whose position needs to be moved.
memref_ptr: ctypes.POINTER
The reference C pointer whose type determines the size of each
element for the move.

Returns:
@@ -487,7 +491,7 @@ def exec_buddy_graph(*args):

Returns:
List[torch.Tensor]
The result of executing the graph, represented as a list of
output tensors.
"""
# A list of ctypes pointers representing memory references for input
@@ -500,13 +504,13 @@ def exec_buddy_graph(*args):
)
for tensor in args
]
# A list of ctypes pointers representing memory references for
# output tensors.
output_memref = [
ctypes.pointer(ctypes.pointer(graph._output_descriptor()))
]
args_memref = output_memref + input_memref
# Invoke the graph's function using the provided execution engine
# and memory references
ee.invoke(graph._func_name, *args_memref)

@@ -523,7 +527,7 @@ def exec_buddy_graph(*args):
# Move to the next element in memory based on the size of the
# current output type
outdata_ptr = move_c_ptr(outdata_ptr, output_ptr[0])
# Convert each NumPy array to a PyTorch tensor and return the list
# of tensors
return [torch.from_numpy(tensor) for tensor in output_tensor]
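
For illustration, a hedged sketch of how the callable backed by exec_buddy_graph could be used; the wrapper below is assumed, and only the tensor-in, tensor-list-out contract comes from the code above:

```
import torch

def run_once(compiled_fn, seq_len: int = 40):
    # Hypothetical wrapper: a DynamoCompiler-compiled callable maps input
    # tensors to a list of output tensors via exec_buddy_graph.
    tokens = torch.ones((1, seq_len), dtype=torch.int64)
    outputs = compiled_fn(tokens)  # List[torch.Tensor]
    return outputs[0]
```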

37 changes: 37 additions & 0 deletions frontend/Python/graph/operation.py
@@ -128,6 +128,7 @@ def kwargs(self):
    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, new_name):
        self._name = new_name
@@ -404,37 +405,44 @@ def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class Conv2dOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ReduceType
        self._layout = "NCHW_FCHW"


class ReluOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class SigmoidOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class IotaOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.PlaceholderType


class ScalarTensorOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.PlaceholderType


class WhereOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class MaxPool2dWithIndicesOp(Op):
    def __init__(self) -> None:
        super().__init__()
@@ -448,17 +456,20 @@ def __init__(self) -> None:
        self._op_type = OpType.ReduceType
        self._layout = "NCHW"


class CallOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self.call_func_name = ""
        self._op_type = OpType.Unfusable


class FuncOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.Unfusable


class ReciprocalOp(Op):
    def __init__(self) -> None:
        super().__init__()
@@ -470,12 +481,38 @@ def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class ClampMinOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class ClampMaxOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class GreaterEqualOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.BroadcastType


class GreaterThanOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.BroadcastType


class CosOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class SinOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType
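
Following the same pattern, supporting another unary aten op would add a class like the sketch below (illustrative, not part of this PR), plus a matching registry entry such as "tan.default": TanOp in frontend.py:

```
class TanOp(Op):
    # Hypothetical example mirroring CosOp and SinOp above.
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType
```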