From de2ec94aa5ca7c122dfeb03a2d915f0425c97a52 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Tue, 21 Jan 2025 19:00:44 +0100 Subject: [PATCH 01/14] add executorch export --- optimum/commands/export/executorch.py | 67 +++ optimum/commands/register/register_export.py | 19 + optimum/executorch/__init__.py | 29 ++ optimum/executorch/modeling.py | 460 ++++++++++++++++++ optimum/exporters/executorch/__init__.py | 50 ++ optimum/exporters/executorch/__main__.py | 160 ++++++ optimum/exporters/executorch/convert.py | 90 ++++ .../exporters/executorch/recipe_registry.py | 68 +++ .../exporters/executorch/recipes/__init__.py | 13 + .../exporters/executorch/recipes/xnnpack.py | 97 ++++ optimum/exporters/executorch/task_registry.py | 68 +++ .../exporters/executorch/tasks/__init__.py | 13 + .../exporters/executorch/tasks/causal_lm.py | 66 +++ tests/export/__init__.py | 14 + tests/export/test_exporters_executorch.py | 115 +++++ tests/runtime/__init__.py | 14 + tests/runtime/test_modeling.py | 70 +++ tests/runtime/test_modeling_gemma.py | 54 ++ tests/runtime/test_modeling_gemma2.py | 56 +++ tests/runtime/test_modeling_llama.py | 83 ++++ tests/runtime/test_modeling_olmo.py | 54 ++ tests/runtime/test_modeling_qwen2.py | 52 ++ 22 files changed, 1712 insertions(+) create mode 100644 optimum/commands/export/executorch.py create mode 100644 optimum/commands/register/register_export.py create mode 100644 optimum/executorch/__init__.py create mode 100644 optimum/executorch/modeling.py create mode 100644 optimum/exporters/executorch/__init__.py create mode 100644 optimum/exporters/executorch/__main__.py create mode 100644 optimum/exporters/executorch/convert.py create mode 100644 optimum/exporters/executorch/recipe_registry.py create mode 100644 optimum/exporters/executorch/recipes/__init__.py create mode 100644 optimum/exporters/executorch/recipes/xnnpack.py create mode 100644 optimum/exporters/executorch/task_registry.py create mode 100644 optimum/exporters/executorch/tasks/__init__.py create mode 100644 optimum/exporters/executorch/tasks/causal_lm.py create mode 100644 tests/export/__init__.py create mode 100644 tests/export/test_exporters_executorch.py create mode 100644 tests/runtime/__init__.py create mode 100644 tests/runtime/test_modeling.py create mode 100644 tests/runtime/test_modeling_gemma.py create mode 100644 tests/runtime/test_modeling_gemma2.py create mode 100644 tests/runtime/test_modeling_llama.py create mode 100644 tests/runtime/test_modeling_olmo.py create mode 100644 tests/runtime/test_modeling_qwen2.py diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py new file mode 100644 index 0000000..2bf2f1d --- /dev/null +++ b/optimum/commands/export/executorch.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +"""Defines the command line for the export with ExecuTorch.""" + +from pathlib import Path +from typing import TYPE_CHECKING + +from ...exporters import TasksManager +from ..base import BaseOptimumCLICommand + + +if TYPE_CHECKING: + from argparse import ArgumentParser + + +def parse_args_executorch(parser): + required_group = parser.add_argument_group("Required arguments") + required_group.add_argument( + "-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from." + ) + required_group.add_argument( + "-o", + "--output_dir", + type=Path, + help="Path indicating the directory where to store the generated ExecuTorch model.", + ) + required_group.add_argument( + "--task", + type=str, + default="text-generation", + help=( + "The task to export the model for. Available tasks depend on the model, but are among:" + f" {str(TasksManager.get_all_tasks())}." + ), + ) + required_group.add_argument( + "--recipe", + type=str, + default="xnnpack", + help='Pre-defined recipes for export to ExecuTorch. Defaults to "xnnpack".', + ) + + +class ExecuTorchExportCommand(BaseOptimumCLICommand): + @staticmethod + def parse_args(parser: "ArgumentParser"): + return parse_args_executorch(parser) + + def run(self): + from ...exporters.executorch import main_export + + main_export( + model_name_or_path=self.args.model, + task=self.args.task, + recipe=self.args.recipe, + output_dir=self.args.output_dir, + ) diff --git a/optimum/commands/register/register_export.py b/optimum/commands/register/register_export.py new file mode 100644 index 0000000..3959de6 --- /dev/null +++ b/optimum/commands/register/register_export.py @@ -0,0 +1,19 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..export import ExportCommand +from ..export.executorch import ExecuTorchExportCommand + + +REGISTER_COMMANDS = [(ExecuTorchExportCommand, ExportCommand)] diff --git a/optimum/executorch/__init__.py b/optimum/executorch/__init__.py new file mode 100644 index 0000000..cbc9b37 --- /dev/null +++ b/optimum/executorch/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
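For reference, once `ExecuTorchExportCommand` is registered under `ExportCommand` via `REGISTER_COMMANDS` above, the exporter is reachable through `optimum-cli`. A minimal sketch of an invocation, mirroring how the test suite later drives it (the model ID and output directory are illustrative placeholders, not values mandated by the exporter):

```python
# Illustrative CLI invocation of the command defined above; the model ID and
# output directory are placeholders chosen for the example.
import subprocess

subprocess.run(
    "optimum-cli export executorch"
    " --model NousResearch/Llama-3.2-1B"
    " --task text-generation"
    " --recipe xnnpack"
    " --output_dir ./llama3_2_1b_executorch",
    shell=True,
    check=True,
)
```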
+ +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "modeling": [ + "ExecuTorchModelForCausalLM", + ], +} + +if TYPE_CHECKING: + from .modeling import ExecuTorchModelForCausalLM +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py new file mode 100644 index 0000000..b93309f --- /dev/null +++ b/optimum/executorch/modeling.py @@ -0,0 +1,460 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +"""ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers.""" + +import logging +import os +import warnings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import List, Optional, Union + +import torch +from executorch.extension.pybindings.portable_lib import ( + ExecuTorchModule, + _load_for_executorch, +) +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from transformers import ( + AutoModelForCausalLM, + PretrainedConfig, + PreTrainedTokenizer, +) + +from ..exporters.executorch import main_export +from ..modeling_base import OptimizedModel + + +logger = logging.getLogger(__name__) + + +class ExecuTorchModelForCausalLM(OptimizedModel): + """ + ExecuTorch model with a causal language modeling head for inference using the ExecuTorch Runtime. + + This class provides an interface for loading, running, and generating outputs from a causal language model + optimized for ExecuTorch Runtime. It includes utilities for exporting and loading pre-trained models + compatible with ExecuTorch runtime. + + Attributes: + auto_model_class (`Type`): + Associated Transformers class, `AutoModelForCausalLM`. + et_model (`ExecuTorchModule`): + The loaded ExecuTorch model. + use_kv_cache (`bool`): + Whether key-value caching is enabled. For performance reasons, the exported model is + optimized to use a static cache. + max_cache_size (`int`): + Maximum sequence length supported by the cache. + max_batch_size (`int`): + Maximum supported batch size. + dtype (`str`): + Data type of the model parameters. + bos_token_id (`int`): + Beginning-of-sequence token ID. + eos_token_id (`int`): + End-of-sequence token ID. + vocab_size (`int`): + Size of the model vocabulary. 
+    """
+
+    auto_model_class = AutoModelForCausalLM
+
+    def __init__(
+        self,
+        model: "ExecuTorchModule",
+        config: "PretrainedConfig",
+    ):
+        super().__init__(model, config)
+        self.et_model = model
+        metadata = self.et_model.method_names()
+        logging.info(f"Load all static methods: {metadata}")
+        if "use_kv_cache" in metadata:
+            self.use_kv_cache = self.et_model.run_method("use_kv_cache")[0]
+        if "get_max_seq_len" in metadata:
+            self.max_cache_size = self.et_model.run_method("get_max_seq_len")[0]
+        if "get_max_batch_size" in metadata:
+            self.max_batch_size = self.et_model.run_method("get_max_batch_size")[0]
+        if "get_dtype" in metadata:
+            self.dtype = self.et_model.run_method("get_dtype")[0]
+        if "get_bos_id" in metadata:
+            self.bos_token_id = self.et_model.run_method("get_bos_id")[0]
+        if "get_eos_id" in metadata:
+            self.eos_token_id = self.et_model.run_method("get_eos_id")[0]
+        if "get_vocab_size" in metadata:
+            self.vocab_size = self.et_model.run_method("get_vocab_size")[0]
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the model, which is compatible with the ExecuTorch runtime for LLMs.
+
+        Args:
+            input_ids (`torch.Tensor`): Tensor representing the current input token ID to the model.
+            cache_position (`torch.Tensor`): Tensor representing the current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        return self.et_model.forward((input_ids, cache_position))[0]
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: Union[str, Path],
+        export: bool = True,
+        task: str = "",
+        recipe: str = "",
+        config: "PretrainedConfig" = None,
+        subfolder: str = "",
+        revision: Optional[str] = None,
+        cache_dir: str = HUGGINGFACE_HUB_CACHE,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        token: Optional[Union[bool, str]] = None,
+        **kwargs,
+    ) -> "ExecuTorchModelForCausalLM":
+        """
+        Load a pre-trained ExecuTorch model.
+
+        Args:
+            model_name_or_path (`Union[str, Path]`):
+                Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="meta-llama/Llama-3.2-1B"` or `model_name_or_path="/path/to/model_folder"`.
+            export (`bool`, *optional*, defaults to `True`):
+                If `True`, the model will be exported from eager to ExecuTorch after being fetched from huggingface.co. `model_name_or_path` must be a valid model ID on huggingface.co.
+                If `False`, a previously exported ExecuTorch model will be loaded from a local path. `model_name_or_path` must be a valid local directory where a `model.pte` is stored.
+            task (`str`, defaults to `""`):
+                The task to export the model for, e.g. "text-generation". It is required to specify a task when `export` is `True`.
+            recipe (`str`, defaults to `""`):
+                The recipe to use for the export, e.g. "xnnpack". It is required to specify a recipe when `export` is `True`.
+            config (`PretrainedConfig`, *optional*):
+                Configuration of the pre-trained model.
+            subfolder (`str`, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can
+                specify the folder name here.
+            revision (`str`, defaults to `"main"`):
+                Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.
+            cache_dir (`Optional[str]`, defaults to `None`):
+                Path indicating where to store cache. The default Hugging Face cache path will be used by default.
+ force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + **kwargs: + Additional configuration options to tasks and recipes. + + Returns: + `ExecuTorchModelForCausalLM`: An instance of the ExecuTorch model for text generation task. + """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + if export: + # Fetch the model from huggingface.co and export it to ExecuTorch + if task == "": + raise ValueError("Please specify a task to export the model for.") + if recipe == "": + raise ValueError("Please specify a recipe to export the model for.") + return cls._export( + model_id=model_name_or_path, + task=task, + recipe=recipe, + config=config, + **kwargs, + ) + else: + # Load the ExecuTorch model from a local path + return cls._from_pretrained( + model_dir_path=model_name_or_path, + config=config, + ) + + @classmethod + def _from_pretrained( + cls, + model_dir_path: Union[str, Path], + config: PretrainedConfig, + subfolder: str = "", + revision: Optional[str] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + force_download: bool = False, + local_files_only: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + ) -> "ExecuTorchModelForCausalLM": + """ + Load a pre-trained ExecuTorch model from a local directory. + + Args: + model_dir_path (`Union[str, Path]`): + Path to the directory containing the ExecuTorch model file (`model.pte`). + config (`PretrainedConfig`, *optional*): + Configuration of the pre-trained model. + subfolder (`str`, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can + specify the folder name here. + revision (`str`, defaults to `"main"`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id. + cache_dir (`Optional[str]`, defaults to `None`): + Path indicating where to store cache. The default Hugging Face cache path will be used by default. + force_download (`bool`, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + local_files_only (`Optional[bool]`, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. 
If `True`, will use the token generated
+                when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
+
+        Returns:
+            `ExecuTorchModelForCausalLM`: The initialized ExecuTorch model.
+
+        """
+        full_path = os.path.join(f"{model_dir_path}", "model.pte")
+        model = _load_for_executorch(full_path)
+        logging.info(f"Loaded model from {full_path}")
+        logging.debug(f"{model.method_meta('forward')}")
+        return cls(
+            model=model,
+            config=config,
+        )
+
+    def _save_pretrained(self, save_directory):
+        """
+        Saves the model weights to a directory, so that the model can be re-loaded using the
+        [`from_pretrained`] class method.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def _export(
+        cls,
+        model_id: str,
+        task: str,
+        recipe: str,
+        config: PretrainedConfig,
+        cache_dir: str = HUGGINGFACE_HUB_CACHE,
+        trust_remote_code: bool = False,
+        subfolder: str = "",
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        token: Optional[Union[bool, str]] = None,
+        **kwargs,
+    ):
+        """
+        Fetch a model from the Hugging Face Hub and export it to ExecuTorch format.
+
+        Args:
+            model_id (`str`):
+                Model ID on huggingface.co, for example: `model_id="meta-llama/Llama-3.2-1B"`.
+            task (`str`):
+                The task to export the model for, e.g. "text-generation".
+            recipe (`str`):
+                The recipe to use for the export, e.g. "xnnpack".
+            config (`PretrainedConfig`, *optional*):
+                Configuration of the pre-trained model.
+            cache_dir (`Optional[str]`, defaults to `None`):
+                Path indicating where to store cache. The default Hugging Face cache path will be used by default.
+            trust_remote_code (`bool`, defaults to `False`):
+                Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories
+                you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the
+                model repository.
+            subfolder (`str`, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can
+                specify the folder name here.
+            revision (`str`, defaults to `"main"`):
+                Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.
+            force_download (`bool`, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            local_files_only (`Optional[bool]`, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`):
+                Deprecated. Please use the `token` argument instead.
+            token (`Optional[Union[bool,str]]`, defaults to `None`):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
+            **kwargs:
+                Additional configuration options passed to tasks and recipes.
+
+        Returns:
+            `ExecuTorchModelForCausalLM`: The loaded and exported ExecuTorch model.
+
+        """
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed soon.
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + # Export to ExecuTorch and save the pte file to the temporary directory + main_export( + model_name_or_path=model_id, + output_dir=save_dir_path, + task=task, + recipe=recipe, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cls._from_pretrained( + model_dir_path=save_dir_path, + config=config, + use_auth_token=use_auth_token, + subfolder=subfolder, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + ) + + def generate( + self, + prompt_tokens: List[int], + echo: bool = False, + pos_base: int = 0, + max_seq_len: Optional[int] = None, + ) -> List[int]: + """ + Generate tokens from a prompt using the ExecuTorch model. + + Args: + prompt_tokens (List[int]): + List of token IDs representing the prompt. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `False`. + pos_base (`int`, *optional*): + Base position for the prompt tokens. Defaults to 0. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + + Returns: + List[int]: List of generated token IDs. + + Note: + Temporarily implemented this method in Python due to limited access to ExecuTorch's c++ LLM runner via pybind. + Expect improvements to the pybind interface in ExecuTorch version 0.4.1. + """ + self.device = torch.device("cpu") + if max_seq_len is None: + # Default to max_cache_size if max_seq_len is not specified + max_seq_len = self.max_cache_size + elif max_seq_len > self.max_cache_size: + logging.warning( + f"max_seq_len={max_seq_len} is larger than max_cache_size={self.max_cache_size}. Generating tokens will be truncated to max_cache_size." + ) + max_seq_len = self.max_cache_size + generated_tokens = [] + + # prefill + for i, prompt_token in enumerate(prompt_tokens): + logits = self.forward( + input_ids=torch.tensor([prompt_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor([i], dtype=torch.long, device=self.device), + ) + + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens = prompt_tokens + [next_token] + + while len(generated_tokens) < max_seq_len: + logits = self.forward( + input_ids=torch.tensor([next_token], dtype=torch.long, device=self.device).unsqueeze(0), + cache_position=torch.tensor( + [pos_base + len(generated_tokens) - 1], + dtype=torch.long, + device=self.device, + ), + ) + next_token = torch.argmax(logits, dim=-1).item() + generated_tokens.append(next_token) + if next_token == self.eos_token_id: + break + + return generated_tokens if echo else generated_tokens[len(prompt_tokens) :] + + def text_generation( + self, + tokenizer: "PreTrainedTokenizer", + prompt: str, + echo: bool = True, + max_seq_len: Optional[int] = None, + ): + """ + Perform text generation task for a given prompt using the ExecuTorch model. 
+ + Args: + tokenizer (`PreTrainedTokenizer`): + The tokenizer used to encode and decode the prompt and output. + prompt (`str`): + The text prompt to complete. + echo (`bool`, *optional*): + Whether to include prompt tokens in the generated output. Defaults to `True`. + max_seq_len (`int`, *optional*): + Maximum sequence length for the generated output. + Defaults to None and uses the model's `max_cache_size` attribute. + Will be truncated to maximal cache size if larger than `max_cache_size`. + """ + self.tokenizer = tokenizer + + # Sanity check + if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.bos_token_id: + raise ValueError( + f"The tokenizer's bos_token_id={self.tokenizer.bos_token_id} must be the same as the model's bos_token_id={self.bos_token_id}." + ) + if self.tokenizer.eos_token_id is not None and self.tokenizer.eos_token_id != self.eos_token_id: + raise ValueError( + f"The tokenizer's eos_token_id={self.tokenizer.eos_token_id} must be the same as the model's eos_token_id={self.eos_token_id}." + ) + + prompt_tokens = self.tokenizer.encode(prompt) + generated_tokens = self.generate( + prompt_tokens=prompt_tokens, + echo=echo, + max_seq_len=max_seq_len, + ) + return self.tokenizer.decode(generated_tokens, skip_special_tokens=True) diff --git a/optimum/exporters/executorch/__init__.py b/optimum/exporters/executorch/__init__.py new file mode 100644 index 0000000..3409e69 --- /dev/null +++ b/optimum/exporters/executorch/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "convert": [ + "export_to_executorch", + ], + "recipe_registry": [ + "discover_recipes", + "register_recipe", + ], + "task_registry": [ + "discover_tasks", + "register_task", + ], + "tasks": [ + "causal_lm", + ], + "recipes": [ + "xnnpack", + ], + "__main__": ["main_export"], +} + +if TYPE_CHECKING: + from .__main__ import main_export + from .convert import export_to_executorch +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py new file mode 100644 index 0000000..e3b561f --- /dev/null +++ b/optimum/exporters/executorch/__main__.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations under the License.
+
+"""Entry point to the optimum.exporters.executorch command line."""
+
+import argparse
+import os
+import warnings
+from pathlib import Path
+from typing import Optional, Union
+
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+
+from optimum.utils.import_utils import is_transformers_version
+
+from ...commands.export.executorch import parse_args_executorch
+from .convert import export_to_executorch
+from .task_registry import discover_tasks, task_registry
+
+
+def main_export(
+    model_name_or_path: str,
+    task: str,
+    recipe: str,
+    output_dir: Union[str, Path],
+    cache_dir: str = HUGGINGFACE_HUB_CACHE,
+    trust_remote_code: bool = False,
+    pad_token_id: Optional[int] = None,
+    subfolder: str = "",
+    revision: str = "main",
+    force_download: bool = False,
+    local_files_only: bool = False,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    token: Optional[Union[bool, str]] = None,
+    **kwargs,
+):
+    """
+    Full-suite ExecuTorch export function, exporting **from a model ID on Hugging Face Hub or a local model repository**.
+
+    Args:
+        model_name_or_path (`str`):
+            Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="meta-llama/Llama-3.2-1B"` or `model_name_or_path="/path/to/model_folder"`.
+        task (`str`):
+            The task to export the model for, e.g. "text-generation".
+        recipe (`str`):
+            The recipe to use for the export, e.g. "xnnpack".
+        output_dir (`Union[str, Path]`):
+            Path indicating the directory where to store the generated ExecuTorch model.
+        cache_dir (`Optional[str]`, defaults to `None`):
+            Path indicating where to store cache. The default Hugging Face cache path will be used by default.
+        trust_remote_code (`bool`, defaults to `False`):
+            Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories
+            you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the
+            model repository.
+        pad_token_id (`Optional[int]`, defaults to `None`):
+            This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it.
+        subfolder (`str`, defaults to `""`):
+            In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can
+            specify the folder name here.
+        revision (`str`, defaults to `"main"`):
+            Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.
+        force_download (`bool`, defaults to `False`):
+            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+            cached versions if they exist.
+        local_files_only (`Optional[bool]`, defaults to `False`):
+            Whether or not to only look at local files (i.e., do not try to download the model).
+        use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`):
+            Deprecated. Please use the `token` argument instead.
+        token (`Optional[Union[bool,str]]`, defaults to `None`):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`).
+        **kwargs:
+            Additional configuration options passed to tasks and recipes.
+
+    Example usage:
+    ```python
+    >>> from optimum.exporters.executorch import main_export
+
+    >>> main_export("meta-llama/Llama-3.2-1B", "text-generation", "xnnpack", "meta_llama3_2_1b/")
+    ```
+    """
+
+    if is_transformers_version("<", "4.46"):
+        raise ValueError(
+            "The minimum Transformers version compatible with ExecuTorch is 4.46.0. Please upgrade to Transformers 4.46.0 or later."
+        )
+
+    if use_auth_token is not None:
+        warnings.warn(
+            "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
+            FutureWarning,
+        )
+        if token is not None:
+            raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
+        token = use_auth_token
+
+    # Dynamically discover and import registered tasks
+    discover_tasks()
+
+    # Load the model for the specified task
+    try:
+        task_func = task_registry[task]
+    except KeyError as e:
+        raise RuntimeError(f"The task '{task}' isn't registered. Detailed error: {e}")
+
+    model = task_func(model_name_or_path, **kwargs)
+
+    if task == "text-generation":
+        from transformers.integrations.executorch import TorchExportableModuleWithStaticCache
+
+        model = TorchExportableModuleWithStaticCache(model)
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    return export_to_executorch(
+        model=model,
+        task=task,
+        recipe=recipe,
+        output_dir=output_dir,
+        **kwargs,
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser("Hugging Face Optimum ExecuTorch exporter")
+
+    parse_args_executorch(parser)
+
+    # Retrieve CLI arguments. Only arguments defined by parse_args_executorch are
+    # forwarded; the parser does not define cache_dir, trust_remote_code, or pad_token_id.
+    args = parser.parse_args()
+
+    main_export(
+        model_name_or_path=args.model,
+        output_dir=args.output_dir,
+        task=args.task,
+        recipe=args.recipe,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py
new file mode 100644
index 0000000..aceb733
--- /dev/null
+++ b/optimum/exporters/executorch/convert.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
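Taken together, `main_export` is a thin driver over the two registries. A condensed sketch of the dispatch it performs for the text-generation task (the model ID is illustrative, and both registries are assumed to be populated by the discovery helpers shown above):

```python
# Condensed sketch of what main_export() does for text-generation; the model ID
# is a placeholder, and discovery populates the registries by importing modules.
from transformers.integrations.executorch import TorchExportableModuleWithStaticCache

from optimum.exporters.executorch.recipe_registry import discover_recipes, recipe_registry
from optimum.exporters.executorch.task_registry import discover_tasks, task_registry

discover_tasks()
discover_recipes()

eager_model = task_registry["text-generation"]("NousResearch/Llama-3.2-1B")
wrapped = TorchExportableModuleWithStaticCache(eager_model)  # static KV-cache wrapper
program = recipe_registry["xnnpack"](wrapped, "text-generation")  # lower for ExecuTorch
```

In the real flow, `export_to_executorch` performs the last step and additionally serializes the program to `<output_dir>/model.pte`.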
+
+"""ExecuTorch model check and export functions."""
+
+import logging
+import os
+from pathlib import Path
+from typing import Union
+
+from transformers.utils import is_torch_available
+
+from optimum.utils.import_utils import is_transformers_version
+
+from .recipe_registry import discover_recipes, recipe_registry
+
+
+if is_torch_available():
+    from transformers.modeling_utils import PreTrainedModel
+
+if is_transformers_version(">=", "4.46"):
+    from transformers.integrations.executorch import (
+        TorchExportableModuleWithStaticCache,
+    )
+
+logger = logging.getLogger(__name__)
+
+
+def export_to_executorch(
+    model: Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"],
+    task: str,
+    recipe: str,
+    output_dir: Union[str, Path],
+    **kwargs,
+):
+    """
+    Export a pre-trained PyTorch model to the ExecuTorch format using a specified recipe.
+
+    This function facilitates the transformation of a PyTorch model into an optimized ExecuTorch program.
+
+    Args:
+        model (`Union["PreTrainedModel", "TorchExportableModuleWithStaticCache"]`):
+            A PyTorch model to be exported. This can be a standard HuggingFace `PreTrainedModel` or a wrapped
+            module like `TorchExportableModuleWithStaticCache` for the text-generation task.
+        task (`str`):
+            The specific task the exported model will perform, e.g., "text-generation".
+        recipe (`str`):
+            The recipe to guide the export process, e.g., "xnnpack". Recipes define the optimization and lowering steps.
+            Will raise an exception if the specified recipe is not registered in the recipe registry.
+        output_dir (`Union[str, Path]`):
+            Path to the directory where the resulting ExecuTorch model will be saved.
+        **kwargs:
+            Additional configuration options passed to the recipe.
+
+    Returns:
+        `ExecuTorchProgram`:
+            The lowered ExecuTorch program object.
+
+    Notes:
+        - The function uses a dynamic recipe discovery mechanism to identify and import the specified recipe.
+        - The exported model is stored in the specified output directory with the fixed filename `model.pte`.
+        - The resulting ExecuTorch program is serialized and saved to the output directory.
+    """
+
+    # Dynamically discover and import registered recipes
+    discover_recipes()
+
+    # Export and lower the model to ExecuTorch with the recipe
+    try:
+        recipe_func = recipe_registry[recipe]
+    except KeyError as e:
+        raise RuntimeError(f"The recipe '{recipe}' isn't registered. Detailed error: {e}")
+
+    executorch_prog = recipe_func(model, task, **kwargs)
+
+    full_path = os.path.join(f"{output_dir}", "model.pte")
+    with open(full_path, "wb") as f:
+        executorch_prog.write_to_file(f)
+        logging.info(f"Saved exported program to {full_path}")
+
+    return executorch_prog
diff --git a/optimum/exporters/executorch/recipe_registry.py b/optimum/exporters/executorch/recipe_registry.py
new file mode 100644
index 0000000..2eb728b
--- /dev/null
+++ b/optimum/exporters/executorch/recipe_registry.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations under the License. + +import importlib +import logging +import pkgutil + + +logger = logging.getLogger(__name__) + +recipe_registry = {} + +package_name = "optimum.exporters.executorch.recipes" + + +def register_recipe(recipe_name): + """ + Decorator to register a recipe for exporting and lowering an ExecuTorch model under a specific name. + + Args: + recipe_name (`str`): + The name of the recipe to associate with a callable recipe. + + Returns: + `Callable`: + The original function wrapped as a registered recipe. + + Example: + ```python + @register_recipe("my_new_recipe") + def my_new_recipe(...): + ... + ``` + """ + + def decorator(func): + recipe_registry[recipe_name] = func + return func + + return decorator + + +def discover_recipes(): + """ + Dynamically discovers and imports all recipe modules within the `optimum.exporters.executorch.recipes` package. + + Ensures recipes under `./recipes` directory are dynamically loaded without requiring manual imports. + + Notes: + New recipes **must** be added to the `./recipes` directory to be discovered and used by `main_export`. + Failure to do so will prevent dynamic discovery and registration. Recipes must also use the + `@register_recipe` decorator to be properly registered in the `recipe_registry`. + """ + package = importlib.import_module(package_name) + package_path = package.__path__ + + for _, module_name, _ in pkgutil.iter_modules(package_path): + logger.info(f"Importing {package_name}.{module_name}") + importlib.import_module(f"{package_name}.{module_name}") diff --git a/optimum/exporters/executorch/recipes/__init__.py b/optimum/exporters/executorch/recipes/__init__.py new file mode 100644 index 0000000..a2e21cf --- /dev/null +++ b/optimum/exporters/executorch/recipes/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from . import xnnpack diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py new file mode 100644 index 0000000..d3b3a5d --- /dev/null +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
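As the registry notes above describe, recipes are discovered purely by module import. A hypothetical sketch of a third-party recipe module that `discover_recipes()` would pick up from the `recipes/` directory (the file name, recipe name, and function body are all illustrative):

```python
# Hypothetical module, e.g. optimum/exporters/executorch/recipes/my_backend.py;
# discover_recipes() imports it, and the decorator adds it to recipe_registry.
from ..recipe_registry import register_recipe


@register_recipe("my_backend")
def export_with_my_backend(model, task, **kwargs):
    # export_to_executorch() only requires the returned object to expose
    # write_to_file(), which it uses to serialize <output_dir>/model.pte.
    raise NotImplementedError("Lower `model` for `task` here and return the program.")
```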
+
+from typing import Union
+
+import torch
+import torch.export._trace
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge_transform_and_lower,
+)
+from torch.nn.attention import SDPBackend
+from transformers import PreTrainedModel
+from transformers.integrations.executorch import TorchExportableModuleWithStaticCache
+
+from ..recipe_registry import register_recipe
+
+
+@register_recipe("xnnpack")
+def export_to_executorch_with_xnnpack(
+    model: Union[PreTrainedModel, TorchExportableModuleWithStaticCache],
+    task: str,
+    **kwargs,
+):
+    """
+    Export a PyTorch model to ExecuTorch with delegation to the XNNPACK backend.
+
+    This function also writes the metadata required by the ExecuTorch runtime to the model.
+
+    Args:
+        model (Union[PreTrainedModel, TorchExportableModuleWithStaticCache]):
+            The PyTorch model to be exported to ExecuTorch.
+        task (str):
+            The task name to export the model for (e.g., "text-generation").
+        **kwargs:
+            Additional keyword arguments for recipe-specific configurations.
+
+    Returns:
+        ExecuTorchProgram:
+            The exported and optimized program for ExecuTorch.
+    """
+    metadata = {}
+    if task == "text-generation":
+        example_input_ids = torch.tensor([[1]], dtype=torch.long)
+        example_cache_position = torch.tensor([0], dtype=torch.long)
+
+        def _get_constant_methods(model: PreTrainedModel):
+            metadata = {
+                "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
+                "get_bos_id": model.config.bos_token_id,
+                "get_eos_id": model.config.eos_token_id,
+                "get_head_dim": model.config.hidden_size / model.config.num_attention_heads,
+                "get_max_batch_size": model.generation_config.cache_config.batch_size,
+                "get_max_seq_len": model.generation_config.cache_config.max_cache_len,
+                "get_n_kv_heads": model.config.num_key_value_heads,
+                "get_n_layers": model.config.num_hidden_layers,
+                "get_vocab_size": model.config.vocab_size,
+                "use_kv_cache": model.generation_config.use_cache,
+            }
+            return {k: v for k, v in metadata.items() if v is not None}
+
+        metadata = _get_constant_methods(model if isinstance(model, PreTrainedModel) else model.model)
+    else:
+        # TODO: Prepare model inputs for other tasks
+        raise ValueError(f"Unsupported task '{task}'.")
+
+    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+        exported_program = torch.export._trace._export(
+            model,
+            args=(example_input_ids,),
+            kwargs={"cache_position": example_cache_position},
+            pre_dispatch=False,
+            strict=True,
+        )
+
+    return to_edge_transform_and_lower(
+        exported_program,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _skip_dim_order=True,
+        ),
+        constant_methods=metadata,
+    ).to_executorch(
+        config=ExecutorchBackendConfig(
+            extract_delegate_segments=True,
+        ),
+    )
diff --git a/optimum/exporters/executorch/task_registry.py b/optimum/exporters/executorch/task_registry.py
new file mode 100644
index 0000000..fdc34f0
--- /dev/null
+++ b/optimum/exporters/executorch/task_registry.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import importlib +import logging +import pkgutil + + +logger = logging.getLogger(__name__) + +task_registry = {} + +package_name = "optimum.exporters.executorch.tasks" + + +def register_task(task_name): + """ + Decorator to register a task under a specific name. + + Args: + task_name (`str`): + The name of the task to associate with a callable task. + + Returns: + `Callable`: + The original function wrapped as a registered task. + + Example: + ```python + @register_task("my_new_task") + def my_new_task(...): + ... + ``` + """ + + def decorator(func): + task_registry[task_name] = func + return func + + return decorator + + +def discover_tasks(): + """ + Dynamically discovers and imports all task modules within the `optimum.exporters.executorch.tasks` package. + + Ensures tasks under `./tasks` directory are dynamically loaded without requiring manual imports. + + Notes: + New tasks **must** be added to the `./tasks` directory to be discovered and used by `main_export`. + Failure to do so will prevent dynamic discovery and registration. Tasks must also use the + `@register_task` decorator to be properly registered in the `task_registry`. + """ + package = importlib.import_module(package_name) + package_path = package.__path__ + + for _, module_name, _ in pkgutil.iter_modules(package_path): + logger.info(f"Importing {package_name}.{module_name}") + importlib.import_module(f"{package_name}.{module_name}") diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py new file mode 100644 index 0000000..754a824 --- /dev/null +++ b/optimum/exporters/executorch/tasks/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +from . import causal_lm diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py new file mode 100644 index 0000000..b02da8b --- /dev/null +++ b/optimum/exporters/executorch/tasks/causal_lm.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations under the License.
+
+from transformers import AutoModelForCausalLM, GenerationConfig
+
+from ..task_registry import register_task
+
+
+@register_task("text-generation")
+def load_causal_lm_model(model_name_or_path: str, **kwargs):
+    """
+    Loads a causal language model for text generation and registers it under the task
+    'text-generation' using Hugging Face's AutoModelForCausalLM.
+
+    Args:
+        model_name_or_path (str):
+            Model ID on huggingface.co or path on disk to the model repository to export. For example:
+            `model_name_or_path="meta-llama/Llama-3.2-1B"` or `model_name_or_path="/path/to/model_folder"`
+        **kwargs:
+            Additional configuration options for the model:
+                - dtype (str, optional):
+                    Data type for model weights (default: "float32").
+                    Options include "float16" and "bfloat16".
+                - attn_implementation (str, optional):
+                    Attention mechanism implementation (default: "sdpa").
+                - cache_implementation (str, optional):
+                    Cache management strategy (default: "static").
+                - max_length (int, optional):
+                    Maximum sequence length for generation (default: 2048).
+
+    Returns:
+        transformers.PreTrainedModel:
+            An instance of a model subclass (e.g., Llama, Gemma) with the configuration for exporting
+            and lowering to ExecuTorch.
+    """
+    device = "cpu"
+    batch_size = 1
+    dtype = kwargs.get("dtype", "float32")
+    attn_implementation = kwargs.get("attn_implementation", "sdpa")
+    cache_implementation = kwargs.get("cache_implementation", "static")
+    max_length = kwargs.get("max_length", 2048)
+
+    return AutoModelForCausalLM.from_pretrained(
+        model_name_or_path,
+        device_map=device,
+        torch_dtype=dtype,
+        attn_implementation=attn_implementation,
+        generation_config=GenerationConfig(
+            use_cache=True,
+            cache_implementation=cache_implementation,
+            max_length=max_length,
+            cache_config={
+                "batch_size": batch_size,
+                "max_cache_len": max_length,
+            },
+        ),
+    )
diff --git a/tests/export/__init__.py b/tests/export/__init__.py
new file mode 100644
index 0000000..fdc0257
--- /dev/null
+++ b/tests/export/__init__.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/export/test_exporters_executorch.py b/tests/export/test_exporters_executorch.py
new file mode 100644
index 0000000..f246710
--- /dev/null
+++ b/tests/export/test_exporters_executorch.py
@@ -0,0 +1,115 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import tempfile +import unittest + +import pytest +from transformers.testing_utils import slow + + +class TestExportToExecuTorchCLI(unittest.TestCase): + def test_helps_no_raise(self): + subprocess.run( + "optimum-cli export executorch --help", + shell=True, + check=True, + ) + + @slow + @pytest.mark.run_slow + def test_llama3_2_1b_export_to_executorch(self): + model_id = "NousResearch/Llama-3.2-1B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_llama3_2_3b_export_to_executorch(self): + model_id = "NousResearch/Hermes-3-Llama-3.2-3B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_qwen2_5_export_to_executorch(self): + model_id = "Qwen/Qwen2.5-0.5B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_gemma2_export_to_executorch(self): + model_id = "unsloth/gemma-2-2b-it" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_gemma_export_to_executorch(self): + model_id = "weqweasdas/RM-Gemma-2B" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) + + @slow + @pytest.mark.run_slow + def test_olmo_export_to_executorch(self): + model_id = "allenai/OLMo-1B-hf" + task = "text-generation" + recipe = "xnnpack" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --output_dir {tempdir}/executorch", + shell=True, + check=True, + ) + self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte")) diff --git a/tests/runtime/__init__.py b/tests/runtime/__init__.py new file mode 100644 index 0000000..fdc0257 --- /dev/null +++ b/tests/runtime/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/runtime/test_modeling.py b/tests/runtime/test_modeling.py
new file mode 100644
index 0000000..c97b461
--- /dev/null
+++ b/tests/runtime/test_modeling.py
@@ -0,0 +1,70 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+import unittest
+
+import pytest
+from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from transformers.testing_utils import (
+    slow,
+)
+
+from optimum.executorch import ExecuTorchModelForCausalLM
+
+
+class ExecuTorchModelIntegrationTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @slow
+    @pytest.mark.run_slow
+    def test_load_model_from_hub(self):
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_name_or_path="NousResearch/Llama-3.2-1B",
+            export=True,
+            task="text-generation",
+            recipe="xnnpack",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+    @slow
+    @pytest.mark.run_slow
+    def test_load_model_from_local_path(self):
+        from optimum.exporters.executorch import main_export
+
+        model_id = "NousResearch/Llama-3.2-1B"
+        task = "text-generation"
+        recipe = "xnnpack"
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            # Export to a local dir
+            main_export(
+                model_name_or_path=model_id,
+                task=task,
+                recipe=recipe,
+                output_dir=tempdir,
+            )
+            self.assertTrue(os.path.exists(f"{tempdir}/model.pte"))
+
+            # Load the exported model from a local dir
+            model = ExecuTorchModelForCausalLM.from_pretrained(
+                model_name_or_path=tempdir,
+                export=False,
+            )
+            self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+            self.assertIsInstance(model.model, ExecuTorchModule)
diff --git a/tests/runtime/test_modeling_gemma.py b/tests/runtime/test_modeling_gemma.py
new file mode 100644
index 0000000..0e4238b
--- /dev/null
+++ b/tests/runtime/test_modeling_gemma.py
@@ -0,0 +1,54 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
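Complementing the loading tests above, a hedged end-to-end sketch of the runtime API (the export directory and tokenizer ID are illustrative; `export=False` assumes a previous export already produced `model.pte` in that directory):

```python
# Illustrative runtime usage; paths and IDs are placeholders for the example.
from transformers import AutoTokenizer

from optimum.executorch import ExecuTorchModelForCausalLM

model = ExecuTorchModelForCausalLM.from_pretrained(
    model_name_or_path="./llama3_2_1b_executorch",
    export=False,
)
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
print(
    model.text_generation(
        tokenizer=tokenizer,
        prompt="Simply put, the theory of relativity states that",
        max_seq_len=64,
    )
)
```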
+ +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import ( + slow, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_gemma_text_generation_with_xnnpack(self): + # TODO: Switch to use google/gemma-2b once https://github.com/huggingface/optimum/issues/2127 is fixed + # model_id = "google/gemma-2b" + model_id = "weqweasdas/RM-Gemma-2B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "Hello I am doing a project for my school and I need to write a report on the history of the United States." + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="Hello I am doing a project for my school", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) diff --git a/tests/runtime/test_modeling_gemma2.py b/tests/runtime/test_modeling_gemma2.py new file mode 100644 index 0000000..22fe4ab --- /dev/null +++ b/tests/runtime/test_modeling_gemma2.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import ( + slow, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_gemma2_text_generation_with_xnnpack(self): + # TODO: Switch to use google/gemma-2-2b once https://github.com/huggingface/optimum/issues/2127 is fixed + # model_id = "google/gemma-2-2b" + model_id = "unsloth/gemma-2-2b-it" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = ( + "Hello I am doing a project for my school and I need to make sure it is a great to be creative and I can!" 
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="Hello I am doing a project for my school",
+            max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)),
+        )
+        self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT)
diff --git a/tests/runtime/test_modeling_llama.py b/tests/runtime/test_modeling_llama.py
new file mode 100644
index 0000000..fb08a56
--- /dev/null
+++ b/tests/runtime/test_modeling_llama.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import pytest
+from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from transformers import AutoTokenizer
+from transformers.testing_utils import (
+    slow,
+)
+
+from optimum.executorchruntime import ExecuTorchModelForCausalLM
+
+
+class ExecuTorchModelIntegrationTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @slow
+    @pytest.mark.run_slow
+    def test_llama3_2_1b_text_generation_with_xnnpack(self):
+        # TODO: Switch to use meta-llama/Llama-3.2-1B once https://github.com/huggingface/optimum/issues/2127 is fixed
+        # model_id = "meta-llama/Llama-3.2-1B"
+        model_id = "NousResearch/Llama-3.2-1B"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_name_or_path=model_id,
+            export=True,
+            task="text-generation",
+            recipe="xnnpack",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        EXPECTED_GENERATED_TEXT = "Simply put, the theory of relativity states that the laws of physics are the same in all inertial frames of reference."
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="Simply put, the theory of relativity states that",
+            max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)),
+        )
+        self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT)
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skip(reason="OOMs with macos-15 CI instances on GH.")
+    def test_llama3_2_3b_text_generation_with_xnnpack(self):
+        # TODO: Switch to use meta-llama/Llama-3.2-3B once https://github.com/huggingface/optimum/issues/2127 is fixed
+        # model_id = "meta-llama/Llama-3.2-3B"
+        model_id = "NousResearch/Hermes-3-Llama-3.2-3B"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_name_or_path=model_id,
+            export=True,
+            task="text-generation",
+            recipe="xnnpack",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        EXPECTED_GENERATED_TEXT = (
+            "Simply put, the theory of relativity states that time is relative and can be affected "
+            "by an object's speed. This theory was developed by Albert Einstein in the early 20th "
+            "century. The theory has two parts"
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="Simply put, the theory of relativity states that",
+            max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)),
+        )
+        self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT)
diff --git a/tests/runtime/test_modeling_olmo.py b/tests/runtime/test_modeling_olmo.py
new file mode 100644
index 0000000..aa57496
--- /dev/null
+++ b/tests/runtime/test_modeling_olmo.py
@@ -0,0 +1,54 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import pytest
+from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from transformers import AutoTokenizer
+from transformers.testing_utils import (
+    slow,
+)
+
+from optimum.executorchruntime import ExecuTorchModelForCausalLM
+
+
+class ExecuTorchModelIntegrationTest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @slow
+    @pytest.mark.run_slow
+    def test_olmo_text_generation_with_xnnpack(self):
+        model_id = "allenai/OLMo-1B-hf"
+        model = ExecuTorchModelForCausalLM.from_pretrained(
+            model_name_or_path=model_id,
+            export=True,
+            task="text-generation",
+            recipe="xnnpack",
+        )
+        self.assertIsInstance(model, ExecuTorchModelForCausalLM)
+        self.assertIsInstance(model.model, ExecuTorchModule)
+
+        EXPECTED_GENERATED_TEXT = (
+            "Simply put, the theory of relativity states that the speed of light is the same in all directions."
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        generated_text = model.text_generation(
+            tokenizer=tokenizer,
+            prompt="Simply put, the theory of relativity states that",
+            max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)),
+        )
+        self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT)
diff --git a/tests/runtime/test_modeling_qwen2.py b/tests/runtime/test_modeling_qwen2.py
new file mode 100644
index 0000000..ef624a7
--- /dev/null
+++ b/tests/runtime/test_modeling_qwen2.py
@@ -0,0 +1,52 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest + +import pytest +from executorch.extension.pybindings.portable_lib import ExecuTorchModule +from transformers import AutoTokenizer +from transformers.testing_utils import ( + slow, +) + +from optimum.executorchruntime import ExecuTorchModelForCausalLM + + +class ExecuTorchModelIntegrationTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @slow + @pytest.mark.run_slow + def test_qwen2_5_text_generation_with_xnnpack(self): + model_id = "Qwen/Qwen2.5-0.5B" + model = ExecuTorchModelForCausalLM.from_pretrained( + model_name_or_path=model_id, + export=True, + task="text-generation", + recipe="xnnpack", + ) + self.assertIsInstance(model, ExecuTorchModelForCausalLM) + self.assertIsInstance(model.model, ExecuTorchModule) + + EXPECTED_GENERATED_TEXT = "My favourite condiment is iced tea. I love it with my breakfast, my lunch" + tokenizer = AutoTokenizer.from_pretrained(model_id) + generated_text = model.text_generation( + tokenizer=tokenizer, + prompt="My favourite condiment is ", + max_seq_len=len(tokenizer.encode(EXPECTED_GENERATED_TEXT)), + ) + self.assertEqual(generated_text, EXPECTED_GENERATED_TEXT) From e631d2cb69e407e31d47aa0cf6d65ce0c9d632e3 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 15:30:54 +0100 Subject: [PATCH 02/14] add setup --- optimum/executorch/version.py | 15 ++++++++ setup.py | 68 +++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 optimum/executorch/version.py create mode 100644 setup.py diff --git a/optimum/executorch/version.py b/optimum/executorch/version.py new file mode 100644 index 0000000..6b2bfd7 --- /dev/null +++ b/optimum/executorch/version.py @@ -0,0 +1,15 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__version__ = "0.0.1" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..866074d --- /dev/null +++ b/setup.py @@ -0,0 +1,68 @@ +import re + +from setuptools import find_namespace_packages, setup + + +# Ensure we match the version set in optimum/executorch/version.py +filepath = "optimum/executorch/version.py" +try: + with open(filepath) as version_file: + (__version__,) = re.findall('__version__ = "(.*)"', version_file.read()) +except Exception as error: + assert False, "Error: Could not open '%s' due %s\n" % (filepath, error) + +INSTALL_REQUIRE = [ + "optimum~=1.23", + "executorch>=0.4.0", + "transformers>=4.46", +] + +TESTS_REQUIRE = [ + "pytest", + "parameterized", + "sentencepiece", + "datasets", + "safetensors", +] + + +QUALITY_REQUIRE = ["black~=23.1", "ruff==0.4.4"] + + +EXTRAS_REQUIRE = { + "tests": TESTS_REQUIRE, + "quality": QUALITY_REQUIRE, +} + + +setup( + name="optimum-executorch", + version=__version__, + description="Optimum Executorch is an interface between the Hugging Face libraries and ExecuTorch", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "License :: OSI Approved :: Apache Software License", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + keywords="transformers, quantization, inference, executorch", + url="https://github.com/huggingface/optimum", + author="HuggingFace Inc. Special Ops Team", + author_email="hardware@huggingface.co", + license="Apache", + packages=find_namespace_packages(include=["optimum*"]), + install_requires=INSTALL_REQUIRE, + extras_require=EXTRAS_REQUIRE, + python_requires=">=3.9.0", + include_package_data=True, + zip_safe=False, +) From b704d9401f954522ff6acdbf6c95ce6a4f83f5bb Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 15:51:43 +0100 Subject: [PATCH 03/14] add makefile --- Makefile | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..fec38f4 --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +SHELL := /bin/bash +CURRENT_DIR = $(shell pwd) +DEFAULT_CLONE_URL := https://github.com/huggingface/optimum-executorch.git +# If CLONE_URL is empty, revert to DEFAULT_CLONE_URL +REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL)) + +.PHONY: style test + +# Run code quality checks +style_check: + black --check . + ruff check . + +style: + black . + ruff check . 
--fix + +# Run tests for the library +test: + python -m pytest tests + +# Utilities to release to PyPi +build_dist_install_tools: + pip install build + pip install twine + +build_dist: + rm -rf build + rm -rf dist + python -m build + +pypi_upload: build_dist + python -m twine upload dist/* \ No newline at end of file From 94b7766e8447611277526de00ca295c6a33bcbba Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 15:57:09 +0100 Subject: [PATCH 04/14] add workflows --- .github/workflows/quality.yml | 39 +++++++++++++++++ .github/workflows/test_executorch_export.yml | 35 ++++++++++++++++ .github/workflows/test_executorch_runtime.yml | 42 +++++++++++++++++++ 3 files changed, 116 insertions(+) create mode 100644 .github/workflows/quality.yml create mode 100644 .github/workflows/test_executorch_export.yml create mode 100644 .github/workflows/test_executorch_runtime.yml diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml new file mode 100644 index 0000000..a81a520 --- /dev/null +++ b/.github/workflows/quality.yml @@ -0,0 +1,39 @@ +name: Code Quality +on: + push: + branches: + - main + - v*-release + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + quality: + runs-on: ubuntu-22.04 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: 3.9 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[quality] + + - name: Check style with black + run: | + black --check . + + - name: Check style with ruff + run: | + ruff check . \ No newline at end of file diff --git a/.github/workflows/test_executorch_export.yml b/.github/workflows/test_executorch_export.yml new file mode 100644 index 0000000..5be6001 --- /dev/null +++ b/.github/workflows/test_executorch_export.yml @@ -0,0 +1,35 @@ +name: ExecuTorch Export / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + os: [macos-15] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for ExecuTorch + run: | + pip install .[tests] + pip list + - name: Run tests + working-directory: tests + run: | + RUN_SLOW=1 pytest executorch/export/test_*.py -s -vvvv --durations=0 diff --git a/.github/workflows/test_executorch_runtime.yml b/.github/workflows/test_executorch_runtime.yml new file mode 100644 index 0000000..3a1c7d8 --- /dev/null +++ b/.github/workflows/test_executorch_runtime.yml @@ -0,0 +1,42 @@ +name: ExecuTorch Runtime / Python - Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12'] + os: [macos-15] + test-modeling: + - test_modeling_gemma2.py + - test_modeling_gemma.py + - test_modeling_llama.py + - test_modeling_olmo.py + - test_modeling.py + - test_modeling_qwen2.py + + runs-on: ${{ matrix.os }} + steps: + - uses: 
actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for ExecuTorch + run: | + pip install .[tests] + pip list + - name: Run tests + working-directory: tests + run: | + RUN_SLOW=1 pytest executorch/runtime/${{ matrix.test-modeling }} -s -vvvv --durations=0 From 1f3e63da04615cd02d0ab19256446120f2c1471c Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 18:34:46 +0100 Subject: [PATCH 05/14] add pyproject --- pyproject.toml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e18dc9f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[tool.black] +line-length = 119 +target-version = ['py37'] + +[tool.ruff] +# Never enforce `E501` (line length violations). +ignore = ["C901", "E501", "E741", "W605"] +select = ["C", "E", "F", "I", "W"] +line-length = 119 + +# Ignore import violations in all `__init__.py` files. +[tool.ruff.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.isort] +lines-after-imports = 2 +known-first-party = ["optimum"] + +[tool.pytest.ini_options] +markers = [ + "run_slow", +] \ No newline at end of file From e45437d06f5480d5e216f3c71550079451a62b74 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 18:36:34 +0100 Subject: [PATCH 06/14] fix style --- optimum/executorch/modeling.py | 9 +++++---- optimum/exporters/executorch/recipes/xnnpack.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index b93309f..22bd86d 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -20,10 +20,6 @@ from typing import List, Optional, Union import torch -from executorch.extension.pybindings.portable_lib import ( - ExecuTorchModule, - _load_for_executorch, -) from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from transformers import ( AutoModelForCausalLM, @@ -31,6 +27,11 @@ PreTrainedTokenizer, ) +from executorch.extension.pybindings.portable_lib import ( + ExecuTorchModule, + _load_for_executorch, +) + from ..exporters.executorch import main_export from ..modeling_base import OptimizedModel diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py index d3b3a5d..a3cdf47 100644 --- a/optimum/exporters/executorch/recipes/xnnpack.py +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -14,14 +14,15 @@ import torch import torch.export._trace +from torch.nn.attention import SDPBackend +from transformers import PreTrainedModel, TorchExportableModuleWithStaticCache + from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import ( EdgeCompileConfig, ExecutorchBackendConfig, to_edge_transform_and_lower, ) -from torch.nn.attention import SDPBackend -from transformers import PreTrainedModel, TorchExportableModuleWithStaticCache from ..recipe_registry import register_recipe From a8823c79db92bb684284d5fda4de66f3dda1f3ca Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 18:44:36 +0100 Subject: [PATCH 07/14] rename --- .github/workflows/quality.yml | 2 +- .../{test_executorch_runtime.yml => test_executorch_models.yml} | 2 +- .../workflows/{test_executorch_export.yml => test_export.yml} | 2 +- tests/{runtime => models}/__init__.py | 0 
tests/{runtime => models}/test_modeling.py | 0 tests/{runtime => models}/test_modeling_gemma.py | 0 tests/{runtime => models}/test_modeling_gemma2.py | 0 tests/{runtime => models}/test_modeling_llama.py | 0 tests/{runtime => models}/test_modeling_olmo.py | 0 tests/{runtime => models}/test_modeling_qwen2.py | 0 10 files changed, 3 insertions(+), 3 deletions(-) rename .github/workflows/{test_executorch_runtime.yml => test_executorch_models.yml} (91%) rename .github/workflows/{test_executorch_export.yml => test_export.yml} (90%) rename tests/{runtime => models}/__init__.py (100%) rename tests/{runtime => models}/test_modeling.py (100%) rename tests/{runtime => models}/test_modeling_gemma.py (100%) rename tests/{runtime => models}/test_modeling_gemma2.py (100%) rename tests/{runtime => models}/test_modeling_llama.py (100%) rename tests/{runtime => models}/test_modeling_olmo.py (100%) rename tests/{runtime => models}/test_modeling_qwen2.py (100%) diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index a81a520..0b970dc 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -28,7 +28,7 @@ jobs: - name: Install dependencies run: | pip install --upgrade pip - pip install .[quality] + pip install "black~=23.1" "ruff==0.4.4" - name: Check style with black run: | diff --git a/.github/workflows/test_executorch_runtime.yml b/.github/workflows/test_executorch_models.yml similarity index 91% rename from .github/workflows/test_executorch_runtime.yml rename to .github/workflows/test_executorch_models.yml index 3a1c7d8..e03c851 100644 --- a/.github/workflows/test_executorch_runtime.yml +++ b/.github/workflows/test_executorch_models.yml @@ -39,4 +39,4 @@ jobs: - name: Run tests working-directory: tests run: | - RUN_SLOW=1 pytest executorch/runtime/${{ matrix.test-modeling }} -s -vvvv --durations=0 + RUN_SLOW=1 pytest runtime/${{ matrix.test-modeling }} -s -vvvv --durations=0 diff --git a/.github/workflows/test_executorch_export.yml b/.github/workflows/test_export.yml similarity index 90% rename from .github/workflows/test_executorch_export.yml rename to .github/workflows/test_export.yml index 5be6001..aa8f7e9 100644 --- a/.github/workflows/test_executorch_export.yml +++ b/.github/workflows/test_export.yml @@ -32,4 +32,4 @@ jobs: - name: Run tests working-directory: tests run: | - RUN_SLOW=1 pytest executorch/export/test_*.py -s -vvvv --durations=0 + RUN_SLOW=1 pytest export/test_*.py -s -vvvv --durations=0 diff --git a/tests/runtime/__init__.py b/tests/models/__init__.py similarity index 100% rename from tests/runtime/__init__.py rename to tests/models/__init__.py diff --git a/tests/runtime/test_modeling.py b/tests/models/test_modeling.py similarity index 100% rename from tests/runtime/test_modeling.py rename to tests/models/test_modeling.py diff --git a/tests/runtime/test_modeling_gemma.py b/tests/models/test_modeling_gemma.py similarity index 100% rename from tests/runtime/test_modeling_gemma.py rename to tests/models/test_modeling_gemma.py diff --git a/tests/runtime/test_modeling_gemma2.py b/tests/models/test_modeling_gemma2.py similarity index 100% rename from tests/runtime/test_modeling_gemma2.py rename to tests/models/test_modeling_gemma2.py diff --git a/tests/runtime/test_modeling_llama.py b/tests/models/test_modeling_llama.py similarity index 100% rename from tests/runtime/test_modeling_llama.py rename to tests/models/test_modeling_llama.py diff --git a/tests/runtime/test_modeling_olmo.py b/tests/models/test_modeling_olmo.py similarity index 100% 
rename from tests/runtime/test_modeling_olmo.py rename to tests/models/test_modeling_olmo.py diff --git a/tests/runtime/test_modeling_qwen2.py b/tests/models/test_modeling_qwen2.py similarity index 100% rename from tests/runtime/test_modeling_qwen2.py rename to tests/models/test_modeling_qwen2.py From 8635ef5ad58e3466a06aba40dc230f4e5a26436a Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 18:47:37 +0100 Subject: [PATCH 08/14] fix --- .github/workflows/test_executorch_models.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_executorch_models.yml b/.github/workflows/test_executorch_models.yml index e03c851..39eaf49 100644 --- a/.github/workflows/test_executorch_models.yml +++ b/.github/workflows/test_executorch_models.yml @@ -39,4 +39,4 @@ jobs: - name: Run tests working-directory: tests run: | - RUN_SLOW=1 pytest runtime/${{ matrix.test-modeling }} -s -vvvv --durations=0 + RUN_SLOW=1 pytest models/${{ matrix.test-modeling }} -s -vvvv --durations=0 From a8831c404a3b1697bdaa100f227fc847a236a103 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 18:52:53 +0100 Subject: [PATCH 09/14] fix --- .../{test_executorch_models.yml => test_models.yml} | 0 tests/models/test_modeling.py | 6 ++---- tests/models/test_modeling_gemma.py | 6 ++---- tests/models/test_modeling_gemma2.py | 6 ++---- tests/models/test_modeling_llama.py | 6 ++---- tests/models/test_modeling_olmo.py | 6 ++---- tests/models/test_modeling_qwen2.py | 6 ++---- 7 files changed, 12 insertions(+), 24 deletions(-) rename .github/workflows/{test_executorch_models.yml => test_models.yml} (100%) diff --git a/.github/workflows/test_executorch_models.yml b/.github/workflows/test_models.yml similarity index 100% rename from .github/workflows/test_executorch_models.yml rename to .github/workflows/test_models.yml diff --git a/tests/models/test_modeling.py b/tests/models/test_modeling.py index c97b461..90ee504 100644 --- a/tests/models/test_modeling.py +++ b/tests/models/test_modeling.py @@ -19,11 +19,9 @@ import pytest from executorch.extension.pybindings.portable_lib import ExecuTorchModule -from transformers.testing_utils import ( - slow, -) +from transformers.testing_utils import slow -from optimum.executorchruntime import ExecuTorchModelForCausalLM +from optimum.executorch import ExecuTorchModelForCausalLM class ExecuTorchModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/test_modeling_gemma.py b/tests/models/test_modeling_gemma.py index 0e4238b..49a9cf2 100644 --- a/tests/models/test_modeling_gemma.py +++ b/tests/models/test_modeling_gemma.py @@ -18,11 +18,9 @@ import pytest from executorch.extension.pybindings.portable_lib import ExecuTorchModule from transformers import AutoTokenizer -from transformers.testing_utils import ( - slow, -) +from transformers.testing_utils import slow -from optimum.executorchruntime import ExecuTorchModelForCausalLM +from optimum.executorch import ExecuTorchModelForCausalLM class ExecuTorchModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/test_modeling_gemma2.py b/tests/models/test_modeling_gemma2.py index 22fe4ab..4c875bb 100644 --- a/tests/models/test_modeling_gemma2.py +++ b/tests/models/test_modeling_gemma2.py @@ -18,11 +18,9 @@ import pytest from executorch.extension.pybindings.portable_lib import ExecuTorchModule from transformers import AutoTokenizer -from transformers.testing_utils import ( - slow, -) +from transformers.testing_utils import slow -from optimum.executorchruntime import 
ExecuTorchModelForCausalLM +from optimum.executorch import ExecuTorchModelForCausalLM class ExecuTorchModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/test_modeling_llama.py b/tests/models/test_modeling_llama.py index fb08a56..cf13a32 100644 --- a/tests/models/test_modeling_llama.py +++ b/tests/models/test_modeling_llama.py @@ -18,11 +18,9 @@ import pytest from executorch.extension.pybindings.portable_lib import ExecuTorchModule from transformers import AutoTokenizer -from transformers.testing_utils import ( - slow, -) +from transformers.testing_utils import slow -from optimum.executorchruntime import ExecuTorchModelForCausalLM +from optimum.executorch import ExecuTorchModelForCausalLM class ExecuTorchModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/test_modeling_olmo.py b/tests/models/test_modeling_olmo.py index aa57496..21f4694 100644 --- a/tests/models/test_modeling_olmo.py +++ b/tests/models/test_modeling_olmo.py @@ -18,11 +18,9 @@ import pytest from executorch.extension.pybindings.portable_lib import ExecuTorchModule from transformers import AutoTokenizer -from transformers.testing_utils import ( - slow, -) +from transformers.testing_utils import slow -from optimum.executorchruntime import ExecuTorchModelForCausalLM +from optimum.executorch import ExecuTorchModelForCausalLM class ExecuTorchModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/test_modeling_qwen2.py b/tests/models/test_modeling_qwen2.py index ef624a7..aa57493 100644 --- a/tests/models/test_modeling_qwen2.py +++ b/tests/models/test_modeling_qwen2.py @@ -18,11 +18,9 @@ import pytest from executorch.extension.pybindings.portable_lib import ExecuTorchModule from transformers import AutoTokenizer -from transformers.testing_utils import ( - slow, -) +from transformers.testing_utils import slow -from optimum.executorchruntime import ExecuTorchModelForCausalLM +from optimum.executorch import ExecuTorchModelForCausalLM class ExecuTorchModelIntegrationTest(unittest.TestCase): From 54c2cc374007b10cfbf7bf9fbdf69d7667baaa26 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 18:55:25 +0100 Subject: [PATCH 10/14] optimum from source --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 866074d..20d5327 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,8 @@ assert False, "Error: Could not open '%s' due %s\n" % (filepath, error) INSTALL_REQUIRE = [ - "optimum~=1.23", + # "optimum~=1.24", + "optimum@git+https://github.com/huggingface/optimum.git", "executorch>=0.4.0", "transformers>=4.46", ] From fe10bc26076f745e84c2afe60e818b7e5f4b5db6 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 18:59:17 +0100 Subject: [PATCH 11/14] add accelerate tests requirements --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 20d5327..ce36549 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ ] TESTS_REQUIRE = [ + "accelerate>=0.26.0", "pytest", "parameterized", "sentencepiece", From 74bf56dd8df92bc19c88f9aa96a4075272ab9264 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Wed, 22 Jan 2025 19:13:21 +0100 Subject: [PATCH 12/14] add command info --- optimum/commands/export/executorch.py | 4 +++- setup.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py index 2bf2f1d..dbc26c2 100644 --- a/optimum/commands/export/executorch.py +++ b/optimum/commands/export/executorch.py @@ -16,7 +16,7 @@ 
from typing import TYPE_CHECKING from ...exporters import TasksManager -from ..base import BaseOptimumCLICommand +from ..base import BaseOptimumCLICommand, CommandInfo if TYPE_CHECKING: @@ -52,6 +52,8 @@ def parse_args_executorch(parser): class ExecuTorchExportCommand(BaseOptimumCLICommand): + COMMAND = CommandInfo(name="executorch", help="Export models to ExecuTorch.") + @staticmethod def parse_args(parser: "ArgumentParser"): return parse_args_executorch(parser) diff --git a/setup.py b/setup.py index ce36549..803f1f0 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ INSTALL_REQUIRE = [ # "optimum~=1.24", - "optimum@git+https://github.com/huggingface/optimum.git", + "optimum@git+https://github.com/huggingface/optimum.git@optimum-executorch", "executorch>=0.4.0", "transformers>=4.46", ] From 2a9b828771e8f3b49d0cb372b68bd7929089bce1 Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 23 Jan 2025 10:43:16 +0100 Subject: [PATCH 13/14] update copyright --- optimum/commands/export/executorch.py | 18 ++++++++++-------- optimum/executorch/__init__.py | 18 ++++++++++-------- optimum/executorch/modeling.py | 18 ++++++++++-------- optimum/exporters/executorch/__init__.py | 18 ++++++++++-------- optimum/exporters/executorch/__main__.py | 18 ++++++++++-------- optimum/exporters/executorch/convert.py | 18 ++++++++++-------- .../exporters/executorch/recipe_registry.py | 18 ++++++++++-------- .../exporters/executorch/recipes/__init__.py | 18 ++++++++++-------- .../exporters/executorch/recipes/xnnpack.py | 18 ++++++++++-------- optimum/exporters/executorch/task_registry.py | 18 ++++++++++-------- optimum/exporters/executorch/tasks/__init__.py | 18 ++++++++++-------- .../exporters/executorch/tasks/causal_lm.py | 18 ++++++++++-------- 12 files changed, 120 insertions(+), 96 deletions(-) diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py index dbc26c2..62ca5b0 100644 --- a/optimum/commands/export/executorch.py +++ b/optimum/commands/export/executorch.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Defines the command line for the export with ExecuTorch.""" diff --git a/optimum/executorch/__init__.py b/optimum/executorch/__init__.py index cbc9b37..ed88421 100644 --- a/optimum/executorch/__init__.py +++ b/optimum/executorch/__init__.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING diff --git a/optimum/executorch/modeling.py b/optimum/executorch/modeling.py index 22bd86d..d33c6c1 100644 --- a/optimum/executorch/modeling.py +++ b/optimum/executorch/modeling.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ExecuTorchModelForXXX classes, allowing to run ExecuTorch Models with ExecuTorch Runtime using the same API as Transformers.""" diff --git a/optimum/exporters/executorch/__init__.py b/optimum/exporters/executorch/__init__.py index 3409e69..3636354 100644 --- a/optimum/exporters/executorch/__init__.py +++ b/optimum/exporters/executorch/__init__.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. 
# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import TYPE_CHECKING diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py index e3b561f..bb0b30c 100644 --- a/optimum/exporters/executorch/__main__.py +++ b/optimum/exporters/executorch/__main__.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Entry point to the optimum.exporters.executorch command line.""" diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py index aceb733..c4c0108 100644 --- a/optimum/exporters/executorch/convert.py +++ b/optimum/exporters/executorch/convert.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ExecuTorch model check and export functions.""" diff --git a/optimum/exporters/executorch/recipe_registry.py b/optimum/exporters/executorch/recipe_registry.py index 2eb728b..52df5cc 100644 --- a/optimum/exporters/executorch/recipe_registry.py +++ b/optimum/exporters/executorch/recipe_registry.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib import logging diff --git a/optimum/exporters/executorch/recipes/__init__.py b/optimum/exporters/executorch/recipes/__init__.py index a2e21cf..833692a 100644 --- a/optimum/exporters/executorch/recipes/__init__.py +++ b/optimum/exporters/executorch/recipes/__init__.py @@ -1,13 +1,15 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. 
+# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from . import xnnpack diff --git a/optimum/exporters/executorch/recipes/xnnpack.py b/optimum/exporters/executorch/recipes/xnnpack.py index a3cdf47..4b581ff 100644 --- a/optimum/exporters/executorch/recipes/xnnpack.py +++ b/optimum/exporters/executorch/recipes/xnnpack.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Union diff --git a/optimum/exporters/executorch/task_registry.py b/optimum/exporters/executorch/task_registry.py index fdc34f0..f186f76 100644 --- a/optimum/exporters/executorch/task_registry.py +++ b/optimum/exporters/executorch/task_registry.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import importlib import logging diff --git a/optimum/exporters/executorch/tasks/__init__.py b/optimum/exporters/executorch/tasks/__init__.py index 754a824..12155cd 100644 --- a/optimum/exporters/executorch/tasks/__init__.py +++ b/optimum/exporters/executorch/tasks/__init__.py @@ -1,13 +1,15 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from . import causal_lm diff --git a/optimum/exporters/executorch/tasks/causal_lm.py b/optimum/exporters/executorch/tasks/causal_lm.py index b02da8b..4024393 100644 --- a/optimum/exporters/executorch/tasks/causal_lm.py +++ b/optimum/exporters/executorch/tasks/causal_lm.py @@ -1,14 +1,16 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from transformers import AutoModelForCausalLM, GenerationConfig From ed91e2171c428a4de3bba07fe070833c250f59ac Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 23 Jan 2025 15:56:59 +0100 Subject: [PATCH 14/14] branch no longer exist --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 803f1f0..ce36549 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ INSTALL_REQUIRE = [ # "optimum~=1.24", - "optimum@git+https://github.com/huggingface/optimum.git@optimum-executorch", + "optimum@git+https://github.com/huggingface/optimum.git", "executorch>=0.4.0", "transformers>=4.46", ]