SD3 and Flux support #2073

Merged: 32 commits from diffusers-transformer-export into main, Nov 19, 2024.

Changes from all commits (all authored by IlyasMoutawwakil):
07d9952  sd3 support (Oct 21, 2024)
08190c7  unsupported cli model types (Oct 21, 2024)
4518691  flux transformer support, unet export fixes, updated callback test, u… (Oct 23, 2024)
aa74f63  fixes (Oct 23, 2024)
b566392  move input generators (Oct 24, 2024)
88901ad  dummy diffusers (Oct 24, 2024)
11467ae  style (Oct 28, 2024)
6c0093a  Merge branch 'main' into diffusers-transformer-export (Nov 1, 2024)
e4c63f7  sd3 support (Oct 21, 2024)
c9e7003  unsupported cli model types (Oct 21, 2024)
0ba1e92  flux transformer support, unet export fixes, updated callback test, u… (Oct 23, 2024)
1ee6942  fixes (Oct 23, 2024)
f2c43a0  move input generators (Oct 24, 2024)
bbbb668  dummy diffusers (Oct 24, 2024)
1e1715e  style (Oct 28, 2024)
572ba46  Merge branch 'diffusers-transformer-export' of https://github.com/hug… (Nov 1, 2024)
f07ceef  distribute ort tests (Nov 3, 2024)
c1d4443  Merge branch 'main' into diffusers-transformer-export (Nov 4, 2024)
264f5f6  fix (Nov 4, 2024)
093e7c3  fix (Nov 4, 2024)
d22382c  fix (Nov 4, 2024)
a77549c  test num images (Nov 4, 2024)
69f344b  single process to reduce re-exports (Nov 4, 2024)
3400727  test (Nov 15, 2024)
8b12d7a  revert unnecessary changes (Nov 18, 2024)
b1b295d  T5Encoder inherits from TextEncoder (Nov 18, 2024)
7f03a0e  style (Nov 18, 2024)
395a4f7  fix typo in timestep (Nov 18, 2024)
a1827d0  style (Nov 18, 2024)
813002d  only test sd3 and flux on latest transformers (Nov 18, 2024)
2183fb7  conditional sd3 and flux modeling (Nov 19, 2024)
417db60  forgot sd3 inpaint (Nov 19, 2024)
13 changes: 4 additions & 9 deletions .github/workflows/test_onnxruntime.yml
@@ -24,14 +24,11 @@ jobs:
         os: ubuntu-20.04

     runs-on: ${{ matrix.os }}

     steps:
       - name: Free Disk Space (Ubuntu)
         if: matrix.os == 'ubuntu-20.04'
         uses: jlumbroso/free-disk-space@main
-        with:
-          tool-cache: false
-          swap-storage: false
-          large-packages: false

       - name: Checkout code
         uses: actions/checkout@v4
@@ -52,13 +49,11 @@ jobs:
         run: pip install transformers==${{ matrix.transformers-version }}

       - name: Test with pytest (in series)
-        working-directory: tests
         run: |
-          pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s
+          pytest tests/onnxruntime -m "run_in_series" --durations=0 -vvvv -s

       - name: Test with pytest (in parallel)
+        run: |
+          pytest tests/onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto
         env:
           HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
-        working-directory: tests
-        run: |
-          pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto
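The split above selects tests by pytest marker. A minimal sketch of how a test opts into the serial lane, assuming only that the suite registers a `run_in_series` marker (which the `-m` filters above imply); the test name and body are illustrative:

```python
import pytest


# Selected by `pytest -m "run_in_series"` and excluded by
# `pytest -m "not run_in_series"`; useful for tests too heavy to share
# a pytest-xdist worker (-n auto) with other tests.
@pytest.mark.run_in_series
def test_heavy_pipeline_export():
    ...
```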
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
@@ -319,6 +319,7 @@ def fix_dynamic_axes(
         input_shapes = {}
         dummy_inputs = self.generate_dummy_inputs(framework="np", **input_shapes)
         dummy_inputs = self.generate_dummy_inputs_for_validation(dummy_inputs, onnx_input_names=onnx_input_names)
+        dummy_inputs = self.rename_ambiguous_inputs(dummy_inputs)

         onnx_inputs = {}
         for name, value in dummy_inputs.items():
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/convert.py
@@ -1183,6 +1183,10 @@ def onnx_export_from_model(
         if tokenizer_2 is not None:
             tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

+        tokenizer_3 = getattr(model, "tokenizer_3", None)
+        if tokenizer_3 is not None:
+            tokenizer_3.save_pretrained(output.joinpath("tokenizer_3"))
+
         model.save_config(output)

     if float_dtype == "bf16":
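SD3 pipelines carry a third tokenizer (T5) next to the two CLIP tokenizers, which is why the export now also mirrors `tokenizer_3` into the output directory. A condensed sketch of the pattern, assuming the same attribute-to-subfolder mapping used for `tokenizer` and `tokenizer_2`; the helper itself is not a function from the codebase:

```python
from pathlib import Path


def save_pipeline_tokenizers(model, output: Path) -> None:
    # SD3 exposes up to three tokenizers; each is saved under a
    # subfolder named after its pipeline attribute.
    for attr in ("tokenizer", "tokenizer_2", "tokenizer_3"):
        tokenizer = getattr(model, attr, None)
        if tokenizer is not None:
            tokenizer.save_pretrained(output.joinpath(attr))
```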
123 changes: 109 additions & 14 deletions optimum/exporters/onnx/model_configs.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Model specific ONNX configurations."""
+
 import random
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
@@ -28,6 +29,8 @@
     DummyCodegenDecoderTextInputGenerator,
     DummyDecoderTextInputGenerator,
     DummyEncodecInputGenerator,
+    DummyFluxTransformerTextInputGenerator,
+    DummyFluxTransformerVisionInputGenerator,
     DummyInputGenerator,
     DummyIntGenerator,
     DummyPastKeyValuesGenerator,
@@ -38,6 +41,9 @@
     DummySpeechT5InputGenerator,
     DummyTextInputGenerator,
     DummyTimestepInputGenerator,
+    DummyTransformerTextInputGenerator,
+    DummyTransformerTimestepInputGenerator,
+    DummyTransformerVisionInputGenerator,
     DummyVisionEmbeddingsGenerator,
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
@@ -53,6 +59,7 @@
     NormalizedTextConfig,
     NormalizedTextConfigWithGQA,
     NormalizedVisionConfig,
+    check_if_diffusers_greater,
     check_if_transformers_greater,
     is_diffusers_available,
     logging,
@@ -1037,22 +1044,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
             "pooler_output": {0: "batch_size"},
         }

         if self._normalized_config.output_hidden_states:
             for i in range(self._normalized_config.num_layers + 1):
                 common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

         return common_outputs

-    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
-        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
-
-        # TODO: fix should be by casting inputs during inference and not export
-        if framework == "pt":
-            import torch
-
-            dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
-        return dummy_inputs
-
     def patch_model_for_export(
         self,
         model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
@@ -1062,7 +1060,7 @@ def patch_model_for_export(


 class UNetOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-3
+    ATOL_FOR_VALIDATION = 1e-4
     # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
@@ -1085,17 +1083,19 @@ class UNetOnnxConfig(VisionOnnxConfig):
     def inputs(self) -> Dict[str, Dict[int, str]]:
         common_inputs = {
             "sample": {0: "batch_size", 2: "height", 3: "width"},
-            "timestep": {0: "steps"},
+            "timestep": {},  # a scalar with no dimension
             "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
         }

-        # TODO : add text_image, image and image_embeds
+        # TODO : add addition_embed_type == text_image, image and image_embeds
         # https://github.com/huggingface/diffusers/blob/9366c8f84bfe47099ff047272661786ebb54721d/src/diffusers/models/unets/unet_2d_condition.py#L671
         if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
             common_inputs["text_embeds"] = {0: "batch_size"}
             common_inputs["time_ids"] = {0: "batch_size"}

+        if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None:
+            common_inputs["timestep_cond"] = {0: "batch_size"}
+
         return common_inputs

     @property
@@ -1134,7 +1134,7 @@ def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:


 class VaeEncoderOnnxConfig(VisionOnnxConfig):
-    ATOL_FOR_VALIDATION = 1e-4
+    ATOL_FOR_VALIDATION = 3e-4
     # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
     # operator support, available since opset 14
     DEFAULT_ONNX_OPSET = 14
@@ -1182,6 +1182,101 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         }


+class T5EncoderOnnxConfig(TextEncoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    ATOL_FOR_VALIDATION = 1e-4
+    DEFAULT_ONNX_OPSET = 12  # int64 was supported since opset 12
+
+    @property
+    def inputs(self):
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self):
+        return {
+            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
+        }
+
+
+class SD3TransformerOnnxConfig(VisionOnnxConfig):
+    ATOL_FOR_VALIDATION = 1e-4
+    # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
+    # operator support, available since opset 14
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyTransformerVisionInputGenerator,
+        DummyTransformerTextInputGenerator,
+    )
+
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        image_size="sample_size",
+        num_channels="in_channels",
+        vocab_size="attention_head_dim",
+        hidden_size="joint_attention_dim",
+        projection_size="pooled_projection_dim",
+        allow_new=True,
+    )
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {
+            "hidden_states": {0: "batch_size", 2: "height", 3: "width"},
+            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
+            "pooled_projections": {0: "batch_size"},
+            "timestep": {0: "step"},
+        }
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "out_hidden_states": {0: "batch_size", 2: "height", 3: "width"},
+        }
+
+    @property
+    def torch_to_onnx_output_map(self) -> Dict[str, str]:
+        return {
+            "sample": "out_hidden_states",
+        }
+
+
+class FluxTransformerOnnxConfig(SD3TransformerOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummyTransformerTimestepInputGenerator,
+        DummyFluxTransformerVisionInputGenerator,
+        DummyFluxTransformerTextInputGenerator,
+    )
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"}
+        common_inputs["txt_ids"] = (
+            {0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"}
+        )
+        common_inputs["img_ids"] = (
+            {0: "packed_height_width"}
+            if check_if_diffusers_greater("0.31.0")
+            else {0: "batch_size", 1: "packed_height_width"}
+        )
+
+        if getattr(self._normalized_config, "guidance_embeds", False):
+            common_inputs["guidance"] = {0: "batch_size"}
+
+        return common_inputs
+
+    @property
+    def outputs(self):
+        return {
+            "out_hidden_states": {0: "batch_size", 1: "packed_height_width"},
+        }
+
+
 class GroupViTOnnxConfig(CLIPOnnxConfig):
     pass
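The version check in `FluxTransformerOnnxConfig.inputs` tracks a change in diffusers 0.31.0, where `txt_ids` and `img_ids` dropped their batch dimension. A standalone sketch of the resulting dynamic-axes choice, assuming `check_if_diffusers_greater` is importable from `optimum.utils` (the same utility the config imports); the helper function itself is illustrative:

```python
from optimum.utils import check_if_diffusers_greater


def flux_id_axes() -> dict:
    # From diffusers 0.31.0 on, txt_ids/img_ids are 2D tensors, so only
    # their first dimension stays dynamic; older versions keep a batch axis.
    if check_if_diffusers_greater("0.31.0"):
        return {
            "txt_ids": {0: "sequence_length"},
            "img_ids": {0: "packed_height_width"},
        }
    return {
        "txt_ids": {0: "batch_size", 1: "sequence_length"},
        "img_ids": {0: "batch_size", 1: "packed_height_width"},
    }
```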
29 changes: 23 additions & 6 deletions optimum/exporters/tasks.py
@@ -335,15 +335,27 @@ class TasksManager:
     }

     _DIFFUSERS_SUPPORTED_MODEL_TYPE = {
-        "clip-text-model": supported_tasks_mapping(
+        "t5-encoder": supported_tasks_mapping(
+            "feature-extraction",
+            onnx="T5EncoderOnnxConfig",
+        ),
+        "clip-text": supported_tasks_mapping(
             "feature-extraction",
             onnx="CLIPTextOnnxConfig",
         ),
         "clip-text-with-projection": supported_tasks_mapping(
             "feature-extraction",
             onnx="CLIPTextWithProjectionOnnxConfig",
         ),
-        "unet": supported_tasks_mapping(
+        "flux-transformer-2d": supported_tasks_mapping(
+            "semantic-segmentation",
+            onnx="FluxTransformerOnnxConfig",
+        ),
+        "sd3-transformer-2d": supported_tasks_mapping(
+            "semantic-segmentation",
+            onnx="SD3TransformerOnnxConfig",
+        ),
+        "unet-2d-condition": supported_tasks_mapping(
             "semantic-segmentation",
             onnx="UNetOnnxConfig",
         ),
@@ -1177,12 +1189,17 @@ class TasksManager:
         "transformers": _SUPPORTED_MODEL_TYPE,
     }
     _UNSUPPORTED_CLI_MODEL_TYPE = {
-        "unet",
+        # diffusers model types
+        "clip-text",
+        "clip-text-with-projection",
+        "flux-transformer-2d",
+        "sd3-transformer-2d",
+        "t5-encoder",
+        "unet-2d-condition",
         "vae-encoder",
         "vae-decoder",
-        "clip-text-model",
-        "clip-text-with-projection",
-        "trocr",  # supported through the vision-encoder-decoder model type
+        # redundant model types
+        "trocr",  # same as vision-encoder-decoder
     }
     _SUPPORTED_CLI_MODEL_TYPE = (
         set(_SUPPORTED_MODEL_TYPE.keys())
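Since the new diffusers model types are pipeline components, they are blocked from direct CLI export and reached through a pipeline-level export instead. A hedged usage sketch; the model id and output directory are assumptions, not taken from this PR:

```python
from optimum.exporters.onnx import main_export

# Exports every component of the SD3 pipeline (text encoders, the
# sd3-transformer-2d, VAE, ...) to ONNX in one pass.
main_export(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed model id
    output="sd3_onnx",  # assumed output directory
)
```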