diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py index 58dba4923..fc1d2c73e 100644 --- a/optimum/commands/export/neuronx.py +++ b/optimum/commands/export/neuronx.py @@ -85,6 +85,11 @@ def parse_args_neuronx(parser: "ArgumentParser"): type=Path, help="Path indicating the directory where to store intermediary files generated by Neuronx compiler.", ) + optional_group.add_argument( + "--disable-weights-neff-inline", + action="store_true", + help="Whether to disable the weights / neff graph inline. You can only replace weights of neuron-compiled models when the weights-neff inlining has been disabled during the compilation.", + ) optional_group.add_argument( "--disable-validation", action="store_true", diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index a4c2eb28c..8db4f4a75 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -363,6 +363,7 @@ def main_export( atol: Optional[float] = None, cache_dir: Optional[str] = None, compiler_workdir: Optional[Union[str, Path]] = None, + inline_weights_to_neff: bool = True, optlevel: str = "2", trust_remote_code: bool = False, subfolder: str = "", @@ -415,6 +416,7 @@ def main_export( models_and_neuron_configs=models_and_neuron_configs, output_dir=output, compiler_workdir=compiler_workdir, + inline_weights_to_neff=inline_weights_to_neff, optlevel=optlevel, output_file_names=output_model_names, compiler_kwargs=compiler_kwargs, @@ -523,6 +525,7 @@ def main(): atol=args.atol, cache_dir=args.cache_dir, compiler_workdir=args.compiler_workdir, + inline_weights_to_neff=not args.disable_weights_neff_inline, optlevel=optlevel, trust_remote_code=args.trust_remote_code, subfolder=args.subfolder, diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py index 08fa1b21b..63e85cd44 100644 --- a/optimum/exporters/neuron/convert.py +++ b/optimum/exporters/neuron/convert.py @@ -273,6 +273,7 @@ def export_models( ], output_dir: Path, compiler_workdir: Optional[Path] = None, + inline_weights_to_neff: bool = True, optlevel: str = "2", output_file_names: Optional[Dict[str, str]] = None, compiler_kwargs: Optional[Dict[str, Any]] = {}, @@ -288,6 +289,8 @@ def export_models( Output directory to store the exported Neuron models. compiler_workdir (`Optional[Path]`, defaults to `None`): The directory to store intermediary outputs of the neuron compiler. + inline_weights_to_neff (`bool`, defaults to `True`): + Whether to inline the weights to the neff graph. If set to False, weights will be seperated from the neff. optlevel (`str`, defaults to `"2"`): The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2". 1: enables the core performance optimizations in the compiler, while also minimizing compile time. @@ -334,6 +337,7 @@ def export_models( config=sub_neuron_config, output=output_path, compiler_workdir=compiler_workdir_path, + inline_weights_to_neff=inline_weights_to_neff, optlevel=optlevel, **compiler_kwargs, ) @@ -362,6 +366,7 @@ def export_models( dynamic_batch_size=sub_neuron_config.dynamic_batch_size, compiler_type=NEURON_COMPILER_TYPE, compiler_version=NEURON_COMPILER_VERSION, + inline_weights_to_neff=inline_weights_to_neff, optlevel=optlevel, model_type=getattr(sub_neuron_config, "MODEL_TYPE", None), task=getattr(sub_neuron_config, "task", None), @@ -392,6 +397,7 @@ def export( config: "NeuronDefaultConfig", output: Path, compiler_workdir: Optional[Path] = None, + inline_weights_to_neff: bool = True, optlevel: str = "2", auto_cast: Optional[str] = None, auto_cast_type: str = "bf16", @@ -406,6 +412,7 @@ def export( config=config, output=output, compiler_workdir=compiler_workdir, + inline_weights_to_neff=inline_weights_to_neff, optlevel=optlevel, auto_cast=auto_cast, auto_cast_type=auto_cast_type, @@ -421,6 +428,7 @@ def export_neuronx( config: "NeuronDefaultConfig", output: Path, compiler_workdir: Optional[Path] = None, + inline_weights_to_neff: bool = True, optlevel: str = "2", auto_cast: Optional[str] = None, auto_cast_type: str = "bf16", @@ -437,6 +445,8 @@ def export_neuronx( Directory to store the exported Neuron model. compiler_workdir (`Optional[Path]`, defaults to `None`): The directory used by neuronx-cc, where you can find intermediary outputs (neff, weight, hlo...). + inline_weights_to_neff (`bool`, defaults to `True`): + Whether to inline the weights to the neff graph. If set to False, weights will be seperated from the neff. optlevel (`str`, defaults to `"2"`): The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2". 1: enables the core performance optimizations in the compiler, while also minimizing compile time. @@ -504,10 +514,15 @@ def export_neuronx( dummy_inputs_tuple, compiler_args=compiler_args, input_output_aliases=aliases, + inline_weights_to_neff=inline_weights_to_neff, compiler_workdir=compiler_workdir, ) if config.dynamic_batch_size is True: + if not inline_weights_to_neff: + raise ValueError( + "Dynamic batching is not yet compatible with the weights/neff non-inlined model. Please set `dynamic_batch_size=False` or `inline_weights_to_neff=True`." + ) neuron_model = neuronx.dynamic_batch(neuron_model) # diffusers specific diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index 5b0f20786..a08da0826 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -69,7 +69,7 @@ @register_in_tasks_manager("bert", *COMMON_TEXT_TASKS) class BertNeuronConfig(TextEncoderNeuronConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert") - ATOL_FOR_VALIDATION = 1e-4 + ATOL_FOR_VALIDATION = 1e-3 @property def inputs(self) -> List[str]: @@ -83,6 +83,8 @@ class AlbertNeuronConfig(BertNeuronConfig): @register_in_tasks_manager("convbert", *COMMON_TEXT_TASKS) class ConvBertNeuronConfig(BertNeuronConfig): + ATOL_FOR_VALIDATION = 1e-1 # TODO: why accuracy more off than other arch + @property def outputs(self) -> List[str]: if self.task == "feature-extraction": @@ -91,12 +93,16 @@ def outputs(self) -> List[str]: @register_in_tasks_manager("electra", *COMMON_TEXT_TASKS) -class ElectraNeuronConfig(ConvBertNeuronConfig): - pass +class ElectraNeuronConfig(BertNeuronConfig): + @property + def outputs(self) -> List[str]: + if self.task == "feature-extraction": + return ["last_hidden_state"] + return self._TASK_TO_COMMON_OUTPUTS[self.task] @register_in_tasks_manager("flaubert", *COMMON_TEXT_TASKS) -class FlaubertNeuronConfig(ConvBertNeuronConfig): +class FlaubertNeuronConfig(ElectraNeuronConfig): pass @@ -106,18 +112,18 @@ class MobileBertNeuronConfig(BertNeuronConfig): @register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS) -class RoFormerNeuronConfig(ConvBertNeuronConfig): +class RoFormerNeuronConfig(ElectraNeuronConfig): pass @register_in_tasks_manager("xlm", *COMMON_TEXT_TASKS) -class XLMNeuronConfig(ConvBertNeuronConfig): +class XLMNeuronConfig(ElectraNeuronConfig): pass @register_in_tasks_manager("distilbert", *COMMON_TEXT_TASKS) class DistilBertNeuronConfig(BertNeuronConfig): - ATOL_FOR_VALIDATION = 1e-4 + ATOL_FOR_VALIDATION = 1e-3 @property def inputs(self) -> List[str]: @@ -132,7 +138,7 @@ def outputs(self) -> List[str]: @register_in_tasks_manager("camembert", *COMMON_TEXT_TASKS) class CamembertNeuronConfig(BertNeuronConfig): - ATOL_FOR_VALIDATION = 1e-4 + ATOL_FOR_VALIDATION = 1e-3 @property def inputs(self) -> List[str]: @@ -156,8 +162,8 @@ class XLMRobertaNeuronConfig(CamembertNeuronConfig): # https://github.com/aws-neuron/aws-neuron-sdk/issues/642 # Failed only for INF1: 'XSoftmax' -@register_in_tasks_manager("deberta", *COMMON_TEXT_TASKS) -class DebertaNeuronConfig(BertNeuronConfig): +@register_in_tasks_manager("deberta", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"])) +class DebertaNeuronConfig(ElectraNeuronConfig): @property def inputs(self) -> List[str]: common_inputs = super().inputs @@ -169,8 +175,8 @@ def inputs(self) -> List[str]: # https://github.com/aws-neuron/aws-neuron-sdk/issues/642 # Failed only for INF1: 'XSoftmax' -@register_in_tasks_manager("deberta-v2", *COMMON_TEXT_TASKS) -class DebertaV2NeuronConfig(DebertaNeuronConfig): +@register_in_tasks_manager("deberta-v2", *([task for task in COMMON_TEXT_TASKS if task != "multiple-choice"])) +class DebertaV2NeuronConfig(ElectraNeuronConfig): pass diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py index 1f3f0b108..e26d42afe 100644 --- a/optimum/neuron/modeling_base.py +++ b/optimum/neuron/modeling_base.py @@ -32,7 +32,13 @@ from ..exporters.tasks import TasksManager from ..modeling_base import OptimizedModel from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors -from .utils import NEURON_FILE_NAME, is_neuron_available, store_compilation_config +from .utils import ( + NEURON_FILE_NAME, + check_if_weights_replacable, + is_neuron_available, + replace_weights, + store_compilation_config, +) from .utils.import_utils import is_neuronx_available from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version @@ -103,7 +109,13 @@ def load_model(path: Union[str, Path]) -> torch.jit._script.ScriptModule: path = Path(path) if path.is_file(): - return torch.jit.load(path) + model = torch.jit.load(path) + return model + + def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None): + check_if_weights_replacable(self.config, weights) + if weights is not None: + replace_weights(self.model, weights) def _save_pretrained(self, save_directory: Union[str, Path]): """ @@ -216,6 +228,7 @@ def _export( force_download: bool = False, cache_dir: Optional[str] = None, compiler_workdir: Optional[Union[str, Path]] = None, + inline_weights_to_neff: bool = True, optlevel: str = "2", subfolder: str = "", local_files_only: bool = False, @@ -303,6 +316,7 @@ def _export( config=neuron_config, output=save_dir_path / NEURON_FILE_NAME, compiler_workdir=compiler_workdir, + inline_weights_to_neff=inline_weights_to_neff, optlevel=optlevel, **compiler_kwargs, ) @@ -316,6 +330,7 @@ def _export( dynamic_batch_size=dynamic_batch_size, compiler_type=compiler_type, compiler_version=compiler_version, + inline_weights_to_neff=inline_weights_to_neff, optlevel=optlevel, task=task, ) @@ -570,3 +585,10 @@ def remove_padding( ] return outputs + + @property + def is_weights_neff_separated(self) -> bool: + """ + Whether the Neuron model has separated weights and neff graph (by setting `inline_weights_to_neff=False` during the compilation). + """ + return not self.config.neuron.get("inline_weights_to_neff", True) diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 34747a191..e6b9080ee 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -540,6 +540,7 @@ def _export( force_download: bool = True, cache_dir: Optional[str] = None, compiler_workdir: Optional[str] = None, + inline_weights_to_neff: bool = True, optlevel: str = "2", subfolder: str = "", local_files_only: bool = False, @@ -580,6 +581,8 @@ def _export( standard cache should not be used. compiler_workdir (`Optional[str]`, defaults to `None`): Path to a directory in which the neuron compiler will store all intermediary files during the compilation(neff, weight, hlo graph...). + inline_weights_to_neff (`bool`, defaults to `True`): + Whether to inline the weights to the neff graph. If set to False, weights will be seperated from the neff. optlevel (`str`, defaults to `"2"`): The level of optimization the compiler should perform. Can be `"1"`, `"2"` or `"3"`, defaults to "2". 1: enables the core performance optimizations in the compiler, while also minimizing compile time. @@ -640,6 +643,7 @@ def _export( dynamic_batch_size=dynamic_batch_size, cache_dir=cache_dir, compiler_workdir=compiler_workdir, + inline_weights_to_neff=inline_weights_to_neff, optlevel=optlevel, trust_remote_code=trust_remote_code, subfolder=subfolder, diff --git a/optimum/neuron/modeling_seq2seq.py b/optimum/neuron/modeling_seq2seq.py index 048f189b3..c212a3848 100644 --- a/optimum/neuron/modeling_seq2seq.py +++ b/optimum/neuron/modeling_seq2seq.py @@ -260,6 +260,7 @@ def _export( force_download: bool = True, cache_dir: Optional[str] = None, compiler_workdir: Optional[str] = None, + inline_weights_to_neff: bool = True, optlevel: str = "2", subfolder: str = "", local_files_only: bool = False, @@ -302,6 +303,7 @@ def _export( dynamic_batch_size=dynamic_batch_size, cache_dir=cache_dir, compiler_workdir=compiler_workdir, + inline_weights_to_neff=inline_weights_to_neff, optlevel=optlevel, trust_remote_code=trust_remote_code, subfolder=subfolder, diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 15a51ee0b..764426674 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -35,6 +35,7 @@ is_transformers_neuronx_available, ) from .input_generators import DummyBeamValuesGenerator +from .misc import check_if_weights_replacable, replace_weights from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function from .training_utils import ( diff --git a/optimum/neuron/utils/argument_utils.py b/optimum/neuron/utils/argument_utils.py index 9cc7ec68b..208535796 100644 --- a/optimum/neuron/utils/argument_utils.py +++ b/optimum/neuron/utils/argument_utils.py @@ -145,6 +145,7 @@ def store_compilation_config( dynamic_batch_size: bool, compiler_type: str, compiler_version: str, + inline_weights_to_neff: bool, optlevel: str, model_type: Optional[str] = None, task: str = None, @@ -161,6 +162,7 @@ def store_compilation_config( # Add neuron version to the config, so it can be checked at load time config_args["compiler_type"] = compiler_type config_args["compiler_version"] = compiler_version + config_args["inline_weights_to_neff"] = inline_weights_to_neff # Add input shapes during compilation to the config for axis, shape in input_shapes.items(): diff --git a/optimum/neuron/utils/misc.py b/optimum/neuron/utils/misc.py index 21bf56e1e..9b21c4e4a 100644 --- a/optimum/neuron/utils/misc.py +++ b/optimum/neuron/utils/misc.py @@ -18,7 +18,7 @@ import os import re from pathlib import Path -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union import torch from transformers.modeling_utils import _add_variant @@ -42,6 +42,9 @@ from .require_utils import requires_safetensors +if TYPE_CHECKING: + from transformers import PretrainedConfig + logger = logging.get_logger() @@ -508,3 +511,40 @@ def download_checkpoints_in_cache( resolved_archive_file = filenames_to_safetensors_filenames[Path(resolved_archive_file).name] return resolved_archive_file, sharded_metadata + + +def replace_weights( + model: torch.jit._script.RecursiveScriptModule, + weights: Union[Dict[str, torch.Tensor], torch.nn.Module], + prefix: str = "model", +): + """ + Replaces the weights in a Neuron Model with weights from another model, the original neuron model should have separated weights(by setting `inline_weights_to_neff=Talse` during the tracing). + """ + if isinstance(weights, torch.nn.Module): + weights = weights.state_dict() + + # extract module paths from the weights c module + code = model.weights._c.code + start_str = "__parameters__ = [" + end_str = "]\n" + module_paths = code.split(start_str)[1].split(end_str)[0].strip()[:-1:].replace('"', "").split(", ") + module_paths = [module_path for module_path in module_paths if module_path != ""] + + for module_path in module_paths: + if len(re.findall("\w\d+", module_path)) > 0: + continue + else: + model.weights._c.setattr(module_path, weights[module_path.replace(prefix + "->", "").replace("->", ".")]) + + +def check_if_weights_replacable( + config: "PretrainedConfig", weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] +): + is_weights_neff_separated = ( + not config.neuron.get("inline_weights_to_neff", True) if hasattr(config, "neuron") else False + ) + if weights is not None and not is_weights_neff_separated: + raise RuntimeError( + "Unable to replace weights of the neuron model since its weights and neff are not separated, please set `inline_weights_to_neff=Talse` when converting the model to Neuron format." + ) diff --git a/setup.py b/setup.py index a89b684af..77eea2506 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "diffusers >= 0.25.0", "safetensors", "sentence-transformers >= 2.2.0", + "sacremoses", ] QUALITY_REQUIRES = [ diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 419d689cd..c373e5588 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -19,8 +19,8 @@ "bert": "hf-internal-testing/tiny-random-BertModel", "camembert": "hf-internal-testing/tiny-random-camembert", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", - # "deberta": "hf-internal-testing/tiny-random-DebertaModel", # Failed for INF1: 'XSoftmax' - # "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", # Failed for INF1: 'XSoftmax' + "deberta": "hf-internal-testing/tiny-random-DebertaModel", # Failed for INF1: 'XSoftmax' + "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", # Failed for INF1: 'XSoftmax' "distilbert": "hf-internal-testing/tiny-random-DistilBertModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", "flaubert": "flaubert/flaubert_small_cased", @@ -47,4 +47,6 @@ "sentence-transformers-clip": "sentence-transformers/clip-ViT-B-32", } +WEIGHTS_NEFF_SEPARATION_UNSUPPORTED_ARCH = ["camembert", "roberta"] + SEED = 42 diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py index d7f24bd22..f59656252 100644 --- a/tests/exporters/test_export.py +++ b/tests/exporters/test_export.py @@ -19,7 +19,7 @@ import unittest from pathlib import Path from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import Dict +from typing import Dict, List, Optional from parameterized import parameterized from transformers import AutoConfig, AutoModelForSeq2SeqLM, set_seed @@ -36,6 +36,7 @@ from optimum.exporters.neuron.__main__ import _get_submodels_and_neuron_configs from optimum.exporters.neuron.model_configs import * # noqa: F403 from optimum.exporters.tasks import TasksManager +from optimum.neuron.utils import is_neuron_available from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available, logging from optimum.utils.testing_utils import require_diffusers, require_sentence_transformers @@ -45,6 +46,7 @@ EXPORT_MODELS_TINY, SENTENCE_TRANSFORMERS_MODELS, STABLE_DIFFUSION_MODELS_TINY, + WEIGHTS_NEFF_SEPARATION_UNSUPPORTED_ARCH, ) @@ -56,35 +58,39 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def _get_models_to_test(export_models_dict: Dict): +def _get_models_to_test( + export_models_dict: Dict, + exclude_model_types: Optional[List[str]] = None, +): models_to_test = [] for model_type, model_names_tasks in export_models_dict.items(): model_type = model_type.replace("_", "-") - task_config_mapping = TasksManager.get_supported_tasks_for_model_type(model_type, "neuron") + if exclude_model_types is None or (model_type not in exclude_model_types): + task_config_mapping = TasksManager.get_supported_tasks_for_model_type(model_type, "neuron") - if isinstance(model_names_tasks, str): # test export of all tasks on the same model - tasks = list(task_config_mapping.keys()) - model_tasks = {model_names_tasks: tasks} - else: - n_tested_tasks = sum(len(tasks) for tasks in model_names_tasks.values()) - if n_tested_tasks != len(task_config_mapping): - logger.warning(f"Not all tasks are tested for {model_type}.") - model_tasks = model_names_tasks # possibly, test different tasks on different models - - for model_name, tasks in model_tasks.items(): - for task in tasks: - default_shapes = dict(DEFAULT_DUMMY_SHAPES) - neuron_config_constructor = TasksManager.get_exporter_config_constructor( - model_type=model_type, - exporter="neuron", - task=task, - model_name=model_name, - exporter_config_kwargs={**default_shapes}, - ) - - models_to_test.append( - (f"{model_type}_{task}", model_type, model_name, task, neuron_config_constructor) - ) + if isinstance(model_names_tasks, str): # test export of all tasks on the same model + tasks = list(task_config_mapping.keys()) + model_tasks = {model_names_tasks: tasks} + else: + n_tested_tasks = sum(len(tasks) for tasks in model_names_tasks.values()) + if n_tested_tasks != len(task_config_mapping): + logger.warning(f"Not all tasks are tested for {model_type}.") + model_tasks = model_names_tasks # possibly, test different tasks on different models + + for model_name, tasks in model_tasks.items(): + for task in tasks: + default_shapes = dict(DEFAULT_DUMMY_SHAPES) + neuron_config_constructor = TasksManager.get_exporter_config_constructor( + model_type=model_type, + exporter="neuron", + task=task, + model_name=model_name, + exporter_config_kwargs={**default_shapes}, + ) + + models_to_test.append( + (f"{model_type}_{task}", model_type, model_name, task, neuron_config_constructor) + ) random_pick = os.environ.get("MAX_EXPORT_TEST_COMBINATIONS", None) if random_pick is not None: @@ -98,6 +104,11 @@ class NeuronExportTestCase(unittest.TestCase): Integration tests ensuring supported models are correctly exported. """ + if is_neuron_available(): + # Deberta has 'XSoftmax' unsupported on INF1 + for model in ["deberta", "deberta-v2"]: + EXPORT_MODELS_TINY.pop(model) + def _neuronx_export( self, test_name: str, @@ -106,6 +117,7 @@ def _neuronx_export( task: str, neuron_config_constructor: "NeuronDefaultConfig", dynamic_batch_size: bool = False, + inline_weights_to_neff: bool = True, ): if "sentence-transformers" in model_type: model_class = TasksManager.get_model_class_for_task(task, framework="pt", library="sentence_transformers") @@ -136,6 +148,7 @@ def _neuronx_export( model=model, config=neuron_config, output=Path(output.name), + inline_weights_to_neff=inline_weights_to_neff, ) validate_model_outputs( @@ -153,6 +166,16 @@ def _neuronx_export( def test_export(self, test_name, name, model_name, task, neuron_config_constructor): self._neuronx_export(test_name, name, model_name, task, neuron_config_constructor) + @parameterized.expand( + _get_models_to_test(EXPORT_MODELS_TINY, exclude_model_types=WEIGHTS_NEFF_SEPARATION_UNSUPPORTED_ARCH) + ) + @is_inferentia_test + @requires_neuronx + def test_export_separated_weights(self, test_name, name, model_name, task, neuron_config_constructor): + self._neuronx_export( + test_name, name, model_name, task, neuron_config_constructor, inline_weights_to_neff=False + ) + @parameterized.expand(_get_models_to_test(SENTENCE_TRANSFORMERS_MODELS)) @is_inferentia_test @require_vision diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index 46e64bb7b..52678bd92 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -33,8 +33,8 @@ "bert": "hf-internal-testing/tiny-random-BertModel", "camembert": "hf-internal-testing/tiny-random-camembert", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", - # "deberta": "hf-internal-testing/tiny-random-DebertaModel", # Failed for INF1: 'XSoftmax' - # "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", # Failed for INF1: 'XSoftmax' + "deberta": "hf-internal-testing/tiny-random-DebertaModel", # Failed for INF1: 'XSoftmax' + "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", # Failed for INF1: 'XSoftmax' "distilbert": "hf-internal-testing/tiny-random-DistilBertModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", "flaubert": "flaubert/flaubert_small_cased", diff --git a/tests/inference/test_modeling.py b/tests/inference/test_modeling.py index 96b8f203b..3884b3517 100644 --- a/tests/inference/test_modeling.py +++ b/tests/inference/test_modeling.py @@ -136,6 +136,26 @@ def test_save_compiler_intermediary_files(self): self.assertTrue(os.path.isdir(save_path)) self.assertTrue(os.path.exists(neff_path)) + @requires_neuronx + def test_decouple_weights_neff_and_replace_weight(self): + with tempfile.TemporaryDirectory() as tempdir: + # compile + save_path = f"{tempdir}/neff" + neuron_model = NeuronModelForSequenceClassification.from_pretrained( + self.MODEL_ID, + export=True, + compiler_workdir=save_path, + inline_weights_to_neff=False, + **self.STATIC_INPUTS_SHAPES, + ) + self.assertFalse(neuron_model.config.neuron.get("inline_weights_to_neff")) + + # replace weights + model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_ID) + neuron_model.replace_weights(weights=model) + + self.assertIsInstance(neuron_model.model, torch.jit._script.ScriptModule) + @is_inferentia_test class NeuronModelForFeatureExtractionIntegrationTest(NeuronModelTestMixin): @@ -149,7 +169,7 @@ class NeuronModelForFeatureExtractionIntegrationTest(NeuronModelTestMixin): "camembert", # "convbert", # accuracy off compared to pytorch: atol=1e-1 # "deberta", # INF2 only - # "deberta_v2", # INF2 only + # "deberta-v2", # INF2 only # "distilbert", # accuracy off compared to pytorch: atol=1e-1 "electra", # "flaubert", # accuracy off compared to pytorch (not due to the padding) @@ -165,16 +185,16 @@ class NeuronModelForFeatureExtractionIntegrationTest(NeuronModelTestMixin): "albert", "bert", "camembert", - # "convbert", # accuracy off compared to pytorch: atol=1e-2 - # "deberta", # INF2 only - # "deberta_v2", # INF2 only + "convbert", + "deberta", + "deberta-v2", "distilbert", "electra", - # "flaubert", # accuracy off compared to pytorch (not due to the padding) + "flaubert", "mobilebert", "roberta", "roformer", - # "xlm", # accuracy off compared to pytorch (not due to the padding) + "xlm", "xlm-roberta", ] else: @@ -210,6 +230,7 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + atol = neuron_model_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_dyn = neuron_model_dyn(**tokens) self.assertIn("last_hidden_state", neuron_outputs_dyn) self.assertIsInstance(neuron_outputs_dyn.last_hidden_state, torch.Tensor) @@ -217,7 +238,7 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): torch.allclose( neuron_outputs_dyn.last_hidden_state, transformers_outputs.last_hidden_state, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) @@ -225,7 +246,9 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): self.assertIsInstance(neuron_outputs_dyn.pooler_output, torch.Tensor) self.assertTrue( torch.allclose( - neuron_outputs_dyn.pooler_output, transformers_outputs.pooler_output, atol=self.ATOL_FOR_VALIDATION + neuron_outputs_dyn.pooler_output, + transformers_outputs.pooler_output, + atol=atol, ) ) @@ -258,6 +281,10 @@ def test_compare_to_transformers_non_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + if is_neuron_available(): + atol = self.ATOL_FOR_VALIDATION + else: + atol = neuron_model_non_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_non_dyn = neuron_model_non_dyn(**tokens) self.assertIn("last_hidden_state", neuron_outputs_non_dyn) self.assertIsInstance(neuron_outputs_non_dyn.last_hidden_state, torch.Tensor) @@ -265,7 +292,7 @@ def test_compare_to_transformers_non_dyn_bs(self, model_arch): torch.allclose( neuron_outputs_non_dyn.last_hidden_state, transformers_outputs.last_hidden_state, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) @@ -275,7 +302,7 @@ def test_compare_to_transformers_non_dyn_bs(self, model_arch): torch.allclose( neuron_outputs_non_dyn.pooler_output, transformers_outputs.pooler_output, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) @@ -337,13 +364,14 @@ def test_sentence_transformers_dyn_bs(self, model_arch): neuron_outputs_dyn = neuron_model_dyn(**tokens) # Validate token_embeddings + atol = neuron_model_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION self.assertIn("token_embeddings", neuron_outputs_dyn) self.assertIsInstance(neuron_outputs_dyn.token_embeddings, torch.Tensor) self.assertTrue( torch.allclose( neuron_outputs_dyn.token_embeddings, sentence_transformers_outputs.token_embeddings, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) @@ -354,7 +382,7 @@ def test_sentence_transformers_dyn_bs(self, model_arch): torch.allclose( neuron_outputs_dyn.sentence_embedding, sentence_transformers_outputs.sentence_embedding, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) @@ -372,8 +400,6 @@ class NeuronModelForMaskedLMIntegrationTest(NeuronModelTestMixin): "bert", "camembert", # "convbert", # accuracy off compared to pytorch: atol=1e-1 - # "deberta", # INF2 only - # "deberta_v2", # INF2 only # "distilbert", # accuracy off compared to pytorch: atol=1e-1 "electra", # "flaubert", # accuracy off compared to pytorch (not due to the padding) @@ -389,16 +415,16 @@ class NeuronModelForMaskedLMIntegrationTest(NeuronModelTestMixin): "albert", "bert", "camembert", - # "convbert", # accuracy off compared to pytorch: atol=1e-2 - # "deberta", # INF2 only - # "deberta_v2", # INF2 only + "convbert", + "deberta", + "deberta-v2", "distilbert", "electra", - # "flaubert", # accuracy off compared to pytorch (not due to the padding) + "flaubert", "mobilebert", "roberta", "roformer", - # "xlm", # accuracy off compared to pytorch (not due to the padding) + "xlm", "xlm-roberta", ] else: @@ -440,11 +466,16 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + atol = neuron_model_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_dyn = neuron_model_dyn(**tokens) self.assertIn("logits", neuron_outputs_dyn) self.assertIsInstance(neuron_outputs_dyn.logits, torch.Tensor) self.assertTrue( - torch.allclose(neuron_outputs_dyn.logits, transformers_outputs.logits, atol=self.ATOL_FOR_VALIDATION) + torch.allclose( + neuron_outputs_dyn.logits, + transformers_outputs.logits, + atol=atol, + ) ) gc.collect() @@ -476,11 +507,19 @@ def test_compare_to_transformers_non_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + if is_neuron_available(): + atol = self.ATOL_FOR_VALIDATION + else: + atol = neuron_model_non_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_non_dyn = neuron_model_non_dyn(**tokens) self.assertIn("logits", neuron_outputs_non_dyn) self.assertIsInstance(neuron_outputs_non_dyn.logits, torch.Tensor) self.assertTrue( - torch.allclose(neuron_outputs_non_dyn.logits, transformers_outputs.logits, atol=self.ATOL_FOR_VALIDATION) + torch.allclose( + neuron_outputs_non_dyn.logits, + transformers_outputs.logits, + atol=atol, + ) ) gc.collect() @@ -538,8 +577,6 @@ class NeuronModelForQuestionAnsweringIntegrationTest(NeuronModelTestMixin): "bert", "camembert", # "convbert", # accuracy off compared to pytorch: atol=1e-1 - # "deberta", # INF2 only - # "deberta_v2", # INF2 only # "distilbert", # accuracy off compared to pytorch: atol=1e-1 "electra", # "flaubert", # accuracy off compared to pytorch (not due to the padding) @@ -555,16 +592,16 @@ class NeuronModelForQuestionAnsweringIntegrationTest(NeuronModelTestMixin): "albert", "bert", "camembert", - # "convbert", # accuracy off compared to pytorch: atol=1e-2 - # "deberta", # INF2 only - # "deberta_v2", # INF2 only + "convbert", + "deberta", + "deberta-v2", "distilbert", "electra", - # "flaubert", # accuracy off compared to pytorch (not due to the padding) + "flaubert", "mobilebert", "roberta", "roformer", - # "xlm", # accuracy off compared to pytorch (not due to the padding) + "xlm", "xlm-roberta", ] else: @@ -608,6 +645,7 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + atol = neuron_model_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_dyn = neuron_model_dyn(**tokens) self.assertIn("start_logits", neuron_outputs_dyn) self.assertIn("end_logits", neuron_outputs_dyn) @@ -619,14 +657,14 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): torch.allclose( torch.Tensor(neuron_outputs_dyn.start_logits), transformers_outputs.start_logits, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) self.assertTrue( torch.allclose( torch.Tensor(neuron_outputs_dyn.end_logits), transformers_outputs.end_logits, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) @@ -659,6 +697,10 @@ def test_compare_to_transformers_non_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + if is_neuron_available(): + atol = self.ATOL_FOR_VALIDATION + else: + atol = neuron_model_non_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_non_dyn = neuron_model_non_dyn(**tokens) self.assertIn("start_logits", neuron_outputs_non_dyn) self.assertIn("end_logits", neuron_outputs_non_dyn) @@ -670,14 +712,14 @@ def test_compare_to_transformers_non_dyn_bs(self, model_arch): torch.allclose( torch.Tensor(neuron_outputs_non_dyn.start_logits), transformers_outputs.start_logits, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) self.assertTrue( torch.allclose( torch.Tensor(neuron_outputs_non_dyn.end_logits), transformers_outputs.end_logits, - atol=self.ATOL_FOR_VALIDATION, + atol=atol, ) ) @@ -706,7 +748,8 @@ def test_non_dyn_bs_neuron_model_on_false_batch_size(self): self.assertIn("set `dynamic_batch_size=True` during the compilation", str(context.exception)) - @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + # TODO: exclude flaubert, xlm for now as the pipeline seems to pad already input_ids to max, and running tiny test will fail. (ValueError: Unable to pad input_ids with shape: torch.Size([1, 384]) on dimension 1 as input shapes must be inferior than the static shapes used for compilation: torch.Size([1, 32]).) + @parameterized.expand([x for x in SUPPORTED_ARCHITECTURES if x not in ["flaubert", "xlm"]], skip_on_empty=True) def test_pipeline_model(self, model_arch): model_args = {"test_name": model_arch + "_dyn_bs_false", "model_arch": model_arch} self._setup(model_args) @@ -739,7 +782,7 @@ class NeuronModelForSequenceClassificationIntegrationTest(NeuronModelTestMixin): "camembert", # "convbert", # accuracy off compared to pytorch: atol=1e-1 # "deberta", # INF2 only - # "deberta_v2", # INF2 only + # "deberta-v2", # INF2 only # "distilbert", # accuracy off compared to pytorch: atol=1e-1 "electra", # "flaubert", # accuracy off compared to pytorch (not due to the padding) @@ -755,12 +798,12 @@ class NeuronModelForSequenceClassificationIntegrationTest(NeuronModelTestMixin): "albert", "bert", "camembert", - # "convbert", # accuracy off compared to pytorch: atol=1e-2 - # "deberta", # INF2 only - # "deberta_v2", # INF2 only + "convbert", + "deberta", + "deberta-v2", "distilbert", "electra", - # "flaubert", # accuracy off compared to pytorch (not due to the padding) + "flaubert", "mobilebert", "roberta", "roformer", @@ -808,11 +851,16 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + atol = neuron_model_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_dyn = neuron_model_dyn(**tokens) self.assertIn("logits", neuron_outputs_dyn) self.assertIsInstance(neuron_outputs_dyn.logits, torch.Tensor) self.assertTrue( - torch.allclose(neuron_outputs_dyn.logits, transformers_outputs.logits, atol=self.ATOL_FOR_VALIDATION) + torch.allclose( + neuron_outputs_dyn.logits, + transformers_outputs.logits, + atol=atol, + ) ) gc.collect() @@ -844,11 +892,19 @@ def test_compare_to_transformers_non_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + if is_neuron_available(): + atol = self.ATOL_FOR_VALIDATION + else: + atol = neuron_model_non_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_non_dyn = neuron_model_non_dyn(**tokens) self.assertIn("logits", neuron_outputs_non_dyn) self.assertIsInstance(neuron_outputs_non_dyn.logits, torch.Tensor) self.assertTrue( - torch.allclose(neuron_outputs_non_dyn.logits, transformers_outputs.logits, atol=self.ATOL_FOR_VALIDATION) + torch.allclose( + neuron_outputs_non_dyn.logits, + transformers_outputs.logits, + atol=atol, + ) ) gc.collect() @@ -908,7 +964,7 @@ class NeuronModelForTokenClassificationIntegrationTest(NeuronModelTestMixin): "camembert", # "convbert", # accuracy off compared to pytorch: atol=1e-1 # "deberta", # INF2 only - # "deberta_v2", # INF2 only + # "deberta-v2", # INF2 only # "distilbert", # accuracy off compared to pytorch: atol=1e-1 "electra", # "flaubert", # accuracy off compared to pytorch (not due to the padding) @@ -924,16 +980,16 @@ class NeuronModelForTokenClassificationIntegrationTest(NeuronModelTestMixin): "albert", "bert", "camembert", - # "convbert", # accuracy off compared to pytorch: atol=1e-2 - # "deberta", # INF2 only - # "deberta_v2", # INF2 only + "convbert", + "deberta", + "deberta-v2", "distilbert", "electra", - # "flaubert", # accuracy off compared to pytorch (not due to the padding) + "flaubert", "mobilebert", "roberta", "roformer", - # "xlm", # accuracy off compared to pytorch (not due to the padding) + "xlm", "xlm-roberta", ] else: @@ -977,11 +1033,16 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + atol = neuron_model_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_dyn = neuron_model_dyn(**tokens) self.assertIn("logits", neuron_outputs_dyn) self.assertIsInstance(neuron_outputs_dyn.logits, torch.Tensor) self.assertTrue( - torch.allclose(neuron_outputs_dyn.logits, transformers_outputs.logits, atol=self.ATOL_FOR_VALIDATION) + torch.allclose( + neuron_outputs_dyn.logits, + transformers_outputs.logits, + atol=atol, + ) ) gc.collect() @@ -1013,11 +1074,19 @@ def test_compare_to_transformers_non_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**tokens) # Numeric validation + if is_neuron_available(): + atol = self.ATOL_FOR_VALIDATION + else: + atol = neuron_model_non_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_non_dyn = neuron_model_non_dyn(**tokens) self.assertIn("logits", neuron_outputs_non_dyn) self.assertIsInstance(neuron_outputs_non_dyn.logits, torch.Tensor) self.assertTrue( - torch.allclose(neuron_outputs_non_dyn.logits, transformers_outputs.logits, atol=self.ATOL_FOR_VALIDATION) + torch.allclose( + neuron_outputs_non_dyn.logits, + transformers_outputs.logits, + atol=atol, + ) ) gc.collect() @@ -1077,7 +1146,7 @@ class NeuronModelForMultipleChoiceIntegrationTest(NeuronModelTestMixin): "camembert", # "convbert", # accuracy off compared to pytorch: atol=1e-1 # "deberta", # INF2 only - # "deberta_v2", # INF2 only + # "deberta-v2", # INF2 only # "distilbert", # accuracy off compared to pytorch: atol=1e-1 "electra", # "flaubert", # accuracy off compared to pytorch (not due to the padding) @@ -1094,11 +1163,9 @@ class NeuronModelForMultipleChoiceIntegrationTest(NeuronModelTestMixin): "bert", "camembert", # "convbert", # accuracy off compared to pytorch: atol=1e-2 - # "deberta", # INF2 only - # "deberta_v2", # INF2 only "distilbert", "electra", - # "flaubert", # accuracy off compared to pytorch (not due to the padding) + "flaubert", "mobilebert", "roberta", # "roformer", # accuracy off compared to pytorch: atol=1e-1 @@ -1146,17 +1213,22 @@ def test_compare_to_transformers_dyn_bs(self, model_arch): transformers_outputs = transformers_model(**pt_inputs) # Numeric validation + atol = neuron_model_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_dyn = neuron_model_dyn(**pt_inputs) self.assertIn("logits", neuron_outputs_dyn) self.assertIsInstance(neuron_outputs_dyn.logits, torch.Tensor) self.assertTrue( - torch.allclose(neuron_outputs_dyn.logits, transformers_outputs.logits, atol=self.ATOL_FOR_VALIDATION) + torch.allclose( + neuron_outputs_dyn.logits, + transformers_outputs.logits, + atol=atol, + ) ) gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) - def test_compare_to_transformers_non_dyn_bas(self, model_arch): + def test_compare_to_transformers_non_dyn_bs(self, model_arch): model_args = { "test_name": model_arch + "_dyn_bs_false", "model_arch": model_arch, @@ -1191,11 +1263,19 @@ def test_compare_to_transformers_non_dyn_bas(self, model_arch): transformers_outputs = transformers_model(**pt_inputs) # Numeric validation + if is_neuron_available(): + atol = self.ATOL_FOR_VALIDATION + else: + atol = neuron_model_non_dyn.neuron_config.ATOL_FOR_VALIDATION or self.ATOL_FOR_VALIDATION neuron_outputs_non_dyn = neuron_model_non_dyn(**pt_inputs) self.assertIn("logits", neuron_outputs_non_dyn) self.assertIsInstance(neuron_outputs_non_dyn.logits, torch.Tensor) self.assertTrue( - torch.allclose(neuron_outputs_non_dyn.logits, transformers_outputs.logits, atol=self.ATOL_FOR_VALIDATION) + torch.allclose( + neuron_outputs_non_dyn.logits, + transformers_outputs.logits, + atol=atol, + ) ) gc.collect()