Skip to content

Commit

Permalink
Fix disable shape inference for optimization (#652)
Browse files Browse the repository at this point in the history
  • Loading branch information
regisss authored Dec 29, 2022
1 parent 725abec commit 9ac1703
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions optimum/onnxruntime/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,30 @@ def optimize(
LOGGER.info("Optimizing model...")

for model_path in self.onnx_model_path:
optimizer = optimize_model(
model_path.as_posix(),
model_type,
self.normalized_config.num_attention_heads,
self.normalized_config.hidden_size,
opt_level=optimization_config.optimization_level,
optimization_options=optimization_options,
use_gpu=optimization_config.optimize_for_gpu,
only_onnxruntime=not optimization_config.enable_transformers_specific_optimizations,
)

if optimization_config.fp16:
# keep_io_types to keep inputs/outputs as float32
optimizer.convert_float_to_float16(keep_io_types=True)
try:
optimizer = optimize_model(
model_path.as_posix(),
model_type,
self.normalized_config.num_attention_heads,
self.normalized_config.hidden_size,
opt_level=optimization_config.optimization_level,
optimization_options=optimization_options,
use_gpu=optimization_config.optimize_for_gpu,
only_onnxruntime=not optimization_config.enable_transformers_specific_optimizations,
)

if optimization_config.fp16:
# keep_io_types to keep inputs/outputs as float32
optimizer.convert_float_to_float16(
use_symbolic_shape_infer=not optimization_config.disable_shape_inference, keep_io_types=True
)
except Exception as e:
if "Incomplete symbolic shape inference" in str(e):
err = RuntimeError(
f"{str(e)}. Try to set `disable_shape_inference=True` in your optimization configuration."
)
raise err from e
raise

suffix = f"_{file_suffix}" if file_suffix else ""
output_path = save_dir.joinpath(f"{model_path.stem}{suffix}").with_suffix(model_path.suffix)
Expand Down

0 comments on commit 9ac1703

Please sign in to comment.