Support for custom input shapes in exporters onnx #575

Status: Merged

Commits (33 total; the diff below shows changes from 11 commits):
0be6099 - add support for custom input shapes in exporters onnx (fxmarty, Dec 12, 2022)
a3ceec9 - add tests (fxmarty, Dec 12, 2022)
cff1a93 - Merge branch 'master' into exporters-onnx-variable-inputs-shapes (fxmarty, Dec 13, 2022)
b16c70e - Merge branch 'master' into exporters-onnx-variable-inputs-shapes (fxmarty, Dec 14, 2022)
10d7b21 - fix test (fxmarty, Dec 14, 2022)
a595236 - fix bug (fxmarty, Dec 14, 2022)
3d792ca - Update tests/exporters/test_onnx_export.py (fxmarty, Dec 16, 2022)
e5a14dc - Merge branch 'master' into exporters-onnx-variable-inputs-shapes (fxmarty, Dec 16, 2022)
435fd0d - Merge branch 'exporters-onnx-variable-inputs-shapes' of https://githu… (fxmarty, Dec 16, 2022)
c063e9d - remove run slow (fxmarty, Dec 16, 2022)
671c09b - fix _get_models_to_test (fxmarty, Dec 16, 2022)
5e8118e - Merge branch 'master' into exporters-onnx-variable-inputs-shapes (fxmarty, Dec 20, 2022)
2a9cb29 - fixup post merge (fxmarty, Dec 20, 2022)
00a973b - yet fixup post merge (fxmarty, Dec 20, 2022)
d338cd7 - use f string (fxmarty, Dec 20, 2022)
b190cb7 - authorize unused arguments in input generators (fxmarty, Dec 20, 2022)
ddb6c79 - use hf-internal-testing models (fxmarty, Dec 20, 2022)
8024706 - Merge branch 'master' into exporters-onnx-variable-inputs-shapes (fxmarty, Dec 20, 2022)
116ae21 - remove redundant argument (fxmarty, Dec 20, 2022)
92fe49a - fixup merge (fxmarty, Dec 20, 2022)
7bb9742 - remove redundant arg (fxmarty, Dec 20, 2022)
a36d945 - fix messed merge (fxmarty, Dec 21, 2022)
63f1dca - change blenderbot (fxmarty, Dec 21, 2022)
7916ae8 - workign deit (fxmarty, Dec 21, 2022)
5bd4d0d - valid gptj (fxmarty, Dec 21, 2022)
2f1effc - valid marian (fxmarty, Dec 21, 2022)
4f9e5f9 - add whisper (fxmarty, Dec 21, 2022)
163cc00 - fix test (fxmarty, Dec 21, 2022)
8775425 - fix marian (fxmarty, Dec 21, 2022)
5f28904 - Merge branch 'master' into exporters-onnx-variable-inputs-shapes (fxmarty, Dec 21, 2022)
1ed3997 - Merge branch 'master' into exporters-onnx-variable-inputs-shapes (fxmarty, Dec 21, 2022)
81d693d - nit (fxmarty, Dec 21, 2022)
ecc36ba - fix bloom (fxmarty, Dec 21, 2022)
2 changes: 1 addition & 1 deletion .github/workflows/test_exporters.yml
@@ -31,4 +31,4 @@ jobs:
      - name: Test with unittest
        working-directory: tests
        run: |
-         RUN_SLOW=1 pytest exporters --durations=0
          pytest exporters -s --durations=0
Review comment (Member):
Are you sure you want to output everything in the log?

Reply (@fxmarty, Contributor, Author, Dec 16, 2022):
yes, why not?
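
For context: RUN_SLOW=1 is the transformers-style environment switch that enables tests marked as slow, and pytest's -s flag (an alias for --capture=no) disables output capture so everything the tests print reaches the CI log, which is what the comment above is asking about. A minimal sketch of the gating convention the removed variable interacts with, assuming optimum follows the usual transformers pattern (this is not copied from optimum's test utilities):

```python
# Sketch of a transformers-style "slow test" gate (assumed convention).
import os
import unittest

def slow(test_case):
    """Skip the decorated test unless RUN_SLOW=1 is set in the environment."""
    return unittest.skipUnless(
        os.environ.get("RUN_SLOW", "0") == "1", "slow test; set RUN_SLOW=1 to run"
    )(test_case)

class ExportTests(unittest.TestCase):
    @slow
    def test_full_model_matrix(self):  # hypothetical test name
        ...
```

With RUN_SLOW=1 dropped from the workflow, slow-marked tests are skipped in CI by default (compare the "remove run slow" commit in the list above).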

120 changes: 99 additions & 21 deletions optimum/exporters/onnx/__main__.py
@@ -14,12 +14,12 @@
# limitations under the License.
"""Entry point to the optimum.exporters.onnx command line."""

- from argparse import ArgumentParser
import argparse
from pathlib import Path

from transformers import AutoTokenizer

- from ...utils import logging
from ...utils import DEFAULT_DUMMY_SHAPES, logging
from ...utils.save_utils import maybe_save_preprocessors
from ..tasks import TasksManager
from .base import OnnxConfigWithPast
@@ -32,20 +32,48 @@


def main():
- parser = ArgumentParser("Hugging Face Optimum ONNX exporter")
- parser.add_argument(
parser = argparse.ArgumentParser(
"Hugging Face Optimum ONNX exporter", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

required_group = parser.add_argument_group("Required arguments")
required_group.add_argument(
"-m", "--model", type=str, required=True, help="Model ID on huggingface.co or path on disk to load model from."
)
- parser.add_argument(
required_group.add_argument(
"output", type=Path, help="Path indicating the directory where to store generated ONNX model."
)

optional_group = parser.add_argument_group("Optional arguments")
optional_group.add_argument(
"--task",
default="auto",
help="The type of task to export the model with.",
help=(
"The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
f" {str(list(TasksManager._TASKS_TO_AUTOMODELS.keys()))}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
),
)
optional_group.add_argument(
"--for-ort",
action="store_true",
help=(
"This exports models ready to be run with Optimum's ORTModel. Useful for encoder-decoder models for"
"conditional generation. If enabled the encoder and decoder of the model are exported separately."
),
)
optional_group.add_argument(
"--opset",
type=int,
default=None,
help="If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used.",
)
parser.add_argument("--opset", type=int, default=None, help="ONNX opset version to export the model with.")
parser.add_argument(
"--atol", type=float, default=None, help="Absolute difference tolerance when validating the model."
optional_group.add_argument(
"--atol",
type=float,
default=None,
help="If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.",
)
- parser.add_argument(
optional_group.add_argument(
"--framework",
type=str,
choices=["pt", "tf"],
@@ -56,7 +84,7 @@ def main():
" or what is available in the environment."
),
)
- parser.add_argument(
optional_group.add_argument(
"--pad_token_id",
type=int,
default=None,
@@ -65,16 +93,60 @@ def main():
" it."
),
)
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
parser.add_argument(
"--for-ort",
action="store_true",
help=(
"This exports models ready to be run with optimum.onnxruntime. Useful for encoder-decoder models for"
"conditional generation. If enabled the encoder and decoder of the model are exported separately."
),
optional_group.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input.)"
)
doc_input = " to use in the example input given to the ONNX export."
input_group.add_argument(
"--batch_size",
type=int,
default=DEFAULT_DUMMY_SHAPES["batch_size"],
help="Text tasks only. Batch size" + doc_input,
)
input_group.add_argument(
"--sequence_length",
type=int,
default=DEFAULT_DUMMY_SHAPES["sequence_length"],
help="Text tasks only. Sequence length " + doc_input,
)
input_group.add_argument(
"--num_choices",
type=int,
default=DEFAULT_DUMMY_SHAPES["num_choices"],
help="Text tasks only. Num choices " + doc_input,
)
input_group.add_argument(
"--width", type=int, default=DEFAULT_DUMMY_SHAPES["width"], help="Image tasks only. Width " + doc_input
)
input_group.add_argument(
"--height", type=int, default=DEFAULT_DUMMY_SHAPES["height"], help="Image tasks only. Height " + doc_input
)
input_group.add_argument(
"--num_channels",
type=int,
default=DEFAULT_DUMMY_SHAPES["num_channels"],
help="Image tasks only. Number of channels " + doc_input,
)
input_group.add_argument(
"--feature_size",
type=int,
default=DEFAULT_DUMMY_SHAPES["feature_size"],
help="Audio tasks only. Feature size " + doc_input,
)
input_group.add_argument(
"--nb_max_frames",
type=int,
default=DEFAULT_DUMMY_SHAPES["nb_max_frames"],
help="Audio tasks only. Maximum number of frames " + doc_input,
)
input_group.add_argument(
"--audio_sequence_length",
type=int,
default=DEFAULT_DUMMY_SHAPES["audio_sequence_length"],
help="Audio tasks only. Audio sequence length " + doc_input,
)
parser.add_argument("output", type=Path, help="Path indicating the directory where to store generated ONNX model.")

# Retrieve CLI arguments
args = parser.parse_args()
@@ -93,6 +165,11 @@ def main():
f"The task could not be automatically inferred. Please provide the argument --task with the task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

# get input shapes
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
input_shapes[input_name] = getattr(args, input_name)

# Allocate the model
model = TasksManager.get_model_from_task(task, args.model, framework=args.framework, cache_dir=args.cache_dir)
model_type = model.config.model_type.replace("_", "-")
@@ -145,9 +222,10 @@ def main():
opset=args.opset,
output_dir=args.output.parent,
fn_get_models_from_config=fn_get_models_from_config,
input_shapes=input_shapes,
)
else:
- onnx_inputs, onnx_outputs = export(model, onnx_config, args.opset, args.output)
onnx_inputs, onnx_outputs = export(model, onnx_config, args.opset, args.output, input_shapes=input_shapes)

# Saving the model config as this is needed sometimes.
model.config.save_pretrained(args.output.parent)
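Taken together, the new "Input shapes" group exposes every key of DEFAULT_DUMMY_SHAPES as a flag: batch_size, sequence_length and num_choices for text tasks; width, height and num_channels for image tasks; feature_size, nb_max_frames and audio_sequence_length for audio tasks. Because each flag defaults to its DEFAULT_DUMMY_SHAPES entry, main() can copy all of them into input_shapes unconditionally, and any flag the user did not pass simply reproduces the default. A minimal sketch of an invocation, assuming a build of optimum that includes this patch; the model ID, task and shape values are examples only:

```python
# Sketch: driving the exporter CLI with custom dummy-input shapes.
# Assumptions: optimum with this PR installed; model ID and shapes are examples.
import subprocess

subprocess.run(
    [
        "python", "-m", "optimum.exporters.onnx",
        "--model", "hf-internal-testing/tiny-random-bert",
        "--task", "sequence-classification",
        "--batch_size", "4",        # overrides DEFAULT_DUMMY_SHAPES["batch_size"]
        "--sequence_length", "32",  # overrides DEFAULT_DUMMY_SHAPES["sequence_length"]
        "tiny_bert_onnx",           # the positional `output` directory
    ],
    check=True,
)
```
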
42 changes: 25 additions & 17 deletions optimum/exporters/onnx/convert.py
@@ -77,6 +77,7 @@ def validate_models_outputs(
Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], OnnxConfig]],
],
output_names: Optional[List[str]] = None,
input_shapes: Dict = {},
):
"""
Validates the export of several models, by checking that the outputs from both the reference and the exported model match.
Expand All @@ -99,6 +100,9 @@ def validate_models_outputs(
output_names (`Optional[List[str]]`, defaults to `None`):
The names to use for the exported ONNX files. The order must be the same as the order of submodels in the ordered dict returned by `fn_get_models_from_config`.
If None, will use the keys from the output of `fn_get_models_from_config` as names.
input_shapes (`Dict`, defaults to `{}`):
If specified, allows to use specific shapes to validate the ONNX model on.

Raises:
ValueError: If the outputs shapes or values do not match between the reference and the exported model.
"""
@@ -122,11 +126,7 @@
else output_dir.joinpath(model_name + ".onnx")
)
validate_model_outputs(
- sub_onnx_config,
- submodel,
- onnx_model_path,
- onnx_named_outputs[i],
- atol,
sub_onnx_config, submodel, onnx_model_path, onnx_named_outputs[i], atol, input_shapes=input_shapes
)


@@ -136,6 +136,7 @@ def validate_model_outputs(
onnx_model: Path,
onnx_named_outputs: List[str],
atol: float,
input_shapes: Dict = {},
):
"""
Validates the export by checking that the outputs from both the reference and the exported model match.
Expand All @@ -151,6 +152,8 @@ def validate_model_outputs(
The names of the outputs to check.
atol (`float`):
The absolute tolerance in terms of outputs difference between the reference and the exported model.
input_shapes (`Dict`, defaults to `{}`):
If specified, allows to use specific shapes to validate the ONNX model on.

Raises:
ValueError: If the outputs shapes or values do not match between the reference and the exported model.
@@ -160,7 +163,7 @@
logger.info("Validating ONNX model...")

framework = "pt" if is_torch_available() and issubclass(type(reference_model), PreTrainedModel) else "tf"
- reference_model_inputs = config.generate_dummy_inputs(framework=framework)
reference_model_inputs = config.generate_dummy_inputs(framework=framework, **input_shapes)

# Create ONNX Runtime session
options = SessionOptions()
@@ -252,6 +255,7 @@ def export_pytorch(
opset: int,
output: Path,
device: str = "cpu",
input_shapes: Dict = dict(),
) -> Tuple[List[str], List[str]]:
"""
Exports a PyTorch model to an ONNX Intermediate Representation.
Expand All @@ -268,6 +272,8 @@ def export_pytorch(
device (`str`, *optional*, defaults to `cpu`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
input_shapes (`Dict`, defaults to `{}`):
If specified, allows to use specific shapes for the example input provided to the ONNX exporter.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -291,7 +297,7 @@
setattr(model.config, override_config_key, override_config_value)

# Check that inputs match, and order them properly
- dummy_inputs = config.generate_dummy_inputs(framework="pt")
dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)
device = torch.device(device)
if device.type == "cuda" and torch.cuda.is_available():
model.to(device)
@@ -405,6 +411,7 @@ def export_models(
],
output_names: Optional[List[str]] = None,
device: str = "cpu",
input_shapes: Dict = dict(),
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Exports a Pytorch or TensorFlow encoder decoder model to an ONNX Intermediate Representation.
@@ -429,6 +436,8 @@
device (`str`, *optional*, defaults to `cpu`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
input_shapes (`Dict`, defaults to `{}`):
If specified, allows to use specific shapes for the example input provided to the ONNX exporter.
Returns:
`Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named
inputs from the ONNX configuration.
@@ -448,15 +457,7 @@
if output_names is not None
else output_dir.joinpath(model_name + ".onnx")
)
- outputs.append(
- export(
- submodel,
- sub_onnx_config,
- opset,
- output_path,
- device=device,
- )
- )
outputs.append(export(submodel, sub_onnx_config, opset, output_path, device=device, input_shapes=input_shapes))

outputs = list(map(list, zip(*outputs)))
return outputs
@@ -468,6 +469,7 @@ def export(
opset: int,
output: Path,
device: str = "cpu",
input_shapes: Dict = {},
) -> Tuple[List[str], List[str]]:
"""
Exports a Pytorch or TensorFlow model to an ONNX Intermediate Representation.
@@ -484,6 +486,8 @@
device (`str`, *optional*, defaults to `cpu`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
input_shapes (`Dict`, defaults to `{}`):
If specified, allows to use specific shapes for the example input provided to the ONNX exporter.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
@@ -508,11 +512,15 @@
f"Unsupported PyTorch version for this model. Minimum required is {config.MIN_TORCH_VERSION},"
f" got: {torch.__version__}"
)
- return export_pytorch(model, config, opset, output, device=device)
return export_pytorch(model, config, opset, output, device=device, input_shapes=input_shapes)

elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
if device == "cuda":
raise RuntimeError("`tf2onnx` does not support export on CUDA device.")
if input_shapes != dict():
logger.warning(
"`input_shapes` argument is not supported by the Tensorflow ONNX export and will be ignored."
)
return export_tensorflow(model, config, opset, output)

else:
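The same overrides are available from Python: export threads input_shapes down to generate_dummy_inputs, and validate_model_outputs takes the same dict so validation runs on the shapes the model was exported with (the TensorFlow branch ignores input_shapes with a warning). A rough sketch under a few assumptions: that export and validate_model_outputs are re-exported from optimum.exporters.onnx, that a BertOnnxConfig with the usual (config, task=...) constructor lives in model_configs.py, and that the DEFAULT_ONNX_OPSET and ATOL_FOR_VALIDATION class attributes exist; the model ID is an example.

```python
# Sketch: programmatic export with explicit dummy-input shapes.
# Imports and class attributes are assumptions spelled out in the lead-in.
from pathlib import Path

from optimum.exporters.onnx import export, validate_model_outputs
from optimum.exporters.onnx.model_configs import BertOnnxConfig
from optimum.exporters.tasks import TasksManager

task = "sequence-classification"
model = TasksManager.get_model_from_task(task, "hf-internal-testing/tiny-random-bert")
onnx_config = BertOnnxConfig(model.config, task=task)

shapes = {"batch_size": 4, "sequence_length": 32}  # forwarded as **kwargs to generate_dummy_inputs
onnx_inputs, onnx_outputs = export(
    model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, Path("model.onnx"), input_shapes=shapes
)

# Validate on the same shapes the model was exported with.
validate_model_outputs(
    onnx_config, model, Path("model.onnx"), onnx_outputs, onnx_config.ATOL_FOR_VALIDATION, input_shapes=shapes
)
```
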
17 changes: 9 additions & 8 deletions optimum/exporters/onnx/model_configs.py
@@ -19,6 +19,7 @@
from packaging import version

from ...utils import (
DEFAULT_DUMMY_SHAPES,
DummyDecoderTextInputGenerator,
DummyPastKeyValuesGenerator,
DummySeq2SeqDecoderTextInputGenerator,
@@ -332,9 +333,9 @@
self,
task: str,
normalized_config: NormalizedSeq2SeqConfig,
- batch_size: int = 2,
- sequence_length: int = 16,
- num_choices: int = 4,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
random_num_choices_range: Optional[Tuple[int, int]] = None,
@@ -400,7 +401,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGen
self.task, self._normalized_config, **kwargs
)
dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1](
- self.task, self._normalized_config, batch_size=dummy_text_input_generator.batch_size, **kwargs
self.task, self._normalized_config, **kwargs
)
task = "default" if self.task != "causal-lm" else "causal-lm"
kwargs = {}
@@ -486,12 +487,12 @@ def outputs(self) -> Mapping[str, Mapping[int, str]]:
}
return common_outputs

- def generate_dummy_inputs(self, framework: str = "pt"):
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
# This will handle the attention mask padding when Bart is used for causal-lm.
if self.task == "causal-lm":
self.PAD_ATTENTION_MASK_TO_MATCH_TOTAL_SEQUENCE_LENGTH = True

- dummy_inputs = super().generate_dummy_inputs(framework=framework)
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

# if self.use_past and self.task in ["default", "seq2seq-lm"]:
# attention_mask_length = dummy_inputs["decoder_attention_mask"].shape[1]
@@ -740,9 +741,9 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
# "attention_mask": dynamic_axis,
}

- def generate_dummy_inputs(self, framework: str = "pt"):
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
self.is_generating_dummy_inputs = True
- dummy_inputs = super().generate_dummy_inputs(framework=framework)
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
specialized_inputs_name = self.inputs_name
self.is_generating_dummy_inputs = True
dummy_inputs[self.inputs_name] = dummy_inputs.pop(specialized_inputs_name)
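The model_configs.py side completes the plumbing: generate_dummy_inputs now accepts **kwargs and forwards them to super(), the dummy input generators take their defaults from the shared DEFAULT_DUMMY_SHAPES dict instead of hard-coded literals, and, per the "authorize unused arguments in input generators" commit, generators tolerate shape kwargs they do not use, so one flat dict can be handed to text, image and audio generators alike. A toy sketch of that pattern, using hypothetical classes rather than the actual optimum ones:

```python
# Toy illustration of the kwargs-forwarding pattern (hypothetical classes).
import torch

DEFAULT_DUMMY_SHAPES = {"batch_size": 2, "sequence_length": 16}

class DummyTextInputGenerator:
    def __init__(
        self,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        **kwargs,  # unused shape kwargs (e.g. an image width) are accepted and ignored
    ):
        self.batch_size = batch_size
        self.sequence_length = sequence_length

    def generate(self) -> torch.Tensor:
        return torch.randint(0, 1000, (self.batch_size, self.sequence_length))

class ToyOnnxConfig:
    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        # Every shape kwarg flows straight into the generator constructor.
        generator = DummyTextInputGenerator(**kwargs)
        return {"input_ids": generator.generate()}

inputs = ToyOnnxConfig().generate_dummy_inputs(sequence_length=32, width=64)
print(inputs["input_ids"].shape)  # torch.Size([2, 32]); `width` was ignored
```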