[Inference] Add t5 support for export and inference #267

Merged · 38 commits · Dec 4, 2023

Commits:
- `64f12e3` init (JingyaHuang, Oct 19, 2023)
- `aa5a379` update wrappers (JingyaHuang, Oct 20, 2023)
- `6580875` encoder support (JingyaHuang, Oct 23, 2023)
- `e997f5f` decoder export (JingyaHuang, Oct 24, 2023)
- `7621e39` CLI support (JingyaHuang, Oct 25, 2023)
- `1eaa54a` validation (JingyaHuang, Oct 25, 2023)
- `2231afb` add seq2seq base model (JingyaHuang, Oct 26, 2023)
- `a7ce906` Merge branch 'main' into add-t5-export (JingyaHuang, Oct 30, 2023)
- `72ed695` modeling export and loading (JingyaHuang, Oct 30, 2023)
- `16ddeeb` fix style (JingyaHuang, Nov 5, 2023)
- `8c5d796` Merge branch 'main' into add-t5-export (JingyaHuang, Nov 5, 2023)
- `3efdbc8` finish base modeling funcs (JingyaHuang, Nov 5, 2023)
- `cdc885e` quick test inference (JingyaHuang, Nov 6, 2023)
- `2384e52` fix config loding (JingyaHuang, Nov 7, 2023)
- `ae9df1a` finish modeling, works (JingyaHuang, Nov 9, 2023)
- `a3784cf` add part of tests (JingyaHuang, Nov 9, 2023)
- `348b3b0` Merge branch 'main' into add-t5-export (JingyaHuang, Nov 9, 2023)
- `308c08e` tests done (JingyaHuang, Nov 9, 2023)
- `9d93f15` Merge branch 'main' into add-t5-export (JingyaHuang, Nov 14, 2023)
- `634a908` Merge branch 'main' into add-t5-export (JingyaHuang, Nov 17, 2023)
- `12e9311` apply some suggestions (JingyaHuang, Nov 17, 2023)
- `b8fb359` Merge branch 'main' into add-t5-export (JingyaHuang, Nov 22, 2023)
- `13445c9` fix style (JingyaHuang, Nov 22, 2023)
- `065ebaf` Merge branch 'main' into add-t5-export (JingyaHuang, Nov 23, 2023)
- `925f8f6` Merge branch 'main' into add-t5-export (JingyaHuang, Nov 23, 2023)
- `ded43a4` address part of comments (JingyaHuang, Nov 23, 2023)
- `994374b` apply some suggestions (JingyaHuang, Nov 23, 2023)
- `5c55ec1` add pad left support and log (JingyaHuang, Nov 24, 2023)
- `9396c7a` fix enable custom max length instead of real max length limit (JingyaHuang, Nov 27, 2023)
- `9676a65` reuse neuron gen mix (JingyaHuang, Nov 28, 2023)
- `e8d72c2` fix beam (JingyaHuang, Nov 28, 2023)
- `dd4b1c7` fix tests (JingyaHuang, Nov 29, 2023)
- `df7cde7` support optional outputs for decoder (JingyaHuang, Nov 29, 2023)
- `92cd6e5` enhance tests (JingyaHuang, Dec 1, 2023)
- `6f69d6d` fix style (JingyaHuang, Dec 1, 2023)
- `9f461f8` fix style (JingyaHuang, Dec 1, 2023)
- `d6a24b6` apply suggestions (JingyaHuang, Dec 2, 2023)
- `3b07ba1` fix tests (JingyaHuang, Dec 2, 2023)

Files changed:
2 changes: 1 addition & 1 deletion docs/source/tutorials/stable_diffusion.mdx
@@ -357,7 +357,7 @@ To avoid Neuron device out of memory, it's suggested to finish all base inference
Latent Consistency Models (LCMs) were proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao](https://huggingface.co/papers/2310.04378). LCMs enable inference with fewer steps on any pre-trained LDMs, including Stable Diffusion and SDXL.

In `optimum-neuron`, you can:
-- Use the class `NeuronLatentConsistencyModelPipeline` to compile and run inference of LCMs distilled from Stable Diffusion (SD) models,
+- Use the class `NeuronLatentConsistencyModelPipeline` to compile and run inference of LCMs distilled from Stable Diffusion (SD) models.
- And continue to use the class `NeuronStableDiffusionXLPipeline` for LCMs distilled from SDXL models.

Here are examples to compile the LCMs of Stable Diffusion ([SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7)) and Stable Diffusion XL ([latent-consistency/lcm-sdxl](https://huggingface.co/latent-consistency/lcm-sdxl)), and then run inference on AWS Inferentia 2:
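The full snippets referenced above are truncated in this diff view. A minimal sketch of the Stable Diffusion LCM flow, assuming the `NeuronLatentConsistencyModelPipeline` API named in the doc (shape values, output directory, and prompt are illustrative, not from this PR):

```python
from optimum.neuron import NeuronLatentConsistencyModelPipeline

# Compile once; Neuron requires static shapes, so they are fixed at export time.
pipe = NeuronLatentConsistencyModelPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    export=True,
    batch_size=1,
    height=768,
    width=768,
    num_images_per_prompt=1,
)
pipe.save_pretrained("lcm_sd_neuron/")  # hypothetical output directory

# LCMs only need a handful of denoising steps at inference time.
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=4).images[0]
```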
16 changes: 16 additions & 0 deletions optimum/commands/export/neuronx.py
@@ -102,6 +102,12 @@ def parse_args_neuronx(parser: "ArgumentParser"):
type=int,
help=f"Sequence length {doc_input}",
)
input_group.add_argument(
"--num_beams",
type=int,
default=1,
help=f"Number of beams for beam search {doc_input}",
)
input_group.add_argument(
"--num_choices",
type=int,
@@ -135,6 +141,16 @@ def parse_args_neuronx(parser: "ArgumentParser"):
"UNet model ID on huggingface.co or path on disk to load model from. This will replace the unet in the original Stable Diffusion pipeline."
),
)
optional_group.add_argument(
"--output_hidden_states",
action="store_true",
help=("Whether or not for the traced model to return the hidden states of all layers."),
)
optional_group.add_argument(
"--output_attentions",
action="store_true",
help=("Whether or not for the traced model to return the attentions tensors of all attention layers."),
)


class NeuronxExportCommand(BaseOptimumCLICommand):
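Together with the existing export options, the new arguments extend the `optimum-cli export neuron` interface. A hypothetical invocation for a T5 model (only `--num_beams`, `--output_attentions`, and `--output_hidden_states` come from this PR; the model, task, and shape flags are assumptions based on the surrounding CLI):

```bash
optimum-cli export neuron \
  --model t5-small \
  --task text2text-generation \
  --batch_size 1 \
  --sequence_length 64 \
  --num_beams 4 \
  --output_hidden_states \
  t5_small_neuron/
```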
225 changes: 170 additions & 55 deletions optimum/exporters/neuron/__main__.py
@@ -22,14 +22,16 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from requests.exceptions import ConnectionError as RequestsConnectionError
-from transformers import AutoConfig
+from transformers import AutoConfig, PretrainedConfig

from ...neuron.utils import (
DECODER_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_NAME,
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_VAE_DECODER_NAME,
DIFFUSION_MODEL_VAE_ENCODER_NAME,
ENCODER_NAME,
NEURON_FILE_NAME,
is_neuron_available,
is_neuronx_available,
@@ -43,6 +45,7 @@
from .model_configs import * # noqa: F403
from .utils import (
build_stable_diffusion_components_mandatory_shapes,
get_encoder_decoder_models_for_export,
get_stable_diffusion_models_for_export,
replace_stable_diffusion_submodels,
)
@@ -64,8 +67,10 @@


if TYPE_CHECKING:
from transformers import PreTrainedModel

if is_diffusers_available():
-    from diffusers import StableDiffusionPipeline
+    from diffusers import DiffusionPipeline, StableDiffusionPipeline


logger = logging.get_logger()
@@ -103,7 +108,11 @@ def infer_task(task: str, model_name_or_path: str) -> str:

def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int]:
config = AutoConfig.from_pretrained(args.model)

model_type = config.model_type.replace("_", "-")
if config.is_encoder_decoder:
model_type = model_type + "-encoder"

neuron_config_constructor = TasksManager.get_exporter_config_constructor(
model_type=model_type, exporter="neuron", task=task
)
@@ -112,6 +121,18 @@ def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int]:
return input_shapes


def customize_optional_outputs(args: argparse.Namespace) -> Dict[str, bool]:
"""
Customize the optional outputs of the traced model, e.g., if `output_attentions=True`, the attention tensors will be traced.
"""
possible_outputs = ["output_attentions", "output_hidden_states"]

customized_outputs = {}
for name in possible_outputs:
customized_outputs[name] = getattr(args, name, False)
return customized_outputs
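For example, exporting with `--output_hidden_states` but without `--output_attentions` yields the following mapping (a sketch reusing the helper defined above; the namespace stands in for parsed CLI arguments):

```python
import argparse

args = argparse.Namespace(output_attentions=False, output_hidden_states=True)
assert customize_optional_outputs(args) == {
    "output_attentions": False,
    "output_hidden_states": True,
}
```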


def normalize_stable_diffusion_input_shapes(
args: argparse.Namespace,
) -> Dict[str, Dict[str, int]]:
@@ -173,6 +194,135 @@ def infer_stable_diffusion_shapes_from_diffusers(
return input_shapes


def _get_submodels_and_neuron_configs(
model: Union["PreTrainedModel", "DiffusionPipeline"],
input_shapes: Dict[str, int],
task: str,
output: Path,
dynamic_batch_size: bool = False,
model_name_or_path: Optional[Union[str, Path]] = None,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
is_stable_diffusion = "stable-diffusion" in task
is_encoder_decoder = (
getattr(model.config, "is_encoder_decoder", False) if isinstance(model.config, PretrainedConfig) else False
)

if is_stable_diffusion:
# TODO: Enable optional outputs for Stable Diffusion
if output_attentions or output_hidden_states:
raise ValueError(
f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
)
models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion(
model, input_shapes, task, output, dynamic_batch_size, submodels
)
elif is_encoder_decoder:
optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states}
models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_encoder_decoder(
model, input_shapes, task, output, dynamic_batch_size, model_name_or_path, **optional_outputs
)
else:
# TODO: Enable optional outputs for encoders
if output_attentions or output_hidden_states:
raise ValueError(
f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
)
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="neuron", task=task
)
neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
model_name = model.name_or_path.split("/")[-1]
output_model_names = {model_name: "model.neuron"}
models_and_neuron_configs = {model_name: (model, neuron_config)}
maybe_save_preprocessors(model_name_or_path, output)
return models_and_neuron_configs, output_model_names


def _get_submodels_and_neuron_configs_for_stable_diffusion(
model: Union["PreTrainedModel", "DiffusionPipeline"],
input_shapes: Dict[str, int],
task: str,
output: Path,
dynamic_batch_size: bool = False,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
):
check_compiler_compatibility_for_stable_diffusion()
model = replace_stable_diffusion_submodels(model, submodels)
if is_neuron_available():
raise RuntimeError(
"Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
)
input_shapes = infer_stable_diffusion_shapes_from_diffusers(input_shapes, model)

# Saving the model config and preprocessor as this is needed sometimes.
model.scheduler.save_pretrained(output.joinpath("scheduler"))
if getattr(model, "tokenizer", None) is not None:
model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
if getattr(model, "tokenizer_2", None) is not None:
model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
if getattr(model, "feature_extractor", None) is not None:
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
model.save_config(output)

models_and_neuron_configs = get_stable_diffusion_models_for_export(
pipeline=model,
task=task,
dynamic_batch_size=dynamic_batch_size,
**input_shapes,
)
output_model_names = {
DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
}
if getattr(model, "text_encoder", None) is not None:
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME
)
if getattr(model, "text_encoder_2", None) is not None:
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
)
del model

return models_and_neuron_configs, output_model_names


def _get_submodels_and_neuron_configs_for_encoder_decoder(
model: "PreTrainedModel",
input_shapes: Dict[str, int],
task: str,
output: Path,
dynamic_batch_size: bool = False,
model_name_or_path: Optional[Union[str, Path]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
if is_neuron_available():
raise RuntimeError(
"Encoder-decoder models export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
)

models_and_neuron_configs = get_encoder_decoder_models_for_export(
model=model,
task=task,
dynamic_batch_size=dynamic_batch_size,
input_shapes=input_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
output_model_names = {
ENCODER_NAME: os.path.join(ENCODER_NAME, NEURON_FILE_NAME),
DECODER_NAME: os.path.join(DECODER_NAME, NEURON_FILE_NAME),
}
maybe_save_preprocessors(model_name_or_path, output)

return models_and_neuron_configs, output_model_names
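Once exported, the encoder and decoder live side by side under `ENCODER_NAME` and `DECODER_NAME`. A hypothetical inference sketch (the `NeuronModelForSeq2SeqLM` class name is inferred from this PR's "add seq2seq base model" commit, not shown in this diff; paths and generation settings are illustrative):

```python
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForSeq2SeqLM

# "t5_small_neuron/" is assumed to contain the compiled encoder/ and decoder/ folders.
model = NeuronModelForSeq2SeqLM.from_pretrained("t5_small_neuron/")
tokenizer = AutoTokenizer.from_pretrained("t5_small_neuron/")

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
generated = model.generate(**inputs, num_beams=4, max_length=64)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```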


def main_export(
model_name_or_path: str,
output: Union[str, Path],
@@ -188,14 +338,17 @@ def main_export(
local_files_only: bool = False,
use_auth_token: Optional[Union[bool, str]] = None,
do_validation: bool = True,
-    submodels: Dict[str, Union[Path, str]] = None,
+    submodels: Optional[Dict[str, Union[Path, str]]] = None,
+    output_attentions: bool = False,
+    output_hidden_states: bool = False,
**input_shapes,
):
output = Path(output)
if not output.parent.exists():
output.parent.mkdir(parents=True)

task = TasksManager.map_from_synonym(task)
is_stable_diffusion = "stable-diffusion" in task

model_kwargs = {
"task": task,
@@ -211,58 +364,17 @@
}
model = TasksManager.get_model_from_task(**model_kwargs)

is_stable_diffusion = "stable-diffusion" in task
if not is_stable_diffusion:
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="neuron", task=task
)
neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
if atol is None:
atol = neuron_config.ATOL_FOR_VALIDATION
model_name = model.name_or_path.split("/")[-1]
output_model_names = {model_name: "model.neuron"}
models_and_neuron_configs = {model_name: (model, neuron_config)}
maybe_save_preprocessors(model, output.parent)

if is_stable_diffusion:
model = replace_stable_diffusion_submodels(model, submodels)
check_compiler_compatibility_for_stable_diffusion()
if is_neuron_available():
raise RuntimeError(
"Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
)
input_shapes = infer_stable_diffusion_shapes_from_diffusers(input_shapes, model)

# Saving the model config and preprocessor as this is needed sometimes.
model.scheduler.save_pretrained(output.joinpath("scheduler"))
if hasattr(model, "tokenizer") and model.tokenizer is not None:
model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
if hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None:
model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
if hasattr(model, "feature_extractor") and model.feature_extractor is not None:
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
model.save_config(output)

models_and_neuron_configs = get_stable_diffusion_models_for_export(
pipeline=model,
task=task,
dynamic_batch_size=dynamic_batch_size,
**input_shapes,
)
output_model_names = {
DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
}
if hasattr(model, "text_encoder") and model.text_encoder is not None:
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME
)
if hasattr(model, "text_encoder_2") and model.text_encoder_2 is not None:
output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
)
del model
models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs(
model=model,
input_shapes=input_shapes,
task=task,
output=output,
dynamic_batch_size=dynamic_batch_size,
model_name_or_path=model_name_or_path,
submodels=submodels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

_, neuron_outputs = export_models(
models_and_neuron_configs=models_and_neuron_configs,
@@ -329,6 +441,8 @@ def main():
input_shapes = normalize_input_shapes(task, args)
submodels = None

optional_outputs = customize_optional_outputs(args)

main_export(
model_name_or_path=args.model,
output=args.output,
@@ -340,6 +454,7 @@
trust_remote_code=args.trust_remote_code,
do_validation=not args.disable_validation,
submodels=submodels,
**optional_outputs,
**input_shapes,
)

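After the refactor, `main_export` is a thin entry point that dispatches to the helpers above. A hypothetical programmatic export of a T5 model (the import path and task name are assumptions; `num_beams` travels through `**input_shapes`, while the two `output_*` flags are the new keyword arguments):

```python
from optimum.exporters.neuron import main_export

main_export(
    model_name_or_path="t5-small",
    output="t5_small_neuron/",
    task="text2text-generation",
    batch_size=1,
    sequence_length=64,
    num_beams=4,
    output_hidden_states=True,
)
```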
8 changes: 7 additions & 1 deletion optimum/exporters/neuron/base.py
@@ -119,6 +119,9 @@ def __init__(
audio_sequence_length: Optional[int] = None,
point_batch_size: Optional[int] = None,
nb_points_per_image: Optional[int] = None,
num_beams: int = 1,
output_attentions: bool = False,
output_hidden_states: bool = False,
# TODO: add custom dtype after optimum 1.13 release
# int_dtype: str = "int64",
# float_dtype: str = "fp32",
@@ -147,13 +150,16 @@ def __init__(
"audio_sequence_length": audio_sequence_length,
"point_batch_size": point_batch_size,
"nb_points_per_image": nb_points_per_image,
"num_beams": num_beams,
}
input_shapes = {}
for name, value in axes_values.items():
if value is not None:
input_shapes[name] = value
setattr(self, name, value)
setattr(self, "input_shapes", input_shapes)
setattr(self, "output_attentions", output_attentions)
setattr(self, "output_hidden_states", output_hidden_states)
setattr(self, "compiler_type", compiler_type)
setattr(self, "compiler_version", compiler_version)

@@ -290,7 +296,7 @@ def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
flatten[name] = value
return flatten

-    def check_model_inputs_order(
+    def patch_model_for_export(
self,
model: "PreTrainedModel",
dummy_inputs: Optional[Dict[str, torch.Tensor]] = None,
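The new constructor arguments surface on every Neuron export config: `num_beams` joins the stored input shapes, and the two `output_*` flags become attributes. A sketch of building a config for the T5 encoder (the `-encoder` suffix convention comes from `normalize_input_shapes` in `__main__.py`; the task name and shape values are assumptions):

```python
from transformers import AutoConfig
from optimum.exporters.tasks import TasksManager

config = AutoConfig.from_pretrained("t5-small")
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
    model_type="t5-encoder", exporter="neuron", task="text2text-generation"
)
neuron_config = neuron_config_constructor(
    config,
    batch_size=1,
    sequence_length=64,
    num_beams=4,               # stored in neuron_config.input_shapes
    output_attentions=True,    # stored as neuron_config.output_attentions
)
```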