From 34a345fa5eeb7d5ae7e63d0f8c61c29cd8d40754 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 4 Apr 2024 13:56:47 -0700 Subject: [PATCH 1/3] Migrate onnxrewriter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Squashed of the following steps: - #1328 - #1329 - #1330 - #1331 - #1332 - #1333 - #1343 - #1345 Co-authored-by: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Co-authored-by: Justin Chu Co-authored-by: Xavier Dupré Co-authored-by: "G. Ramalingam" Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Co-authored-by: Ti-Tai Wang [ghstack-poisoned] --- .gitattributes | 2 + .github/workflows/main.yaml | 2 + .lintrunner.toml | 31 +- examples/pattern_rewriting.py | 191 +++ onnxscript/_legacy_ir/__init__.py | 338 +++++ onnxscript/_legacy_ir/irbuilder.py | 210 +++ onnxscript/_legacy_ir/irbuilder_test.py | 198 +++ onnxscript/_legacy_ir/protobuilder.py | 126 ++ onnxscript/_legacy_ir/protobuilder_test.py | 215 +++ onnxscript/_legacy_ir/visitor.py | 922 +++++++++++++ onnxscript/_legacy_ir/visitor_test.py | 38 + onnxscript/backend/onnx_export_test.py | 6 +- onnxscript/converter_test.py | 60 +- .../function_libs/torch_lib/graph_building.py | 2 +- onnxscript/ir/serde.py | 18 +- onnxscript/optimizer/__init__.py | 110 ++ onnxscript/optimizer/constant_folding.py | 280 ++++ onnxscript/optimizer/constant_folding_test.py | 444 +++++++ onnxscript/optimizer/copy_propagation.py | 81 ++ onnxscript/optimizer/copy_propagation_test.py | 49 + onnxscript/optimizer/evaluator.py | 434 ++++++ onnxscript/optimizer/fold_constants_v0.py | 248 ++++ onnxscript/optimizer/function_folding_test.py | 162 +++ onnxscript/optimizer/remove_unused.py | 127 ++ .../optimizer/remove_unused_function.py | 56 + onnxscript/optimizer/remove_unused_test.py | 173 +++ .../optimizer/simple_function_folding.py | 239 ++++ .../optimizer/simple_function_folding_test.py | 218 +++ onnxscript/rewriter/__init__.py | 39 + 
onnxscript/rewriter/broadcast_to_matmul.py | 175 +++ .../rewriter/broadcast_to_matmul_test.py | 283 ++++ onnxscript/rewriter/cast_constant_of_shape.py | 75 ++ .../rewriter/cast_constant_of_shape_test.py | 46 + onnxscript/rewriter/erfgelu.py | 30 + onnxscript/rewriter/function_rule.py | 242 ++++ onnxscript/rewriter/gemm_to_matmul_add.py | 21 + .../rewriter/gemm_to_matmul_add_test.py | 254 ++++ onnxscript/rewriter/generic_pattern.py | 1165 +++++++++++++++++ onnxscript/rewriter/generic_pattern_test.py | 501 +++++++ onnxscript/rewriter/no_op.py | 44 + onnxscript/rewriter/no_op_test.py | 180 +++ onnxscript/rewriter/onnxruntime/__init__.py | 59 + .../group_normalization_merge_silu.py | 58 + .../group_normalization_merge_silu_test.py | 125 ++ .../instance_to_group_normalization.py | 152 +++ .../instance_to_group_normalization_test.py | 435 ++++++ onnxscript/rewriter/onnxruntime/softmax.py | 64 + .../rewriter/onnxruntime/softmax_test.py | 92 ++ .../onnxruntime/transformers/__init__.py | 16 + .../onnxruntime/transformers/fastgelu.py | 31 + .../onnxruntime/transformers/fastgelu_test.py | 21 + .../onnxruntime/transformers/layernorm.py | 45 + .../transformers/layernorm_test.py | 21 + .../transformers/multihead_attention.py | 604 +++++++++ .../transformers/multihead_attention_test.py | 71 + onnxscript/rewriter/pattern.py | 1069 +++++++++++++++ onnxscript/rewriter/pattern_test.py | 305 +++++ .../{testing.py => testing/__init__.py} | 135 +- onnxscript/tests/common/testutils.py | 14 - onnxscript/type_annotation_test.py | 2 +- onnxscript/utils/__init__.py | 0 onnxscript/utils/evaluation_utils.py | 54 + onnxscript/utils/timing_utils.py | 33 + onnxscript/utils/utils.py | 82 ++ pyproject.toml | 49 +- requirements-dev.txt | 2 +- .../Speech2Text2ForCausalLM_dynamo.onnx | 3 + .../dynamo/test_data_set_0/input_0.pb | 3 + .../dynamo/test_data_set_0/input_1.pb | 3 + .../dynamo/test_data_set_0/output_0.pb | 3 + .../dynamo/test_data_set_0/output_1.pb | 3 + 
.../dynamo/test_data_set_0/output_10.pb | 3 + .../dynamo/test_data_set_0/output_11.pb | 3 + .../dynamo/test_data_set_0/output_12.pb | 3 + .../dynamo/test_data_set_0/output_13.pb | 3 + .../dynamo/test_data_set_0/output_2.pb | 3 + .../dynamo/test_data_set_0/output_3.pb | 3 + .../dynamo/test_data_set_0/output_4.pb | 3 + .../dynamo/test_data_set_0/output_5.pb | 3 + .../dynamo/test_data_set_0/output_6.pb | 3 + .../dynamo/test_data_set_0/output_7.pb | 3 + .../dynamo/test_data_set_0/output_8.pb | 3 + .../dynamo/test_data_set_0/output_9.pb | 3 + .../dynamo/mobilenetv2_100_dynamo.onnx | 3 + .../dynamo/test_data_set_0/input_0.pb | 3 + .../dynamo/test_data_set_0/output_0.pb | 3 + .../resnet18/dynamo/resnet18_dynamo.onnx | 3 + .../dynamo/test_data_set_0/input_0.pb | 3 + .../dynamo/test_data_set_0/output_0.pb | 3 + .../attn_llama2_4_34_0.onnx | 3 + .../test_data_set_0/input_0.pb | 3 + .../test_data_set_0/input_1.pb | 3 + .../test_data_set_0/input_2.pb | 3 + .../test_data_set_0/input_3.pb | 3 + .../test_data_set_0/input_4.pb | 3 + .../test_data_set_0/input_5.pb | 3 + .../test_data_set_0/input_6.pb | 3 + .../test_data_set_0/input_7.pb | 3 + .../test_data_set_0/input_8.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../attn_llama2_4_34_1.onnx | 3 + .../test_data_set_0/input_0.pb | 3 + .../test_data_set_0/input_1.pb | 3 + .../test_data_set_0/input_2.pb | 3 + .../test_data_set_0/input_3.pb | 3 + .../test_data_set_0/input_4.pb | 3 + .../test_data_set_0/input_5.pb | 3 + .../test_data_set_0/input_6.pb | 3 + .../test_data_set_0/input_7.pb | 3 + .../test_data_set_0/input_8.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../attn_llama2_4_36_0.onnx | 3 + .../test_data_set_0/input_0.pb | 3 + .../test_data_set_0/input_1.pb | 3 + .../test_data_set_0/input_2.pb | 3 + .../test_data_set_0/input_3.pb | 3 + .../test_data_set_0/input_4.pb | 3 + 
.../test_data_set_0/input_5.pb | 3 + .../test_data_set_0/input_6.pb | 3 + .../test_data_set_0/input_7.pb | 3 + .../test_data_set_0/input_8.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../attn_phi_1_5_0/attn_phi_1_5_0.onnx | 3 + .../attn_phi_1_5_0/test_data_set_0/input_0.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_1.pb | 3 + .../test_data_set_0/input_10.pb | 3 + .../test_data_set_0/input_11.pb | 3 + .../test_data_set_0/input_12.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_2.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_3.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_4.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_5.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_6.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_7.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_8.pb | 3 + .../attn_phi_1_5_0/test_data_set_0/input_9.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../attn_phi_1_5_1/attn_phi_1_5_1.onnx | 3 + .../attn_phi_1_5_1/test_data_set_0/input_0.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_1.pb | 3 + .../test_data_set_0/input_10.pb | 3 + .../test_data_set_0/input_11.pb | 3 + .../test_data_set_0/input_12.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_2.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_3.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_4.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_5.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_6.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_7.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_8.pb | 3 + .../attn_phi_1_5_1/test_data_set_0/input_9.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../attn_phi_1_5_2/attn_phi_1_5_2.onnx | 3 + .../attn_phi_1_5_2/test_data_set_0/input_0.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_1.pb | 3 + 
.../test_data_set_0/input_10.pb | 3 + .../test_data_set_0/input_11.pb | 3 + .../test_data_set_0/input_12.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_2.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_3.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_4.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_5.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_6.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_7.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_8.pb | 3 + .../attn_phi_1_5_2/test_data_set_0/input_9.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../attn_phi_1_5_3/attn_phi_1_5_3.onnx | 3 + .../attn_phi_1_5_3/test_data_set_0/input_0.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_1.pb | 3 + .../test_data_set_0/input_10.pb | 3 + .../test_data_set_0/input_11.pb | 3 + .../test_data_set_0/input_12.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_2.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_3.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_4.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_5.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_6.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_7.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_8.pb | 3 + .../attn_phi_1_5_3/test_data_set_0/input_9.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../attn_yi_4_37_0/attn_yi_4_37_0.onnx | 3 + .../attn_yi_4_37_0/test_data_set_0/input_0.pb | 3 + .../attn_yi_4_37_0/test_data_set_0/input_1.pb | 3 + .../attn_yi_4_37_0/test_data_set_0/input_2.pb | 3 + .../attn_yi_4_37_0/test_data_set_0/input_3.pb | 3 + .../attn_yi_4_37_0/test_data_set_0/input_4.pb | 3 + .../attn_yi_4_37_0/test_data_set_0/input_5.pb | 3 + .../attn_yi_4_37_0/test_data_set_0/input_6.pb | 3 + .../attn_yi_4_37_0/test_data_set_0/input_7.pb | 3 + .../attn_yi_4_37_0/test_data_set_0/input_8.pb | 3 + .../test_data_set_0/output_0.pb | 3 + 
.../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../gelu_phi_1_5_0/gelu_phi_1_5_0.onnx | 3 + .../gelu_phi_1_5_0/test_data_set_0/input_0.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../gelu_phi_1_5_1/gelu_phi_1_5_1.onnx | 3 + .../gelu_phi_1_5_1/test_data_set_0/input_0.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../gelu_phi_1_5_2/gelu_phi_1_5_2.onnx | 3 + .../gelu_phi_1_5_2/test_data_set_0/input_0.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../gelu_phi_1_5_3/gelu_phi_1_5_3.onnx | 3 + .../gelu_phi_1_5_3/test_data_set_0/input_0.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../ln_llama2_0/ln_llama2_0.onnx | 3 + .../ln_llama2_0/test_data_set_0/input_0.pb | 3 + .../ln_llama2_0/test_data_set_0/input_1.pb | 3 + .../ln_llama2_0/test_data_set_0/output_0.pb | 3 + .../ln_llama2_1/ln_llama2_1.onnx | 3 + .../ln_llama2_1/test_data_set_0/input_0.pb | 3 + .../ln_llama2_1/test_data_set_0/input_1.pb | 3 + .../ln_llama2_1/test_data_set_0/output_0.pb | 3 + .../ln_llama2_2/ln_llama2_2.onnx | 3 + .../ln_llama2_2/test_data_set_0/input_0.pb | 3 + .../ln_llama2_2/test_data_set_0/input_1.pb | 3 + .../ln_llama2_2/test_data_set_0/output_0.pb | 3 + .../ln_llama2_3/ln_llama2_3.onnx | 3 + .../ln_llama2_3/test_data_set_0/input_0.pb | 3 + .../ln_llama2_3/test_data_set_0/input_1.pb | 3 + .../ln_llama2_3/test_data_set_0/output_0.pb | 3 + .../sdpa_llama2_0/sdpa_llama2_0.onnx | 3 + .../sdpa_llama2_0/test_data_set_0/input_0.pb | 3 + .../sdpa_llama2_0/test_data_set_0/input_1.pb | 3 + .../sdpa_llama2_0/test_data_set_0/input_2.pb | 3 + .../sdpa_llama2_0/test_data_set_0/input_3.pb | 3 + .../sdpa_llama2_0/test_data_set_0/input_4.pb | 3 + .../sdpa_llama2_0/test_data_set_0/input_5.pb | 3 + .../sdpa_llama2_0/test_data_set_0/input_6.pb | 3 + .../sdpa_llama2_0/test_data_set_0/input_7.pb | 3 + .../sdpa_llama2_0/test_data_set_0/output_0.pb | 3 + .../sdpa_llama2_0/test_data_set_0/output_1.pb | 3 + .../sdpa_llama2_0/test_data_set_0/output_2.pb | 3 + 
.../sdpa_llama2_1/sdpa_llama2_1.onnx | 3 + .../sdpa_llama2_1/test_data_set_0/input_0.pb | 3 + .../sdpa_llama2_1/test_data_set_0/input_1.pb | 3 + .../sdpa_llama2_1/test_data_set_0/input_2.pb | 3 + .../sdpa_llama2_1/test_data_set_0/input_3.pb | 3 + .../sdpa_llama2_1/test_data_set_0/input_4.pb | 3 + .../sdpa_llama2_1/test_data_set_0/input_5.pb | 3 + .../sdpa_llama2_1/test_data_set_0/input_6.pb | 3 + .../sdpa_llama2_1/test_data_set_0/input_7.pb | 3 + .../sdpa_llama2_1/test_data_set_0/output_0.pb | 3 + .../sdpa_llama2_1/test_data_set_0/output_1.pb | 3 + .../sdpa_llama2_1/test_data_set_0/output_2.pb | 3 + .../sdpa_llama2_2/sdpa_llama2_2.onnx | 3 + .../sdpa_llama2_2/test_data_set_0/input_0.pb | 3 + .../sdpa_llama2_2/test_data_set_0/input_1.pb | 3 + .../sdpa_llama2_2/test_data_set_0/input_2.pb | 3 + .../sdpa_llama2_2/test_data_set_0/input_3.pb | 3 + .../sdpa_llama2_2/test_data_set_0/input_4.pb | 3 + .../sdpa_llama2_2/test_data_set_0/input_5.pb | 3 + .../sdpa_llama2_2/test_data_set_0/input_6.pb | 3 + .../sdpa_llama2_2/test_data_set_0/input_7.pb | 3 + .../sdpa_llama2_2/test_data_set_0/output_0.pb | 3 + .../sdpa_llama2_2/test_data_set_0/output_1.pb | 3 + .../sdpa_llama2_2/test_data_set_0/output_2.pb | 3 + .../sdpa_llama2_3/sdpa_llama2_3.onnx | 3 + .../sdpa_llama2_3/test_data_set_0/input_0.pb | 3 + .../sdpa_llama2_3/test_data_set_0/input_1.pb | 3 + .../sdpa_llama2_3/test_data_set_0/input_2.pb | 3 + .../sdpa_llama2_3/test_data_set_0/input_3.pb | 3 + .../sdpa_llama2_3/test_data_set_0/input_4.pb | 3 + .../sdpa_llama2_3/test_data_set_0/input_5.pb | 3 + .../sdpa_llama2_3/test_data_set_0/input_6.pb | 3 + .../sdpa_llama2_3/test_data_set_0/input_7.pb | 3 + .../sdpa_llama2_3/test_data_set_0/output_0.pb | 3 + .../sdpa_llama2_3/test_data_set_0/output_1.pb | 3 + .../sdpa_llama2_3/test_data_set_0/output_2.pb | 3 + .../sdpa_llama2_4_38_0.onnx | 3 + .../test_data_set_0/input_0.pb | 3 + .../test_data_set_0/input_1.pb | 3 + .../test_data_set_0/input_2.pb | 3 + .../test_data_set_0/input_3.pb | 
3 + .../test_data_set_0/input_4.pb | 3 + .../test_data_set_0/input_5.pb | 3 + .../test_data_set_0/input_6.pb | 3 + .../test_data_set_0/input_7.pb | 3 + .../test_data_set_0/input_8.pb | 3 + .../test_data_set_0/output_0.pb | 3 + .../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + .../unittest_models/sdpa_yi_0/sdpa_yi_0.onnx | 3 + .../sdpa_yi_0/test_data_set_0/input_0.pb | 3 + .../sdpa_yi_0/test_data_set_0/input_1.pb | 3 + .../sdpa_yi_0/test_data_set_0/input_2.pb | 3 + .../sdpa_yi_0/test_data_set_0/input_3.pb | 3 + .../sdpa_yi_0/test_data_set_0/input_4.pb | 3 + .../sdpa_yi_0/test_data_set_0/input_5.pb | 3 + .../sdpa_yi_0/test_data_set_0/input_6.pb | 3 + .../sdpa_yi_0/test_data_set_0/input_7.pb | 3 + .../sdpa_yi_0/test_data_set_0/output_0.pb | 3 + .../sdpa_yi_0/test_data_set_0/output_1.pb | 3 + .../sdpa_yi_0/test_data_set_0/output_2.pb | 3 + .../unittest_models/sdpa_yi_1/sdpa_yi_1.onnx | 3 + .../sdpa_yi_1/test_data_set_0/input_0.pb | 3 + .../sdpa_yi_1/test_data_set_0/input_1.pb | 3 + .../sdpa_yi_1/test_data_set_0/input_2.pb | 3 + .../sdpa_yi_1/test_data_set_0/input_3.pb | 3 + .../sdpa_yi_1/test_data_set_0/input_4.pb | 3 + .../sdpa_yi_1/test_data_set_0/input_5.pb | 3 + .../sdpa_yi_1/test_data_set_0/input_6.pb | 3 + .../sdpa_yi_1/test_data_set_0/input_7.pb | 3 + .../sdpa_yi_1/test_data_set_0/output_0.pb | 3 + .../sdpa_yi_1/test_data_set_0/output_1.pb | 3 + .../sdpa_yi_1/test_data_set_0/output_2.pb | 3 + .../sdpa_yi_4_38_0/sdpa_yi_4_38_0.onnx | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_0.pb | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_1.pb | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_2.pb | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_3.pb | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_4.pb | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_5.pb | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_6.pb | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_7.pb | 3 + .../sdpa_yi_4_38_0/test_data_set_0/input_8.pb | 3 + .../test_data_set_0/output_0.pb | 3 + 
.../test_data_set_0/output_1.pb | 3 + .../test_data_set_0/output_2.pb | 3 + {onnxscript/tests => tests}/README.md | 0 {onnxscript/tests => tests}/__init__.py | 0 .../tests => tests}/common/__init__.py | 0 .../common/onnx_script_test_case.py | 0 tests/common/testutils.py | 117 ++ .../tests => tests}/eager_mode_test.py | 0 {onnxscript/tests => tests}/eager_test.py | 4 +- .../tests => tests}/external_tensor_test.py | 0 .../function_libs/torch_lib/README.md | 0 .../torch_lib/error_reproduction.py | 0 .../function_libs/torch_lib/extra_opinfo.py | 0 .../function_libs/torch_lib/ops_test.py | 2 +- .../torch_lib/ops_test_common.py | 2 +- .../function_libs/torch_lib/ops_test_data.py | 2 +- .../tests => tests}/functions/attr_test.py | 2 +- .../tests => tests}/functions/gemmgelu.py | 0 .../functions/gemmgelu_test.py | 4 +- .../tests => tests}/functions/if_test.py | 4 +- .../functions/onnxfns1A_test.py | 4 +- .../functions/onnxfns2_test.py | 4 +- .../tests => tests}/functions/onnxfns_test.py | 4 +- .../functions/ort_custom_ops.py | 0 {onnxscript/tests => tests}/if_test.py | 2 +- tests/ir/serde_test.py | 30 + {onnxscript/tests => tests}/loop_test.py | 2 +- .../tests => tests}/models/__init__.py | 0 {onnxscript/tests => tests}/models/attrref.py | 0 .../tests => tests}/models/cast_like.py | 0 .../tests => tests}/models/different_opset.py | 0 {onnxscript/tests => tests}/models/dropout.py | 0 .../tests => tests}/models/eager_op.py | 0 {onnxscript/tests => tests}/models/eg1.py | 0 {onnxscript/tests => tests}/models/getitem.py | 2 +- .../tests => tests}/models/graph_attr.py | 0 .../tests => tests}/models/identity.py | 0 .../tests => tests}/models/if_statement.py | 0 .../tests => tests}/models/loops_break.py | 0 .../tests => tests}/models/loops_while.py | 0 {onnxscript/tests => tests}/models/m1.py | 0 {onnxscript/tests => tests}/models/multi.py | 0 .../tests => tests}/models/onnxfns1.py | 0 .../tests => tests}/models/onnxfns1A.py | 0 .../tests => tests}/models/onnxfns2.py | 0 .../tests 
=> tests}/models/opt_input.py | 0 .../tests => tests}/models/opt_output.py | 0 .../tests => tests}/models/renaming.py | 0 .../tests => tests}/models/sequences.py | 0 .../tests => tests}/models/signal_dft.py | 0 .../tests => tests}/models/subfunction.py | 0 .../tests => tests}/models/type_double.py | 0 .../tests => tests}/onnx_types_test.py | 0 {onnxscript/tests => tests}/operator_test.py | 0 tests/optimizer/test_models.py | 67 + 388 files changed, 12539 insertions(+), 94 deletions(-) create mode 100644 .gitattributes create mode 100644 examples/pattern_rewriting.py create mode 100644 onnxscript/_legacy_ir/__init__.py create mode 100644 onnxscript/_legacy_ir/irbuilder.py create mode 100644 onnxscript/_legacy_ir/irbuilder_test.py create mode 100644 onnxscript/_legacy_ir/protobuilder.py create mode 100644 onnxscript/_legacy_ir/protobuilder_test.py create mode 100644 onnxscript/_legacy_ir/visitor.py create mode 100644 onnxscript/_legacy_ir/visitor_test.py create mode 100644 onnxscript/optimizer/__init__.py create mode 100644 onnxscript/optimizer/constant_folding.py create mode 100644 onnxscript/optimizer/constant_folding_test.py create mode 100644 onnxscript/optimizer/copy_propagation.py create mode 100644 onnxscript/optimizer/copy_propagation_test.py create mode 100644 onnxscript/optimizer/evaluator.py create mode 100644 onnxscript/optimizer/fold_constants_v0.py create mode 100644 onnxscript/optimizer/function_folding_test.py create mode 100644 onnxscript/optimizer/remove_unused.py create mode 100644 onnxscript/optimizer/remove_unused_function.py create mode 100644 onnxscript/optimizer/remove_unused_test.py create mode 100644 onnxscript/optimizer/simple_function_folding.py create mode 100644 onnxscript/optimizer/simple_function_folding_test.py create mode 100644 onnxscript/rewriter/__init__.py create mode 100644 onnxscript/rewriter/broadcast_to_matmul.py create mode 100644 onnxscript/rewriter/broadcast_to_matmul_test.py create mode 100644 
onnxscript/rewriter/cast_constant_of_shape.py create mode 100644 onnxscript/rewriter/cast_constant_of_shape_test.py create mode 100644 onnxscript/rewriter/erfgelu.py create mode 100644 onnxscript/rewriter/function_rule.py create mode 100644 onnxscript/rewriter/gemm_to_matmul_add.py create mode 100644 onnxscript/rewriter/gemm_to_matmul_add_test.py create mode 100644 onnxscript/rewriter/generic_pattern.py create mode 100644 onnxscript/rewriter/generic_pattern_test.py create mode 100644 onnxscript/rewriter/no_op.py create mode 100644 onnxscript/rewriter/no_op_test.py create mode 100644 onnxscript/rewriter/onnxruntime/__init__.py create mode 100644 onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py create mode 100644 onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py create mode 100644 onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py create mode 100644 onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py create mode 100644 onnxscript/rewriter/onnxruntime/softmax.py create mode 100644 onnxscript/rewriter/onnxruntime/softmax_test.py create mode 100644 onnxscript/rewriter/onnxruntime/transformers/__init__.py create mode 100644 onnxscript/rewriter/onnxruntime/transformers/fastgelu.py create mode 100644 onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py create mode 100644 onnxscript/rewriter/onnxruntime/transformers/layernorm.py create mode 100644 onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py create mode 100644 onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py create mode 100644 onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py create mode 100644 onnxscript/rewriter/pattern.py create mode 100644 onnxscript/rewriter/pattern_test.py rename onnxscript/{testing.py => testing/__init__.py} (66%) delete mode 100644 onnxscript/tests/common/testutils.py create mode 100644 onnxscript/utils/__init__.py create mode 100644 
onnxscript/utils/evaluation_utils.py create mode 100644 onnxscript/utils/timing_utils.py create mode 100644 onnxscript/utils/utils.py create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/Speech2Text2ForCausalLM_dynamo.onnx create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/input_0.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/input_1.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_0.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_1.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_10.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_11.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_12.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_13.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_2.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_3.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_4.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_5.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_6.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_7.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_8.pb create mode 100644 testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_9.pb create mode 100644 testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx create mode 100644 testdata/e2e_models/mobilenetv2_100/dynamo/test_data_set_0/input_0.pb create 
mode 100644 testdata/e2e_models/mobilenetv2_100/dynamo/test_data_set_0/output_0.pb create mode 100644 testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx create mode 100644 testdata/e2e_models/resnet18/dynamo/test_data_set_0/input_0.pb create mode 100644 testdata/e2e_models/resnet18/dynamo/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/attn_llama2_4_34_0.onnx create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/attn_llama2_4_34_1.onnx create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_3.pb create mode 100644 
testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/attn_llama2_4_36_0.onnx create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/attn_phi_1_5_0.onnx create mode 100644 
testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_10.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_11.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_12.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_9.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/attn_phi_1_5_1.onnx create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_10.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_11.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_12.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_2.pb create mode 100644 
testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_9.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/attn_phi_1_5_2.onnx create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_10.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_11.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_12.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_8.pb create mode 100644 
testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_9.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/attn_phi_1_5_3.onnx create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_10.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_11.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_12.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_9.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/attn_yi_4_37_0.onnx create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_0.pb create mode 100644 
testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/gelu_phi_1_5_0/gelu_phi_1_5_0.onnx create mode 100644 testdata/unittest_models/gelu_phi_1_5_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/gelu_phi_1_5_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/gelu_phi_1_5_1/gelu_phi_1_5_1.onnx create mode 100644 testdata/unittest_models/gelu_phi_1_5_1/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/gelu_phi_1_5_1/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/gelu_phi_1_5_2/gelu_phi_1_5_2.onnx create mode 100644 testdata/unittest_models/gelu_phi_1_5_2/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/gelu_phi_1_5_2/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/gelu_phi_1_5_3/gelu_phi_1_5_3.onnx create mode 100644 testdata/unittest_models/gelu_phi_1_5_3/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/gelu_phi_1_5_3/test_data_set_0/output_0.pb create mode 100644 
testdata/unittest_models/ln_llama2_0/ln_llama2_0.onnx create mode 100644 testdata/unittest_models/ln_llama2_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/ln_llama2_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/ln_llama2_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/ln_llama2_1/ln_llama2_1.onnx create mode 100644 testdata/unittest_models/ln_llama2_1/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/ln_llama2_1/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/ln_llama2_1/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/ln_llama2_2/ln_llama2_2.onnx create mode 100644 testdata/unittest_models/ln_llama2_2/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/ln_llama2_2/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/ln_llama2_2/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/ln_llama2_3/ln_llama2_3.onnx create mode 100644 testdata/unittest_models/ln_llama2_3/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/ln_llama2_3/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/ln_llama2_3/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/sdpa_llama2_0.onnx create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_6.pb create mode 100644 
testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/sdpa_llama2_1.onnx create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/sdpa_llama2_2.onnx create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_5.pb create mode 100644 
testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/sdpa_llama2_3.onnx create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/sdpa_llama2_4_38_0.onnx create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_4.pb create mode 100644 
testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/sdpa_yi_0.onnx create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/sdpa_yi_1.onnx create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_3.pb create mode 100644 
testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_2.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/sdpa_yi_4_38_0.onnx create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_0.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_1.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_2.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_3.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_4.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_5.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_6.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_7.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_8.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_0.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_1.pb create mode 100644 testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_2.pb rename {onnxscript/tests => tests}/README.md (100%) rename {onnxscript/tests => tests}/__init__.py (100%) rename {onnxscript/tests => tests}/common/__init__.py (100%) rename {onnxscript/tests => tests}/common/onnx_script_test_case.py (100%) create mode 100644 tests/common/testutils.py rename {onnxscript/tests => 
tests}/eager_mode_test.py (100%) rename {onnxscript/tests => tests}/eager_test.py (99%) rename {onnxscript/tests => tests}/external_tensor_test.py (100%) rename {onnxscript/tests => tests}/function_libs/torch_lib/README.md (100%) rename {onnxscript/tests => tests}/function_libs/torch_lib/error_reproduction.py (100%) rename {onnxscript/tests => tests}/function_libs/torch_lib/extra_opinfo.py (100%) rename {onnxscript/tests => tests}/function_libs/torch_lib/ops_test.py (99%) rename {onnxscript/tests => tests}/function_libs/torch_lib/ops_test_common.py (99%) rename {onnxscript/tests => tests}/function_libs/torch_lib/ops_test_data.py (99%) rename {onnxscript/tests => tests}/functions/attr_test.py (95%) rename {onnxscript/tests => tests}/functions/gemmgelu.py (100%) rename {onnxscript/tests => tests}/functions/gemmgelu_test.py (94%) rename {onnxscript/tests => tests}/functions/if_test.py (92%) rename {onnxscript/tests => tests}/functions/onnxfns1A_test.py (93%) rename {onnxscript/tests => tests}/functions/onnxfns2_test.py (95%) rename {onnxscript/tests => tests}/functions/onnxfns_test.py (95%) rename {onnxscript/tests => tests}/functions/ort_custom_ops.py (100%) rename {onnxscript/tests => tests}/if_test.py (97%) create mode 100644 tests/ir/serde_test.py rename {onnxscript/tests => tests}/loop_test.py (96%) rename {onnxscript/tests => tests}/models/__init__.py (100%) rename {onnxscript/tests => tests}/models/attrref.py (100%) rename {onnxscript/tests => tests}/models/cast_like.py (100%) rename {onnxscript/tests => tests}/models/different_opset.py (100%) rename {onnxscript/tests => tests}/models/dropout.py (100%) rename {onnxscript/tests => tests}/models/eager_op.py (100%) rename {onnxscript/tests => tests}/models/eg1.py (100%) rename {onnxscript/tests => tests}/models/getitem.py (98%) rename {onnxscript/tests => tests}/models/graph_attr.py (100%) rename {onnxscript/tests => tests}/models/identity.py (100%) rename {onnxscript/tests => tests}/models/if_statement.py (100%) 
rename {onnxscript/tests => tests}/models/loops_break.py (100%) rename {onnxscript/tests => tests}/models/loops_while.py (100%) rename {onnxscript/tests => tests}/models/m1.py (100%) rename {onnxscript/tests => tests}/models/multi.py (100%) rename {onnxscript/tests => tests}/models/onnxfns1.py (100%) rename {onnxscript/tests => tests}/models/onnxfns1A.py (100%) rename {onnxscript/tests => tests}/models/onnxfns2.py (100%) rename {onnxscript/tests => tests}/models/opt_input.py (100%) rename {onnxscript/tests => tests}/models/opt_output.py (100%) rename {onnxscript/tests => tests}/models/renaming.py (100%) rename {onnxscript/tests => tests}/models/sequences.py (100%) rename {onnxscript/tests => tests}/models/signal_dft.py (100%) rename {onnxscript/tests => tests}/models/subfunction.py (100%) rename {onnxscript/tests => tests}/models/type_double.py (100%) rename {onnxscript/tests => tests}/onnx_types_test.py (100%) rename {onnxscript/tests => tests}/operator_test.py (100%) create mode 100644 tests/optimizer/test_models.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..fc077c629 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +**/*.pb filter=lfs diff=lfs merge=lfs -text +**/*.onnx filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 558312015..042a1b262 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -67,6 +67,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install nox run: python -m pip install nox + - name: Pull Test Data + run: git lfs pull - name: Run tests run: nox -t ${{ matrix.nox-tag }} --forcecolor -- -v --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml env: diff --git a/.lintrunner.toml b/.lintrunner.toml index 4fd361f0d..e86109025 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -8,7 +8,7 @@ include_patterns = [ '**/*.pyi', ] exclude_patterns = [ - 
'onnxscript/tests/models/**', + 'tests/models/**', ] command = [ 'python', @@ -43,9 +43,25 @@ exclude_patterns = [ 'onnxscript/evaluator_test.py', 'onnxscript/evaluator.py', 'onnxscript/onnx_types.py', - 'onnxscript/tests/**', # Skip linting test files for speed + 'tests/**', # Skip linting test files for speed 'onnxscript/**/*_test.py', # Skip linting test files for speed 'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy + 'onnxscript/optimizer/evaluator.py', # FIXME + 'onnxscript/optimizer/constant_folding.py', # FIXME + 'onnxscript/_legacy_ir/__init__.py', # FIXME + 'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME + 'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME + 'onnxscript/rewriter/function_rule.py', # FIXME + 'onnxscript/_legacy_ir/irbuilder.py', # FIXME + 'onnxscript/optimizer/fold_constants_v0.py', # FIXME + 'onnxscript/rewriter/pattern.py', # FIXME + 'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME + 'onnxscript/tools/function_unittest_producer.py', # FIXME + 'onnxscript/_legacy_ir/visitor.py', # FIXME + 'onnxscript/_legacy_ir/protobuilder.py', # FIXME + 'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME + 'onnxscript/ir/serde.py', # FIXME + 'onnxrewriter/rewriter/pattern/generic_pattern_test.py', # FIXME ] command = [ 'python', @@ -74,7 +90,7 @@ include_patterns = [ '**/*.py', ] exclude_patterns = [ - 'onnxscript/tests/onnx_backend_test_code/**', + 'tests/onnx_backend_test_code/**', ] command = [ 'python', @@ -105,9 +121,12 @@ exclude_patterns = [ 'docs/examples/**', 'docs/tutorial/examples/**', 'onnxscript/converter_test.py', - 'onnxscript/tests/functions/**', - 'onnxscript/tests/models/**', - 'onnxscript/tests/onnx_backend_test_code/**', + 'tests/functions/**', + 'tests/models/**', + 'tests/onnx_backend_test_code/**', + 'onnxscript/optimizer/**', # FIXME + 'onnxscript/rewriter/**', # FIXME + 
"""Onnx Pattern Rewriting.

This script shows how to define a rewriting rule based on patterns.
The objective is to replace a sequence of nodes in an ONNX model with
another, more efficient, sequence of nodes.

First a dummy model
===================
"""

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
import onnxscript._legacy_ir as oir
import onnxscript.rewriter.generic_pattern as org


def get_rotary_model(bad_model=False):
    """Build a small ONNX model containing a rotary-embedding-style subgraph.

    :param bad_model: if True, emit a ``ConcatTrainingBad`` node instead of
        ``ConcatTraining`` so that the rewriting rule defined below fails to match.
    :return: an ``onnx.ModelProto`` importing opset 18 and the ``com.microsoft`` domain.
    """
    inputs = [
        oh.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]),
        oh.make_tensor_value_info("pos_ids", onnx.TensorProto.FLOAT, shape=[]),
        oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]),
    ]
    nodes = [
        oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
        # to=1 is onnx.TensorProto.FLOAT
        oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1),
        oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
        oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
        oh.make_node(
            "ConcatTrainingBad" if bad_model else "ConcatTraining",
            ["_onx_transpose0", "_onx_transpose0"],
            ["_onx_concattraining0", "_onx_concattraining1"],
            domain="com.microsoft",
        ),
        oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
        oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1),
        oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
        oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1),
    ]
    outputs = [
        oh.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []),
        oh.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []),
    ]
    model = oh.make_model(
        oh.make_graph(
            nodes,
            "experiment",
            inputs,
            outputs,
        ),
        opset_imports=[
            oh.make_opsetid("", 18),
            oh.make_opsetid("com.microsoft", 18),
        ],
    )
    return model


model = get_rotary_model()
ir_model = oir.irbuilder.build_ir(model)


####################################
# The rewriting pattern
# =====================

op = onnxscript.opset18
msft_op = onnxscript.values.Opset("com.microsoft", 1)


def rotary_match_pattern(x, pos_ids, axis):
    """The pattern to match."""
    unsqueeze = op.Unsqueeze(x, axis)
    cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)

    matmul = op.MatMul(pos_ids, cast)
    transpose = op.Transpose(matmul)
    # ConcatTraining has two outputs; ``length`` is part of the pattern but otherwise unused.
    output, length = msft_op.ConcatTraining(transpose, transpose)

    sin = op.Sin(output)
    cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
    cos = op.Cos(output)
    cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT)
    return cast1, cast2


def validate_rotary_mapping(g, matched_nodes, added_nodes) -> bool:
    """The validation post matching.

    Returns True to validate the replacement,
    False not to apply it.

    :param g: model
    :param matched_nodes: matched nodes
    :param added_nodes: nodes replacing the matched nodes
    """
    del g
    del matched_nodes
    del added_nodes
    return True


def rotary_apply_pattern(x, pos_ids, axis):
    """The replacement pattern."""
    # The caches are unseeded random placeholders; a real rewrite would supply
    # the actual cos/sin tables.
    cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
    return part1, part2


###########################
# The rule
# ========
#
# The rule is easy to create.


rule = org.make_pattern_rule(
    rotary_match_pattern,
    rotary_apply_pattern,
    validate_rotary_mapping,
)

################################
# ``validate_rotary_mapping`` always returns True.
# This argument can be ignored in that case.

rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)

##########################
# Let's apply it.
rule.apply_to_model(ir_model)


########################
# And finally, we can generate the model.

opt_onx = oir.protobuilder.build_model_proto(ir_model)

########################
# Let's see what it looks like.

for node in opt_onx.graph.node:
    print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}")

#############################
# What if it fails?
# =================


model = get_rotary_model(True)
ir_model = oir.irbuilder.build_ir(model)

rule.apply_to_model(ir_model)
opt_onx = oir.protobuilder.build_model_proto(ir_model)

print([n.op_type for n in opt_onx.graph.node])

################################
# The match did not happen.
# Let's increase the verbosity.

rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

rule.apply_to_model(ir_model)

######################################
# The logs show every time the algorithm rejected a pattern.
# We can see the following:
#
# ::
#
#     [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
#         --hint--: BACKWARD: different node types
#         --pattern
#         ConcatTraining(transpose, transpose) -> (output, length)
#         -- model
#         ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
#         iteration=1
#         --marked-- #2
#         Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
#         Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
#         len(stacked)=0:[]
#
# Line 673 in file `generic_pattern.py`, the match was rejected.
# It says while comparing two nodes in the backward direction,
# node types do not match.
# It also says that two nodes were actually matched.
from __future__ import annotations

import dataclasses
from collections import deque
from typing import List, Tuple, Union

import numpy as np
import onnx


class Unknown:
    """A special value used to indicate that a value is not a statically known constant.

    We use this instead of None because None is a valid constant value (since ONNX
    supports the Optional type).
    """

    # Singleton guard: holds the single allowed instance once it has been created.
    instance = None

    def __init__(self) -> None:
        # Enforce the singleton property: constructing a second Unknown is a bug.
        if Unknown.instance is not None:
            raise ValueError("Unknown.instance is already set")
        Unknown.instance = self


# Singleton instance of Unknown
unknown = Unknown()
# Alias used by analyses to mark a value as "not a compile-time constant".
NotConstant = unknown

# ConcreteValue: This type represents constant values that an ONNX variable can take.
# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals,
# maps, etc.
# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto.
# A uniform representation would be helpful, but we should avoid unnecessary conversions for
# large tensors. Should be cleaned up in the new IR.
ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None]

# SymbolicValue: This information is used to enable partial-evaluation and specialization
# of sequence operations, as well as elimination of redundant Identity ops.
# The symbolic value of a variable X can be:
# - a string with the value "Y", indicating that "X" is a copy of "Y"
# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values
# Eg., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to
# "SequenceConstruct(A, B, C)".
# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of
# tensors, etc.
However, we currently only handle lists of tensors. + +SymbolicValue = Union[str, List[str]] + +FunctionId = Tuple[str, str, str] + + +def get_function_id(function: onnx.FunctionProto) -> FunctionId: + return (function.domain, function.name, getattr(function, "overload", "")) + + +def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId: + return (node.domain, node.op_type, getattr(node, "overload", "")) + + +@dataclasses.dataclass +class StaticValueInfo: + name: str + value: ConcreteValue = NotConstant + type: onnx.TypeProto | None = None + symbolic_value: SymbolicValue | None = None + + def is_copy(self) -> bool: + return isinstance(self.symbolic_value, str) + + def tensor_shape_proto(self) -> onnx.TensorShapeProto | None: + """Returns the shape of a tensor or None. + + A return value of None could mean that the type is unknown or that the type is not a tensor + or that the tensor shape (that is, even the rank) is unknown. + """ + type = self.type + if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + return type.tensor_type.shape + return None + + @property + def shape(self) -> list[str | int | None] | None: + """Returns the shape in a list. + + Str means that the shape is dynamic. 
+ """ + type = self.type + if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + dims = [] + for dim in type.tensor_type.shape.dim: + if dim.HasField("dim_param"): + dims.append(dim.dim_param) + elif dim.HasField("dim_value"): + dims.append(dim.dim_value) + else: + dims.append(None) + return dims + if self.value_as_np_array is not None: + return list(self.value_as_np_array.shape) + return None + + @property + def element_type(self) -> int | None: + """Returns the element type of a tensor, or None if type is not known or is not a tensor.""" + type = self.type + if type and type.HasField("tensor_type"): + return type.tensor_type.elem_type + return None + + def identity_merge_from(self, other: StaticValueInfo) -> None: + """Merge the value of other into self. + + This models the effect of an identity (copy) operation. + This will update static-analysis information based on incoming value. + """ + if not isinstance(other, StaticValueInfo): + raise TypeError(f"Cannot merge {other} into {self}.") + if other.value is not NotConstant: + self.value = other.value + # TODO: merge and combine best shape information from both types. + if other.tensor_shape_proto() is not None and other.element_type is not None: + self.type = other.type + # We cannot copy symbolic value across different scopes. + + # WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo + # does not fill in the following fields. These fields are filled in by the IRBuilder + # which constructs the IR from the ONNX model. 
+ node: Node | None = None + uses: list[Node] = dataclasses.field(default_factory=list) + output_index: int | None = None + is_output: bool = False + + @property + def const_value(self) -> ConcreteValue: + return self.value + + @property + def value_as_np_array(self) -> np.ndarray | None: + if isinstance(self.value, np.ndarray): + return self.value + if isinstance(self.value, onnx.TensorProto): + return onnx.numpy_helper.to_array(self.value) + return None + + def def_node(self) -> Node | None: + return self.node + + def def_index(self) -> int: + return self.output_index + + def is_same_as(self, other: StaticValueInfo) -> bool: + """Returns true if this value represents the same IR object as the other value. + + This is *not* value-equality, but rather object-equality. + """ + return self is other + + def __str__(self) -> str: + shape = self.shape + if shape is not None: + shape = [str(dim) for dim in shape] + shape_str = f"[{', '.join(shape)}]" + else: + shape_str = "None" + return ( + f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, " + f"{'has const value' if self.value is not unknown else 'no const value'}.)" + ) + + +Value = StaticValueInfo + + +class Model: + def __init__(self) -> None: + self.gen_var_counter: int = 0 + + def set( + self, + model_proto: onnx.ModelProto, + graph: Graph, + functions: list[Function], + version_map: dict[str, int], + ) -> None: + """TODO. This is a temporary patch.""" + self.original_model_proto = model_proto + self.graph = graph + self.functions = functions + self.version_map = version_map + + def make_new_name(self): + # Temporary hack. + self.gen_var_counter += 1 + return f"_gen_{self.gen_var_counter}" + + def __str__(self) -> str: + # TODO: Naive string representation for debugging. Need to improve this. 
+ return "\n".join( + [ + f"ModelGraph: {self.graph}", + f"Functions: {self.functions}", + f"VersionMap: {self.version_map}", + ] + ) + + +class Graph: + def __init__(self, graph_proto: onnx.GraphProto): + self.original_graph_proto = graph_proto + self.nodes: deque[Node] = deque() + self.values: dict[str, Value] = {} + + @property + def name(self) -> str: + return self.original_graph_proto.name + + def __str__(self) -> str: + return "\n".join( + [ + "Graph", + f"Nodes: {[str(n) for n in self.nodes]}", + f"Values: {[str(v) for v in self.values]}", + ] + ) + + @property + def input_names(self) -> list[str]: + return [_.name for _ in self.original_graph_proto.input] + + @property + def output_names(self) -> list[str]: + return [_.name for _ in self.original_graph_proto.output] + + +class Function: + def __init__(self, function_proto: onnx.FunctionProto): + self.original_function_proto = function_proto + self.nodes = deque() + self.values = {} + + @property + def id(self) -> FunctionId: + return (self.domain, self.name, self.overload) + + @property + def domain(self) -> str: + return self.original_function_proto.domain + + @property + def name(self) -> str: + return self.original_function_proto.name + + @property + def overload(self) -> str: + return getattr(self.original_function_proto, "overload", "") + + def __str__(self) -> str: + return "\n".join( + [ + "Function", + f"Nodes: {[str(n) for n in self.nodes]}", + f"Values: {[str(v) for v in self.values]}", + ] + ) + + +class RefAttr: + def __init__(self, name: str, ref_attr_name: str, type) -> None: + self.name = name + self.ref_attr_name = ref_attr_name + self.type = type + + def to_proto(self) -> onnx.AttributeProto: + attr_proto = onnx.AttributeProto() + attr_proto.name = self.name + attr_proto.ref_attr_name = self.ref_attr_name + attr_proto.type = self.type + return attr_proto + + +class Node: + def __init__( + self, + node_proto: onnx.NodeProto, + populate_io: bool = False, + ) -> None: + 
self.original_node_proto = node_proto + self.domain: str = node_proto.domain + self.version: int | None = None + self.op_type: str = node_proto.op_type + if populate_io: + self.inputs: list[Value | None] = [Value(i) for i in node_proto.input] + self.outputs: list[Value | None] = [Value(i) for i in node_proto.output] + else: + self.inputs: list[Value | None] = [] + self.outputs: list[Value | None] = [] + self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {} + + def __repr__(self) -> str: + return ( + f"{self.op_type}({','.join(self.original_node_proto.input)})" + f"->{','.join(self.original_node_proto.output)}" + ) + + @property + def name(self) -> str: + return self.original_node_proto.name + + @property + def input_names(self): + return self.original_node_proto.input + + @property + def output_names(self): + return self.original_node_proto.output + + @property + def attribute(self): + return self.original_node_proto.attribute + + def set_version_if_custom_op(self, version_map: dict[str, int]) -> None: + if self.domain != "" and self.domain in version_map: + self.version = version_map[self.domain] + + def get_attribute(self, name: str) -> int | float | None: + return self.attributes.get(name, None) + + def __str__(self) -> str: + return "\n".join( + [ + "Node", + f"OpType: {self.op_type}", + f"Inputs: {self.inputs}", + f"Outputs: {self.outputs}", + f"Attributes: {self.attributes}", + ] + ) diff --git a/onnxscript/_legacy_ir/irbuilder.py b/onnxscript/_legacy_ir/irbuilder.py new file mode 100644 index 000000000..5bee6083b --- /dev/null +++ b/onnxscript/_legacy_ir/irbuilder.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import warnings +from typing import Any + +import onnx + +import onnxscript._legacy_ir as ir +from onnxscript._legacy_ir import visitor +from onnxscript.utils import utils + +""" NOTE: IRBuilder and function visiting + +Current IRBuilder is designed to visit function by definition, instead of function by callsite. 
+This has the following implications during visiting:
+- Prior to IR 10 / ONNX 1.16, value_info is not defined in function. They are experimentally defined under
+ main graph for models produced by PyTorch 2.2+ dynamo onnx exporter. Hence a workaround is required in `process_node`
+ to load function value info from a pre-processed `FunctionShapeEnv` object.
+ Post IR 10, using `process_value_info` method is enough to retrieve and process both function and graph
+ value_info.
+- ref_attr_name is not resolved during visiting, because it requires the function callsite information.
+
+"""
+
+
+class IRBuilder:
+ def __init__(self):
+ self._current_graphs: list[ir.Graph] = []
+ # See NOTE: IRBuilder and function visiting
+ self._current_function: ir.Function | None = None
+ self._function_subgraphs: list[ir.Graph] = []
+ self.functions: dict[ir.FunctionId, ir.Function] = {}
+
+ def visit_model(self, model_proto: onnx.ModelProto) -> ir.Model:
+ self._function_shape_env = visitor.FunctionShapeEnv()
+ self._function_shape_env.load_from_model_proto(model_proto)
+ self._ir_version = model_proto.ir_version
+ self.version_map = {x.domain: x.version for x in model_proto.opset_import}
+ functions = [self.visit_function(function) for function in model_proto.functions]
+ self.functions = {function.id: function for function in functions}
+ graph = self.visit_graph(model_proto.graph)
+ model = ir.Model()
+ model.set(model_proto, graph, functions, self.version_map)
+ return model
+
+ def visit_graph(self, graph: onnx.GraphProto) -> ir.Graph:
+ self.enter_graph(ir.Graph(graph))
+ for input in graph.input:
+ self.process_graph_input(input)
+ for init in graph.initializer:
+ self.process_initializer(init)
+ for node in graph.node:
+ self.process_node(node)
+ for output in graph.output:
+ self.process_graph_output(output)
+ for value_info in graph.value_info:
+ self.process_value_info(value_info)
+ return self.exit_graph()
+
+ def visit_function(self, function: onnx.FunctionProto) -> 
ir.Function: + self._current_function = ir.Function(function) + for input in function.input: + self.process_function_input(input) + for node in function.node: + self.process_node(node) + for output in function.output: + self.process_function_output(output) + for value_info in getattr(function, "value_info", []): + self.process_value_info(value_info) + function_ir = self._current_function + self._current_function = None + return function_ir + + @property + def current_graph_or_function(self) -> ir.Graph | ir.Function: + if self._function_subgraphs: + assert self._current_function is not None + return self._function_subgraphs[-1] + if self._current_function is not None: + return self._current_function + return self._current_graphs[-1] + + def enter_graph(self, graph: ir.Graph): + if self._current_function is not None: + self._function_subgraphs.append(graph) + else: + self._current_graphs.append(graph) + + def exit_graph(self) -> ir.Graph: + if self._current_function is not None: + return self._function_subgraphs.pop() + else: + return self._current_graphs.pop() + + def _lookup_from_graphs(self, name: str, graphs: list[ir.Graph]) -> ir.Value | None: + for graph in reversed(graphs): + value = graph.values.get(name, None) + if value is not None: + return value + return None + + def lookup(self, name: str) -> ir.Value | None: + if self._current_function is not None: + value = self._lookup_from_graphs(name, self._function_subgraphs) + if value is not None: + return value + return self._current_function.values.get(name, None) + return self._lookup_from_graphs(name, self._current_graphs) + + def bind(self, name: str, value: ir.Value): + self.current_graph_or_function.values[name] = value + + def process_graph_input(self, input: onnx.ValueInfoProto): + newvalue = ir.Value(name=input.name, type=input.type) + self.bind(input.name, newvalue) + + def process_initializer(self, init: onnx.TensorProto): + # TODO(titaiwang): Take care of the case where the initializer is already 
defined? + if init.name not in self.current_graph_or_function.values: + newvalue = ir.Value(name=init.name, value=init) + self.bind(init.name, newvalue) + + def process_node(self, node): + node_ir = ir.Node(node) + node_ir.set_version_if_custom_op(self.version_map) + self.current_graph_or_function.nodes.append(node_ir) + for name in node.input: + value = self.lookup(name) + node_ir.inputs.append(value) + if value is not None: + value.uses.append(node_ir) + else: + # TODO(titaiwang): Do something more than warnings? + warnings.warn(f"Use of undefined variable {name!r}.", stacklevel=1) + for index, output in enumerate(node.output): + newvalue = ir.Value(name=output, node=node_ir, output_index=index) + if self._current_function is not None: + ir_value = self._function_shape_env.lookup( + self._current_function.original_function_proto, output + ) + if ir_value is not None: + newvalue.identity_merge_from(ir_value) + node_ir.outputs.append(newvalue) + self.bind(output, newvalue) + for attr in node.attribute: + attr_val = self.process_attribute(attr) + node_ir.attributes[attr.name] = attr_val + # Set constant-value for Constant node: + if node.op_type == "Constant" and node.domain in {"", "ai.onnx"}: + node_ir.outputs[0].value = utils.get_constant_node_value(node, node.output[0]) + + def process_attribute(self, attr: onnx.AttributeProto) -> ir.Graph | list[ir.Graph] | Any: + if attr.HasField("g"): + return self.visit_graph(attr.g) + elif len(attr.graphs) > 0: + return [self.visit_graph(graph) for graph in attr.graphs] + elif attr.ref_attr_name: + return ir.RefAttr(attr.name, attr.ref_attr_name, attr.type) + else: + # This returns Any based on onnx.helper.get_attribute_value's return type. + return onnx.helper.get_attribute_value(attr) + + def process_graph_output(self, output: onnx.ValueInfoProto): + value = self.lookup(output.name) + if value is None: + # TODO(titaiwang): Should we remove the non-output value from the graph.values? 
+ warnings.warn(
+ f"Graph contains no definition for output '{output.name}'.",
+ stacklevel=1,
+ )
+ else:
+ value.type = output.type
+ value.is_output = True
+
+ def process_function_input(self, input: str):
+ ir_value = self._function_shape_env.lookup(
+ self._current_function.original_function_proto, input
+ )
+ if ir_value is None:
+ ir_value = ir.Value(name=input)
+ self.bind(input, ir_value)
+
+ def process_function_output(self, output: str):
+ value = self.lookup(output)
+ if value is None:
+ print(f"WARNING: Function contains no definition for output '{output}'.")
+ else:
+ value.is_output = True
+
+ def process_value_info(self, value_info: onnx.ValueInfoProto):
+ function_id, ir_value = self._function_shape_env.process_value_info(value_info)
+ existing_value = self.lookup(value_info.name)
+ if existing_value is not None:
+ existing_value.identity_merge_from(ir_value)
+ ir_value = existing_value
+
+ if self._ir_version >= 10:
+ # ONNX >= 1.16 where value_info can be defined in function
+ self.bind(ir_value.name, ir_value)
+ elif function_id is not None:
+ # All value_infos are defined in main graph
+ # This needs to be handled while visiting function, so do nothing here. 
+ pass + else: + self.bind(ir_value.name, ir_value) + + +def build_ir(model: onnx.ModelProto): + """Builds an IR from an ONNX model proto.""" + return IRBuilder().visit_model(model) diff --git a/onnxscript/_legacy_ir/irbuilder_test.py b/onnxscript/_legacy_ir/irbuilder_test.py new file mode 100644 index 000000000..531215258 --- /dev/null +++ b/onnxscript/_legacy_ir/irbuilder_test.py @@ -0,0 +1,198 @@ +import unittest + +import onnx.parser + +from onnxscript._legacy_ir import irbuilder + + +class IRBuilderTest(unittest.TestCase): + def test_irbuilder(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + three = Constant () + x_cube = Pow(x, three) + B = Constant () + x_cube_mul_B = Mul(x_cube, B) + sum = Add(x, x_cube_mul_B) + C = Constant () + C_times_sum = Mul(C, sum) + tanh = Tanh(C_times_sum) + one = Constant () + one_plus_tanh = Add(one, tanh) + half = Constant () + half_x = Mul(half, x) + z = Mul(one_plus_tanh, half_x) + } + """ + ) + irbuilder.build_ir(model) + + def test_shape_is_accessible_for_graph_value_with_value_info(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + + { + t = Add (x, y) + z = Add (t, x) + } + """ + ) + irmodel = irbuilder.build_ir(model) + self.assertEqual( + irmodel.graph.nodes[0].outputs[0].tensor_shape_proto(), + onnx.TensorShapeProto(dim=[onnx.TensorShapeProto.Dimension(dim_param="N")]), + ) + + def test_shape_is_accessible_for_function_value_with_experimental_value_info(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + + afunction (x, y) => (z) + { + o = MatMul (x, y) + shape = Constant () + z = Reshape (o, shape) + } + """ + ) + # Hack to put value_info in since parser does not support this experimental naming format + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + 
"pkg.custom::afunction/o", onnx.TensorProto.FLOAT, ["N", "K"] + ) + ) + irmodel = irbuilder.build_ir(model) + self.assertEqual( + irmodel.functions[0].nodes[0].outputs[0].tensor_shape_proto(), + onnx.TensorShapeProto( + dim=[ + onnx.TensorShapeProto.Dimension(dim_param="N"), + onnx.TensorShapeProto.Dimension(dim_param="K"), + ] + ), + ) + + def test_function_input_is_correctly_linked_with_subnodes_in_function_when_shape_is_missing( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = afunction (x, y) + } + + afunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + irmodel = irbuilder.build_ir(model) + self.assertIsNotNone(irmodel.functions[0].nodes[0].inputs[0]) + self.assertIsNotNone(irmodel.functions[0].nodes[0].inputs[1]) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[0], irmodel.functions[0].values["x"] + ) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[1], irmodel.functions[0].values["y"] + ) + + def test_function_input_is_correctly_linked_with_subnodes_in_function_when_shape_is_present( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = afunction (x, y) + } + + afunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + # Hack to put value_info in since parser does not support this experimental naming format + model.graph.value_info.extend( + [ + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/x", onnx.TensorProto.FLOAT, ["N"] + ), + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/y", onnx.TensorProto.FLOAT, ["M"] + ), + ] + ) + irmodel = irbuilder.build_ir(model) + self.assertIsNotNone(irmodel.functions[0].nodes[0].inputs[0]) + self.assertIsNotNone(irmodel.functions[0].nodes[0].inputs[1]) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[0], irmodel.functions[0].values["x"] + ) + 
self.assertEqual( + irmodel.functions[0].nodes[0].inputs[1], irmodel.functions[0].values["y"] + ) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[0].tensor_shape_proto(), + onnx.TensorShapeProto( + dim=[ + onnx.TensorShapeProto.Dimension(dim_param="N"), + ] + ), + ) + self.assertEqual( + irmodel.functions[0].nodes[0].inputs[1].tensor_shape_proto(), + onnx.TensorShapeProto( + dim=[ + onnx.TensorShapeProto.Dimension(dim_param="M"), + ] + ), + ) + + def test_out_of_context_value_reference_is_correct(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { + two = Constant () + z = If (cond) < + then_branch = then_graph () => (then_z) { + three = Constant () + temp = Add (two, three) + then_z = Mul (temp, x) + }, + else_branch = else_graph () => (else_z) { + four = Constant () + temp = Add (two, four) + else_z = Mul (temp, x) + } + > + } + """ + ) + irmodel = irbuilder.build_ir(model) + then_graph = irmodel.graph.nodes[1].attributes["then_branch"] + self.assertIsNotNone(then_graph.nodes[2].inputs[1]) + else_graph = irmodel.graph.nodes[1].attributes["else_branch"] + self.assertIsNotNone(else_graph.nodes[2].inputs[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/_legacy_ir/protobuilder.py b/onnxscript/_legacy_ir/protobuilder.py new file mode 100644 index 000000000..bdaad92de --- /dev/null +++ b/onnxscript/_legacy_ir/protobuilder.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import onnx +import onnx.helper +from onnx.helper import make_attribute + +import onnxscript._legacy_ir as ir + + +class ModelProtoBuilder: + def __init__(self): + self.opset_imports: dict[str, onnx.OperatorSetIdProto] = {} + + def visit_ir_model(self, ir_model: ir.Model) -> onnx.ModelProto: + model_proto = onnx.ModelProto() + model_proto.ir_version = ir_model.original_model_proto.ir_version + # TODO (sbhokare) : Find a way of copying model properties without + # each property individually + # 
Copy over model properties + model_proto.doc_string = ir_model.original_model_proto.doc_string + model_proto.domain = ir_model.original_model_proto.domain + model_proto.metadata_props.extend(ir_model.original_model_proto.metadata_props) + model_proto.model_version = ir_model.original_model_proto.model_version + model_proto.producer_name = ir_model.original_model_proto.producer_name + model_proto.producer_version = ir_model.original_model_proto.producer_version + model_proto.training_info.extend(ir_model.original_model_proto.training_info) + + for domain, version in ir_model.version_map.items(): + operator_setid_proto = model_proto.opset_import.add() + operator_setid_proto.domain, operator_setid_proto.version = domain, version + self.opset_imports[domain] = operator_setid_proto + for function in ir_model.functions: + function_proto = model_proto.functions.add() + self.visit_ir_function(function, function_proto) + graph_proto = model_proto.graph + self.visit_ir_graph(ir_model.graph, graph_proto) + return model_proto + + def visit_ir_graph( + self, ir_graph: ir.Graph, graph_proto: onnx.GraphProto + ) -> onnx.GraphProto: + graph_proto.name = ir_graph.name + # Copy over graph properties + graph_proto.doc_string = ir_graph.original_graph_proto.doc_string + # graph_proto.metadata_props = ir_graph.original_graph_proto.metadata_props) + graph_proto.quantization_annotation.extend( + ir_graph.original_graph_proto.quantization_annotation + ) + + for node in ir_graph.nodes: + node_proto = graph_proto.node.add() + self.process_ir_node(node, node_proto) + for i in ir_graph.original_graph_proto.input: + graph_proto.input.append(i) + for o in ir_graph.original_graph_proto.output: + graph_proto.output.append(o) + for val in ir_graph.original_graph_proto.value_info: + graph_proto.value_info.append(val) + for i in ir_graph.original_graph_proto.initializer: # type: ignore[assignment] + graph_proto.initializer.append(i) # type: ignore[arg-type] + return graph_proto + + def 
visit_ir_function( + self, ir_function: ir.Function, function_proto: onnx.FunctionProto + ) -> onnx.FunctionProto: + function_proto.name = ir_function.name + function_proto.domain = ir_function.domain + # Copy over function properties + function_proto.doc_string = ir_function.original_function_proto.doc_string + # function_proto.metadata_props = ir_function.original_function_proto.metadata_props) + + for node in ir_function.nodes: + # TODO: deduplicate the opset import of function? + operator_setid_proto = function_proto.opset_import.add() + if node.domain in self.opset_imports: + operator_setid_proto.domain = self.opset_imports[node.domain].domain + operator_setid_proto.version = self.opset_imports[node.domain].version + else: + raise ValueError(f"Unknown domain {node.domain}") + node_proto = function_proto.node.add() + self.process_ir_node(node, node_proto) + # TODO (shubham) : Propagate shape-type info + for i in ir_function.original_function_proto.input: + function_proto.input.append(i) + for o in ir_function.original_function_proto.output: + function_proto.output.append(o) + for attr in ir_function.original_function_proto.attribute: + function_proto.attribute.append(attr) + for attr_proto in ir_function.original_function_proto.attribute_proto: + function_proto.attribute_proto.append(attr_proto) + for val in getattr(ir_function.original_function_proto, "value_info", []): + function_proto.value_info.append(val) + return function_proto + + def process_ir_node(self, ir_node: ir.Node, node_proto: onnx.NodeProto) -> onnx.NodeProto: + node_proto.op_type = ir_node.op_type + node_proto.domain = ir_node.domain + # Copy over node properties + node_proto.name = ir_node.original_node_proto.name + node_proto.doc_string = ir_node.original_node_proto.doc_string + # node_proto.metadata_props = ir_node.original_node_proto.metadata_props) + + for i in ir_node.inputs: + node_proto.input.append(i.name if i is not None else "") + for o in ir_node.outputs: + assert o is not None + 
node_proto.output.append(o.name) + for attr in ir_node.attributes.items(): + attr_proto = self.process_attribute(attr) + node_proto.attribute.append(attr_proto) + return node_proto + + def process_attribute(self, attr): + attr_name, attr_val = attr + if isinstance(attr_val, ir.RefAttr): + return attr_val.to_proto() + if isinstance(attr_val, ir.Graph): + graph_proto = onnx.GraphProto() + attr_val = self.visit_ir_graph(attr_val, graph_proto) + attr_proto = make_attribute(attr_name, attr_val) + return attr_proto + + +def build_model_proto(model: ir.Model) -> onnx.ModelProto: + """Builds an ONNX model proto from an IR.""" + return ModelProtoBuilder().visit_ir_model(model) diff --git a/onnxscript/_legacy_ir/protobuilder_test.py b/onnxscript/_legacy_ir/protobuilder_test.py new file mode 100644 index 000000000..d56ebb95c --- /dev/null +++ b/onnxscript/_legacy_ir/protobuilder_test.py @@ -0,0 +1,215 @@ +import unittest + +import numpy as np +import onnx.checker +import onnx.parser + +from onnxscript._legacy_ir import irbuilder, protobuilder +from onnxscript.rewriter import pattern +from onnxscript.rewriter.onnxruntime import instance_to_group_normalization + +op = pattern.onnxop + + +class ConcatSerializeTest(unittest.TestCase): + def rule(self) -> pattern.RewriteRule: + def concat_pattern(x, y, axis): + seq = op.SequenceConstruct(x, y) + return op.ConcatFromSequence(seq, axis=axis) + + def concat(x, y, axis): + return op.Concat(x, y, axis=axis) + + return pattern.RewriteRule(concat_pattern, concat) + + def test_concat_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[M] z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + # Tests related to IR + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 1) + # Tests related to serialization to ModelProto + model_proto = protobuilder.build_model_proto(ir) + 
onnx.checker.check_model(model_proto) + + def test_concat_in_function_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = pkg.custom.afunction (x, y) + } + + afunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + # Tests related to IR + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 1) + self.assertEqual(len(ir.functions[0].nodes), 1) + self.assertEqual(ir.functions[0].nodes[0].op_type, "Concat") + # Tests related to serialization to ModelProto + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + def test_concat_in_nested_function_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = pkg.custom.afunction (x, y) + } + + afunction (x, y) => (z) + { + z = pkg.custom.nestedfunction(x, y) + } + + nestedfunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + # Tests related to IR + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 2) + self.assertEqual(len(ir.functions[0].nodes), 1) + self.assertEqual(len(ir.functions[1].nodes), 1) + self.assertEqual(ir.functions[0].nodes[0].op_type, "nestedfunction") + self.assertEqual(ir.functions[1].nodes[0].op_type, "Concat") + # Tests related to serialization to ModelProto + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + +class ControlFlowSerializeTest(unittest.TestCase): + def test_conditional_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] y) + { + f = Constant () + t = Constant () + y1 = local.myfun (f, x) + y = local.myfun (t, y1) + } + + myfun (b, lx) => (ly) + { + ly = If (b) < + then_branch 
= g1 () => (float[N] z_then) + { + two = Constant () + z_then = Mul (lx, two) + }, + else_branch = g2 () => (float[N] z_else) + { + three = Constant () + z_else = Mul (lx, three) + } + > + } + """ + ) + ir = irbuilder.build_ir(model) + # Tests related to serialization to ModelProto + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + def test_function_attribute_serialize(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] y) + { + f = Constant () + t = Constant () + y1 = local.myfun (f, x) + y = local.myfun (t, y1) + } + + myfun (l, lx) => (ly) + { + ly = Mul (l, lx) + } + """ + ) + ir = irbuilder.build_ir(model) + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + function_proto = model_proto.functions[0] + self.assertEqual(function_proto.attribute, ["a"]) + self.assertEqual(len(function_proto.attribute_proto), 1) + b_attr_proto = function_proto.attribute_proto[0] + self.assertEqual(b_attr_proto.name, "b") + self.assertEqual(b_attr_proto.type, onnx.AttributeProto.INT) + self.assertEqual(b_attr_proto.i, 1) + + def test_com_microsoft_opset_is_supported_in_protobuilder(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + image_reshape = Reshape (image, shape_a) + instance_norm = InstanceNormalization (image_reshape, scale, B) + shape_b = Constant() + instance_norm_reshape = Reshape (instance_norm, shape_b) + mul_output = Mul (instance_norm_reshape, weight) + output = Add (mul_output, bias) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight = np.random.rand(320, 1, 1).astype(np.float16) + bias = np.random.rand(320, 1, 1).astype(np.float16) + model.graph.initializer.extend( + [ + onnx.helper.make_tensor( + "scale", + onnx.TensorProto.FLOAT16, + [32], + np.ones(32, dtype=np.float16), + ), + 
onnx.helper.make_tensor( + "B", onnx.TensorProto.FLOAT16, [32], np.zeros(32, dtype=np.float16) + ), + onnx.helper.make_tensor( + "weight", onnx.TensorProto.FLOAT16, [320, 1, 1], weight + ), + onnx.helper.make_tensor("bias", onnx.TensorProto.FLOAT16, [320, 1, 1], bias), + ] + ) + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 1) + model_proto = protobuilder.build_model_proto(ir) + onnx.checker.check_model(model_proto) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/_legacy_ir/visitor.py b/onnxscript/_legacy_ir/visitor.py new file mode 100644 index 000000000..4895e60a2 --- /dev/null +++ b/onnxscript/_legacy_ir/visitor.py @@ -0,0 +1,922 @@ +from __future__ import annotations + +import dataclasses +import logging +from typing import Any, Sequence + +import numpy as np +import onnx + +import onnxscript._legacy_ir as ir +from onnxscript.utils.utils import ( + get_initializer_type, + is_control_flow_op, + normalize_domain, +) + +logger = logging.getLogger(__name__) + + +def _override_inferred_value_type_with_symbolic_value_type( + symbolic_value: ir.Value | None, + inferred_value: ir.Value | None, +) -> ir.Value | None: + if inferred_value is not None and symbolic_value is not None: + inferred_value.type = symbolic_value.type + if inferred_value is None: + inferred_value = symbolic_value + return inferred_value + + +def is_local_function_node( + node: onnx.NodeProto, functions: dict[ir.FunctionId, onnx.FunctionProto] +) -> bool: + return ir.get_function_id_from_node(node) in functions + + +class FunctionShapeEnv: + def __init__(self): + # Mapping from (domain, function_name, overload) to {value_name: ir_value} + self._function_values: dict[ir.FunctionId, dict[str, ir.Value]] = {} + + def load_from_model_proto(self, model_proto: onnx.ModelProto) -> None: + for value_info in model_proto.graph.value_info: + self.load_from_value_info(value_info) + + def 
save_to_model_proto(self, model_proto: onnx.ModelProto) -> None: + for ( + domain, + function_name, + overload, + ), named_ir_values in self._function_values.items(): + for ir_value in named_ir_values.values(): + if ( + value_info := self.save_to_value_info( + ir_value, domain, function_name, overload + ) + ) is not None: + model_proto.graph.value_info.append(value_info) + + def load_from_value_info(self, value_info: onnx.ValueInfoProto) -> None: + function_id, ir_value = self.process_value_info(value_info) + if function_id is not None: + logger.debug( + "Loads torch symbolic value info '%s'.", + value_info.name, + ) + self._function_values.setdefault(function_id, {})[ir_value.name] = ir_value + + def process_value_info( + self, value_info: onnx.ValueInfoProto + ) -> tuple[ir.FunctionId | None, ir.Value]: + name = value_info.name + if len(splits := name.split("/")) == 2: + # Experimental function value info format. + # To be deprecated after ONNX 1.16, where value_info is introduced in FunctionProto. + function_id, value_name = splits + splits = function_id.split("::") + domain, function_name = splits[0], splits[1] + # 'overload' is introduced in ONNX 1.16, consider it as empty string prior to that. + # The code is for future proof, in case overload is encoded in this format. + overload = "" + if len(splits) == 3: + overload = splits[2] + function_id = (domain, function_name, overload) + else: + # Standard main graph value info format. 
+ function_id = None + value_name = name + return function_id, ir.Value(value_name, type=value_info.type) + + def save_to_value_info( + self, value: ir.Value, domain: str, function_name: str, overload: str + ) -> onnx.ValueInfoProto | None: + if overload != "": + raise NotImplementedError("Overload is not supported yet.") + function_id = f"{domain}::{function_name}" + + if value.type is not None: + return onnx.helper.make_value_info(f"{function_id}/{value.name}", value.type) + return None + + def lookup(self, function: onnx.FunctionProto, value_name: str) -> ir.Value | None: + """Lookup ir value of 'value_name' inside 'function'.""" + function_id = ir.get_function_id(function) + function_values = self._function_values.get(function_id) + if function_values is None or (ir_value := function_values.get(value_name)) is None: + logger.debug( + "Lookup Missed %s torch symbolic value info in function %s::%s.", + value_name, + function.domain, + function.name, + ) + return None + logger.debug( + "Lookup found %s torch symbolic value info in function %s::%s.", + value_name, + function.domain, + function.name, + ) + return ir_value + + def bind(self, value: ir.Value, domain: str, function_name: str, overload: str) -> None: + """Bind ir value 'value' to 'value_name' inside 'function'.""" + function_id = (domain, function_name, overload) + self._function_values.setdefault(function_id, {})[value.name] = value + + def get_ir_values(self, function: onnx.FunctionProto) -> dict[str, ir.Value]: + """Get all ir values inside 'function'.""" + function_id = ir.get_function_id(function) + return self._function_values.get(function_id, {}) + + +class SubScope: + values: dict[str, ir.Value] + ref_attributes: dict[str, onnx.AttributeProto] + owner: onnx.GraphProto | onnx.FunctionProto + + def __init__(self, owner: onnx.GraphProto | onnx.FunctionProto): + self.values = {} + self.ref_attributes = {} + self.owner = owner + + def lookup(self, name: str) -> ir.Value | None: + return 
self.values.get(name) + + def bind(self, name: str, value: ir.Value) -> None: + self.values[name] = value + + def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: + return self.ref_attributes.get(ref_attr_name) + + def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: + self.ref_attributes[ref_attr_name] = attr + + def readable_strs(self, indent: int = 0) -> list[str]: + indent_str = " " * indent + strs = [] + if isinstance(self.owner, onnx.GraphProto): + strs.append(f"Graph {self.owner.name}:") + else: + strs.append(f"Function {self.owner.name}:") + strs.append(" ir.Values:") + for name, value in self.values.items(): + strs.append(f" {name}: {value}") + strs.append(" RefAttributes:") + for name, attr in self.ref_attributes.items(): + strs.append(f" {name}: {attr}") + + return [f"{indent_str}{s}" for s in strs] + + def __str__(self) -> str: + return "\n".join(self.readable_strs()) + + +@dataclasses.dataclass +class Scope: + _sub_scopes: list[SubScope] = dataclasses.field(default_factory=list) + + def lookup(self, name: str) -> ir.Value | None: + """Lookup value by name from all SubScopes.""" + for sub_scope in reversed(self._sub_scopes): + if (result := sub_scope.lookup(name)) is not None: + return result + return None + + def bind(self, name: str, value: ir.Value) -> None: + """Bind value to name in the most recent SubScope.""" + if name == "": + raise ValueError("Cannot bind to empty name.") + if value is None: + raise ValueError(f"Cannot bind None to value {name}.") + self._sub_scopes[-1].bind(name, value) + + def lookup_or_create(self, name: str) -> ir.Value: + """Lookup value by name from all SubScopes. 
If not found, create a new one in most recent SubScope.""" + if name == "": + raise ValueError("Cannot lookup or create empty name.") + for sub_scope in reversed(self._sub_scopes): + if (result := sub_scope.lookup(name)) is not None: + return result + value = ir.Value(name=name) + self.bind(name, value) + return value + + def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: + for sub_scope in reversed(self._sub_scopes): + if (result := sub_scope.lookup_ref_attribute(ref_attr_name)) is not None: + return result + return None + + def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: + self._sub_scopes[-1].bind_ref_attribute(ref_attr_name, attr) + + def enter_sub_scope(self, owner: onnx.GraphProto) -> None: + self._sub_scopes.append(SubScope(owner)) + + def exit_sub_scope(self) -> SubScope: + return self._sub_scopes.pop() + + def current_function_scope(self) -> SubScope | None: + if len(self._sub_scopes) == 0: + return None + if isinstance(self._sub_scopes[0].owner, onnx.FunctionProto): + return self._sub_scopes[0] + return None + + def current_function(self) -> onnx.FunctionProto | None: + current_function_scope = self.current_function_scope() + if current_function_scope is not None: + return current_function_scope.owner + return None + + def current_graph(self) -> onnx.GraphProto | None: + for sub_scope in reversed(self._sub_scopes): + if isinstance(sub_scope.owner, onnx.GraphProto): + return sub_scope.owner + return None + + def readable_strs(self, indent: int = 0) -> list[str]: + indent_str = " " * indent + strs = [] + for i, sub_scope in enumerate(self._sub_scopes): + strs.append(f"SubScope {i}:") + strs.extend(sub_scope.readable_strs(indent=indent + 2)) + return [f"{indent_str}{s}" for s in strs] + + def __str__(self) -> str: + return "\n".join(self.readable_strs()) + + +@dataclasses.dataclass +class ScopeStack: + """Stack of scopes. 
+ + Each Scope represents statically-nested SubScopes (where inner SubScopes can access names defined in outer SubScopes) + produced by subgraphs (occurring as attribute values), except for the first SubScope which could be produced by a function. + With a ScopeStack, there is no such possibility of referencing variables defined higher up in the stack by name. + Instead, it is meant to represent a sequence of (nested) function-calls. Each entry in the stack (except the outermost) + represents a call to a function. + + Thus, we would use a ScopeStack for a context-sensitive analysis (where we recursively process a called function). + For a context-insensitive analysis, we would only need a Scope (where we recursively process subgraphs). + + To debug, `print(scope_stack)` will print the scope structure as well as the info stored + in each scope. + """ + + _scopes: list[Scope] = dataclasses.field(default_factory=lambda: [Scope()]) + + def current_scope(self) -> Scope: + return self._scopes[-1] + + def lookup(self, name: str) -> ir.Value | None: + """Lookup value by name from the current Scope.""" + return self.current_scope().lookup(name) + + def bind(self, name: str, value: ir.Value) -> None: + """Bind value to name in the current Scope.""" + self.current_scope().bind(name, value) + + def lookup_or_create(self, name: str) -> ir.Value: + """Lookup value by name from the current Scope. 
If not found, create a new one.""" + return self.current_scope().lookup_or_create(name) + + def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: + return self.current_scope().lookup_ref_attribute(ref_attr_name) + + def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: + self.current_scope().bind_ref_attribute(ref_attr_name, attr) + + def enter_graph_scope(self, graph: onnx.GraphProto) -> None: + self.current_scope().enter_sub_scope(graph) + + def exit_graph_scope(self) -> SubScope: + sub_scope = self.current_scope().exit_sub_scope() + assert isinstance(sub_scope.owner, onnx.GraphProto), "Expected graph scope." + return sub_scope + + def enter_function_scope(self, function: onnx.FunctionProto) -> None: + self._scopes.append(Scope()) + self.current_scope().enter_sub_scope(function) + + def exit_function_scope(self) -> SubScope: + sub_scope = self.current_scope().exit_sub_scope() + assert isinstance(sub_scope.owner, onnx.FunctionProto), "Expected function scope." 
+ self._scopes.pop() + return sub_scope + + def current_function(self) -> onnx.FunctionProto | None: + return self.current_scope().current_function() + + def current_graph(self) -> onnx.GraphProto | None: + return self.current_scope().current_graph() + + def __str__(self) -> str: + strs = ["ScopeStack:"] + for i, scope in enumerate(self._scopes): + strs.append(f" Scope {i}:") + strs.extend(scope.readable_strs(indent=2)) + return "\n".join(strs) + + +class ProtoVisitorCore: + def visit_model(self, model: onnx.ModelProto): + self.process_model(model) + for opset in model.opset_import: + self.process_opset_import(opset) + self.visit_graph(model.graph) + for function in model.functions: + self.visit_function(function) + + def process_model(self, model: onnx.ModelProto): + pass + + def process_opset_import(self, opset: onnx.OperatorSetIdProto): + pass + + def visit_graph(self, graph: onnx.GraphProto): + self.enter_scope(graph) + self.process_graph(graph) + for input in graph.input: + self.process_graph_input(input) + for init in graph.initializer: + self.process_initializer(init) + for value_info in graph.value_info: + self.process_value_info(value_info) + for node in graph.node: + self.visit_node(node) + for output in graph.output: + self.process_graph_output(output) + self.exit_scope(graph) + + def visit_function(self, function: onnx.FunctionProto): + self.enter_function_scope(function) + self.process_function(function) + for input in function.input: + self.process_function_input(input) + for node in function.node: + self.visit_node(node) + for output in function.output: + self.process_function_output(output) + self.exit_function_scope(function) + + def process_function_input(self, input: str): + pass + + def process_function_output(self, output: str): + pass + + def process_function(self, function: onnx.FunctionProto): + pass + + def enter_function_scope(self, function: onnx.FunctionProto): + pass + + def exit_function_scope(self, function: onnx.FunctionProto) ->
SubScope: + pass + + def enter_scope(self, graph: onnx.GraphProto): + pass + + def process_graph(self, graph: onnx.GraphProto): + pass + + def exit_scope(self, graph: onnx.GraphProto) -> SubScope: + pass + + def process_graph_input(self, input: onnx.ValueInfoProto): + pass + + def process_initializer(self, init: onnx.TensorProto): + pass + + def process_value_info(self, value_info: onnx.ValueInfoProto): + pass + + def visit_node(self, node: onnx.NodeProto): + self.process_node(node) + for attr in node.attribute: + self.visit_attribute(attr) + + def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: + pass + + def process_graph_output(self, output: onnx.ValueInfoProto): + pass + + def visit_attribute(self, attr: onnx.AttributeProto): + self.process_attribute(attr) + if attr.HasField("g"): + self.visit_graph(attr.g) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + self.visit_graph(graph) + + def process_attribute(self, attr: onnx.AttributeProto): + pass + + +class ProtoVisitor(ProtoVisitorCore): + def __init__( + self, external_data_folder: str = "", *, do_shape_inference: bool = False + ) -> None: + super().__init__() + self.scopes = ScopeStack() + self.function_shape_env = FunctionShapeEnv() + self.version_map = {} # Map from domain to version + self.do_shape_inference = do_shape_inference + self.external_data_folder = external_data_folder + self.modified = False + + def process_opset_import(self, opset: onnx.OperatorSetIdProto): + domain = normalize_domain(opset.domain) + self.version_map[domain] = opset.version + + def lookup_version(self, domain: str) -> int: + domain = normalize_domain(domain) + return self.version_map.get(domain, 1) # TODO: handle missing domain + + def lookup(self, name: str) -> ir.Value | None: + if name == "": + return None + if (result := self.scopes.lookup(name)) is None: + logger.debug("Lookup value %s unfound.", name) + raise ValueError( + f"Undefined variable {name}.\n" + f"Available variables: 
{self.scopes.current_scope()}" + ) + logger.debug("Lookup value %s. Shape %s", name, result.tensor_shape_proto()) + return result + + def bind(self, name: str, value: ir.Value) -> None: + logger.debug("Binding value %s. Shape %s", name, value.tensor_shape_proto()) + self.scopes.bind(name, value) + + def lookup_or_create(self, name: str) -> ir.Value: + return self.scopes.lookup_or_create(name) + + def has_input(self, node: onnx.NodeProto, index: int) -> bool: + return index < len(node.input) and node.input[index] != "" + + # TODO: Cleanup handling of undefined variables. May fail in some of methods below. + + def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: + if index < len(node.input): + return self.lookup(node.input[index]) + return None + + def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: + info = self.get_input(node, index) + return info.type if info is not None else None + + def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: + info = self.get_input(node, index) + return info.element_type if info is not None else None + + def input_shape(self, node: onnx.NodeProto, index: int) -> onnx.TensorShapeProto | None: + info = self.get_input(node, index) + return info.tensor_shape_proto() if info is not None else None + + def input_const_value(self, node: onnx.NodeProto, index: int) -> Any: + if not self.has_input(node, index): + return None # This is treated as a known constant value "None" + info = self.get_input(node, index) + return info.value + + def has_output(self, node: onnx.NodeProto, index: int) -> bool: + return index < len(node.output) and node.output[index] != "" + + def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: + if index < len(node.output): + return self.lookup(node.output[index]) + return None + + def get_input_value( + self, node: onnx.NodeProto, index: int, default: Any | None = None + ) -> Any | None: + info = self.get_input(node, index) + 
if info is not None: + return info.value + return default + + def get_input_type( + self, node: onnx.NodeProto, index: int, default: onnx.TypeProto | None = None + ) -> onnx.TypeProto | None: + info = self.get_input(node, index) + if info is not None: + return info.type + return default + + def enter_scope(self, graph: onnx.GraphProto): + logger.debug("enter_scope: graph %s", graph.name) + self.scopes.enter_graph_scope(graph) + + def exit_scope(self, graph: onnx.GraphProto) -> SubScope: + logger.debug("exit_scope: graph %s", graph.name) + return self.scopes.exit_graph_scope() + + def enter_function_scope(self, function: onnx.FunctionProto): + logger.debug("enter_function_scope: function %s", function.name) + self.scopes.enter_function_scope(function) + ir_values = self.function_shape_env.get_ir_values(function) + for name, ir_value in ir_values.items(): + inferred_ir_value = self.lookup_or_create(name) + updated_ir_value = _override_inferred_value_type_with_symbolic_value_type( + ir_value, inferred_ir_value + ) + self.bind(name, updated_ir_value) + + def exit_function_scope(self, function: onnx.FunctionProto) -> SubScope: + logger.debug("exit_function_scope: function %s", function.name) + # Sync ir value back to function_shape_env + function_scope = self.scopes.exit_function_scope() + for ir_value in function_scope.values.values(): + self.function_shape_env.bind(ir_value, *ir.get_function_id(function)) + return function_scope + + def process_initializer(self, init: onnx.TensorProto): + array = onnx.numpy_helper.to_array(init, self.external_data_folder) + self.bind( + init.name, + ir.Value(name=init.name, value=array, type=get_initializer_type(init)), + ) + + def process_graph_input(self, input: onnx.ValueInfoProto): + self.bind(input.name, ir.Value(name=input.name, type=input.type)) + + def process_value_info(self, value_info: onnx.ValueInfoProto): + logger.debug("process_value_info: %s", value_info) + value = self.lookup_or_create(value_info.name) + value.type = 
value_info.type + # Populate function shape environment + self.function_shape_env.load_from_value_info(value_info) + + def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: + output_types = {} + if self.do_shape_inference and not is_control_flow_op(node): + # Control-flow ops are more complicated. Not supported here yet. + # TODO: handle optional inputs + def get_constant_value(i: int) -> onnx.TensorProto | None: + value = self.input_const_value(node, i) + if isinstance(value, np.ndarray) and value.size < 20: + return onnx.numpy_helper.from_array(value, node.input[i]) + return None + + input_types = {x: self.input_type(node, i) for i, x in enumerate(node.input)} + input_data = {x: get_constant_value(i) for i, x in enumerate(node.input)} + input_data = {k: v for k, v in input_data.items() if v is not None} + if any(t is None for t in input_types.values()): + logger.debug( + "Skipping shape inference for node %s due to missing input type.", + node.name, + ) + else: + # TODO: pass in constant values, ir_version + try: + schema = onnx.defs.get_schema( + node.op_type, self.lookup_version(node.domain), node.domain + ) + output_types = onnx.shape_inference.infer_node_outputs( + schema, node, input_types, input_data + ) + except Exception as e: + logger.debug( + "Skipping shape inference for node %s due to exception: %s", + node.name, + e, + ) + + for output in node.output: + info = self.lookup_or_create(output) + if output in output_types: + # TODO: merge types + info.type = output_types[output] + + +class ProtoTransformer(ProtoVisitor): + # TODO(lowpri) Practically this is useless. + # Subgraph only exist in 'if' nodes. 'if' nodes only exist in torchlib functions. + # There is no pre-existing value_info in torchlib functions. + # def exit_scope(self, graph: onnx.GraphProto) -> SubScope: + # # Also sync updated ir values back to value_info in graph. 
+ # sub_scope = super().exit_scope(graph) + + def visit_node(self, node: onnx.NodeProto) -> list[onnx.NodeProto] | None: + replacement = self.process_node(node) + logger.debug( + "visit_node: %s::%s %s replacement %s", + node.domain, + node.op_type, + node.name, + "found" if replacement is not None else "missed", + ) + if replacement is None: + # No change. Process attributes. + for attr in node.attribute: + self.visit_attribute(attr) + return None + else: + self.modified = True + # We recursively visit the replacement nodes. + result = [] + for newnode in replacement: + n = self.visit_node(newnode) + if n is not None: + result.extend(n) + else: + result.append(newnode) + return result + + def visit_graph(self, graph: onnx.GraphProto) -> dict[str, ir.Value]: + self.enter_scope(graph) + self.process_graph(graph) + for input in graph.input: + self.process_graph_input(input) + for init in graph.initializer: + self.process_initializer(init) + for value_info in graph.value_info: + self.process_value_info(value_info) + updates = [] + nodes = graph.node + for i, node in enumerate(nodes): + replacement = self.visit_node(node) + if replacement is not None: + updates.append((i, replacement)) + for i, replacement in reversed(updates): + old_node_name = nodes[i].name + del nodes[i] + for newnode in reversed(replacement): + logger.debug( + "Replacement node %s for %s. Size %s", + newnode.name, + old_node_name, + newnode.ByteSize(), + ) + nodes.insert(i, newnode) + for output in graph.output: + self.process_graph_output(output) + return self.exit_scope(graph) + + +class FunctionCallsiteAnalysis(ProtoVisitor): + """Collects the callsites of each function.""" + + def __init__(self): + super().__init__() + self.functions: dict[ir.FunctionId, onnx.FunctionProto] = {} + self.function_calls: dict[ir.FunctionId, list[onnx.NodeProto]] = {} + + def visit_function(self, function: onnx.FunctionProto): + # Do not visit function via model.functions. + # Only visit function at callsites. 
+ # The purpose of this analysis is to collect the callsites of each function. + pass + + def visit_node(self, node: onnx.NodeProto) -> None: + if is_local_function_node(node, self.functions): + function_id = ir.get_function_id_from_node(node) + self.function_calls.setdefault(function_id, []).append(node) + for subnode in self.functions[function_id].node: + self.visit_node(subnode) + + def visit_model(self, model: onnx.ModelProto) -> None: + for function in model.functions: + self.functions[ir.get_function_id(function)] = function + + super().visit_model(model) + + +class FunctionRenamer: + _POSTFIX_FORMAT = "{name}|{postfix}_{count}" + + def __init__(self, postfix="folded"): + self._function_key_to_instance_count = {} + self._postfix = postfix + + def rename(self, function: onnx.FunctionProto) -> None: + domain = function.domain + name = function.name + key = (domain, name) + self._function_key_to_instance_count.setdefault(key, 0) + function.name = self._POSTFIX_FORMAT.format( + name=name, + postfix=self._postfix, + count=self._function_key_to_instance_count[key], + ) + self._function_key_to_instance_count[key] += 1 + + +class FunctionCallsiteProtoTransformer(ProtoTransformer): + """Unlike other base visitors, this is a special visitor that visits functions at their callsite. + + This allows transforming and constructing specialized functions based on callsite context. 
+ """ + + _functions: dict[ir.FunctionId, onnx.FunctionProto] + _function_callsites: dict[ir.FunctionId, list[onnx.NodeProto]] + _new_functions: list[onnx.FunctionProto] + _function_renamer: FunctionRenamer + + def _gather_function_metadata(self, model: onnx.ModelProto): + analysis = FunctionCallsiteAnalysis() + analysis.visit_model(model) + self._functions = analysis.functions + self._function_callsites = analysis.function_calls + self._new_functions = [] + self._function_renamer = FunctionRenamer() + + def process_function_outputs(self, function: onnx.FunctionProto) -> bool: + """Process function outputs. + + This method is called when a function is visited at its callsite. + + Returns: + True if the function outputs are modified. + """ + del function # Unused + return False + + def process_function_node_outputs( + self, + node: onnx.NodeProto, + function_scope: SubScope, + ) -> None: + """Fetch value infos of function output to re-bind them for function node output.""" + function = function_scope.owner + output_values = [function_scope.lookup(output) for output in function.output] + for actual_name, formal_value in zip(node.output, output_values): + if formal_value is None: + raise RuntimeError( + "Missing output %s in function-call to %s", + actual_name, + node.op_type, + ) + actual_value = self.lookup_or_create(actual_name) + actual_value.identity_merge_from(formal_value) + if logger.level <= logging.INFO: + logger.info( + "Binding outputs for function %s. 
%s => %s", + function.name, + actual_value, + node.output, + ) + + def lookup_ref_attribute(self, ref_attr_name: str) -> onnx.AttributeProto | None: + return self.scopes.lookup_ref_attribute(ref_attr_name) + + def bind_ref_attribute(self, ref_attr_name: str, attr: onnx.AttributeProto) -> None: + self.scopes.bind_ref_attribute(ref_attr_name, attr) + + def visit_model(self, model: onnx.ModelProto): + self._gather_function_metadata(model) + + self.process_model(model) + for opset in model.opset_import: + self.process_opset_import(opset) + self.visit_graph(model.graph) + + for new_function in self._new_functions: + model.functions.append(new_function) + + self.function_shape_env.save_to_model_proto(model) + + def visit_node(self, node: onnx.NodeProto) -> list[onnx.NodeProto] | None: + if is_local_function_node(node, self._functions): + function_id = ir.get_function_id_from_node(node) + if function_id not in self._functions: + # Do not recursively visit new functions. + return None + replacement, _ = self.process_function_node(node) + else: + replacement = self.process_node(node) + logger.debug( + "visit_node: %s::%s %s replacement %s", + node.domain, + node.op_type, + node.name, + "found" if replacement is not None else "missed", + ) + if replacement is None: + # No change. Process attributes. + for attr in node.attribute: + self.visit_attribute(attr) + return None + else: + self.modified = True + # We recursively visit the replacement nodes. 
+ result = [] + for newnode in replacement: + n = self.visit_node(newnode) + if n is not None: + result.extend(n) + else: + result.append(newnode) + return result + + def process_function_node( + self, node: onnx.NodeProto + ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: + function_id = ir.get_function_id_from_node(node) + function = self._functions[function_id] + + is_unique_callsite = len(self._function_callsites[function_id]) == 1 + if not is_unique_callsite: + mutable_function = onnx.FunctionProto() + mutable_function.CopyFrom(function) + else: + mutable_function = function + + logger.info("Visit function %s node %s", function_id, node.name) + actual_input_value_infos = [self.lookup(input) for input in node.input] + # Handle omitted inputs, these are considered optional inputs of the function. + actual_input_value_infos.extend( + [None] * (len(function.input) - len(actual_input_value_infos)) + ) + ref_attributes = { + attr_proto.name: self.lookup_ref_attribute(attr_proto.ref_attr_name) + for attr_proto in node.attribute + if attr_proto.ref_attr_name + } + + self.enter_function_scope(mutable_function) + if logger.level <= logging.INFO: + printable_actual_input_value_infos = [str(x) for x in actual_input_value_infos] + logger.info( + "Actual input value infos: %s", + printable_actual_input_value_infos, + ) + logger.info("Enter function scope: %s", self.scopes.current_scope()) + + logger.debug("Binding inputs for function %s", function.name) + for actual_input_value_info, formal_input in zip( + actual_input_value_infos, function.input + ): + formal_info = ir.Value(formal_input) + if actual_input_value_info is not None: + formal_info.identity_merge_from(actual_input_value_info) + self.bind(formal_input, formal_info) + + for attr_proto in function.attribute_proto: + # Default value of function attributes. 
+ self.bind_ref_attribute(attr_proto.name, attr_proto) + + for attr_proto in node.attribute: + if attr_proto.ref_attr_name: + concrete_attribute = ref_attributes.get(attr_proto.name) + if concrete_attribute is None: + continue + self.bind_ref_attribute(attr_proto.name, concrete_attribute) + else: + self.bind_ref_attribute(attr_proto.name, attr_proto) + + # Visit inner function nodes. + node_updates: list[tuple[int, list[onnx.NodeProto]]] = [] + nodes = mutable_function.node + for i, inner_node in enumerate(nodes): + replacement = self.visit_node(inner_node) + if replacement is not None: + node_updates.append((i, replacement)) + for i, replacement in reversed(node_updates): + old_node_name = nodes[i].name + old_node_op_type = nodes[i].op_type + del nodes[i] + for newnode in reversed(replacement): + logger.debug( + "Replacement node inside function %s: %s for %s %s. Size %s", + node.name, + newnode.output, + old_node_name, + old_node_op_type, + newnode.ByteSize(), + ) + nodes.insert(i, newnode) + added_domains = set() + del mutable_function.opset_import[:] + for inner_node in nodes: + # Update opset_import if needed. 
+ if inner_node.domain not in added_domains: + version = self.lookup_version(inner_node.domain) + mutable_function.opset_import.append( + onnx.OperatorSetIdProto(domain=inner_node.domain, version=version) + ) + added_domains.add(inner_node.domain) + + output_updates = self.process_function_outputs(mutable_function) + + is_new_function = not is_unique_callsite and (node_updates or output_updates) + if is_new_function: + self._new_functions.append(mutable_function) + self._function_renamer.rename(mutable_function) + node.op_type = mutable_function.name + + function_scope = self.exit_function_scope(mutable_function) + + self.process_function_node_outputs(node, function_scope) + + logger.info("Exit function scope: %s", function_scope) + logger.info("Exit function %s node %s", function_id, node.name) + + if is_new_function: + return [node], mutable_function + return None, None diff --git a/onnxscript/_legacy_ir/visitor_test.py b/onnxscript/_legacy_ir/visitor_test.py new file mode 100644 index 000000000..e4559472e --- /dev/null +++ b/onnxscript/_legacy_ir/visitor_test.py @@ -0,0 +1,38 @@ +import unittest + +import onnx + +from onnxscript._legacy_ir import visitor + + +class FunctionCallsiteProtoTransformerTest(unittest.TestCase): + def test_function_optional_input_is_recorded_by_shape_env(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + z = custom.function(x) + } + < + domain: "custom", + opset_import: ["" : 18] + > + function (x, optional_y, optional_z) => (return_val) + { + return_val = custom.custom_op (x, optional_y, optional_z) + } + """ + ) + + model_visitor = visitor.FunctionCallsiteProtoTransformer() + model_visitor.visit_model(model) + self.assertIsNotNone( + model_visitor.function_shape_env.lookup(model.functions[0], "optional_y") + ) + self.assertIsNotNone( + model_visitor.function_shape_env.lookup(model.functions[0], "optional_z") + ) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index fac474312..efcc8ae8a 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -20,7 +20,7 @@ import onnxscript.testing import onnxscript.values from onnxscript.backend import onnx_backend, onnx_export -from onnxscript.tests.models import type_double +from tests.models import type_double @dataclasses.dataclass @@ -112,7 +112,7 @@ def extract_functions(name: str, content: str, test_folder: pathlib.Path): init.touch(exist_ok=True) file = test_folder / f"{name}.py" file.write_text(content, encoding="utf-8") - import_name = f"onnxscript.tests.{test_folder.parts[-1]}.{name}" + import_name = f"tests.{test_folder.parts[-1]}.{name}" try: mod = importlib.import_module(import_name) except (SyntaxError, ImportError) as e: @@ -133,7 +133,7 @@ def exec_main(f, *inputs): class TestOnnxBackEnd(unittest.TestCase): - root_folder = pathlib.Path(__file__).parent.parent + root_folder = pathlib.Path(__file__).parent.parent.parent test_folder = root_folder / "tests" / "onnx_backend_test_code" temp_folder = root_folder / "tests" / "export" diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index bcbfdd625..0bfcf1d38 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -29,7 +29,7 @@ from onnxscript import BOOL, FLOAT, INT64, converter, graph, script, tensor from onnxscript.onnx_opset import opset11 as op11 from onnxscript.onnx_opset import opset15 as op -from onnxscript.tests.common import onnx_script_test_case, testutils +from tests.common import onnx_script_test_case, testutils TEST_INPUT_DIR = pathlib.Path(__file__).parent / "tests" / "models" TEST_OUTPUT_DIR = TEST_INPUT_DIR / "testoutputs" @@ -132,7 +132,7 @@ def validate_run(self, script_tests): self.check_run(val.function, val.input, val.output[0]) def test_eager_op(self): - from onnxscript.tests.models import eager_op + from tests.models import eager_op 
test_functions = self.validate_save(eager_op, check_ort=True) @@ -195,39 +195,39 @@ def cast_add(x, y): self.assertEqual(output_value_info.type.tensor_type.elem_type, onnx.TensorProto.FLOAT) def test_onnxfns1(self): - from onnxscript.tests.models import onnxfns1 + from tests.models import onnxfns1 self.validate(onnxfns1) def test_onnxfns1A(self): - from onnxscript.tests.models import onnxfns1A + from tests.models import onnxfns1A self.validate(onnxfns1A) def test_ort_custom_ops(self): - from onnxscript.tests.functions import ort_custom_ops + from tests.functions import ort_custom_ops self.validate(ort_custom_ops) def test_unary_op(self): - from onnxscript.tests.models import m1 + from tests.models import m1 self.validate_save(m1) def test_subfunction_check_model(self): - from onnxscript.tests.models import subfunction + from tests.models import subfunction model = subfunction.MyElu.function_ir.to_model_proto(producer_name="p2o") model = onnx.shape_inference.infer_shapes(model) onnx.checker.check_model(model) def test_subfunction(self): - from onnxscript.tests.models import subfunction + from tests.models import subfunction self.validate_save(subfunction, check_ort=True) def test_if_models(self): - from onnxscript.tests.models import if_statement + from tests.models import if_statement self.validate_save(if_statement) @@ -246,28 +246,28 @@ def sumprod(x: FLOAT["N"], N: INT64) -> (FLOAT["N"], FLOAT["N"]): # noqa: F821 self.assertEqual(proto.doc_string.strip(), "Combines ReduceSum, ReduceProd.") def test_signal(self): - from onnxscript.tests.models import signal_dft + from tests.models import signal_dft # shape_inference crashes on stft. 
self.validate_save(signal_dft, shape_inference=False) def test_multi(self): - from onnxscript.tests.models import multi + from tests.models import multi self.validate_save(multi, shape_inference=False) def test_dropout(self): - from onnxscript.tests.models import dropout + from tests.models import dropout self.validate_save(dropout, shape_inference=False) def test_attrref(self): - from onnxscript.tests.models import attrref + from tests.models import attrref self.validate_save(attrref, shape_inference=False) def test_renaming(self): - from onnxscript.tests.models import renaming + from tests.models import renaming self.validate_save(renaming, shape_inference=False) @@ -276,18 +276,18 @@ def test_renaming(self): reason="default_opset must be specified in script for functions that do not contain any use of an ONNX op", ) def test_opt_output(self): - from onnxscript.tests.models import opt_output + from tests.models import opt_output self.validate_save(opt_output, shape_inference=False) def test_opt_input(self): - from onnxscript.tests.models import opt_input + from tests.models import opt_input self.validate_save(opt_input, shape_inference=False) @unittest.skip("A function with attributes cannot be exported as a model.") def test_onnxfns2(self): - from onnxscript.tests.models import onnxfns2 + from tests.models import onnxfns2 self.validate_save(onnxfns2, shape_inference=False) @@ -301,7 +301,7 @@ def clipmax(x: FLOAT, max: FLOAT): self.validate_save(clipmax) def test_type_double(self): - from onnxscript.tests.models import type_double + from tests.models import type_double fcts = self.validate_save(type_double, check_ort=False) f = fcts["double_abs"] @@ -320,17 +320,17 @@ def test_type_double(self): self.validate_save(type_double, check_ort=True) def test_cast_like(self): - from onnxscript.tests.models import cast_like + from tests.models import cast_like self.validate_expansion(cast_like) def test_identity(self): - from onnxscript.tests.models import identity + 
from tests.models import identity self.validate_expansion(identity) def test_opset_import(self): - from onnxscript.tests.models import different_opset + from tests.models import different_opset fcts = self.validate_save(different_opset, shape_inference=False) s16 = str(fcts["shape_A"]) @@ -345,7 +345,7 @@ def test_opset_import(self): self.assertNotIn("version: 15", sdef) def test_sequences(self): - from onnxscript.tests.models import sequences + from tests.models import sequences test_functions = self.validate_save(sequences, check_ort=True) @@ -372,7 +372,7 @@ def test_sequences(self): np.testing.assert_almost_equal(eager_mode, result) def test_loops_break(self): - from onnxscript.tests.models import loops_break + from tests.models import loops_break test_functions = self.validate_save(loops_break, check_ort=True) self.assertIn("loop1", test_functions) @@ -392,7 +392,7 @@ def test_loops_break(self): self.assertEqual(y.tolist(), [0, 11, -22]) def test_loops_while(self): - from onnxscript.tests.models import loops_while + from tests.models import loops_while test_functions = self.validate_save(loops_while, check_ort=True) self.assertIn("loop1", test_functions) @@ -409,7 +409,7 @@ def test_loops_while(self): self.assertEqual(res.tolist(), [0, 10, -20]) def test_getitem(self): - from onnxscript.tests.models import getitem + from tests.models import getitem self.validate_save(getitem, check_ort=True, skip_check_ort=None) self.validate_run(getitem) @@ -459,28 +459,28 @@ def check_run(self, onnxfn, inputs, expected_output): np.testing.assert_equal(output, expected_output) def test_graph_attr_scan(self): - from onnxscript.tests.models.graph_attr import cumulative_sum + from tests.models.graph_attr import cumulative_sum inputs = [np.array([1, 2, 3, 4, 5], dtype=np.int64)] expected_output = np.array([1, 3, 6, 10, 15], dtype=np.int64) self.check_run(cumulative_sum, inputs, expected_output) def test_graph_attr_loop(self): - from onnxscript.tests.models.graph_attr import 
sum_to + from tests.models.graph_attr import sum_to inputs = [np.array(6, dtype=np.int64)] expected_output = np.array([0, 1, 3, 6, 10, 15], dtype=np.int64) self.check_run(sum_to, inputs, expected_output) def test_graph_attr_loop_error(self): - from onnxscript.tests.models.graph_attr import sum_to_error + from tests.models.graph_attr import sum_to_error input = np.array(6, dtype=np.int64) with self.assertRaisesRegex(TypeError, "@graph"): sum_to_error(input) def test_loop_outer_scope(self): - from onnxscript.tests.models.graph_attr import loop_add + from tests.models.graph_attr import loop_add input_x = np.array([1, 2, 3], dtype=np.int64) input_m = np.array(3, dtype=np.int64) @@ -504,7 +504,7 @@ def inner(): return op.DummyOp(body=inner) def test_attr(self): - from onnxscript.tests.functions import attr_test + from tests.functions import attr_test self.validate_run(attr_test) diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py index de6de9323..cde621b64 100644 --- a/onnxscript/function_libs/torch_lib/graph_building.py +++ b/onnxscript/function_libs/torch_lib/graph_building.py @@ -830,7 +830,7 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto): new_value_info.pop(input.name, None) for output in onnx_model.graph.output: new_value_info.pop(output.name, None) - for tensor in onnx_model.graph.initializer: + for tensor in onnx_model.graph.initializer: # type: ignore[assignment] new_value_info.pop(tensor.name, None) existing_value_info.update(new_value_info) onnx_model.graph.value_info.extend(existing_value_info.values()) diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 45942e51a..dc29d826c 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -55,7 +55,7 @@ import logging import os import typing -from typing import Any, Mapping, Sequence +from typing import Any, List, Mapping, Sequence import numpy as np import onnx @@ -418,7 +418,7 @@ def 
_deserialize_graph( def deserialize_function(proto: onnx.FunctionProto) -> _core.Function: inputs = [_core.Input(name) for name in proto.input] values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc] - value_info = {info.name: info for info in proto.value_info} + value_info = {info.name: info for info in getattr(proto, "value_info", [])} # TODO(justinchuby): Handle unsorted nodes nodes = [_deserialize_node(node, [values], value_info=value_info) for node in proto.node] @@ -431,7 +431,9 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function: doc_string=_get_field(proto, "doc_string"), opset_imports=deserialize_opset_import(proto.opset_import), name=( - f"{proto.name}_{proto.domain}" + f"__{proto.overload}" if proto.overload else "" + f"{proto.name}_{proto.domain}" + f"__{proto.overload}" + if hasattr(proto, "overload") and proto.overload + else "" ), ) attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto] @@ -442,9 +444,9 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function: return _core.Function( domain=proto.domain, name=proto.name, - overload=proto.overload, + overload=getattr(proto, "overload", ""), graph=graph, - attributes=typing.cast(list[_core.Attr], attributes), + attributes=typing.cast(List[_core.Attr], attributes), ) @@ -639,7 +641,7 @@ def _deserialize_node( break if not found: raise ValueError( - f"Input '{name}' of node '{proto.name}({proto.domain}::{proto.op_type}:{proto.overload})' not found in any scope" + f"Input '{name}' of node '{proto.name}({proto.domain}::{proto.op_type}:{getattr(proto, 'overload', '')})' not found in any scope" f" (current depth: {len(scoped_values)})" ) node = _core.Node( @@ -647,7 +649,7 @@ def _deserialize_node( proto.op_type, node_inputs, [_deserialize_attribute(a, scoped_values) for a in proto.attribute], - overload=proto.overload, + overload=getattr(proto, "overload", ""), num_outputs=len(proto.output), name=_get_field(proto, 
"name"), ) @@ -664,7 +666,7 @@ def _deserialize_node( proto.op_type, ) scoped_values[-1][output] = value - for prop in proto.metadata_props: + for prop in getattr(proto, "metadata_props", []): node.metadata_props[prop.key] = prop.value return node diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py new file mode 100644 index 000000000..98c2038d9 --- /dev/null +++ b/onnxscript/optimizer/__init__.py @@ -0,0 +1,110 @@ +import logging +from typing import Any + +import onnx + +from onnxscript import rewriter +from onnxscript.optimizer.constant_folding import fold_constants +from onnxscript.optimizer.copy_propagation import ( + do_copy_propagation, + do_sequence_simplification, +) +from onnxscript.optimizer.remove_unused import remove_unused_nodes +from onnxscript.optimizer.remove_unused_function import remove_unused_functions +from onnxscript.optimizer.simple_function_folding import ( + inline_functions_with_unused_outputs, + inline_simple_functions, +) +from onnxscript.rewriter import ( + broadcast_to_matmul, + cast_constant_of_shape, + gemm_to_matmul_add, + no_op, +) + +logger = logging.getLogger(__name__) + + +def optimize( + model: onnx.ModelProto, + num_iterations: int = 2, + *, + onnx_shape_inference: bool = True, + stop_if_no_change: bool = True, + external_data_folder: str = "", + **kwargs: Any, +) -> onnx.ModelProto: + """Optimize the model. Perform optimizations and clean-ups such as constant folding, dead code elimination, etc. + + Args: + model (onnx.ModelProto): The model to optimize. + num_iterations (int, optional): Number of iterations to perform. + onnx_shape_inference (bool, optional): Whether to perform onnx shape inference on the model. + Set this to False to turn off onnx shape inference, and rely on model carried shapes and types. + This is useful for models produced by PyTorch 2.2+ dynamo onnx exporter, where the model carries + the symbolic shapes recorded from dynamo tracing. 
+ stop_if_no_change (bool, optional): Whether to stop if no change is detected. + external_data_folder (str, optional): The folder to store external data. + **kwargs: Additional keyword arguments. For BC purposes. + """ + if kwargs.pop("function_aware_folding", None) is not None: + logger.warning( + "'function_aware_folding' is deprecated. 'optimize' now supports both fully inlined models and models with functions. " + "To achieve the same behavior as 'function_aware_folding=True' before, set 'onnx_shape_inference=False'. " + "This would turn off incremental onnx shape inference and rely on model carried shapes and types. " + "See 'onnx_shape_inference' for more details." + ) + for _ in range(num_iterations): + if onnx_shape_inference: + model = onnx.shape_inference.infer_shapes( + model, check_type=True, strict_mode=True, data_prop=True + ) + + inline_simple_functions(model) + modified = fold_constants( + model, external_data_folder, onnx_shape_inference=onnx_shape_inference + ) + + remove_unused_nodes(model) + inline_simple_functions(model) + remove_unused_functions(model) + inline_functions_with_unused_outputs(model) + # NOTE: This is general rewrite rules + model = rewriter.rewrite( + model, + pattern_rewrite_rules=[ + *no_op.rules.rules, # TODO: merge this rule into constant folding? 
+ *broadcast_to_matmul.rules.rules, + gemm_to_matmul_add.rule, + *cast_constant_of_shape.rules.rules, + ], + ) + if stop_if_no_change and not modified: + logger.debug("Stopping after %d iterations.", _) + break + + for node in model.graph.node: + logger.debug("Node %s::%s name %s.", node.domain, node.op_type, node.name) + + for function in model.functions: + for node in function.node: + logger.debug( + "Function %s::%s node %s::%s name %s.", + function.domain, + function.name, + node.domain, + node.op_type, + node.name, + ) + + # do_sequence_simplification(model) + return model + + +__all__ = [ + "fold_constants", + "remove_unused_nodes", + "optimize", + "do_copy_propagation", + "do_sequence_simplification", +] diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/constant_folding.py new file mode 100644 index 000000000..9a51298c7 --- /dev/null +++ b/onnxscript/optimizer/constant_folding.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import logging +from typing import Any, Sequence + +import numpy as np +import onnx +import onnx.reference.ops + +import onnxscript._legacy_ir as ir +from onnxscript._legacy_ir import visitor +from onnxscript.optimizer import evaluator +from onnxscript.utils.utils import ( + is_control_flow_op, + is_onnx_domain, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024 + +# Ops excluded from constant-propagation: +# * Random ops, which are not deterministic (checked below) +# * Control flow ops (checked by presence of graph-attribute) + +non_deterministic_ops = frozenset( + { + "RandomUniform", + "RandomNormal", + "RandomUniformLike", + "RandomNormalLike", + "Multinomial", + } +) + +onnx_domain = frozenset({"", "onnx.ai"}) + + +def is_non_deterministic_op(node: onnx.NodeProto) -> bool: + return node.op_type in non_deterministic_ops and is_onnx_domain(node.domain) + + +def is_constant_op(node: onnx.NodeProto) -> bool: + return node.op_type in {"Constant", 
"ConstantOfShape"} and is_onnx_domain(node.domain) + + +class ConstantFolder(visitor.FunctionCallsiteProtoTransformer): + def __init__( + self, + registry: evaluator.PartialEvaluatorRegistry, + external_data_folder: str, + *, + do_shape_inference: bool, + ) -> None: + self.registry = registry + # TODO: make evaluator a parameter + self.evaluate = evaluator.reference_evaluator.evaluate + self._do_shape_inference = do_shape_inference + self._init() + super().__init__(external_data_folder, do_shape_inference=do_shape_inference) + + def _init(self) -> None: + self.counts = {} + self.sizes = {} + + def add_count(self, op: str, size: int = 1): + self.counts[op] = self.counts.get(op, 0) + 1 + self.sizes[op] = self.sizes.get(op, 0) + size + + def foldable_value(self, name: str, value): + """Checks if a runtime-constant can and should be folded into the graph. + + We fold constants only if they are tensors (not lists of tensors, for example) + and have size below desired limit. + """ + if value is ir.NotConstant: + return None + + if not isinstance(value, np.ndarray): + # ONNX does not have a way to represent non-tensor constants, eg. a sequence. + # So, a constant-value of type sequence is not folded, but it can be used + # to optimize subsequent operations when possible. 
+ logger.warning( + "Skip storing constant folded value %s due to unsupported type %s.", + name, + type(value), + ) + return None + + if value.nbytes > _DEFAULT_CONSTANT_FOLD_SIZE_LIMIT: + logger.warning( + "Skip storing constant folded nvalue %s due to large size %s.", + name, + value.nbytes, + ) + return None + + return onnx.numpy_helper.from_array(value, name) + + def new_constant(self, name, value): + if isinstance(value, (int, float, np.ScalarType)): + value = np.array(value) + + info = self.lookup_or_create(name) + info.value = value + + tensor = self.foldable_value(name, value) + if tensor is None: + return None + + logger.debug( + "New constant for value %s dtype: %s shape: %s", + name, + value.dtype, + value.shape, + ) + info.type = onnx.helper.make_tensor_type_proto( + onnx.helper.np_dtype_to_tensor_dtype(value.dtype), value.shape + ) + node = onnx.helper.make_node("Constant", inputs=[], outputs=[name], value=tensor) + return [node] + + def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]: + if self.scopes.current_scope().current_function_scope(): + # Need to resolve ref_attr_name if inside a function. + attr_dict = {} + for attribute in attributes: + concrete_attribute = ( + self.lookup_ref_attribute(attribute.ref_attr_name) + if attribute.ref_attr_name + else attribute + ) + if concrete_attribute is None: + continue + attr_dict[attribute.name] = onnx.helper.get_attribute_value(concrete_attribute) + return attr_dict + return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes} + + def replace_copy(self, node: onnx.NodeProto) -> None: + for i in range(len(node.input)): + input = self.get_input(node, i) + if input is not None and input.is_copy(): + old_value = self.lookup_or_create(input.name) + assert isinstance(input.symbolic_value, str) + new_value = self.lookup_or_create(input.symbolic_value) + # Merge meta info. 
It is important to do if the new value + # is created by evaluator, and thus carries zero meta info. + # Since this is a copy, the meta info should be the same. + new_value.identity_merge_from(old_value) + node.input[i] = input.symbolic_value + + def process_function_outputs(self, function: onnx.FunctionProto) -> bool: + # Resolve copy for function subgraph output. + # Avoid copy of function subgraph input, because it is illegal for a direct edge + # from function input to function output. + prohibited_value_set = set(function.input) + updated = False + for i, output_name in enumerate(function.output): + output = self.lookup(output_name) + if ( + output is not None + and output.is_copy() + and output.symbolic_value not in prohibited_value_set + ): + old_value = self.lookup_or_create(output.name) + assert isinstance(output.symbolic_value, str) + new_value = self.lookup_or_create(output.symbolic_value) + new_value.identity_merge_from(old_value) + function.output[i] = output.symbolic_value + updated = True + return updated + + def process_node(self, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: + self.replace_copy(node) + + super().process_node(node) + + inputs = [self.lookup(x) for x in node.input] + attrs = self.convert_attributes(node.attribute) + + domain = node.domain + op = node.op_type + version = self.lookup_version(domain) + + # if any(x is Undefined for x in inputs): + # return None + # Above check ensures that none of the optimizations below need to handle + # undefined inputs + + op_optimizers = self.registry.lookup_evaluators(domain, op, version) + for optimizer in op_optimizers: + assert optimizer + output = optimizer(self, node) + if output is None: + continue + if isinstance(output, list): + return output + else: + # Currently handles single output only + self.add_count(node.op_type, output.size) + return self.new_constant(node.output[0], output) + + if is_control_flow_op(node) or is_non_deterministic_op(node): + return None + + 
input_values = [x.value if x is not None else None for x in inputs] + if any(x is ir.NotConstant for x in input_values): + return None + + outputs = self.evaluate(domain, op, version, *input_values, **attrs) + # TODO: what if evaluated value is None? + if outputs is None: + return None + if len(node.output) == 1 and not isinstance(outputs, (tuple, list)): + replacement = self.new_constant(node.output[0], outputs) + if is_constant_op(node): + return None + self.add_count(op, outputs.size) + return replacement + else: + logger.warning("Skipping constant folding for op %s with multiple outputs.", op) + return None + + def process_function_node( + self, node: onnx.NodeProto + ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: + self.replace_copy(node) + + _, new_function = super().process_function_node(node) + + # Replace function node with Constant if all outputs are constants + ir_values = [self.lookup(output_name) for output_name in node.output] + tensors = [ + self.foldable_value(output_name, ir_value.value if ir_value is not None else None) + for output_name, ir_value in zip(node.output, ir_values) + ] + if all(tensor is not None for tensor in tensors): + replacements = [] + for output_name, tensor in zip(node.output, tensors): + newnode = onnx.helper.make_node( + "Constant", inputs=[], outputs=[output_name], value=tensor + ) + replacements.append(newnode) + logger.debug( + "Function node replacements: node %s %s (%s/%s)", + node.name, + [replacement.output for replacement in replacements], + len(replacements), + len(node.output), + ) + return replacements, new_function + return None, new_function + + def visit_model(self, model: onnx.ModelProto) -> None: + self._init() + + super().visit_model(model) + + +def fold_constants( + model: onnx.ModelProto, + external_data_folder: str = "", + *, + onnx_shape_inference: bool = False, +) -> bool: + """Returns true iff the model was modified.""" + folder = ConstantFolder( + evaluator.registry, + 
external_data_folder, + do_shape_inference=onnx_shape_inference, + ) + folder.visit_model(model) + for op in folder.counts: + logger.info( + "Constant-folded '%s' %s times, with %s size.", + op, + folder.counts[op], + folder.sizes[op], + ) + return folder.modified diff --git a/onnxscript/optimizer/constant_folding_test.py b/onnxscript/optimizer/constant_folding_test.py new file mode 100644 index 000000000..64a27e33d --- /dev/null +++ b/onnxscript/optimizer/constant_folding_test.py @@ -0,0 +1,444 @@ +import unittest + +import onnx +import pytest + +from onnxscript import optimizer + + +class FoldConstantsTest(unittest.TestCase): + def test_fold_add(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, four) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "four") + + def test_fold_cast_like(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + two_float = CastLike(two, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "four") + + def test_fold_shape(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + two_float = CastLike(rank, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "four") + + def test_fold_shape_slice(self): + model = onnx.parser.parse_model( + """ + + agraph (float[M, N, 16, 16] x) => (float[M, N, 16, 16] z) { + shape = Shape (x) + two = 
Size(shape) + two_float = CastLike(two, x) + four = Add(two_float, two_float) + z = Mul(x, four) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "four") + + def test_fold_if_cond(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + zero = Constant () + zero_cast = CastLike (zero, rank) + is_scalar = Equal(zero_cast, rank) + z = If (is_scalar) < + then_branch = then_graph () => (then_z) { then_z = Add (x, x) }, + else_branch = else_graph () => (else_z) { else_z = Mul (x, x) } + > + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 1) + self.assertEqual(optimized.graph.node[0].output[0], "z") + self.assertEqual(optimized.graph.node[0].op_type, "Mul") + + def test_fold_inside_if_branch(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x, bool cond) => (float[16, 16] z) { + two = Constant () + z = If (cond) < + then_branch = then_graph () => (then_z) { + three = Constant () + temp = Add (two, three) + then_z = Mul (temp, x) + }, + else_branch = else_graph () => (else_z) { + four = Constant () + temp = Add (two, four) + else_z = Mul (temp, x) + } + > + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 1) + then_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "then_branch") + self.assertEqual(len(then_graph.node), 2) + else_graph = onnx.helper.get_node_attr_value(optimized.graph.node[0], "else_branch") + self.assertEqual(len(else_graph.node), 2) + + def test_fold_if_propagate(self): + model = onnx.parser.parse_model( + """ + + agraph (float[16, 16] x) => (float[16, 16] z) { + shape = Shape(x) + rank = Size(shape) + zero = Constant () + two = Constant () + zero_cast = CastLike (zero, rank) 
+ is_scalar = Equal(zero_cast, rank) + m = If (is_scalar) < + then_branch = then_graph () => (then_z) { then_z = Add (x, x) }, + else_branch = else_graph () => (else_z) { else_z = Mul (two, two) } + > + m_square = Mul (m, m) + z = Mul (x, m_square) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + print(onnx.printer.to_text(optimized)) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "m_square") + self.assertEqual(optimized.graph.node[0].op_type, "Constant") + + def test_fold_redundant_cast(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + x_cast = CastLike(x, two) + z = Mul(x_cast, two) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + + def test_fold_redundant_cast2(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + z = CastLike(x, two) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 1) + self.assertEqual(optimized.graph.node[0].op_type, "Identity") + self.assertEqual(optimized.graph.node[0].output[0], "z") + self.assertEqual(optimized.graph.node[0].input[0], "x") + + @pytest.mark.skip(reason="Feature removed to catch errors early") + def test_fold_undefined_vars(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + four = Add(two, two) + y = Shape(t1) + w = CastLike(x, t2) + w2 = CastLike(t3, t4) + w3 = Size(t5) + z = Sum (four, y, w, w2, w3) + } + """ + ) + # No optimizations expected. Just make sure it doesn't crash. 
+ optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False) + self.assertEqual(len(optimized.graph.node), 6) + + def test_shape_inference(self): + model = onnx.parser.parse_model( + """ + + agraph (int64[64] x) => (int64[N] z) { + one = Constant () + cond = Equal(one, one) + temp = If (cond) < + then_branch = then_graph () => (then_z) { + shape1 = Constant () + then_z = Reshape(x, shape1) + }, + else_branch = else_graph () => (else_z) { + shape2 = Constant () + else_z = Reshape(x, shape2) + }> + shape = Shape(temp) # shape = [8, 8] or [64], but [8, 8] after constant propagation + rank = Size(shape) # rank = 2 or 1, but 2 after constant propagation + C = Add (rank, rank) + z = Mul(x, C) + } + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + print(onnx.printer.to_text(optimized)) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(optimized.graph.node[0].output[0], "C") + + def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,512] x) => ( return_val) { + int64_128 = Constant () + splits = SplitToSequence (x, int64_128) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + int64_3 = Constant () + split_3 = SequenceAt (splits, int64_3) + return_val = Concat (split_0, split_1, split_2, split_3) +} + """ + ) + + # TODO: There is an unrelated limitation that `symbolic_value` is not + # utilized when the value is only referenced by graph output. + # E.g., the following test model will not have this optimization + # applied. 
+ """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,512] x) => ( split_0, split_1, split_2, split_3) { + int64_128 = Constant () + splits = SplitToSequence (x, int64_128) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + int64_3 = Constant () + split_3 = SequenceAt (splits, int64_3) +} + """ + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 2) + self.assertEqual(len(optimized.graph.node[-2].output), 4) + self.assertEqual(optimized.graph.node[-2].op_type, "Split") + + def test_static_split_to_sequence_with_list_split_and_squence_at_is_folded_as_split( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,512] x) => ( return_val) { + const = Constant () + splits = SplitToSequence (x, const) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + return_val = Concat (split_0, split_1, split_2) +} + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 3) + self.assertEqual(len(optimized.graph.node[-2].output), 3) + self.assertEqual(optimized.graph.node[-2].op_type, "Split") + + def test_static_split_to_sequence_with_list_split_no_keepdims_and_squence_at_is_folded_as_split_with_squeeze( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => ( return_val) { + const = Constant () + splits = SplitToSequence (x, const) + int64_0 = Constant () + split_0 = SequenceAt (splits, int64_0) + int64_1 = Constant () + split_1 = SequenceAt (splits, int64_1) + int64_2 = Constant () + split_2 = SequenceAt (splits, int64_2) + return_val = 
Concat (split_0, split_1, split_2) +} + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 7) + self.assertEqual(len(optimized.graph.node[1].output), 3) + self.assertEqual(optimized.graph.node[1].op_type, "Split") + self.assertEqual(len([n for n in optimized.graph.node if n.op_type == "Squeeze"]), 3) + + def test_static_split_to_sequence_with_uneven_split(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18, "pkg.onnxscript.torch_lib.common" : 1], + producer_name: "pytorch", + producer_version: "2.2.0" +> +main_graph (float[3,5] l_tensor_x_) => (float[3,5] return_val) + < _val_2, float[3,5] l_tensor_x_, float[2,5] getitem, float[1,5] getitem_1> +{ + _val_1 = Constant () + _val_2 = pkg.onnxscript.torch_lib.aten_split (l_tensor_x_, _val_1) + _val_3 = Constant () + getitem = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_3) + _val_5 = Constant () + getitem_1 = pkg.onnxscript.torch_lib.aten_getitem (_val_2, _val_5) + return_val = Concat (getitem_1, getitem) +} +< + domain: "pkg.onnxscript.torch_lib", + opset_import: ["" : 18] +> +aten_split (self, split_size) => (return_val) +{ + return_val = SplitToSequence (self, split_size) +} +< + domain: "pkg.onnxscript.torch_lib", + opset_import: ["" : 18] +> +aten_getitem (self, i) => (return_val) +{ + return_val = SequenceAt (self, i) +} +< + domain: "pkg.onnxscript.torch_lib.common", + opset_import: ["" : 18] +> +Rank (input) => (return_val) +{ + tmp = Shape (input) + return_val = Size (tmp) +} +< + domain: "pkg.onnxscript.torch_lib.common", + opset_import: ["" : 18] +> +IsScalar (input) => (return_val) +{ + tmp = Shape (input) + tmp_0 = Size (tmp) + tmp_1 = Constant () + return_val = Equal (tmp_0, tmp_1) +} + """ + ) + optimized = optimizer.optimize(model, onnx_shape_inference=False) + + print(onnx.printer.to_text(optimized)) + self.assertEqual(len(optimized.graph.node), 2) + 
self.assertEqual(len(optimized.graph.node[0].output), 2) + self.assertEqual(optimized.graph.node[0].op_type, "Split") + + def test_split_to_sequence_and_concat_from_sequence_with_new_axis_0( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => ( return_val) { + const = Constant () + splits = SplitToSequence (x, const) + return_val = ConcatFromSequence (splits) +} + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 3) + self.assertEqual(optimized.graph.node[2].op_type, "Concat") + onnx.checker.check_model(optimized) + + def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( + self, + ): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,3] x) => ( return_val) { + const = Constant () + splits = SplitToSequence (x, const) + return_val = ConcatFromSequence (splits) +} + """ + ) + optimized = optimizer.optimize(model, num_iterations=1) + self.assertEqual(len(optimized.graph.node), 7) + self.assertEqual(optimized.graph.node[6].op_type, "Concat") + onnx.checker.check_model(optimized) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/optimizer/copy_propagation.py b/onnxscript/optimizer/copy_propagation.py new file mode 100644 index 000000000..6a7d4143d --- /dev/null +++ b/onnxscript/optimizer/copy_propagation.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Any + +import onnx + +import onnxscript.optimizer.remove_unused +from onnxscript._legacy_ir import visitor +from onnxscript.utils.utils import is_onnx_op + + +class CopyPropagator(visitor.ProtoVisitor): + def __init__(self): + super().__init__() + + def visit_node(self, node: onnx.NodeProto) -> None: + super().visit_node(node) + for i in range(len(node.input)): + input = self.get_input(node, i) + if input is not None and input.is_copy(): + node.input[i] = 
input.symbolic_value # type: ignore[assignment] + + if is_onnx_op(node, "Identity"): + input = self.get_input(node, 0) + output = self.get_output(node, 0) + if input is not None and output is not None: + output.symbolic_value = input.name + + +# TODO: "Z = Identity(x)" where Z is a graph-output cannot be handled by this optimization, +# and requires some extension. (Eg., we could rename graph-output to be Z or we can try to +# rename x to be Z.) + + +def get_node_attr_value(node: onnx.NodeProto, attr_name: str, default: Any) -> Any: + matching = [x for x in node.attribute if x.name == attr_name] + if len(matching) > 1: + raise ValueError(f"Node has multiple attributes with name {attr_name}") + if len(matching) < 1: + return default + return onnx.helper.get_attribute_value(matching[0]) + + +class SymbolicEvaluator(CopyPropagator): + def __init__(self): + super().__init__() + + def visit_node(self, node: onnx.NodeProto) -> None: + super().visit_node(node) + + if is_onnx_op(node, "SequenceConstruct"): + output = self.get_output(node, 0) + if output is not None: + output.symbolic_value = list(node.input) + + if is_onnx_op(node, "ConcatFromSequence"): + input = self.get_input(node, 0) + new_axis = get_node_attr_value(node, "new_axis", 0) + if input is not None and isinstance(input.symbolic_value, list) and new_axis == 0: + node.op_type = "Concat" + node.input[:] = input.symbolic_value + for i in range(len(node.attribute)): + if node.attribute[i].name == "new_axis": + del node.attribute[i] + break + + # TODO: handle SequenceEmpty, SequenceAt, etc. 
+ + +def do_copy_propagation(model: onnx.ModelProto, *, remove_unused: bool = True) -> None: + transformer = CopyPropagator() + transformer.visit_model(model) + if remove_unused: + onnxscript.optimizer.remove_unused_nodes(model) + + +def do_sequence_simplification(model: onnx.ModelProto, *, remove_unused: bool = True) -> None: + transformer = SymbolicEvaluator() + transformer.visit_model(model) + if remove_unused: + onnxscript.optimizer.remove_unused_nodes(model) diff --git a/onnxscript/optimizer/copy_propagation_test.py b/onnxscript/optimizer/copy_propagation_test.py new file mode 100644 index 000000000..6b88b027a --- /dev/null +++ b/onnxscript/optimizer/copy_propagation_test.py @@ -0,0 +1,49 @@ +import unittest + +import onnx + +from onnxscript import optimizer + + +class RemoveUnusedTest(unittest.TestCase): + def test_simple_identity_removal(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + t = Identity(x) + t2 = Identity(t) + z = Identity(t2) + } + """ + ) + optimizer.do_copy_propagation(model) + self.assertEqual(len(model.graph.node), 1) + + def test_subgraph_identity_removal(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, bool cond) => (float[N] z) { + t = Identity(x) + t2 = Identity(t) + t3 = If (cond) < + then_branch = then_graph() => (t4) { + t5 = Identity(t2) + t4 = Identity(t5) + }, + else_branch = else__graph() => (t6) { + t7 = Identity(t) + t6 = Identity(t7) + } + > + z = Identity(t3) + } + """ + ) + optimizer.do_copy_propagation(model) + self.assertEqual(len(model.graph.node), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/optimizer/evaluator.py b/onnxscript/optimizer/evaluator.py new file mode 100644 index 000000000..bf3c5a882 --- /dev/null +++ b/onnxscript/optimizer/evaluator.py @@ -0,0 +1,434 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. +# ------------------------------------------------------------------------- + +from __future__ import annotations + +import dataclasses +import logging +import math +from typing import Any, Callable, Protocol, Sequence, Union + +import numpy as np +import onnx +import onnx.reference.ops + +import onnxscript._legacy_ir as ir +from onnxscript.utils.utils import ( + get_node_attr_value, +) + +logger = logging.getLogger(__name__) + +# "Standard" evaluators are used to perform constant-folding. +# The API below works only for non-control-flow ops (ops without any graph-attributes). +# This currently used ONNX's reference implementation. But we could also +# use ORT's implementation if we want to. + + +class ReferenceEvaluator: + def get_evaluator(self, domain: str, op: str, version: int) -> callable | None: + try: + op_impl_class = onnx.reference.ops.load_op(domain, op, version) + return op_impl_class.eval # noqa: TRY300 + except Exception: + return None + + def evaluate(self, domain: str, op: str, version: int, *args, **kwargs) -> Any: + logger.debug("Evaluating %s::%s", domain, op) + evaluator = self.get_evaluator(domain, op, version) + if evaluator is None: + return None + return evaluator(*args, **kwargs) + + +reference_evaluator = ReferenceEvaluator() + +# The "partial evaluators" below are non-standard evaluators. They are used to perform +# partial evaluation and/or static program analysis (abstract interpretation). + + +class IRContext(Protocol): + """A class that represents the context for partial evaluation. + + This is a placeholder, subject to simplification when a proper IR is defined. + """ + + def get_input(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... + + def get_output(self, node: onnx.NodeProto, index: int) -> ir.Value | None: ... + + def input_const_value(self, node: onnx.NodeProto, index: int) -> ir.ConcreteValue: ... 
+ + def input_shape( + self, node: onnx.NodeProto, index: int + ) -> onnx.TensorShapeProto | None: ... + + def input_type(self, node: onnx.NodeProto, index: int) -> onnx.TypeProto | None: ... + + def input_element_type(self, node: onnx.NodeProto, index: int) -> int | None: ... + + def lookup_version(self, domain: str) -> int: ... + + def convert_attributes(self, attributes: Sequence[onnx.AttributeProto]) -> dict: ... + + def new_constant(self, name: str, value: Any) -> Sequence[onnx.NodeProto] | None: ... + + +# A partial-evaluator function takes an IRContext and a node, and returns a list of +# replacement nodes or None (if no replacement is needed). We return None instead +# of [input node] so the caller is aware that the node is not replaced. If the node +# is replaced, the caller will recursively visit the replacement nodes to process them. + +PartialEvaluatorFunction = Union[ + Callable[[IRContext, onnx.NodeProto], Sequence[onnx.NodeProto]], None +] + + +@dataclasses.dataclass +class PartialEvaluator: + """A class that represents a partial-evaluator for a particular op. + + It is applicable for a specific version range (min_version, max_version) of the op. + The min_version and max_version can be None, indicating that there is no version + constraint in that direction. 
+ """ + + min_version: int | None + max_version: int | None + function: PartialEvaluatorFunction + + def valid_for(self, version: int) -> bool: + """Returns True if this evaluator is applicable for the given version.""" + return (self.min_version is None or version >= self.min_version) and ( + self.max_version is None or version <= self.max_version + ) + + +class PartialEvaluatorRegistry: + """A class that maintains a registry of evaluators for ops.""" + + def __init__(self): + self.op_evaluators: dict[tuple[str, str], list[PartialEvaluator]] = {} + + def lookup_evaluators(self, domain: str, opname: str, version: int): + evaluator_list = self.op_evaluators.get((domain, opname), []) + return [ + evaluator.function for evaluator in evaluator_list if evaluator.valid_for(version) + ] + + def register(self, opname: str, domain: str = "", version=None): + if (domain, opname) not in self.op_evaluators: + evaluator_list = [] + self.op_evaluators[(domain, opname)] = evaluator_list + else: + evaluator_list = self.op_evaluators[(domain, opname)] + if version is None: + min_version = None + max_version = None + elif isinstance(version, int): + min_version = version + max_version = version + elif isinstance(version, tuple): + min_version, max_version = version + + def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction: + evaluator_list.append(PartialEvaluator(min_version, max_version, function)) + return function + + return decorator + + +registry: PartialEvaluatorRegistry = PartialEvaluatorRegistry() + +register = registry.register + + +def get_bool_value(val) -> bool | None: + if isinstance(val, bool): + return val + if isinstance(val, np.bool_): + return bool(val) + if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: + return val.item(0) + return None + + +def get_size_info(type: onnx.TypeProto) -> np.ndarray | None: + if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + if all(d.HasField("dim_value") for d in 
type.tensor_type.shape.dim): + size = 1 + for d in type.tensor_type.shape.dim: + size *= d.dim_value + return np.array(size, dtype=np.int64) + return None + + +def get_dim_info(type: onnx.TypeProto, dim: int) -> int | None: + if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + rank = len(type.tensor_type.shape.dim) + dim = dim if dim >= 0 else dim + rank + if dim < 0 or dim >= rank: + return None + if type.tensor_type.shape.dim[dim].HasField("dim_value"): + return type.tensor_type.shape.dim[dim].dim_value + return None + + +@register("Cast") +def cast(context: IRContext, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: + if context.input_shape(node, 0) is not None: + output_value = context.get_output(node, 0) + output_value.type = onnx.TypeProto() + output_value.type.CopyFrom(context.input_type(node, 0)) + output_value.type.tensor_type.elem_type = node.attribute[0].i + return None + + +@register("CastLike") +def cast_like(context: IRContext, node: onnx.NodeProto): + source_element_type = context.input_element_type(node, 0) + target_element_type = context.input_element_type(node, 1) + + if target_element_type is None: + return None + if source_element_type == target_element_type: + node.op_type = "Identity" + del node.input[1] + return [node] + + node.op_type = "Cast" + del node.input[1] + del node.attribute[:] + node.attribute.append(onnx.helper.make_attribute("to", target_element_type)) + return [node] + + +@register("Shape") +def shape(context: IRContext, node: onnx.NodeProto): + shape = context.input_shape(node, 0) + if shape is None: + return None + start = get_node_attr_value(node, "start", 0) + end = get_node_attr_value(node, "end", None) + shape_slice = shape.dim[start:end] + if all(d.HasField("dim_value") for d in shape_slice): + return np.array([d.dim_value for d in shape_slice], dtype=np.int64) + return None + + +@register("Size") +def size(context: IRContext, node: onnx.NodeProto): + type = context.input_type(node, 0) + 
size = get_size_info(type) if type is not None else None + return size + + +@register("If") +def if_op(context: IRContext, node: onnx.NodeProto): + cond = context.input_const_value(node, 0) + if cond is ir.NotConstant: + # Visitor will recursively visit subgraphs to constant-fold them. + return None + cond = get_bool_value(cond) + if cond is not None: + # cond is a constant-value: inline the branch + branch = "then_branch" if cond else "else_branch" + graph = onnx.helper.get_node_attr_value(node, branch) + + formal_outs = list(graph.output) + actual_outs = node.output + renamings = { + formal.name: actual + for formal, actual in zip(formal_outs, actual_outs) + if actual != "" + } + # TODO: Extend renaming to intermediate values. + + def rename(name): + return renamings.get(name, name) + + for sub_node in graph.node: + # TODO: handle renaming inside subgraphs in nodes + sub_node.input[:] = [rename(name) for name in sub_node.input] + sub_node.output[:] = [rename(name) for name in sub_node.output] + # Avoid name collision. + sub_node.name = f"{node.name}_{sub_node.name}" + + # TODO: we should handle initializers as well! 
+ return list(graph.node) + return None + + +@register("Identity") +def identity(context: IRContext, node: onnx.NodeProto): + input = context.get_input(node, 0) + output = context.get_output(node, 0) + if input is not None and output is not None: + output.symbolic_value = input.name + + +@register("SequenceConstruct") +def sequence_construct( + context: IRContext, node: onnx.NodeProto +) -> Sequence[onnx.NodeProto] | None: + output = context.get_output(node, 0) + if output is not None: + output.symbolic_value = list(node.input) + return None + + +@register("ConcatFromSequence") +def concat_from_sequence( + context: IRContext, node: onnx.NodeProto +) -> Sequence[onnx.NodeProto] | None: + input = context.get_input(node, 0) + attrs = context.convert_attributes(node.attribute) + new_axis = attrs.get("new_axis", 0) + if input is not None and isinstance(input.symbolic_value, list): + if new_axis == 0: + node.op_type = "Concat" + node.input[:] = input.symbolic_value + logger.debug("ConcatFromSequence => Concat: %s", node.input) + for i in range(len(node.attribute)): + if node.attribute[i].name == "new_axis": + del node.attribute[i] + return [node] + return [node] + if new_axis == 1: + # Unsqueeze the inputs with concat axis if new_axis is 1 + axis = attrs.get("axis", None) + assert axis is not None + output = context.get_output(node, 0) + axis_node = context.new_constant(f"{output.name}_axis", np.array([axis]))[0] + unsqueeze_nodes = [] + for node_input in input.symbolic_value: + unsqueeze_node = onnx.helper.make_node( + "Unsqueeze", + [node_input, axis_node.output[0]], + [f"{node_input}_unsqueeze"], + ) + unsqueeze_nodes.append(unsqueeze_node) + unsqueeze_outputs = [n.output[0] for n in unsqueeze_nodes] + unsqueeze_nodes = [axis_node, *unsqueeze_nodes] + + # Send unsqueezed outputs to Concat + node.input[:] = unsqueeze_outputs + node.op_type = "Concat" + logger.debug( + "ConcatFromSequence => UnSqueeze %s + Concat %s", + unsqueeze_outputs, + node.input, + ) + for i in 
range(len(node.attribute)): + if node.attribute[i].name == "new_axis": + del node.attribute[i] + return [*unsqueeze_nodes, node] + return None + + +@register("SplitToSequence") +def split_to_sequence( + context: IRContext, node: onnx.NodeProto +) -> Sequence[onnx.NodeProto] | None: + """Rewriting pattern. + + From + + splits = onnx::SplitToSequence(input, split, axis=axis) + + to + + split_0, split_1, ..., split_n = onnx::Split(input, split, axis=axis) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + or + + split_0, split_1, ..., split_n = onnx::Split(input, axis=axis, num_outputs=n+1) + splits = onnx::SequenceConstruct(split_0, split_1, ..., split_n) + + where number of output tensors in `splits` is statically known. + onnx::SequenceConstruct will be further optimized away if possible, by its own designated evaluator. + This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. + """ + input = context.get_input(node, 0) + split = context.get_input(node, 1) + attrs = context.convert_attributes(node.attribute) + output = context.get_output(node, 0) + + if input is None or split is None or output is None: + return None + + axis = attrs.get("axis", 0) + if input.type is None: + return None + split_dimension_size = get_dim_info(input.type, axis) + if split_dimension_size is None: + return None + + split_value = split.value + if split_value is None or split_value is ir.NotConstant: + return None + assert isinstance(split_value, np.ndarray) + + if split_value.ndim == 0: + # split into chunks all of size 'split' if possible. 
+ num_outputs = math.ceil(split_dimension_size / split_value.item()) + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_node = onnx.helper.make_node( + "Split", + [input.name], + split_outputs, + axis=axis, + num_outputs=num_outputs, + ) + else: + # split into 'size(split)' chunks + num_outputs = split_value.size + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_node = onnx.helper.make_node( + "Split", + [input.name, split.name], + split_outputs, + axis=axis, + ) + + keepdims = attrs.get("keepdims", 1) + squeeze_nodes = [] + if keepdims == 0: + # squeeze the split dimension if keepdims is 0 + axis_node = context.new_constant(f"{output.name}_axis", np.array([axis]))[0] + for i in range(num_outputs): + squeeze_node = onnx.helper.make_node( + "Squeeze", + [split_outputs[i], axis_node.output[0]], + [f"{split_outputs[i]}_squeeze"], + ) + squeeze_nodes.append(squeeze_node) + split_outputs = [n.output[0] for n in squeeze_nodes] + squeeze_nodes = [axis_node, *squeeze_nodes] + + node.op_type = "SequenceConstruct" + node.input[:] = split_outputs + del node.attribute[:] + logger.debug( + "SplitToSequence => Split %s + SequenceConstruct %s", + split_node.input, + node.input, + ) + return [split_node, *squeeze_nodes, node] + + +@register("SequenceAt") +def sequence_at(context: IRContext, node: onnx.NodeProto) -> Sequence[onnx.NodeProto] | None: + input = context.get_input(node, 0) + position = context.get_input(node, 1) + output = context.get_output(node, 0) + if input is not None and position is not None: + input_vals = input.symbolic_value + position_val = position.value + if isinstance(input_vals, list) and position_val is not None: + output.symbolic_value = input_vals[position_val] + logger.debug("SquenceAt %s => %s", input, output.symbolic_value) + return None diff --git a/onnxscript/optimizer/fold_constants_v0.py b/onnxscript/optimizer/fold_constants_v0.py new file mode 100644 index 000000000..556f824b8 --- 
/dev/null +++ b/onnxscript/optimizer/fold_constants_v0.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from typing import Any, Sequence + +import numpy as np +import onnx +import onnx.reference.ops + +# Excluded ops include +# * Random ops, which are not deterministic +# * Control flow ops + +excluded_ops = frozenset( + { + "RandomUniform", + "RandomNormal", + "RandomUniformLike", + "RandomNormalLike", + "Multinomial", + "If", + "Loop", + "Scan", + "SequenceMap", + } +) + +onnx_domain = frozenset({"", "onnx.ai"}) + + +def get_evaluator(domain: str, op: str, version: int) -> callable | None: + if op in excluded_ops and domain in onnx_domain: + return None + try: + op_impl_class = onnx.reference.ops.load_op(domain, op, version) + except Exception: + return None + else: + return op_impl_class.eval + + +def convert_attributes(attributes: Sequence[onnx.AttributeProto]) -> dict[str, Any]: + return {attr.name: onnx.helper.get_attribute_value(attr) for attr in attributes} + + +def is_control_flow_op(node: onnx.NodeProto) -> bool: + return any(attr.HasField("g") or len(attr.graphs) > 0 for attr in node.attribute) + + +def is_constant_op(node: onnx.NodeProto) -> bool: + return node.op_type == "Constant" and node.domain == "" + + +def get_bool_value(val) -> bool | None: + if isinstance(val, bool): + return val + if isinstance(val, np.bool_): + return bool(val) + if isinstance(val, np.ndarray) and val.size == 1 and val.dtype == bool: + return val.item(0) + return None + + +def get_shape_info(type: onnx.TypeProto) -> tuple[int, ...] 
| None: + if type.HasField("tensor_type") and type.tensor_type.HasField("shape"): + if all(d.HasField("dim_value") for d in type.tensor_type.shape.dim): + return np.array([d.dim_value for d in type.tensor_type.shape.dim], dtype=np.int64) + return None + + +def get_element_type(type: onnx.TypeProto) -> int | None: + if type.HasField("tensor_type"): + return type.tensor_type.elem_type + return None + + +class State: + def __init__(self, default_value) -> None: + self.scopes = [{}] + self.default_value = default_value + + def lookup(self, name: str) -> Any: + for scope in reversed(self.scopes): + if name in scope: + return scope[name] + return self.default_value + + def bind(self, name: str, value: Any) -> None: + self.scopes[-1][name] = value + + def enter_scope(self) -> None: + self.scopes.append({}) + + def exit_scope(self) -> None: + self.scopes.pop() + + +def is_onnx_op(node: onnx.NodeProto, op: str) -> bool: + return (node.op_type == op) and (node.domain in onnx_domain) + + +def matches(node: onnx.NodeProto, op: str, *arg_predicates) -> bool: + if node.op_type != op or node.domain != "": + return False + if len(node.input) < len(arg_predicates): + return False + return all(pred(input) for pred, input in zip(arg_predicates, node.input)) + + +def get_initializer_type(initializer: onnx.TensorProto) -> onnx.TypeProto: + type = onnx.TypeProto() + type.tensor_type.elem_type = initializer.data_type + dims = type.tensor_type.shape.dim + for dim in initializer.dims: + dims.add().dim_value = dim + return type + + +def fold_constants(model: onnx.ModelProto): + not_constant = object() + var_info = State(default_value=not_constant) + type_info = State(default_value=None) + counts = {} + sizes = {} + + def add_count(op: str, size: int = 1): + counts[op] = counts.get(op, 0) + 1 + sizes[op] = sizes.get(op, 0) + size + + def new_constant(name, value): + var_info.bind(name, value) + tensor = onnx.numpy_helper.from_array(value, name=name) + node = onnx.helper.make_node("Constant", 
inputs=[], outputs=[name], value=tensor) + return node + + def lookup_version(domain: str, op: str) -> int: + for opset in model.opset_import: + if opset.domain == domain: + return opset.version + return 1 # TODO + + def transform_node(node: onnx.NodeProto): + if is_onnx_op(node, "Transpose"): + return [node] + if is_onnx_op(node, "CastLike"): + value = var_info.lookup(node.input[0]) if len(node.input) > 0 else not_constant + if value is not_constant: + return [node] + type = type_info.lookup(node.input[1]) if len(node.input) > 1 else None + element_type = get_element_type(type) if type is not None else None + if element_type is None: + return [node] + evaluator = get_evaluator("", "Cast", lookup_version("", "Cast")) + if evaluator is None: + return [node] + cast_value = evaluator(value, to=element_type) + add_count("CastLike", cast_value.size) + return [new_constant(node.output[0], cast_value)] + if is_onnx_op(node, "Shape"): + type = type_info.lookup(node.input[0]) if len(node.input) > 0 else None + shape = get_shape_info(type) if type is not None else None + if shape is not None: + add_count("Shape", shape.size) + return [new_constant(node.output[0], shape)] + + if is_onnx_op(node, "If"): + cond = var_info.lookup(node.input[0]) if len(node.input) > 0 else None + cond = get_bool_value(cond) + if cond is not None: + # cond is a constant-value: inline the branch + branch = "then_branch" if cond else "else_branch" + graph = onnx.helper.get_node_attr_value(node, branch) + formal_outs = list(graph.output) + actual_outs = node.output + renamings = { + formal.name: actual + for formal, actual in zip(formal_outs, actual_outs) + if actual != "" + } + + def rename(name): + return renamings.get(name, name) + + for node in graph.node: + node.input[:] = [rename(name) for name in node.input] + node.output[:] = [rename(name) for name in node.output] + transform_graph(graph) + add_count("If") + return list(graph.node) + + if is_control_flow_op(node): + for attr in 
node.attribute: + if attr.HasField("g"): + transform_graph(attr.g) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + transform_graph(graph) + return [node] + + domain = node.domain + op = node.op_type + version = lookup_version(domain, op) + inputs = [] + for x in node.input: + if x == "": + inputs.append(None) + else: + v = var_info.lookup(x) + if v is not_constant: + return [node] + inputs.append(v) + evaluator = get_evaluator(domain, op, version) + if evaluator is None: + return [node] + attrs = convert_attributes(node.attribute) + outputs = evaluator(*inputs, **attrs) + if len(node.output) == 1 and not isinstance(outputs, tuple): + replacement = new_constant(node.output[0], outputs) + if is_constant_op(node): + return [node] + add_count(op, outputs.size) + return [replacement] + else: + add_count(op) + return [new_constant(output, outputs[i]) for i, output in enumerate(node.output)] + + def transform_graph(graph: onnx.GraphProto): + var_info.enter_scope() + type_info.enter_scope() + for initializer in graph.initializer: + array = onnx.numpy_helper.to_array(initializer) + var_info.bind(initializer.name, array) + type_info.bind(initializer.name, get_initializer_type(initializer)) + for input in graph.input: + var_info.bind(input.name, not_constant) + type_info.bind(input.name, input.type) + for valueinfo in graph.value_info: + type_info.bind(valueinfo.name, valueinfo.type) + + replacement = [transform_node(node) for node in graph.node] + flattened = [node for nodes in replacement for node in nodes] + del graph.node[:] + graph.node.extend(flattened) + var_info.exit_scope() + type_info.exit_scope() + + transform_graph(model.graph) + for op in counts: + print(f"Constant-folded '{op}' {counts[op]} times, with {sizes[op]} size.") diff --git a/onnxscript/optimizer/function_folding_test.py b/onnxscript/optimizer/function_folding_test.py new file mode 100644 index 000000000..3074dc673 --- /dev/null +++ b/onnxscript/optimizer/function_folding_test.py @@ -0,0 
+1,162 @@ +import unittest + +import onnx + +from onnxscript import optimizer + + +class FunctionFoldingTest(unittest.TestCase): + def test_identity(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x1, bool cond1) => (float[N] z1) { + z1 = local.fun1(x1, cond1) + } + + fun1 (x, cond) => (z) { + t = Identity(x) + t2 = Identity(t) + t3 = If (cond) < + then_branch = then_graph() => (t4) { + t5 = Identity(t2) + t4 = Identity(t5) + }, + else_branch = else__graph() => (t6) { + t7 = Identity(t) + t6 = Identity(t7) + } + > + t4 = Add(t3, t3) + z = Identity(t4) + } + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=1, + ) + self.assertEqual(len(optimized.functions), 0) + self.assertEqual(len(optimized.graph.node), 2) + + def test_sequence_concat(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x1) => (float[M] z1) { + z1 = local.fun1(x1) + } + + fun1 (x) => (z) { + t0 = Add (x, x) + t2 = Add (x, x) + t3 = SequenceConstruct (x, t0, t2, x) + z = ConcatFromSequence (t3) + } + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=1, + ) + function_node = optimized.functions[0].node + self.assertEqual(len(function_node), 3) + self.assertEqual(function_node[2].op_type, "Concat") + + def test_single_user_function_is_modified_inplace_after_folding(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x1) => (float[M] z1) { + z1 = local.fun1(x1) + } + + fun1 (x) => (z) { + t0 = Add (x, x) + t2 = Add (x, x) + t3 = SequenceConstruct (x, t0, t2, x) + z = ConcatFromSequence (t3) + } + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=1, + ) + self.assertEqual(optimized.functions[0].name, "fun1") + + def test_multi_users_function_is_not_modified_inplace_after_folding(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x1) => (float[M] z1, float[M] z2) { + z1 = local.fun1(x1) + 
z2 = local.fun1(x1) + } + + fun1 (x) => (z) { + t0 = Add (x, x) + t2 = Add (x, x) + t3 = SequenceConstruct (x, t0, t2, x) + z = ConcatFromSequence (t3) + } + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=1, + ) + self.assertEqual(len(optimized.functions), 2) + self.assertNotEqual(optimized.functions[0].name, "fun1") + self.assertNotEqual(optimized.functions[1].name, "fun1") + + def test_fold_nested_if_function_succeeds(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 9, + opset_import: ["this" : 1, "" : 21] +> +func (float[1,512] x, float[1,512] y) => ( out) { + out = this.foldable_func (x, y) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable_func (x, y) => (z_6) +{ + cond = Constant () + z_6 = If (cond) ( z_2) { + cond_0 = Not (cond) + z_2 = If (cond_0) ( z) { + z = Add (x, x) + }, else_branch: graph = elseGraph_5 () => ( z_1) { + z_1 = Identity (x) + }> + }, else_branch: graph = elseGraph_4 () => ( z_5) { + z_5 = If (cond) ( z_3) { + z_3 = Add (y, y) + }, else_branch: graph = elseGraph_10 () => ( z_4) { + z_4 = Add (x, y) + }> + }> +} + """ + ) + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + ) + + self.assertEqual(len(optimized.functions), 0) + self.assertEqual(len(optimized.graph.node), 1) + self.assertNotIn("If", {n.op_type for n in optimized.graph.node}) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/optimizer/remove_unused.py b/onnxscript/optimizer/remove_unused.py new file mode 100644 index 000000000..57357f3db --- /dev/null +++ b/onnxscript/optimizer/remove_unused.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import logging +from typing import Sequence + +import onnx +from google.protobuf.internal.containers import ( # type: ignore + RepeatedCompositeFieldContainer, +) + +logger = logging.getLogger(__name__) + + +def remove_unused_optional_outputs( + n: onnx.NodeProto, used: set, opset_import: 
Sequence[onnx.OperatorSetIdProto] +) -> None: + try: + if n.domain not in {"", "onnx.ai"}: + return + onnx_opset_version = 1 + for opset in opset_import: + if opset.domain == n.domain: + onnx_opset_version = opset.version + op_schema = onnx.defs.get_schema(n.op_type, onnx_opset_version, domain=n.domain) + except Exception: + return + # TODO: If current node is a BatchNormalization node, + # based on training_mode atrribute, number of optional outputs and + # how they are handled varies, handle both training_modes + if n.op_type == "BatchNormalization": + return + optional_info = [] + for o in op_schema.outputs: + # Current ops do not have optional outputs if they have variable number of outputs + if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic: + return + optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional) + # If no optional outputs in spec, skip delete operations + if len([o == 1 for o in optional_info]) == 0: + return + + for i, out in enumerate(n.output): + if out not in used and optional_info[i] is True: + n.output[i] = "" + # Only delete trailing unused optional outputs + for o in n.output[::-1]: # type: ignore[assignment] + if o == "": + n.output.pop() + else: + return + + +def compute_used_in_node(n: onnx.NodeProto) -> set[str]: + used = {n for n in n.input if n != ""} + for attr in n.attribute: + if attr.HasField("g"): + used |= compute_used_in_graph(attr.g) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + used |= compute_used_in_graph(graph) + return used + + +def compute_used_in_graph(g: onnx.GraphProto) -> set[str]: + used = set() + for n in g.node: + used |= compute_used_in_node(n) + return used + + +def process_nodes( + nodes: RepeatedCompositeFieldContainer[onnx.NodeProto], + used: set, + opset_import: Sequence[onnx.OperatorSetIdProto], +) -> int: + count = 0 + i = len(nodes) - 1 + while i >= 0: + node = nodes[i] + remove_unused_optional_outputs(node, used, opset_import) + used_outputs = [x 
for x in node.output if x in used] + if not used_outputs: + del nodes[i] + count += 1 + i -= 1 + continue + for attr in node.attribute: + if attr.HasField("g"): + process_graph(attr.g, opset_import) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + process_graph(graph, opset_import) + used |= compute_used_in_node(node) + i -= 1 + return count + + +def process_graph( + graph: onnx.GraphProto, opset_import: Sequence[onnx.OperatorSetIdProto] +) -> int: + used = {output.name for output in graph.output} + + count = process_nodes(graph.node, used, opset_import) + + for i in range(len(graph.initializer) - 1, -1, -1): + if graph.initializer[i].name not in used: + del graph.initializer[i] + count += 1 + + return count + + +def process_function( + function: onnx.FunctionProto, opset_import: Sequence[onnx.OperatorSetIdProto] +) -> int: + used = set(function.output) + + return process_nodes(function.node, used, opset_import) + + +def remove_unused_nodes(model: onnx.ModelProto) -> None: + """Removes unused nodes from the model.""" + count = process_graph(model.graph, model.opset_import) + for function in model.functions: + count += process_function(function, model.opset_import) + + logger.info("Removed %s unused nodes", count) diff --git a/onnxscript/optimizer/remove_unused_function.py b/onnxscript/optimizer/remove_unused_function.py new file mode 100644 index 000000000..573dfaa8b --- /dev/null +++ b/onnxscript/optimizer/remove_unused_function.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import logging + +import onnx +from google.protobuf.internal.containers import ( # type: ignore + RepeatedCompositeFieldContainer, +) + +logger = logging.getLogger(__name__) + + +class UnusedFunctionRemover: + def compute_used_in_node(self, n: onnx.NodeProto) -> set[tuple[str, str]]: + used = {(n.domain, n.op_type)} + for attr in n.attribute: + if attr.HasField("g"): + used |= self.process_graph(attr.g) + elif len(attr.graphs) > 0: + for graph in attr.graphs: + used |= 
self.process_graph(graph) + if (n.domain, n.op_type) in self._functions: + function = self._functions[(n.domain, n.op_type)] + used |= self.process_function(function) + return used + + def process_nodes( + self, nodes: RepeatedCompositeFieldContainer[onnx.NodeProto] + ) -> set[tuple[str, str]]: + used = set() + for node in nodes: + used |= self.compute_used_in_node(node) + return used + + def process_graph(self, graph: onnx.GraphProto) -> set[tuple[str, str]]: + return self.process_nodes(graph.node) + + def process_function(self, function: onnx.FunctionProto) -> set[tuple[str, str]]: + return self.process_nodes(function.node) + + def process_model(self, model: onnx.ModelProto) -> None: + self._functions = {(f.domain, f.name): f for f in model.functions} + used = self.process_graph(model.graph) + count = 0 + logger.debug("Used function protos: %s", used) + for i in range(len(model.functions) - 1, -1, -1): + if (model.functions[i].domain, model.functions[i].name) not in used: + del model.functions[i] + count += 1 + logger.info("Removed %s unused function protos", count) + logger.debug("Function protos left: %s", [f.name for f in model.functions]) + + +def remove_unused_functions(model: onnx.ModelProto) -> None: + """Removes unused function protos from the model.""" + UnusedFunctionRemover().process_model(model) diff --git a/onnxscript/optimizer/remove_unused_test.py b/onnxscript/optimizer/remove_unused_test.py new file mode 100644 index 000000000..350808def --- /dev/null +++ b/onnxscript/optimizer/remove_unused_test.py @@ -0,0 +1,173 @@ +import unittest + +import onnx + +from onnxscript import optimizer + + +class RemoveUnusedTest(unittest.TestCase): + def test_remove_unused_nodes(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + 
self.assertEqual(model.graph.node[0].op_type, "Mul") + + def test_remove_unused_initializers(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[N] z) + { + four = Add(two, two) + z = Mul(x, x) + } + """ + ) + self.assertEqual(len(model.graph.initializer), 1) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "Mul") + self.assertEqual(len(model.graph.initializer), 0) + + def test_partially_used_nodes(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[M] z) { + w1, w2, w3 = Split (x) + z = Mul(w3, w3) + } + """ + ) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 2) + self.assertEqual(model.graph.node[0].op_type, "Split") + + def test_remove_unused_optional_outputs_maxpool(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) { + z, indices = MaxPool (x) + } + """ + ) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "MaxPool") + self.assertEqual(len(model.graph.node[0].output), 2) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "MaxPool") + self.assertEqual(len(model.graph.node[0].output), 1) + + def test_remove_unused_optional_outputs_dropout_in_function(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) + { + z = pkg.custom.afunction (x) + } + + afunction (x) => (z) + { + z, indices = MaxPool (x) + } + """ + ) + self.assertEqual(len(model.functions), 1) + self.assertEqual(len(model.functions[0].node), 1) + self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") + self.assertEqual(len(model.functions[0].node[0].output), 2) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.functions), 1) + self.assertEqual(len(model.functions[0].node), 1) 
+ self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") + self.assertEqual(len(model.functions[0].node[0].output), 1) + + def test_remove_used_optional_outputs_maxpool(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] y, float[1, 1, 5, 5] z) { + y, z = MaxPool (x) + } + """ + ) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "MaxPool") + self.assertEqual(len(model.graph.node[0].output), 2) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 1) + self.assertEqual(model.graph.node[0].op_type, "MaxPool") + self.assertEqual(len(model.graph.node[0].output), 2) + + def test_remove_multiple_unused_optional_outputs_layernorm(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z) { + scale = Constant () + B = Constant () + z, mean, InvStdDev = LayerNormalization(x, scale, B) + } + """ + ) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 3) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 1) + + def test_remove_trailing_unused_optional_outputs_layernorm(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] mean) { + scale = Constant () + B = Constant () + z, mean, InvStdDev = LayerNormalization(x, scale, B) + } + """ + ) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 3) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + 
self.assertEqual(len(model.graph.node[2].output), 2) + + def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] InvStdDev) { + scale = Constant () + B = Constant () + z, mean, InvStdDev = LayerNormalization(x, scale, B) + } + """ + ) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 3) + optimizer.remove_unused_nodes(model) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") + self.assertEqual(len(model.graph.node[2].output), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/optimizer/simple_function_folding.py b/onnxscript/optimizer/simple_function_folding.py new file mode 100644 index 000000000..b15a9c8a0 --- /dev/null +++ b/onnxscript/optimizer/simple_function_folding.py @@ -0,0 +1,239 @@ +"""Inlines the function if it only contains very few number of nodes.""" + +from __future__ import annotations + +import logging +from typing import Sequence + +import onnx + +import onnxscript._legacy_ir as ir +from onnxscript._legacy_ir import visitor +from onnxscript.optimizer import remove_unused + +logger = logging.getLogger(__name__) + + +class FunctionInliner(visitor.FunctionCallsiteProtoTransformer): + counts: dict[ir.FunctionId, int] + + def __init__(self, node_count: int) -> None: + super().__init__() + self._node_count = node_count + + def _gather_function_metadata(self, model: onnx.ModelProto) -> None: + super()._gather_function_metadata(model) + self._function_renamer._postfix = "inlined" + + def visit_model(self, model: onnx.ModelProto) -> None: + self.counts = {} + + super().visit_model(model) + + def should_inline_function(self, function: onnx.FunctionProto) -> bool: + return len(function.node) <= self._node_count + + def 
process_function_node( + self, node: onnx.NodeProto + ) -> tuple[list[onnx.NodeProto] | None, onnx.FunctionProto | None]: + # Recursively process sub nodes first. + function_id = (node.domain, node.op_type, getattr(node, "overload", "")) + function = self._functions[function_id] + replacement, new_function = super().process_function_node(node) + function = new_function if new_function else function + + if self.should_inline_function(function): + self.enter_function_scope(function) + sub_scope = self.exit_function_scope(function) + new_nodes = [] + + formal_outs = function.output + actual_outs = node.output + formal_ins = function.input + actual_ins = node.input + # TODO: Potential collision when actual is "". + # formal.name may collide with existing value names. + input_renamings = dict(zip(formal_ins, actual_ins)) + if len(actual_ins) < len(formal_ins): + input_renamings.update(dict.fromkeys(formal_ins[len(actual_ins) :], "")) + output_renamings = { + formal: actual + for formal, actual in zip(formal_outs, actual_outs) + if actual != "" + } + renamings = {**input_renamings, **output_renamings} + + logger.debug("renamings function %s: %s", function.name, renamings) + + def rename(name: str) -> str: + if name == "": + return name + new_name = renamings.get(name) + if new_name is None: + new_name = f"{node.name}_{name}" + logger.debug("renaming %s to %s", name, new_name) + if (ir_value := sub_scope.lookup(name)) is not None: + if ir_value.tensor_shape_proto() is not None and ir_value.type is not None: + ir_value.name = new_name + self.bind(new_name, ir_value) + return new_name + + ref_attrs = {attr.name: attr for attr in node.attribute} + # logger.debug("inlining simple function %s. 
Ref attrs: %s", function.name, ref_attrs) + + def fill_in_ref(attr: onnx.AttributeProto) -> onnx.AttributeProto: + if attr.ref_attr_name: + new_attr = onnx.AttributeProto() + new_attr.CopyFrom(ref_attrs[attr.ref_attr_name]) + new_attr.name = attr.name + return new_attr + return attr + + def update_graph_attribute( + attr: onnx.AttributeProto, + ) -> onnx.AttributeProto: + if attr.g: + new_attr = onnx.AttributeProto() + new_attr.CopyFrom(attr) + for node in new_attr.g.node: + node.input[:] = [rename(name) for name in node.input] + node.output[:] = [rename(name) for name in node.output] + new_attrs = [] + for attr in node.attribute: + new_attrs.append(update_attribute(attr)) + del node.attribute[:] + node.attribute.extend(new_attrs) + for vi_proto in new_attr.g.input: + vi_proto.name = rename(vi_proto.name) + for vi_proto in new_attr.g.output: + vi_proto.name = rename(vi_proto.name) + return new_attr + return attr + + def update_attribute(attr: onnx.AttributeProto) -> onnx.AttributeProto: + new_attr = fill_in_ref(attr) + new_attr = update_graph_attribute(new_attr) + return new_attr + + for sub_node in function.node: + # logger.debug("inlining simple function. old node: %s", sub_node) + new_node = onnx.NodeProto() + new_node.CopyFrom(sub_node) + new_node.input[:] = [rename(name) for name in new_node.input] + new_node.output[:] = [rename(name) for name in new_node.output] + del new_node.attribute[:] + for attr in sub_node.attribute: + new_node.attribute.append(update_attribute(attr)) + # Avoid name collision. + new_node.name = f"{node.name}_{new_node.name}" + # logger.debug("inlining simple function. new node: %s", new_node) + new_nodes.append(new_node) + + self.counts.setdefault(function_id, 0) + self.counts[function_id] += 1 + + return new_nodes, None + + return replacement, new_function + + +class SelectedFunctionInliner(FunctionInliner): + def __init__(self, functions_to_inline: Sequence[onnx.FunctionProto]): + super().__init__(node_count=0) # node_count unused. 
+ self._functions_to_inline = functions_to_inline + + def should_inline_function(self, function: onnx.FunctionProto) -> bool: + return function in self._functions_to_inline + + +class FindFunctionWithUnusedOutputsVisitor(visitor.ProtoVisitor): + def __init__(self) -> None: + super().__init__() + self._function_with_unused_outputs: dict[ir.FunctionId, onnx.FunctionProto] = {} + self._functions: dict[ir.FunctionId, onnx.FunctionProto] = {} + self._used_nodes: list[onnx.NodeProto] = [] + + def _find_nodes_with_any_unused_output( + self, nodes: Sequence[onnx.NodeProto], used_values: set[str] + ) -> list[onnx.NodeProto]: + target_nodes = [] + for i in range(len(nodes) - 1, -1, -1): + node = nodes[i] + if any(x not in used_values for x in node.output): + # Any unused output means the node is a target node. + target_nodes.append(node) + if all(x not in used_values for x in node.output): + # All unused output means the node is not used at all. + # Hence do not update used_values with the node's inputs. 
+ continue + used_values |= remove_unused.compute_used_in_node(node) + return target_nodes + + def visit_model(self, model: onnx.ModelProto) -> None: + used_values = {output.name for output in model.graph.output} + target_nodes = self._find_nodes_with_any_unused_output(model.graph.node, used_values) + + for function in model.functions: + self._functions[ + (function.domain, function.name, getattr(function, "overload", "")) + ] = function + used_values = set(function.output) + target_nodes.extend( + self._find_nodes_with_any_unused_output(function.node, used_values) + ) + + for node in target_nodes: + if visitor.is_local_function_node(node, self._functions): + function_id = (node.domain, node.op_type, getattr(node, "overload", "")) + self._function_with_unused_outputs[function_id] = self._functions[function_id] + + logger.info( + "Found %s function nodes that have unused outputs.", + len(self._function_with_unused_outputs), + ) + for key in self._function_with_unused_outputs: + logger.info("Function node with unused outputs: %s::%s", key[0], key[1]) + + @property + def function_with_unused_outputs(self) -> dict[ir.FunctionId, onnx.FunctionProto]: + return self._function_with_unused_outputs + + +def inline_simple_functions(model: onnx.ModelProto, node_count: int = 2) -> bool: + inliner = FunctionInliner(node_count) + inliner.visit_model(model) + logger.info( + "inlined %s simple functions based on node count threshold %s.", + len(inliner.counts), + node_count, + ) + for op in inliner.counts: + logger.info( + "Inlined simple function '%s::%s' %s times.", + op[0], + op[1], + inliner.counts[op], + ) + return inliner.modified + + +def inline_functions_with_unused_outputs(model: onnx.ModelProto) -> bool: + # TODO: Use onnx.inliner after 1.16. + # This visitor based inliner is used to ensure the function inner value info remains consistent. 
+ visitor = FindFunctionWithUnusedOutputsVisitor() + visitor.visit_model(model) + # FIXME: Fix the type of the argument passed into SelectedFunctionInliner + inliner = SelectedFunctionInliner(visitor.function_with_unused_outputs.values()) # type: ignore[arg-type] + inliner.visit_model(model) + logger.info( + "inlined %s function nodes that have unused outputs.", + len(inliner.counts), + ) + for op in inliner.counts: + logger.info( + "Inlined function '%s::%s' %s times.", + op[0], + op[1], + inliner.counts[op], + ) + return inliner.modified diff --git a/onnxscript/optimizer/simple_function_folding_test.py b/onnxscript/optimizer/simple_function_folding_test.py new file mode 100644 index 000000000..df7feaec2 --- /dev/null +++ b/onnxscript/optimizer/simple_function_folding_test.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import unittest + +import onnx + +from onnxscript.optimizer import remove_unused_function, simple_function_folding + + +class SingleNodeFunctionFoldingTest(unittest.TestCase): + def test_fold_single_node_function(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["this" : 1, "" : 18] +> +func ( x, y) => ( return_val) { + tmp = this.foldable (x) + return_val = Add (tmp, y) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable (x) => (return_val) +{ + return_val = Identity (x) +} + """ + ) + + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 0) + + def test_fold_single_node_function_ref_attr(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["this" : 1, "" : 18] +> +func ( x, y, z) => ( return_val) { + tmp = this.foldable (x, y) + return_val = Add (tmp, z) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable (x, y) => (return_val) +{ + return_val = Concat (x, y) +} + """ + ) + + simple_function_folding.inline_simple_functions(model) + 
remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 0) + self.assertFalse(model.graph.node[0].attribute[0].ref_attr_name) + self.assertEqual(model.graph.node[0].attribute[0].name, "axis") + + def test_fold_single_node_function_nested(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["this" : 1, "" : 18] +> +func ( x, y, z) => ( return_val) { + tmp = this.non_foldable (x, y) + return_val = Add (tmp, z) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable (x, y) => (return_val) +{ + return_val = Concat (x, y) +} +< + domain: "this", + opset_import: ["this" : 1,"" : 18] +> +non_foldable (x, y) => (return_val) +{ + tmp = this.foldable (x, y) + tmp_0 = this.foldable (x, y) + return_val = Add (tmp, tmp_0) +} + """ + ) + + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 1) + self.assertEqual(model.functions[0].node[0].op_type, "Concat") + self.assertEqual(model.functions[0].node[1].op_type, "Concat") + + def test_fold_single_node_function_create_new_nodes_with_correct_attributes(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 9, + opset_import: ["this" : 1, "" : 21] +> +func (float[1,512] x) => ( a, b, c) { + a = this.prim_cast (x) + b = this.prim_cast (x) + c = this.prim_cast (x) +} +< + domain: "this", + opset_import: ["" : 18] +> +prim_cast (x) => (return_val) +{ + return_val = Cast (x) +} + """ + ) + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + self.assertEqual(len(model.functions), 0) + self.assertEqual(len(model.graph.node), 3) + self.assertEqual(model.graph.node[0].attribute[0].i, 10) + self.assertEqual(model.graph.node[1].attribute[0].i, 6) + self.assertEqual(model.graph.node[2].attribute[0].i, 7) + + def test_fold_nested_if_function_succeeds(self): + model = onnx.parser.parse_model( + 
""" +< + ir_version: 9, + opset_import: ["this" : 1, "" : 21] +> +func (float[1,512] x, float[1,512] y) => ( out) { + out = this.foldable_func (x, y) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable_func (x, y) => (z_6) +{ + cond = Constant () + z_6 = If (cond) ( z_2) { + cond_0 = Not (cond) + z_2 = If (cond_0) ( z) { + z = Add (x, x) + }, else_branch: graph = elseGraph_5 () => ( z_1) { + z_1 = Identity (x) + }> + }, else_branch: graph = elseGraph_4 () => ( z_5) { + z_5 = If (cond) ( z_3) { + z_3 = Add (y, y) + }, else_branch: graph = elseGraph_10 () => ( z_4) { + z_4 = Add (x, y) + }> + }> +} + """ + ) + + simple_function_folding.inline_simple_functions(model) + remove_unused_function.remove_unused_functions(model) + + self.assertEqual(len(model.functions), 0) + self.assertEqual(len(model.graph.node), 2) + self.assertEqual(model.graph.node[1].op_type, "If") + + def test_fold_function_with_unused_output(self): + model = onnx.parser.parse_model( + """ +< + ir_version: 8, + opset_import: ["this" : 1, "" : 18] +> +func ( x, y, z) => ( return_val) { + tmp = this.non_foldable (x, y) + return_val = Add (tmp, z) +} +< + domain: "this", + opset_import: ["" : 18] +> +foldable (x, y) => (return_val, unused, unused1) +{ + return_val = Concat (x, y) + unused = Identity (x) + unused1 = Identity (y) +} +< + domain: "this", + opset_import: ["this" : 1,"" : 18] +> +non_foldable (x, y) => (return_val) +{ + tmp, unused, unused1 = this.foldable (x, y) + tmp_0, unused2, unused3 = this.foldable (x, y) + return_val = Add (tmp, tmp_0) +} + """ + ) + + simple_function_folding.inline_functions_with_unused_outputs(model) + remove_unused_function.remove_unused_functions(model) + self.assertEqual(len(model.functions), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py new file mode 100644 index 000000000..3fe036f22 --- /dev/null +++ b/onnxscript/rewriter/__init__.py @@ -0,0 +1,39 @@ +from __future__ 
import annotations + +from typing import Sequence + +__all__ = [ + # Modules + "irbuilder", + "protobuilder", + "function_rule", + "pattern", + # Functions + "rewrite", +] + +import onnx + +from onnxscript._legacy_ir import irbuilder, protobuilder +from onnxscript.rewriter import function_rule, pattern + +PatternRewriteRule = pattern.RewriteRule +FunctionRewriteRule = function_rule.FunctionRewriteRule + + +def rewrite( + model: onnx.ModelProto, + function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (), + pattern_rewrite_rules: Sequence[PatternRewriteRule] = (), +) -> onnx.ModelProto: + if function_rewrite_rules: + model_ir = irbuilder.build_ir(model) + for rule_cls in function_rewrite_rules: + rule_cls().apply_to_model(model_ir) + model = model_ir.original_model_proto + if pattern_rewrite_rules: + model_ir = irbuilder.build_ir(model) + count = pattern.RewriteRuleSet(pattern_rewrite_rules).apply_to_model(model_ir) + print(f"Applied {count} pattern rewrite rules.") + model = protobuilder.build_model_proto(model_ir) + return model diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py new file mode 100644 index 000000000..6c7ad4c07 --- /dev/null +++ b/onnxscript/rewriter/broadcast_to_matmul.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np + +import onnxscript._legacy_ir as ir +from onnxscript.rewriter import pattern + +op = pattern.onnxop +logger = logging.getLogger(__name__) + + +# condition to check if we need to replace the pattern +def check_if_need_reshape(match_bindings: dict[str, ir.Value | Any]) -> bool: + """If matmul broadcasting is enough, then we don't need the reshapes. + + To validate this, we need to check the following: + 1. Input shapes check: input_a and input_b should be broadcastable + 2. 
Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b) + + If the above are true, then we don't need the reshapes. + + Args: + match_bindings: The match binding dictionary from a MatchResult. + + Returns: + bool: True if we need to replace the pattern, False otherwise. + + """ + input_a_shape = match_bindings["input_a"].shape + input_b_shape = match_bindings["input_b"].shape + shape_c = match_bindings["shape_c"].value_as_np_array + if shape_c is None: + return False + if not isinstance(shape_c, np.ndarray): + logger.info("Unexpected shape_c value. Expected np.ndarray, got %s", type(shape_c)) + return False + if len(shape_c.shape) != 1: + logger.info( + "Unexpected final shape. The shape of 'shape' value is %s", + shape_c.shape, + ) + return False + shape_c = shape_c.tolist() + + # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape + # information. So, we need to check if the shape is None and return False. + if input_a_shape is None or input_b_shape is None or shape_c is None: + logger.info("Shape information is not available for the inputs and outputs.") + return False + + dim_a = len(input_a_shape) + dim_b = len(input_b_shape) + + # 1. Check if input shapes are broadcastable + # 1.a. If the first input is 1-D, check whether + # the dim matches the last second dim of the second input. + mimic_matmul_broadcast_behavior = False + if dim_a < 2: + if input_a_shape[-1] != input_b_shape[-2]: + logger.info("Original shape is not MatMul compatible.") + return False + else: + input_a_shape = [1, *input_a_shape] + dim_a = len(input_a_shape) + mimic_matmul_broadcast_behavior = True + # 1.b. If the second input is 1-D, check whether + # the dim matches the last dim of the first input. 
+ if dim_b < 2: + if input_b_shape[-1] != input_a_shape[-1]: + logger.info("Original shape is not MatMul compatible.") + return False + else: + input_b_shape = [*input_b_shape, 1] + dim_b = len(input_b_shape) + mimic_matmul_broadcast_behavior = True + # 1.c. If both inputs are at least 2-D, check whether + # the last dimension of the first input matches the second + # last dimension of the second input, and shape[:-2] are + # broadcastable. + input_a_shape_except_second_last_dim = input_a_shape[:-2] + [input_a_shape[-1]] + input_b_shape_except_last_dim = input_b_shape[:-1] + broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]] + for idx, (dim_from_a, dim_from_b) in enumerate( + zip( + reversed(input_a_shape_except_second_last_dim), + reversed(input_b_shape_except_last_dim), + ) + ): + if dim_from_a not in {1, dim_from_b}: + logger.info("Original shape is not broadcastable.") + return False + elif idx > 0: + broadcast_matmul_output_shape = [ + max(dim_from_a, dim_from_b), + *broadcast_matmul_output_shape, + ] + + # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b) + # Prepend the broadcast_matmul_output_shape with the longer shape of input + if dim_a > dim_b: + longer_shape = input_a_shape + shorter_shape = input_b_shape + else: + longer_shape = input_b_shape + shorter_shape = input_a_shape + broadcast_matmul_output_shape = ( + longer_shape[: -len(shorter_shape)] + broadcast_matmul_output_shape + ) + if mimic_matmul_broadcast_behavior and dim_b == 2: + broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1] + if mimic_matmul_broadcast_behavior and dim_a == 2: + broadcast_matmul_output_shape.pop(-2) + if shape_c != broadcast_matmul_output_shape: + logger.info( + "Final output shape is not the same. 
Expected %s vs actual %s", + shape_c, + broadcast_matmul_output_shape, + ) + return False + + return True + + +def two_reshapes_matmul_reshape_pattern(input_a, input_b, shape_a, shape_b, shape_c): + # TODO: Modified from `value_ints` to `value` to match pattern in benchmark models. + # This implementation misses pattern of Constants with `value_ints` attribute. + # See more at https://github.com/microsoft/onnx-rewriter/issues/191. + # A better solution is to improve pattern matching and avoid depending on writing + # Constants in pattern. See https://github.com/microsoft/onnx-rewriter/issues/192. + reshape_a = op.Reshape(input_a, shape_a) + reshape_b = op.Reshape(input_b, shape_b) + matmul = op.MatMul(reshape_a, reshape_b) + return op.Reshape(matmul, shape_c) + + +def matmul_with_two_shape_inputs(input_a, input_b, shape_a, shape_b, shape_c): + del shape_a # Unused + del shape_b # Unused + del shape_c # Unused + return op.MatMul(input_a, input_b) + + +def one_reshape_matmul_reshape_pattern(input_a, input_b, shape_a, shape_c): + reshape_a = op.Reshape(input_a, shape_a) + matmul = op.MatMul(reshape_a, input_b) + return op.Reshape(matmul, shape_c) + + +def matmul_with_one_shape_input(input_a, input_b, shape_a, shape_c): + del shape_a # Unused + del shape_c # Unused + return op.MatMul(input_a, input_b) + + +# Register the rewrite rules +two_reshapes_matmul_reshape_rule = pattern.RewriteRule( + two_reshapes_matmul_reshape_pattern, + matmul_with_two_shape_inputs, + check_if_need_reshape, +) +one_reshape_matmul_reshape_rule = pattern.RewriteRule( + one_reshape_matmul_reshape_pattern, + matmul_with_one_shape_input, + # We can use the same check_if_need_reshape function for both the rules, + # as one_reshape_matmul_reshape_pattern is a subset of two_reshapes_matmul_reshape_pattern. + check_if_need_reshape, +) + +# NOTE: The order of the rules is important. Larger pattern should be checked first. 
+rules = pattern.RewriteRuleSet( + [two_reshapes_matmul_reshape_rule, one_reshape_matmul_reshape_rule] +) diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/broadcast_to_matmul_test.py new file mode 100644 index 000000000..b462c46c1 --- /dev/null +++ b/onnxscript/rewriter/broadcast_to_matmul_test.py @@ -0,0 +1,283 @@ +import unittest + +import onnx.parser + +from onnxscript._legacy_ir import irbuilder +from onnxscript.rewriter import broadcast_to_matmul + + +class TwoReshapesMatMulReshapeTest(unittest.TestCase): + def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nested_function( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[1, 4, 512, 64] input_y) => (float[1, 4, 512, 64] output) + { + output = pkg.custom.afunction (input_x, input_y) + } + + afunction (input_x, input_y) => (output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + # Hack to put value_info in since parser does not support this experimental naming format + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_x", + 
onnx.TensorProto.FLOAT, + [1, 4, 512, 512], + ) + ) + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_y", onnx.TensorProto.FLOAT, [1, 4, 512, 64] + ) + ) + + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 1) + self.assertEqual(len(ir.functions[0].nodes), 4) + self.assertEqual(ir.functions[0].nodes[-1].op_type, "MatMul") + + def test_reshape_matmul_reshape_remain_when_input_last_dim_and_second_last_dim_not_matched( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[512, 512, 4] input_x, float[4, 64, 512] input_y) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + def test_reshape_matmul_reshape_remain_when_inputs_are_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 8, 512, 64] input_x, float[4, 4, 64, 512] input_y) => (float[2, 8, 512, 512] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + def test_reshape_matmul_reshape_replace_when_inputs_are_broadcastable_with_one_in_dims( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 8, 512, 64] input_x, float[1, 1, 2, 8, 64, 512] 
input_y) => (float[1, 1, 2, 8, 512, 512] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[4] input_x, float[2, 3, 4, 5] input_y) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_matmul_reshape_remain_when_first_input_is_one_dimension_and_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[8] input_x, float[2, 3, 4, 5] input_y) => (float[2, 3, 2, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + def test_reshape_matmul_reshape_replace_when_second_input_is_one_dimension_and_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 4, 5] input_x, float[5] input_y) => (float[2, 3, 4] output) + { + shape_a = 
Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_matmul_reshape_remain_when_second_input_is_one_dimension_and_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 4, 5] input_x, float[10] input_y) => (float[2, 3, 4, 2] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + def test_reshape_matmul_reshape_remain_when_output_is_not_matmul_broadcasted( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 4, 5] input_x, float[5, 8] input_y) => (float[2, 4, 6, 4] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + shape_b = Constant() + reshape_y = Reshape (input_y, shape_b) + matmul = MatMul (reshape_x, reshape_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 7) + + +class OneReshapeMatMulReshapeTest(unittest.TestCase): + def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 512, 4096] input_x, float[4096, 4096] input_y) => (float[1, 512, 4096] output) + { + shape_a = Constant() + reshape_x = 
Reshape (input_x, shape_a) + matmul = MatMul (reshape_x, input_y) + shape_c = Constant() + output = Reshape (matmul, shape_c) + } + """ + ) + ir = irbuilder.build_ir(model) + count = broadcast_to_matmul.rules.apply_to_model(ir) + self.assertEqual(count, 1) + # The constant nodes are not removed. They should be removed by a subsequent DCE in optimizer. + self.assertEqual(len(ir.graph.nodes), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/cast_constant_of_shape.py new file mode 100644 index 000000000..1b4ac98e9 --- /dev/null +++ b/onnxscript/rewriter/cast_constant_of_shape.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import logging +from typing import Any, Sequence + +import numpy as np +import onnx + +import onnxscript._legacy_ir as ir +from onnxscript.rewriter import pattern + +op = pattern.onnxop +logger = logging.getLogger(__name__) + + +def cast_constant_of_shape( + shape: Sequence[int], + t: Any, + dtype: int, + match_bindings: dict[str, ir.Value | Any] | None = None, +) -> pattern.OpPattern: + constant = op.ConstantOfShape(shape, value=t) + return op.Cast(constant, to=dtype) + + +def fused_cast_constant_of_shape( + shape: Sequence[int], t: Any, dtype: int, match_bindings: dict[str, ir.Value | Any] +) -> pattern.OpPattern: + del dtype # unused + del t # unused + v_dtype = match_bindings["dtype"] + v_t = match_bindings["t"] + casted_val = onnx.numpy_helper.to_array(v_t).astype( # type: ignore[arg-type] + dtype=onnx.helper.tensor_dtype_to_np_dtype(v_dtype) # type: ignore[arg-type] + ) + return op.ConstantOfShape(shape, value=casted_val) + + +def cast_constant_of_shape_without_value( + shape: Sequence[int], + dtype: int, + match_bindings: dict[str, ir.Value | Any] | None = None, +) -> pattern.OpPattern: + del match_bindings # Unused + constant = op.ConstantOfShape(shape) + return op.Cast(constant, to=dtype) + + +def fused_cast_constant_of_shape_without_value( + 
shape: Sequence[int], dtype: int, match_bindings: dict[str, ir.Value | Any] +) -> pattern.OpPattern: + del dtype # Unused + v_dtype = match_bindings["dtype"] + val = np.zeros(1, dtype=onnx.helper.tensor_dtype_to_np_dtype(v_dtype)) # type: ignore + return op.ConstantOfShape(shape, value=val) + + +cast_constant_of_shape_rule = pattern.RewriteRule( + cast_constant_of_shape, + pattern.ReplacementPatternFunction(fused_cast_constant_of_shape, delay_run=True), +) + +cast_constant_of_shape_without_value_rule = pattern.RewriteRule( + cast_constant_of_shape_without_value, + pattern.ReplacementPatternFunction( + fused_cast_constant_of_shape_without_value, delay_run=True + ), +) + +rules = pattern.RewriteRuleSet( + [ + cast_constant_of_shape_rule, + cast_constant_of_shape_without_value_rule, + ] +) diff --git a/onnxscript/rewriter/cast_constant_of_shape_test.py b/onnxscript/rewriter/cast_constant_of_shape_test.py new file mode 100644 index 000000000..c459a40c4 --- /dev/null +++ b/onnxscript/rewriter/cast_constant_of_shape_test.py @@ -0,0 +1,46 @@ +import unittest + +import onnx.parser + +from onnxscript._legacy_ir import irbuilder +from onnxscript.rewriter import cast_constant_of_shape + + +class CastConstantOfShapeTest(unittest.TestCase): + def test_cast_after_constant_of_shape_is_fused(self): + model = onnx.parser.parse_model( + """ + + agraph (int64[2] input_x) => (float16[1, 4] output) + { + constant = ConstantOfShape (input_x) + output = Cast (constant) + } + """ + ) + ir = irbuilder.build_ir(model) + count = cast_constant_of_shape.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 1) + self.assertEqual(ir.graph.nodes[0].attributes["value"].data_type, 10) + + def test_cast_after_constant_of_shape_without_value_is_fused(self): + model = onnx.parser.parse_model( + """ + + agraph (int64[2] input_x) => (float16[1, 4] output) + { + constant = ConstantOfShape (input_x) + output = Cast (constant) + } + """ + ) + ir = 
irbuilder.build_ir(model) + count = cast_constant_of_shape.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 1) + self.assertEqual(ir.graph.nodes[0].attributes["value"].data_type, 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/erfgelu.py b/onnxscript/rewriter/erfgelu.py new file mode 100644 index 000000000..67f0d47e1 --- /dev/null +++ b/onnxscript/rewriter/erfgelu.py @@ -0,0 +1,30 @@ +import math + +from onnxscript.rewriter import pattern + +op = pattern.onnxop + + +# Pattern to match against +def erf_gelu_pattern(x): + # erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) + # half = pattern.Constant(0.5) + # sqrt2 = pattern.Constant(1.4142) + # x_div_sqrt2 = op.Div(x, sqrt2) + # erf = op.Erf(x_div_sqrt2) + # one = pattern.Constant(1.0) + # one_plus_erf = op.Add(erf, one) + # x_mul_one_plus_erf = op.Mul(x, one_plus_erf) + # return op.Mul(half, x_mul_one_plus_erf) + return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0)) + + +msft_op = pattern.msft_op + + +# Replacement +def gelu(x): + return msft_op.Gelu(x) + + +rule = pattern.RewriteRule(erf_gelu_pattern, gelu) diff --git a/onnxscript/rewriter/function_rule.py b/onnxscript/rewriter/function_rule.py new file mode 100644 index 000000000..526626d8f --- /dev/null +++ b/onnxscript/rewriter/function_rule.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import functools +import logging + +import onnx +from packaging import version + +import onnxscript +import onnxscript._legacy_ir as ir +from onnxscript._legacy_ir import visitor +from onnxscript.rewriter import pattern + +logger = logging.getLogger(__name__) + + +class FunctionRewriteError(RuntimeError): ... + + +@functools.lru_cache +def parse_domain(function_domain: str) -> tuple[str, version.Version | None]: + splits = function_domain.split(".") + if splits[0] != "pkg": + raise FunctionRewriteError( + f"Invalid domain: {function_domain}. Must start with 'pkg'." 
+ ) + splits = splits[1:] + for i, s in enumerate(splits): + if s.isdigit(): + return ".".join(splits[:i]), version.parse(".".join(splits[i:])) + return ".".join(splits), None + + +MIN_VERSION = version.parse("0") +MAX_VERSION = version.parse("9999") + + +class VersionController: + def __init__(self): + # A dispatch table for rewrite implementation based on the function package version. + self.dispatch_table: dict[tuple[version.Version, version.Version], callable] = {} + + def register_version( + self, + min_version: version.Version | str | None = None, + max_version: version.Version | str | None = None, + ): + """Register a function implementation for a specific package version range [min_version, max_version). + + Args: + min_version: The minimum version of the package. Inclusive. + max_version: The maximum version of the package. Exclusive. + """ + # TODO: check for version overloap + + min_version = MIN_VERSION if min_version is None else min_version + max_version = MAX_VERSION if max_version is None else max_version + if isinstance(min_version, str): + min_version = version.parse(min_version) + if isinstance(max_version, str): + max_version = version.parse(max_version) + + def deco(func): + self.dispatch_table[(min_version, max_version)] = func + return func + + return deco + + def dispatch(self, version: version.Version | None) -> callable | None: + if version is None: + if len(self.dispatch_table) == 1: + return next(iter(self.dispatch_table.values())) + raise ValueError( + "No function package version specified, however there are multiple " + f"fusion rules based on package version: {self.dispatch_table.keys()}." 
+ ) + for (min_version, max_version), func in self.dispatch_table.items(): + greater_than_min = min_version is None or min_version <= version + less_than_max = max_version is None or version < max_version + if greater_than_min and less_than_max: + return func + return None + + +class FunctionRewriteRule(pattern.RewriteRule): + FUNCTION_KEYWORD: str | tuple[str] + """The keyword to match the function name. If a tuple, any keyword will match.""" + + PACKAGE_NAME: str + """The package name to match. + + For example, 'transformers' to match for domain name 'pkg.transformers.4.36.2'. + """ + + _opset_imports: dict[str, int] + onnx_opset: onnxscript.values.Opset + _function_shape_env: visitor.FunctionShapeEnv + + def __init__(self, opset: onnxscript.values.Opset = onnxscript.opset18) -> None: + self.onnx_opset = opset + + def _match_function(self, function: onnx.FunctionProto, pkg_name: str) -> bool: + # TODO: Consolidate more checks from `compose_new_function` to here. + if pkg_name != self.PACKAGE_NAME: + logger.info( + "Rule %s did not match function %s::%s. 
Package name mismatch '%s' != '%s'.", + self.__class__.__name__, + function.domain, + function.name, + self.PACKAGE_NAME, + pkg_name, + ) + return False + + if isinstance(self.FUNCTION_KEYWORD, str): + return function.name.find(self.FUNCTION_KEYWORD) != -1 + elif isinstance(self.FUNCTION_KEYWORD, tuple): + return any(function.name.find(keyword) != -1 for keyword in self.FUNCTION_KEYWORD) + else: + raise ValueError( # noqa: TRY004 + f"Function keyword must be str or tuple, got {self.FUNCTION_KEYWORD}" + ) + + def _find_node_contains_key_in_name( + self, function: onnx.FunctionProto, keyword: str + ) -> onnx.NodeProto | None: + for node in function.node: + if node.name.find(keyword) != -1: + return node + return None + + def _find_node_by_type( + self, function: onnx.FunctionProto, domain: str, op_type: str + ) -> onnx.NodeProto | None: + # Repeat + for node in function.node: + if node.domain == domain and node.op_type == op_type: + return node + return None + + def _find_constant_node( + self, function: onnx.FunctionProto, value_name: str + ) -> onnx.NodeProto | None: + # Potentially repeat, utility function. + for node in function.node: + for output in node.output: + if output == value_name: + return node + return None + + def compose_new_function( + self, old_function: onnx.FunctionProto, pkg_version: version.Version | None + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + """Compose a new function from the old function. + + Returns: + A tuple of the new function and the opset imports. + + Raises: + FunctionRewriteError: If the rewrite fails. + """ + func = self._version_controller.dispatch(pkg_version) + if func is not None: + return func(self, old_function) + raise FunctionRewriteError( + f"No rewrite implementation for package version {pkg_version}." 
+ ) + + def try_rewrite_function( + self, function: onnx.FunctionProto, model: onnx.ModelProto + ) -> bool: + try: + pkg_name, pkg_version = parse_domain(function.domain) + except FunctionRewriteError as e: + logger.warning("Could not parse domain: %s", e) + return False + + if pkg_version is None and not pkg_name.startswith("onnxscript"): + logger.warning( + "Could not parse version for domain of function %s::%s. " + "Usually this implies the model source is not from a package, but from arbitrary python files instead. " + "For example, models not defined in huggingface/transformers but loaded via 'trust_remote_code=True'.", + function.domain, + function.name, + ) + + if not self._match_function(function, pkg_name): + return False + logger.info( + "Rule %s matched function %s::%s", + self.__class__.__name__, + function.domain, + function.name, + ) + + try: + new_function, opset_imports = self.compose_new_function(function, pkg_version) + except FunctionRewriteError as e: + logger.warning("Could not rewrite function: %s", e) + return False + + nodes = new_function.node + + del function.input[:] + function.input.extend(new_function.input) + del function.output[:] + function.output.extend(new_function.output) + + del function.node[:] + function.node.extend(nodes) + for new_opset in opset_imports: + function.opset_import.append(new_opset) + if new_opset.domain not in self._opset_imports: + model.opset_import.append(new_opset) + + return True + + def try_rewrite(self, model: ir.Model, value) -> bool: + raise NotImplementedError( + "Use `try_rewrite_function` instead for function based rewrites." 
+ ) + + def lookup(self, function: onnx.FunctionProto, value_name: str) -> ir.Value | None: + return self._function_shape_env.lookup(function, value_name) + + def apply_to_model(self, model: ir.Model, *, commute: bool = False) -> int: + del commute # unused + model_proto: onnx.ModelProto = model.original_model_proto + self._function_shape_env = visitor.FunctionShapeEnv() + self._function_shape_env.load_from_model_proto(model.original_model_proto) + self._opset_imports = {x.domain: x.version for x in model_proto.opset_import} + + rewrite_count = 0 + for function in model_proto.functions: + rewrite_count += self.try_rewrite_function(function, model_proto) + return rewrite_count + + def count_matches(self, model, *, commute: bool = False) -> int: + raise NotImplementedError() + + def commute(self) -> list[pattern.RewriteRule]: + raise NotImplementedError() diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/gemm_to_matmul_add.py new file mode 100644 index 000000000..ae44ffe27 --- /dev/null +++ b/onnxscript/rewriter/gemm_to_matmul_add.py @@ -0,0 +1,21 @@ +from onnxscript.rewriter import pattern +from onnxscript.rewriter.broadcast_to_matmul import check_if_need_reshape + +op = pattern.onnxop + + +# Pattern to match against +def reshape_gemm_reshape_pattern(input_a, input_b, input_c, shape_a, shape_c): + reshape_a = op.Reshape(input_a, shape_a) + # TODO: Temporary workaround to support benchmodels. + # Tracked by https://github.com/microsoft/onnx-rewriter/issues/197. 
+ gemm = op.Gemm(reshape_a, input_b, input_c, alpha=1.0, beta=1.0) + return op.Reshape(gemm, shape_c) + + +def matmul_add(input_a, input_b, input_c, shape_a, shape_d): + matmul = op.MatMul(input_a, input_b) + return op.Add(matmul, input_c) + + +rule = pattern.RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_need_reshape) diff --git a/onnxscript/rewriter/gemm_to_matmul_add_test.py b/onnxscript/rewriter/gemm_to_matmul_add_test.py new file mode 100644 index 000000000..615d6311a --- /dev/null +++ b/onnxscript/rewriter/gemm_to_matmul_add_test.py @@ -0,0 +1,254 @@ +import unittest + +import onnx.parser + +from onnxscript._legacy_ir import irbuilder +from onnxscript.rewriter import gemm_to_matmul_add + + +class ReshapeGemmReshapeTest(unittest.TestCase): + def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[4, 512, 64] input_y, float[4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable_in_nested_function( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[4, 512, 64] input_y, float[4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + output = afunction (input_x, input_y, input_z) + } + + afunction (input_x, input_y, input_z) => (output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + # Hack to put value_info in since parser does not support 
this experimental naming format + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_x", + onnx.TensorProto.FLOAT, + [1, 4, 512, 512], + ) + ) + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_y", onnx.TensorProto.FLOAT, [4, 512, 64] + ) + ) + model.graph.value_info.append( + onnx.helper.make_tensor_value_info( + "pkg.custom::afunction/input_z", onnx.TensorProto.FLOAT, [1, 4, 512, 64] + ) + ) + + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 1) + self.assertEqual(len(ir.functions[0].nodes), 4) + self.assertEqual(ir.functions[0].nodes[2].op_type, "MatMul") + self.assertEqual(ir.functions[0].nodes[3].op_type, "Add") + + def test_reshape_gemm_reshape_remain_when_input_last_dim_and_second_last_dim_not_matched( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 4, 512, 512] input_x, float[4, 256, 64] input_y, float[4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + def test_reshape_gemm_reshape_remain_when_inputs_are_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 2, 512, 512] input_x, float[4, 512, 64] input_y, float[4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + 
self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + def test_reshape_gemm_reshape_replace_when_inputs_are_broadcastable_with_one_in_dims( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[4, 512, 512] input_x, float[1, 4, 512, 64] input_y, float[1, 4, 512, 64] input_z) => (float[1, 4, 512, 64] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + self.assertEqual(ir.graph.nodes[2].op_type, "MatMul") + self.assertEqual(ir.graph.nodes[3].op_type, "Add") + + def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[4] input_x, float[2, 3, 4, 5] input_y, float[2, 3, 5] input_z) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + self.assertEqual(ir.graph.nodes[2].op_type, "MatMul") + self.assertEqual(ir.graph.nodes[3].op_type, "Add") + + def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[8] input_x, float[2, 3, 4, 5] input_y, float[2, 3, 5] input_z) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = 
gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 5, 4] input_x, float[4] input_y, float[2, 3, 5] input_z) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 4) + self.assertEqual(ir.graph.nodes[2].op_type, "MatMul") + self.assertEqual(ir.graph.nodes[3].op_type, "Add") + + def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_not_broadcastable( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 5, 4] input_x, float[10] input_y, float[2, 3, 5] input_z) => (float[2, 3, 5] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 5) + + def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 5, 4] input_x, float[5] input_y, float[2, 3, 5] input_z) => (float[2, 4, 6] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + ir = irbuilder.build_ir(model) + count = gemm_to_matmul_add.rule.apply_to_model(ir) + self.assertEqual(count, 0) + 
self.assertEqual(len(ir.graph.nodes), 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py new file mode 100644 index 000000000..5681beab6 --- /dev/null +++ b/onnxscript/rewriter/generic_pattern.py @@ -0,0 +1,1165 @@ +from __future__ import annotations + +import collections +import inspect +import os +import textwrap +import typing + +import onnx +import onnx.helper as oh + +import onnxscript._legacy_ir as oir +import onnxscript.rewriter.pattern as orp + + +def enumerate_subgraphs( + node: oir.Node, +) -> typing.Iterator[tuple[typing.Any, ...]]: + """Returns the subgraphs inside a graph.""" + for att in node.attribute: + if att.type == onnx.AttributeProto.GRAPH and att.g: + this = node, att.name, att.g + yield this + + for no in att.g.node: + for tu in enumerate_subgraphs(no): + yield this + tu + + +class _GraphStructureAPI: + """Common accessors to predecessors and successors.""" + + def __init__(self): + self.predecessors_: dict[str, int] = {} + self.successors_: dict[str, list[int]] = {} + self.nodes_: dict[int, oir.Node] = {} + + def node_before(self, name: str) -> oir.Node | None: + """ + Returns the node producing this output. + + Returns None if it is an input or an initializer. + """ + if name not in self.predecessors_: + return None + predecessor = self.predecessors_[name] + return self.nodes_[predecessor] + + def next_nodes(self, name: str) -> list[oir.Node] | None: + """Returns the node consuming the given results.""" + if name not in self.successors_: + return [] + return [self.nodes_[i] for i in self.successors_[name]] + + +class BuilderWithGraphStructure(_GraphStructureAPI): + """Very concise graph builder. + + It wraps an ONNX graph + and builds successors and predecessors on top of it. 
+ """ + + def __init__(self, bridge: ModelWithGraphStructure): + super().__init__() + self.bridge: ModelWithGraphStructure = bridge + self.input_names: list[str] = [] + self.output_names: list[str] = [] + self.nodes: list[oir.Node] = [] + + def _build(self) -> None: + self.predecessors_: dict[str, int] = {} + self.successors_: dict[str, list[int]] = {} + self.nodes_: dict[int, oir.Node] = {} + + self.outputs_ = set(self.output_names) + for node in self.nodes: + self.nodes_[id(node)] = node + + for k, v in self.nodes_.items(): + assert isinstance(v, oir.Node), f"Unexpected type {type(v)} for node {k}" + for o in v.output_names: + self.predecessors_[o] = k + for i in v.input_names: + if i not in self.successors_: + self.successors_[i] = [] + self.successors_[i].append(k) + + def make_input(self, name: str) -> None: + self.input_names.append(name) + + def make_output(self, name: str) -> None: + self.output_names.append(name) + + def __getattr__(self, name: str) -> typing.Any: + if name in self.__dict__: + return self.__dict__[name] + + # unknown name + assert ( + name[0].upper() == name[0] + ), f"A node type must starts with an upper letter but it is {name!r}" + return lambda *args, _name=name, **kwargs: self._make_node(_name, *args, **kwargs) + + def _make_node( + self, + op_type: str, + *args: str, + output_names: list[str] | int | None = None, + **kwargs: typing.Any, + ) -> str | tuple[str]: + if output_names is None: + # We assume there is only one outputs, we could also check into the schema. + output_names = 1 + return self.make_node(op_type, *args, output_names=output_names, **kwargs) + + def make_node_with_proto(self, node_proto: onnx.NodeProto) -> tuple[str] | str: + node = oir.Node(node_proto, True) + self.nodes.append(node) + assert node.output_names, f"No output in node {node}. This can't be true." 
+ if len(node.output_names) == 1: + return node.output_names[0] + return tuple(node.output_names) + + def make_node( + self, + op_type: str, + *input_names: str, + output_names: int | list[str] | str | None = None, + domain: str = "", + name: str | None = None, + **kwargs: typing.Any, + ) -> str | tuple[str]: + node = oir.Node( + self.bridge.make_node( + op_type, input_names, output_names, domain=domain, name=name, **kwargs + ), + True, + ) + self.nodes.append(node) + assert node.output_names, f"No output in node {node}. This can't be true." + if len(node.output_names) == 1: + return node.output_names[0] + return tuple(node.output_names) + + +class ModelWithGraphStructure(oir.Model, _GraphStructureAPI): + """Implements all the necessary API it needs to work. + + Wraps a :class:`Model` and builds successors and predecessors on + top of it. + """ + + def __init__(self, model: oir.Model, verbose: int = 0): + oir.Model.__init__(self) + _GraphStructureAPI.__init__(self) + self.model = model + if hasattr(self.model, "graph"): + self.nodes = list(model.graph.nodes) + self.input_names = list(model.graph.input_names) + self.output_names = list(model.graph.output_names) + self._build() + else: + # empty graph + self._unique_names: set = set() + self._unique_node_names: set = set() + self.verbose = verbose + + def _build(self) -> None: + """Builds successor and predecessor.""" + self.nodes_ = {} + self.outputs_ = set(self.output_names) + self._unique_node_names = set() + for node in self.nodes: + self.nodes_[id(node)] = node + if node.name: + self._unique_node_names.add(node.name) + + self.predecessors_: dict = {} + self.successors_: dict = {} + # TODO: # initiliazer are missing + self._unique_names = set(self.input_names) | set(self.output_names) + for k, v in self.nodes_.items(): + assert isinstance(v, oir.Node), f"Unexpected type {type(v)} for node {k}" + for o in v.output_names: + self.predecessors_[o] = k + for i in v.input_names: + if i not in self.successors_: + 
self.successors_[i] = [] + self.successors_[i].append(k) + + for sub in enumerate_subgraphs(v): + g = sub[-1] + sub_knowns = set() + for n in g.input: + sub_knowns.add(n.name) + for n in g.initializer: + sub_knowns.add(n.name) + for n in g.sparse_initializer: + sub_knowns.add(n.name) + for n in g.node: + for i in n.input: + if i not in sub_knowns: + # an input coming from the parent + self._unique_names.add(i) + for i in n.output: + sub_knowns.add(i) + + def unique_name(self, prefix: str) -> str: + """Generates a unique result name. + + That excludes existing names as well. + """ + if prefix in self._unique_names: + i = 2 + sug = f"{prefix}2" + while sug in self._unique_names: + i += 1 + sug = f"{prefix}{i}" + self._unique_names.add(sug) + return sug + self._unique_names.add(prefix) + return prefix + + def unique_node_name(self, name: str | None) -> str: + """Creates a unique node name.""" + name = name or "" + if name in self._unique_node_names: + i = 2 + sug = f"{name}2" + while sug in self._unique_node_names: + i += 1 + sug = f"{name}{i}" + self._unique_node_names.add(sug) + return sug + self._unique_node_names.add(name) + return name + + def make_opset(self) -> BuilderWithGraphStructure: + return BuilderWithGraphStructure(self) + + @property + def opsets(self) -> dict: + """Property.""" + return self.model.version_map + + def make_node( + self, + op_type: str, + input_names: str | typing.Sequence[str] | None, + output_names: int | typing.Sequence[str] | str | None = 1, + domain: str = "", + attributes: list[onnx.AttributeProto] | None = None, + name: str | None = None, + **kwargs: typing.Any, + ) -> onnx.NodeProto: + """ + Creates a node without adding it to the graph. 
+ + :param op_type: operator type + :param input_names: input names + :param output_names: outputs names, if one integer, creates n unique names, + if str, creates one unique names, if a list, use the name + :param domain: node domain + :param attributes: list of attributes + :param name: node name + :param kwargs: other attributes + :return: a node + """ + name = self.unique_node_name(name) + if isinstance(output_names, int): + if output_names == 1: + output_names = [self.unique_name(f"{op_type.lower()}")] + else: + output_names = [ + self.unique_name(f"{op_type.lower()}-{i}") for i in range(output_names) + ] + elif isinstance(output_names, str): + output_names = [self.unique_name(output_names)] + + proto = oh.make_node( + op_type, + ( + input_names + if isinstance(input_names, (list, tuple)) + else ([input_names] if isinstance(input_names, str) else None) + ), + output_names, + domain=domain, + name=name, + **kwargs, + ) + if attributes: + proto.attribute.extend(attributes) + return proto + + +class GenericRewriteRule(orp.RewriteRule): + """ + Defines a rewriting rule. + + :param pattern: a pattern defines by :class:`GenericPattern`. 
+ """ + + def __init__(self, pattern: GenericPattern): + self.pattern = pattern + + def matches(self, node: oir.Node, model: oir.Model) -> orp.MatchResult: + del model + del node + raise RuntimeError(f"This pattern {self} is meant to replace not to only match.") + + def try_rewrite( + self, model: oir.Model, node: oir.Node + ) -> tuple[int, list[oir.Node], list[oir.Node]] | None: + """See :meth:`RewriteRule.try_rewrite`.""" + if isinstance(model, ModelWithGraphStructure): + bridge = model + else: + bridge = ModelWithGraphStructure(model) + deleted_nodes = [] + added_nodes = [] + marked = set() + matched = 0 + for matched_nodes in self.pattern.enumerate_matches(bridge, node): + assert all(isinstance(i, oir.Node) for i in matched_nodes) + conflict = False + for node in matched_nodes: + if id(node) in marked: + conflict = True + break + if conflict: + # Some nodes are already marked as rewritten. + continue + + # Let's build the new nodes + new_nodes = self.pattern.apply(bridge, *matched_nodes) + assert all( + isinstance(i, oir.Node) for i in new_nodes + ), f"Unexpected types {[type(n) for n in new_nodes]}" + + if not self.pattern.validate_mapping(bridge, matched_nodes, new_nodes): + continue + + # Everything is good. + marked |= set(map(id, matched_nodes)) + added_nodes.extend(new_nodes) + deleted_nodes.extend(matched_nodes) + matched += 1 + + if matched > 0: + return matched, deleted_nodes, added_nodes + return None + + def count_matches(self, model: oir.Model, *, commute: bool = False) -> int: + """See :meth:`RewriteRule.count_matches`.""" + raise NotImplementedError("Not supported yet.") + + def commute(self) -> list[orp.RewriteRule]: + """See :meth:`RewriteRule.commute`.""" + raise RuntimeError("Not supported (yet?). 
It could lead to many patterns.") + + def apply_to_model(self, model: oir.Model, *, commute: bool = False) -> int: + """See :meth:`RewriteRule.apply_to_model`.""" + return orp.RewriteRuleSet([self], commute=commute).apply_to_model(model) + + +class GenericPattern: + """ + Implements a pattern optimization for quick experimentation. + + Current limitation: + + * The current implementation does match on domain name (easy fix). + * It does not compares attributes either (easy fix as well). + """ + + def __init__(self, verbose: int = 0): + self.verbose = verbose + self._cache: dict = {} + + def validate_mapping( + self, g: oir.Model, deleted_nodes: list[oir.Node], added_nodes: list[oir.Node] + ) -> bool: + """Evaluates the consistency of the replacements.""" + raise NotImplementedError( + "This method could return True but it is better to let you know " + "that it exists. You need to overwrite it to return True." + ) + + def enumerate_matches( + self, g: ModelWithGraphStructure, node: oir.Node | None = None + ) -> typing.Iterator: + """Enumerates all the matches.""" + if node is None: + matched = [] + for node in g.nodes: + res = self.match(g, node) + if res: + matched.append(res) + yield res + else: + res = self.match(g, node) + if res: + yield res + + def none( + self, + node: oir.Node | None = None, + lineno: int | None = None, + msg: str = "", + ) -> None: + """Must be called every time a match fails to trace it. + + It may be useful which reason made a pattern matching fail. + Instead of returning None, method *match* can return the following + expression: + + :: + + return self.none(node, inspect.currentframe().f_lineno) + + By setting the verbosity (see next Section), the user may then know + which lines in the code returned None and which condition failed. + If logs are fully enabled, it shows informations about matched none + and the line deciding the matched failed. + For example, this tells the matching failed at line 601 in ``generic_pattern.py``. 
+ It happens when propagating the match in the backward directions. + The unmatched types are Mul, MatMul and below, + it shows the matched nodes. The first one was Cast. + And the failure happened at iteration 5. + ``139774002356544-139774000632672`` is the pair of ids used in container ``marked``. + ``id(node)`` is used as a unique identifiers of the nodes. + + :: + + [RotaryEmbeddingPattern.match] NONE - line: 601:__main__, op_type=Cast + --hint--: BACKWARD: different node types + --pattern + Mul(pos_ids, cast) -> (mul) + -- model + MatMul(/_original_modu...Expand_output_0, /_original_modu...b/Cast_output_0) -> (/_original_modu...MatMul_output_0) + iteration=5 + --marked-- #6 + Cast(/_original_modu...mb/Cos_output_0) ~ Cast(cos) [139774002356544-139774000632672] + Cos(/_original_modu...ncat_1_output_0) ~ Cos(concattraining-transpose-0) [139774002356448-139774000632048] + ConcatTraining(/_original_modu...nspose_output_0,/_original_modu...nspose_output_0) ~ ConcatTraining(transpose,transpose) [139774002356352-139774000631712] + Transpose(/_original_modu...MatMul_output_0) ~ Transpose(mul) [139774002356256-139774000631184] + Sin(/_original_modu...ncat_1_output_0) ~ Sin(concattraining-transpose-0) [139774002358512-139774000631568] + Cast(/_original_modu...mb/Sin_output_0) ~ Cast(sin) [139774002358608-139774000632384] + len(stacked)=0:[] + + 'hints' are not added everywhere. More can easily be added with method ``_hint``. 
+ """ + if node and self.verbose: + if self.verbose >= 10: + if hasattr(self, "_debug"): + msg2 = self._debug_print() + if msg2: + msg2 = f"\n{textwrap.indent(msg2, ' ')}" + else: + msg2 = "" + print( + f"[{self.__class__.__name__}.match] NONE - line: {lineno}:" + f"{os.path.split(self.__class__.__module__)[-1]}, " + f"op_type={node.op_type}{msg}{msg2}" + ) + + @classmethod + def match_pattern( + cls, + g: ModelWithGraphStructure, + *args: str, + **kwargs: typing.Any, + ) -> list[oir.Node] | None: + """Builds the pattern to match.""" + raise NotImplementedError( + f"Class {cls.__name__!r} must overwrite method match_pattern." + ) + + @classmethod + def _build_pattern( + cls, g: ModelWithGraphStructure, fct: typing.Callable + ) -> BuilderWithGraphStructure: + kwargs = {} + args = [] + + # There should be a better way. + sig = inspect.signature(fct) + for i, p in enumerate(sig.parameters.values()): + if i == 0: + continue + if p.default is not inspect._empty: + # an attribute + kwargs[p.name] = p.default + else: + args.append(p.name) + + assert len(kwargs) == 0, f"Attributes are not supported yet but kwargs={kwargs}" + + g2 = g.make_opset() + for name in args: + g2.make_input(name) + output = fct(g2, *args, **kwargs) + if isinstance(output, str): + g2.make_output(output) + else: + for name in output: + g2.make_output(name) + g2._build() + return g2 + + def _get_match_pattern(self, g: ModelWithGraphStructure) -> BuilderWithGraphStructure: + cache_key = 0, tuple(sorted(g.opsets.items())) + if cache_key in self._cache: + return self._cache[cache_key] + + pat = self._build_pattern(g, self.match_pattern) + self._cache[cache_key] = pat + return pat + + def _get_apply_pattern(self, g: ModelWithGraphStructure) -> BuilderWithGraphStructure: + cache_key = 1, tuple(sorted(g.opsets.items())) + if cache_key in self._cache: + return self._cache[cache_key] + + pat = self._build_pattern(g, self.apply_pattern) + self._cache[cache_key] = pat + return pat + + def display_pattern(self, 
g: ModelWithGraphStructure, fct: typing.Callable) -> str: + """Shows the pattern to match or to apply.""" + pat = self._build_pattern(g, fct) + rows = [] + rows.append( + f"{fct.__name__}({', '.join(pat.input_names)}) -> {', '.join(pat.output_names)}" + ) + for node in pat.nodes: + rows.append( + f"{node.op_type}({', '.join(node.input_names)}) -> " + f"{', '.join(node.output_names)}" + ) + return "\n".join(rows) + + def print_match(self, n1: oir.Node, n2: oir.Node) -> str: + s1 = f"{n1.op_type}({','.join(n1.input_names)})" + s2 = f"{n2.op_type}({','.join(n2.input_names)})" + return f"match {s1} with {s2} (pattern)" + + def _debug_print(self) -> str: + if not hasattr(self, "_debug"): + return "" + + def _s(s: str) -> str: + if len(s) <= 30: + return s + return f"{s[:15]}...{s[-15:]}" + + def _p(n: oir.Node, full: bool = False) -> str: + if isinstance(n, (oir.Node, onnx.NodeProto)): + if full: + return ( + f"{n.op_type}({', '.join(map(_s, n.input_names))}) " + f"-> ({', '.join(map(_s, n.output_names))})" + ) + return f"{n.op_type}({','.join(map(_s, n.input_names))})" + return str(n) + + rows = [] + for k, v in sorted(self._debug.items()): + if k == "stacked": + rows.append(f"len({k})={len(v)}:{v}") + continue + if k == "iteration": + rows.append(f"{k}={v}") + continue + if k == "marked": + rows.append(f"--marked-- #{len(v)}") + for i, tu in v.items(): + rows.append(f" {_p(tu[0])} ~ {_p(tu[1])} [{id(tu[0])}-{i}]") + continue + if k == "hint": + rows.append(f"--hint--: {v[0]}") + for i in v[1:]: + rows.append(" " + _p(i, full=True)) + continue + if k in {"node", "pattern", "pattern_node", "pattern_nodes"}: + continue + rows.append(f"-- not shown {k}") + + return "\n".join(rows) + + def _hint(self, *args: typing.Any) -> None: + """Add debugging information to help users.""" + self._debug["hint"] = args + + def _match_backward( + self, + g: ModelWithGraphStructure, + node: oir.Node, + pat: ModelWithGraphStructure, + marked: dict[int, tuple[oir.Node, oir.Node]], + 
stacked: list[int],
+        n: oir.Node,
+        pn: oir.Node,
+    ) -> int | None:
+        """
+        Matches backward.
+
+        :param g: graph
+        :param node: root node (the node the match began with,
+            used only for debugging)
+        :param pat: pattern
+        :param marked: nodes of the pattern marked as already matched
+        :param stacked: next node to look into
+        :param n: node coming from the graph
+        :param pn: node coming from the pattern
+        :return: number of matched nodes, None or False to indicate a failed match
+        """
+        res = 0
+
+        # predecessors
+        if len(n.input_names) != len(pn.input_names):
+            # not the same number of inputs
+            self._hint(
+                "BACKWARD: not the same number of inputs",
+                "-- pattern",
+                pn,
+                "-- model",
+                n,
+            )
+            return self.none(node, inspect.currentframe().f_lineno)
+        for i, pi in zip(n.input_names, pn.input_names):
+            ppred = pat.node_before(pi)
+            if ppred is None:
+                # ppred is None means the pattern ends here.
+                continue
+            pred = g.node_before(i)
+            if pred is None:
+                # No node in the graph.
+                return self.none(node, inspect.currentframe().f_lineno)
+            if pred.op_type != ppred.op_type:
+                self._hint(
+                    "BACKWARD: different node types",
+                    "--pattern",
+                    ppred,
+                    "-- model",
+                    pred,
+                )
+                return self.none(node, inspect.currentframe().f_lineno)
+            # matching backward
+            key = id(ppred)
+            if key not in marked:
+                if self.verbose >= 10:
+                    print(f"[GenericPattern._match_backward] {self.print_match(pred, ppred)}")
+                marked[key] = pred, ppred
+                stacked.append(key)
+                res += 1
+        if self.verbose > 5 and res > 0:
+            print(f"[GenericPattern._match_backward] add {res} nodes")
+        return res
+
+    def _match_forward(
+        self,
+        g: ModelWithGraphStructure,
+        node: oir.Node,
+        pat: ModelWithGraphStructure,
+        marked: dict[int, tuple[oir.Node, oir.Node]],
+        stacked: list[int],
+        n: oir.Node,
+        pn: oir.Node,
+    ) -> int | None:
+        """
+        Matches forward.
+
+        :param g: graph
+        :param node: root node (the node the match began with,
+            used only for debugging)
+        :param pat: pattern
+        :param marked: nodes of the pattern marked as already matched
+        :param stacked: next node to look into
+        :param n: node coming from the graph
+        :param pn: node coming from the pattern
+        :return: number of matched nodes to continue, None or False to indicate a failed match
+        """
+        res = 0
+
+        # successors
+        if len(n.output_names) != len(pn.output_names):
+            # not the same number of outputs
+            self._hint(
+                "FORWARD: not the same number of output_names",
+                "-- pattern",
+                pn,
+                "-- model",
+                n,
+            )
+            return self.none(node, inspect.currentframe().f_lineno)
+
+        for o, op in zip(n.output_names, pn.output_names):
+            ns = g.next_nodes(o)
+            pns = pat.next_nodes(op)
+            if len(pns) == 0:
+                # The pattern has no node forward, the matching stops.
+                continue
+            if len(ns) < len(pns):
+                # Not enough nodes in the graph to match the pattern,
+                # the result is known.
+                return self.none(node, inspect.currentframe().f_lineno)
+
+            # Here comes the fun part, there is the same number of successors or more
+            # nodes in the graph to match with the pattern.
+            # And we have to handle the nodes already marked as found.
+            # Hopefully, there is only one option.
+
+            if len(ns) == len(pns) == 1:
+                # Let's deal with the simple case
+                if ns[0].op_type != pns[0].op_type:
+                    return self.none(node, inspect.currentframe().f_lineno)
+
+                key = id(pns[0])
+                if key not in marked:
+                    if self.verbose >= 10:
+                        print(
+                            f"[GenericPattern._match_forward]{self.print_match(ns[0], pns[0])}"
+                        )
+                    marked[key] = ns[0], pns[0]
+                    stacked.append(key)
+                    res += 1
+                continue
+
+            # Let's remove the nodes already marked.
+ p_marked = [_ for _ in pns if id(_) not in marked] + id_marked = [id(marked[id(_)][0]) for _ in pns if id(_) in marked] + assert len(id_marked) + len(p_marked) == len(pns), ( + f"Unexpected, id_marked={id_marked}, " + f"id_p_marked={set(map(id, p_marked))}, " + f"pns_ids={set(map(id, pns))}, " + f"ns_ids={set(map(id, ns))}, o={o!r}, op={op!r}, " + f"n.op_type={n.op_type!r}, " + f"n.output={n.output}, np.output={pn.output}, " + f"ns_types={ {_.op_type for _ in ns} }, " + f"pns_types={ {_.op_type for _ in pns} }" + ) + free = [_ for _ in ns if id(_) not in id_marked] + if len(p_marked) == 0: + # Everything is already marked. + continue + if len(free) < len(p_marked): + # Not enough successors to match the remaining patterns. + return self.none(node, inspect.currentframe().f_lineno) + if len(p_marked) == len(free) == 1: + # Only one option again. + if p_marked[0].op_type != free[0].op_type: + return self.none(node, inspect.currentframe().f_lineno) + + key = id(p_marked[0]) + if key not in marked: + if self.verbose >= 10: + print( + f"[GenericPattern._match_forward] {self.print_match(free[0], p_marked[0])}" + ) + marked[key] = free[0], p_marked[0] + stacked.append(key) + res += 1 + continue + + # And now another fun part, let's try to handle the case when + # there is only one option, matching on node type only returns one + # option. + expected_op_type = [_.op_type for _ in p_marked] + got_op_type = [_.op_type for _ in free] + + ec = collections.Counter(expected_op_type) + gc = collections.Counter(got_op_type) + if len(ec) != len(gc) or set(ec) != set(gc): + # unique operator types is different. + self._hint( + "FORWARD: unique operator types are different", + "-- pattern", + ec, + pn, + "-- model", + gc, + n, + "-- model-marked", + id_marked, + ) + return self.none(node, inspect.currentframe().f_lineno) + for k, v in ec.items(): + if gc[k] < v: + # Not enough types to match. 
+                return self.none(node, inspect.currentframe().f_lineno)
+
+            # At this stage, we know matching the types is possible.
+            # We first mark whatever is possible. NOTE(review): gtype_to_node below iterates got_op_type (a list of strings) and reads `.op_type` on each; it presumably should iterate `free` — confirm.
+            ptype_to_node = {_.op_type: _ for _ in p_marked}
+            gtype_to_node = {_.op_type: _ for _ in got_op_type}
+            missing = []
+            for k, v in ec.items():
+                if gc[k] == v == 1:
+                    key = id(ptype_to_node[k])
+                    if key not in marked:
+                        if self.verbose >= 10:
+                            print(
+                                f"[GenericPattern._match_forward] match "
+                                f"{self.print_match(gtype_to_node[k], ptype_to_node[k])}"
+                            )
+                        marked[key] = gtype_to_node[k], ptype_to_node[k]
+                        stacked.append(key)
+                        res += 1
+                else:
+                    missing.append(k)
+
+            if not missing:
+                continue
+
+            # At this stage, there are multiple options for matching. We can:
+            # 1. make assumptions and continue
+            # 2. mark the node as incomplete matching, we could end up stuck anyway.
+            raise AssertionError(
+                f"There are more than one option, this will be implemented later, "
+                f"ec={ec}, gc={gc}"
+            )
+        if self.verbose > 5 and res > 0:
+            print(f"[GenericPattern._match_forward] add {res} nodes")
+            return res
+
+    def match(
+        self,
+        g: ModelWithGraphStructure,
+        node: oir.Node,
+    ) -> list[oir.Node] | None:
+        self._debug = {}
+
+        pat = self._get_match_pattern(g)
+
+        # Let's match the last node.
+        # Then we need to match successors and predecessors.
+        p_node = pat.nodes[-1]  # the last one
+        if node.op_type != p_node.op_type:
+            # The last node does not have the same type.
+ return self.none() + + check_ids = {id(n) for n in pat.nodes} + if self.verbose > 5: + print( + f"[GenericPattern.match] starts with " + f"{node.op_type}({', '.join(node.input_names)})" + ) + if self.verbose >= 10: + print("[GenericPattern.match] match pattern") + print(textwrap.indent(self.display_pattern(g, self.match_pattern), " ")) + + marked = {id(p_node): (node, p_node)} + stacked = [id(p_node)] + iteration = 0 + + if self.verbose > 5: + self._debug = dict( + pattern=pat, + marked=marked, + stacked=stacked, + iteration=iteration, + node=node, + pattern_node=p_node, + pattern_nodes=pat.nodes, + ) + + max_iter = len(pat.nodes) * 2 + while stacked and iteration < max_iter: + assert all(id(b[1]) in check_ids for b in marked.values()), ( + f"At least one id is not part of the pattern ids={check_ids}, " + f"marked={ {id(b[1]) for b in marked.values()} }" + ) + + iteration += 1 + if self.verbose > 5: + print( + f"[GenericPattern.match] iteration={iteration} " + f"n_marked={len(marked)}, n_stacked={len(stacked)}, " + f"marked_types={collections.Counter(_[1].op_type for _ in marked.values())}" + ) + idn = stacked.pop() + n, pn = marked[idn] + + res = self._match_backward(g, node, pat, marked, stacked, n, pn) + if res is None: + if self.verbose > 5: + print("[GenericPattern.match] done. backward failed.") + return res + + assert all(id(b[1]) in check_ids for b in marked.values()), ( + f"At least one id is not part of the pattern ids={check_ids}, " + f"marked={ {id(b[1]) for b in marked.values()} }" + ) + + res = self._match_forward(g, node, pat, marked, stacked, n, pn) + if res is None: + if self.verbose > 5: + print("[GenericPattern.match] done. 
forward failed.") + return res + + assert all(id(b[1]) in check_ids for b in marked.values()), ( + f"At least one id is not part of the pattern ids={check_ids}, " + f"marked={ {id(b[1]) for b in marked.values()} }" + ) + + if self.verbose > 5: + self._debug["iteration"] = iteration + + if iteration >= max_iter and stacked: + self._hint("reached {iteration}>={max_iter} iterations") + return self.none(node, inspect.currentframe().f_lineno) + + if self.verbose > 5: + print(f"[GenericPattern.match] done. {len(marked)} marked nodes") + + # At this point, the pattern is matched but let's make sure. + assert len(marked) == len(pat.nodes), ( + f"Number of marked nodes is different, {len(marked)} marked nodes, " + f"and {len(pat.nodes)} nodes in the pattern, marked is {marked}" + ) + assert len(stacked) == 0, f"There are still {len(stacked)} nodes to explore." + + # We order the matched nodes in the same order than the pattern + # to let next functions to be able to build the matching again. + matched_nodes = [marked[id(n)][0] for i, n in enumerate(pat.nodes)] + return matched_nodes + + @classmethod + def apply_pattern( + cls, + g: ModelWithGraphStructure, + *args: typing.Any, + **kwargs: typing.Any, + ) -> list[oir.Node]: + """Applies the replacement.""" + raise NotImplementedError( + f"Class {cls.__name__!r} must overwrite method 'apply_pattern'." + ) + + def apply( + self, + g: ModelWithGraphStructure, + *nodes: typing.Sequence[oir.Node], + ) -> list[oir.Node]: + assert all(isinstance(n, oir.Node) for n in nodes) + pat = self._build_pattern(g, self.match_pattern) + assert len(nodes) == len(pat.nodes), ( + f"Mismatch matched nodes pattern has {len(pat.nodes)} != {len(nodes)} = " + f"the number of matched nodes" + ) + new_pat = self._build_pattern(g, self.apply_pattern) + assert len(new_pat.input_names) == len(pat.input_names), ( + f"Not the same number of inputs, matched inputs={len(new_pat.input_names)}, " + f"got {len(pat.input_names)} in the applied pattern." 
+ ) + assert len(new_pat.output_names) == len(pat.output_names), ( + f"Not the same number of outputs, matched outputs={pat.output_names}, " + f"got {new_pat.output_names} in the applied pattern." + ) + assert all(isinstance(n, oir.Node) for n in pat.nodes) + + if g.verbose > 5: + print( + f"[GenericPattern.apply] replace {len(nodes)} nodes, " + f"applied {self.display_pattern(g, self.apply_pattern)}" + ) + + matched_pattern_to_applied_pattern = {} + for i, j in zip(pat.input_names, new_pat.input_names): + matched_pattern_to_applied_pattern[i] = j + for i, j in zip(pat.output_names, new_pat.output_names): + matched_pattern_to_applied_pattern[i] = j + + matched_pattern_to_graph_name: dict = {} + input_names = set(pat.input_names) + output_names = set(pat.output_names) + + matched_pairs = list(zip(nodes, pat.nodes)) + for gn, pn in matched_pairs: + assert ( + gn.op_type == pn.op_type + ), f"Unexpected type mismatch {gn.op_type!r} != {pn.op_type!r}" + assert len(gn.input_names) == len( + pn.input_names + ), f"Unexpected number of inputs for type {gn.op_type}" + for a, b in zip(gn.input_names, pn.input_names): + if b not in input_names or b == "": + # optional input or not an interesting input + continue + if b in matched_pattern_to_graph_name: + assert matched_pattern_to_graph_name[b] == a, ( + f"Ambiguities, pattern name {b!r} means " + f"{a!r} or {matched_pattern_to_graph_name[b]}" + ) + else: + matched_pattern_to_graph_name[b] = a + + assert len(gn.output_names) == len( + pn.output_names + ), f"Unexpected number of outputs for type {gn.op_type}" + for a, b in zip(gn.output_names, pn.output_names): + if b not in output_names or b == "": + # Only final outputs are interesting. 
+ continue + assert a != "", f"{a!r} cannot be optional" + if b in matched_pattern_to_graph_name: + assert matched_pattern_to_graph_name[b] == a, ( + f"Ambiguities, pattern name {b!r} means " + f"{a!r} or {matched_pattern_to_graph_name[b]}" + ) + else: + matched_pattern_to_graph_name[b] = a + + # TODO: handle initializers here + # for name, init in pattern.initializers.items(): + # # We add them to the graph, they will be removed if unused. + # new_name = g.make_initializer(name, init) + # replacements[new_name] = name + + replacements = {} + for k, v in matched_pattern_to_graph_name.items(): + replacements[matched_pattern_to_applied_pattern[k]] = v + + # Creation of the new node. + new_nodes = [] + for node in new_pat.nodes: + new_inputs = [] + for i in node.input_names: + assert i in replacements, f"Unable to find {i!r} in {replacements}" + ni = replacements[i] + new_inputs.append(ni) + new_outputs = [] + for o in node.output_names: + if o in replacements: + new_outputs.append(replacements[o]) + else: + # We give it a new name. + n = g.unique_name(o) + replacements[o] = n + new_outputs.append(n) + new_node = g.make_node(node.op_type, new_inputs, new_outputs, domain=node.domain) + new_node.attribute.extend(node.attribute) + new_nodes.append(oir.Node(new_node, True)) + + if g.verbose > 5: + print(f"[GenericPattern.apply] done with {len(new_nodes)} nodes") + + return new_nodes + + def make_rule(self) -> orp.RewriteRule: + """Creates the corresponding rule for this pattern.""" + return GenericRewriteRule(self) + + +class OnnxGenericPattern(GenericPattern): + """An instance of GenericPattern taking onnx model. + + It defines the matching pattern and its replacement. 
+ + :param match_proto: the onnx function defining the matching pattern + :param apply_proto: the onnx function defining the new pattern + :param validate_mapping: the function used to validate a pattern + :param verbose: in [0, 10], increase the verbosity to understand why a pattern + does not match + """ + + def __init__( + self, + match_proto: onnx.FunctionProto, + apply_proto: onnx.FunctionProto, + validate_mapping: typing.Callable, + verbose: int = 0, + ): + super().__init__(verbose=verbose) + self.match_proto = match_proto + self._validate_mapping = validate_mapping + self.apply_proto = apply_proto + self._cache = {} + + def validate_mapping( + self, g: oir.Model, deleted_nodes: list[oir.Node], added_nodes: list[oir.Node] + ) -> bool: + """Evaluates the consistency of the replacements.""" + return self._validate_mapping(g, deleted_nodes, added_nodes) + + def _build_pattern( + self, g: ModelWithGraphStructure, fct: typing.Callable + ) -> BuilderWithGraphStructure: + if fct == self.match_pattern: + key = id(g), "match" + if key in self._cache: + return self._cache[key] + onx = self.match_proto + elif fct == self.apply_pattern: + key = id(g), "apply" + if key in self._cache: + return self._cache[key] + onx = self.apply_proto + else: + raise AssertionError( + f"Function {fct} is not {self.match_pattern} or {self.apply_pattern}." + ) + + g2 = g.make_opset() + for name in onx.input: + g2.make_input(name) + for node in onx.node: + g2.make_node_with_proto(node) + for name in onx.output: + g2.make_output(name) + g2._build() + self._cache[key] = g2 + return g2 + + +def make_pattern_rule( + match_pattern: typing.Callable, + apply_pattern: typing.Callable, + validate_mapping: typing.Callable | None = None, + verbose: int = 0, + opsets: dict[str, "onnxscript.Opset"] | None = None, # noqa: F821 +) -> orp.RewriteRule: + """ + Creates a rewriting rule. 
+ + :param match_pattern: a function interpreted by onnx-script + and converted into an onnx model, this model defines the + nodes to be replaced + :param apply_pattern: a function interpreted by onnx-script and + converted into an onnx model, this model defines the new nodes + replacing the matched nodes + :param validate_mapping: a function validating the matching once + it has happened, it is not valid, the pattern is not applied, + if not specified, the function always return True + :param opsets: opset to consider when converting the function into ONNX, + if not specified, it is opset 18 for the main opset, and opset 1 + for domain com.microsoft. + :return: the rewriting rule + """ + import onnxscript + + if opsets is None: + opsets = dict( + op=onnxscript.opset18, msft_op=onnxscript.values.Opset("com.microsoft", 1) + ) + + if verbose > 5: + print(f"[make_pattern_rule] Converting {match_pattern} into ONNX.") + match = onnxscript.script(**opsets)(match_pattern).to_function_proto() + if verbose > 5: + print("[make_pattern_rule] done.") + print(f"[make_pattern_rule] Converting {apply_pattern} into ONNX.") + apply = onnxscript.script(**opsets)(apply_pattern).to_function_proto() + if verbose > 5: + print("[make_pattern_rule] done.") + + pat = OnnxGenericPattern( + match, + apply, + validate_mapping or (lambda *_, **__: True), + verbose=verbose, + ) + return pat.make_rule() diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py new file mode 100644 index 000000000..f4dd2496c --- /dev/null +++ b/onnxscript/rewriter/generic_pattern_test.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +import contextlib +import io +import os +import time +import unittest + +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh +from numpy.testing import assert_almost_equal +from onnx.reference import ReferenceEvaluator +from onnx.reference.op_run import OpRun + +import onnxscript._legacy_ir 
as oir +import onnxscript._legacy_ir.protobuilder as oip +import onnxscript.rewriter.generic_pattern as org + +TFLOAT = onnx.TensorProto.FLOAT + + +class GenericPatternTest(unittest.TestCase): + def test_bridge_model(self): + model = onnx.parser.parse_model( + """ + + agraph (float[2, 3, 5, 4] input_x, float[5] input_y, float[2, 3, 5] input_z) => (float[2, 4, 6] output) + { + shape_a = Constant() + reshape_x = Reshape (input_x, shape_a) + gemm = Gemm (reshape_x, input_y, input_z) + shape_d = Constant() + output = Reshape (gemm, shape_d) + } + """ + ) + org.ModelWithGraphStructure(oir.irbuilder.build_ir(model)) + + def _range(self, *shape, bias: float | None = None): + n = np.prod(shape) + x = np.arange(n).astype(np.float32) / n + if bias: + x = x + bias + return x.reshape(tuple(shape)).astype(np.float32) + + def test_graph_pattern_builder(self): + class AddAddPattern(org.GenericPattern): + """Replaces Add + Add by AddAdd.""" + + @classmethod + def match_pattern(cls, op: org.BuilderWithGraphStructure, x, y, z): + """Builds the pattern to match.""" + tmp = op.Add(x, y) + return op.Add(tmp, z) + + @classmethod + def apply_pattern(cls, op: org.BuilderWithGraphStructure, x, y, z): + """Builds the pattern to match.""" + return op.AddAdd(x, y, z, domain="ZZZ") + + def validate_mapping( + self, + g: oir.Model, + deleted_nodes: list[oir.Node], + added_nodes: list[oir.Node], + ) -> bool: + assert g + assert len(deleted_nodes) == 2 + assert len(added_nodes) == 1 + return True + + class AddAdd(OpRun): + op_domain = "ZZZ" + + def _run(self, x, y, z): + return (x + y + z,) + + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Add", ["x", "y"], ["gggg"]), + oh.make_node("Add", ["gggg", "z"], ["final"]), + ], + "dummy", + [ + oh.make_tensor_value_info("x", TFLOAT, [None, None]), + oh.make_tensor_value_info("y", TFLOAT, [None, None]), + oh.make_tensor_value_info("z", TFLOAT, [None, None]), + ], + [oh.make_tensor_value_info("final", TFLOAT, [None, None])], + ), + 
opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + onnx.checker.check_model(model) + + ir_model = oir.irbuilder.build_ir(model) + + pattern = AddAddPattern(verbose=0) + rule = pattern.make_rule() + rule.apply_to_model(ir_model) + self.assertEqual( + ["AddAdd"], + [n.op_type for n in ir_model.graph.nodes], + ) + # TODO: do that in pattern.py. + ir_model.version_map["ZZZ"] = 1 + + builder = oip.ModelProtoBuilder() + opt_onx = builder.visit_ir_model(ir_model) + + self.assertEqual( + ["AddAdd"], + [n.op_type for n in opt_onx.graph.node], + ) + + feeds = { + "x": self._range(5, 6), + "y": self._range(5, 6), + "z": self._range(5, 6), + } + ref1 = ReferenceEvaluator(model) + expected = ref1.run(None, feeds) + + self.assertEqual(0, len(opt_onx.graph.initializer)) + opsets = {v.domain: v.version for v in opt_onx.opset_import} + self.assertIn("ZZZ", opsets) + self.assertEqual(opsets["ZZZ"], 1) + + ref2 = ReferenceEvaluator(opt_onx, new_ops=[AddAdd]) + got = ref2.run(None, feeds) + assert_almost_equal(expected[0], got[0]) + + def test_graph_pattern_builder_multi_outputs(self): + class AddAddAddAddPattern(org.GenericPattern): + """Replaces ConstantOfShape + ScatterND with ScatterNDOfShape (com.domain).""" + + @classmethod + def match_pattern(cls, op, x, y, w, z): + """Builds the pattern to match.""" + tmp = op.Add(x, y) + tmp2 = op.Add(tmp, w) + r1 = op.Add(tmp, z) + return tmp2, r1 + + @classmethod + def apply_pattern(cls, op, x, y, w, z): + """Builds the pattern to match.""" + return op.AddAddAddAdd(x, y, w, z, domain="ZZZ", output_names=2) + + def validate_mapping( + self, + g: oir.Model, + deleted_nodes: list[oir.Node], + added_nodes: list[oir.Node], + ) -> bool: + assert g + assert len(deleted_nodes) == 3 + assert len(added_nodes) == 1 + return True + + class AddAddAddAdd(OpRun): + op_domain = "ZZZ" + + def _run(self, x, y, w, z): + return (x + y + w, x + y + z) + + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Add", ["x", "y"], ["gggg"]), + 
oh.make_node("Add", ["gggg", "w"], ["f1"]), + oh.make_node("Add", ["gggg", "z"], ["f2"]), + ], + "dummy", + [ + oh.make_tensor_value_info("x", TFLOAT, [None, None]), + oh.make_tensor_value_info("y", TFLOAT, [None, None]), + oh.make_tensor_value_info("z", TFLOAT, [None, None]), + oh.make_tensor_value_info("w", TFLOAT, [None, None]), + ], + [ + oh.make_tensor_value_info("f1", TFLOAT, [None, None]), + oh.make_tensor_value_info("f2", TFLOAT, [None, None]), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + onnx.checker.check_model(model) + + ir_model = oir.irbuilder.build_ir(model) + + pattern = AddAddAddAddPattern(verbose=0) + rule = pattern.make_rule() + rule.apply_to_model(ir_model) + self.assertEqual( + ["AddAddAddAdd"], + [n.op_type for n in ir_model.graph.nodes], + ) + # TODO: do that in pattern.py. + ir_model.version_map["ZZZ"] = 1 + + builder = oip.ModelProtoBuilder() + opt_onx = builder.visit_ir_model(ir_model) + + self.assertEqual( + ["AddAddAddAdd"], + [n.op_type for n in opt_onx.graph.node], + ) + + feeds = { + "x": self._range(5, 6), + "y": self._range(5, 6), + "w": self._range(5, 6), + "z": self._range(5, 6), + } + ref1 = ReferenceEvaluator(model) + expected = ref1.run(None, feeds) + + self.assertEqual(0, len(opt_onx.graph.initializer)) + opsets = {v.domain: v.version for v in opt_onx.opset_import} + self.assertIn("ZZZ", opsets) + self.assertEqual(opsets["ZZZ"], 1) + + ref2 = ReferenceEvaluator(opt_onx, new_ops=[AddAddAddAdd]) + got = ref2.run(None, feeds) + assert_almost_equal(expected[0], got[0]) + + def check_with_ort(self, model: onnx.ModelProto, providers=None): + import onnxruntime + + if hasattr(onnxruntime, "rewrite"): + raise unittest.SkipTest( + "cannot check with onnxruntime because of a subfolder called onnxruntime." 
+ ) + + if providers is None: + providers = ["CPUExecutionProvider"] + + if isinstance(model, onnx.ModelProto): + model = model.SerializeToString() + sess = onnxruntime.InferenceSession(model, providers=providers) + return sess + + def get_rotary_model(self): + inputs = [ + oh.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]), + oh.make_tensor_value_info("pos_ids", onnx.TensorProto.FLOAT, shape=[]), + oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]), + ] + nodes = [ + oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]), + oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1), + oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]), + oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]), + oh.make_node( + "ConcatTraining", + ["_onx_transpose0", "_onx_transpose0"], + ["_onx_concattraining0", "_onx_concattraining1"], + domain="com.microsoft", + ), + oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]), + oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1), + oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]), + oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1), + ] + outputs = [ + oh.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []), + oh.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []), + ] + model = oh.make_model( + oh.make_graph( + nodes, + "experiment", + inputs, + outputs, + ), + opset_imports=[ + oh.make_opsetid("", 18), + oh.make_opsetid("com.microsoft", 18), + ], + ) + return model + + def test_rotary_embedding(self): + # The test work on a model if it has the expected name. + # A dummy model is used if not present (not implemented yet). 
+ + class RotaryEmbeddingPattern(org.GenericPattern): + """Fusion for Rotary.""" + + @classmethod + def match_pattern(cls, op, x, pos_ids, axis): + # original code: the code does verifies the constant yet + # unsqueeze = op.Unsqueeze(x, [1]) + + unsqueeze = op.Unsqueeze(x, axis) + cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT) + + matmul = op.MatMul(pos_ids, cast) + transpose = op.Transpose(matmul) + output, length = op.ConcatTraining( + transpose, + transpose, + domain="com.microsoft", + output_names=2, + ) + + sin = op.Sin(output) + cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT) + cos = op.Cos(output) + cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT) + return cast1, cast2 + + def validate_mapping(self, g, deleted_nodes, added_nodes) -> bool: + # If some pattern needs to be rejected. + return True + + @classmethod + def apply_pattern(cls, op, x, pos_ids, axis): + del axis + cos_cache = op.Constant( + value=onh.from_array(np.random.rand(256, 256).astype(np.float16)) + ) + sin_cache = op.Constant( + value=onh.from_array(np.random.rand(256, 256).astype(np.float16)) + ) + return op.RotaryEmbedding( + x, + pos_ids, + cos_cache, + sin_cache, + domain="com.microsoft", + output_names=2, + ) + + model = self.get_rotary_model() + + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + # back to ir + ir_model = oir.irbuilder.build_ir(model) + + # starts matching + pattern = RotaryEmbeddingPattern(verbose=10) + rule = pattern.make_rule() + rule.apply_to_model(ir_model) + ir_model.version_map["com.microsoft"] = 1 + + builder = oip.ModelProtoBuilder() + opt_onx = builder.visit_ir_model(ir_model) + + expected = ["Constant", "Constant", "RotaryEmbedding"] + self.assertEqual(expected, [n.op_type for n in opt_onx.graph.node]) + out = buffer.getvalue() + self.assertIn("[GenericPattern.match", out) + + def test_rotary_embedding_onnxscript(self): + # The test work on a model if it has the expected name. 
+ # A dummy model is used if not present (not implemented yet). + import onnxscript + + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + def rotary_match_pattern(x, pos_ids, axis): + unsqueeze = op.Unsqueeze(x, axis) + cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT) + + matmul = op.MatMul(pos_ids, cast) + transpose = op.Transpose(matmul) + output, length = msft_op.ConcatTraining(transpose, transpose) + + sin = op.Sin(output) + cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT) + cos = op.Cos(output) + cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT) + return cast1, cast2 + + def validate_rotary_mapping(g, deleted_nodes, added_nodes) -> bool: + # If some pattern needs to be rejected. + return True + + def rotary_apply_pattern(x, pos_ids, axis): + cos_cache = op.Constant( + value=onh.from_array(np.random.rand(256, 256).astype(np.float16)) + ) + sin_cache = op.Constant( + value=onh.from_array(np.random.rand(256, 256).astype(np.float16)) + ) + part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) + return part1, part2 + + model = self.get_rotary_model() + + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer): + # back to ir + ir_model = oir.irbuilder.build_ir(model) + + # starts matching + rule = org.make_pattern_rule( + rotary_match_pattern, + rotary_apply_pattern, + validate_rotary_mapping, + verbose=10, + ) + + rule.apply_to_model(ir_model) + ir_model.version_map["com.microsoft"] = 1 + + builder = oip.ModelProtoBuilder() + opt_onx = builder.visit_ir_model(ir_model) + + expected = ["Constant", "Constant", "RotaryEmbedding"] + self.assertEqual(expected, [n.op_type for n in opt_onx.graph.node]) + out = buffer.getvalue() + self.assertIn("[GenericPattern.match", out) + + def test_rotary_emb_file_onnxscript(self): + # The test work on a model if it has the expected name. + # A dummy model is used if not present (not implemented yet). 
+ import onnxscript + + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + def rotary_match_pattern(x, pos_ids, axis): + unsqueeze = op.Unsqueeze(x, axis) + cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT) + + matmul = op.MatMul(pos_ids, cast) + transpose = op.Transpose(matmul) + output, length = msft_op.ConcatTraining(transpose, transpose) + + sin = op.Sin(output) + cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT) + cos = op.Cos(output) + cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT) + return cast1, cast2 + + def validate_rotary_mapping(g, deleted_nodes, added_nodes) -> bool: + # If some pattern needs to be rejected. + return True + + def rotary_apply_pattern(x, pos_ids, axis): + cos_cache = op.Constant( + value=onh.from_array(np.random.rand(256, 256).astype(np.float16)) + ) + sin_cache = op.Constant( + value=onh.from_array(np.random.rand(256, 256).astype(np.float16)) + ) + part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) + return part1, part2 + + model = "gemma_optimized_pre_grad_training_2.onnx" + if not os.path.exists(model): + raise unittest.SkipTest(f"{model!r} is missing") + + begin = time.perf_counter() + onx = onnx.load(model) + ir_model = oir.irbuilder.build_ir(onx) + if __name__ == "__main__": + print(f"Loading done in {time.perf_counter() - begin}s") + + begin = time.perf_counter() + rule = org.make_pattern_rule( + rotary_match_pattern, + rotary_apply_pattern, + validate_rotary_mapping, + verbose=10, + ) + + rule.apply_to_model(ir_model) + + if __name__ == "__main__": + print(f"Matching done in {time.perf_counter() - begin}s") + + # TODO: do that in pattern.py. 
+ ir_model.version_map["ZZZ"] = 1 + + begin = time.perf_counter() + builder = oip.ModelProtoBuilder() + opt_onx = builder.visit_ir_model(ir_model) + if __name__ == "__main__": + print(f"Building done in {time.perf_counter() - begin}s") + + begin = time.perf_counter() + buffer = opt_onx.SerializeToString() + with open(f"{model}.opt.onnx", "wb") as f: + f.write(buffer) + if __name__ == "__main__": + print(f"Saving done in {time.perf_counter() - begin}s") + self.check_with_ort(opt_onx) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/no_op.py new file mode 100644 index 000000000..0a149ad96 --- /dev/null +++ b/onnxscript/rewriter/no_op.py @@ -0,0 +1,44 @@ +from onnxscript.rewriter import pattern + +op = pattern.onnxop + +# TODO: Support 1-D constant tensors +# https://github.com/microsoft/onnx-rewriter/issues/186 + + +# Pattern to match against +def mul_by_1(x): + return x * 1 + + +def add_0(x): + return x + 0 + + +def sub_0(x): + return x - 0 + + +def div_by_1(x): + return x / 1 + + +# Replacement +def identity(x): + return op.Identity(x) + + +mul_by_1_rule = pattern.RewriteRule(mul_by_1, identity) +add_0_rule = pattern.RewriteRule(add_0, identity) +sub_0_rule = pattern.RewriteRule(sub_0, identity) +div_by_1_rule = pattern.RewriteRule(div_by_1, identity) +# TODO: Include Mul by 0, 0 by Mul, 0 by Div? 
Those would be 0s, but not no-ops + +rules = pattern.RewriteRuleSet( + [ + *mul_by_1_rule.commute(), + *add_0_rule.commute(), + sub_0_rule, + div_by_1_rule, + ] +) diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/no_op_test.py new file mode 100644 index 000000000..e38d8b7c6 --- /dev/null +++ b/onnxscript/rewriter/no_op_test.py @@ -0,0 +1,180 @@ +import unittest + +import onnx.parser +import parameterized + +from onnxscript._legacy_ir import irbuilder +from onnxscript.rewriter import no_op + + +class NoOpTest(unittest.TestCase): + def _check(self, model_text: str) -> None: + model = onnx.parser.parse_model(model_text) + ir = irbuilder.build_ir(model) + count = no_op.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(ir.graph.nodes[-1].op_type, "Identity") + + @parameterized.parameterized.expand( + [ + ("float one input", "float[M]", "value_float=1.0", "one, input"), + ("int one input", "int32[M]", "value_int=1", "one, input"), + ("float input one", "float[M]", "value_float=1.0", "input, one"), + ("int input one", "int32[M]", "value_int=1", "input, one"), + ] + ) + def test_mul_one_should_become_no_op(self, _, dtype, constant_value, input_order): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + {{ + one = Constant<{constant_value}>() + output = Mul({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float one input", "float[M]", "float one = {1.0}", "one, input"), + ("int one input", "int32[M]", "int32 one = {1}", "one, input"), + ("float input one", "float[M]", "float one = {1.0}", "input, one"), + ("int input one", "int32[M]", "int32 one = {1}", "input, one"), + ] + ) + def test_mul_one_should_become_no_op_initializer( + self, _, dtype, constant_value, input_order + ): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + <{constant_value}> + {{ + output = Mul({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float zero input", 
"float[M]", "value_float=0.0", "zero, input"), + ("int zero input", "int32[M]", "value_int=0", "zero, input"), + ("float input zero", "float[M]", "value_float=0.0", "input, zero"), + ("int input zero", "int32[M]", "value_int=0", "input, zero"), + ] + ) + def test_add_zero_should_become_no_op(self, _, dtype, constant_value, input_order): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + {{ + zero = Constant<{constant_value}>() + output = Add({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float input zero", "float[M]", "float zero = {0.0}", "input, zero"), + ("int input zero", "int32[M]", "int32 zero = {0}", "input, zero"), + ("float input zero", "float[M]", "float zero = {0.0}", "input, zero"), + ("int input zero", "int32[M]", "int32 zero = {0}", "input, zero"), + ] + ) + def test_add_zero_should_become_no_op_initializer( + self, _, dtype, constant_value, input_order + ): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + <{constant_value}> + {{ + output = Add({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float input zero", "float[M]", "value_float=0.0", "input, zero"), + ("int input zero", "int32[M]", "value_int=0", "input, zero"), + ] + ) + def test_sub_zero_should_become_no_op(self, _, dtype, constant_value, input_order): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + {{ + zero = Constant<{constant_value}>() + output = Sub({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float input zero", "float[M]", "float zero = {0.0}", "input, zero"), + ("int input zero", "int32[M]", "int32 zero = {0}", "input, zero"), + ] + ) + def test_sub_zero_should_become_no_op_initializer( + self, _, dtype, constant_value, input_order + ): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + <{constant_value}> + {{ + output = Sub({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + 
("float input one", "float[M]", "value_float=1.0", "input, one"), + ("int input one", "int32[M]", "value_int=1", "input, one"), + ] + ) + def test_div_one_should_become_no_op(self, _, dtype, constant_value, input_order): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + {{ + one = Constant<{constant_value}>() + output = Div({input_order}) + }} + """ + ) + + @parameterized.parameterized.expand( + [ + ("float input one", "float[M]", "float one = {1.0}", "input, one"), + ("int input one", "int32[M]", "int32 one = {1}", "input, one"), + ] + ) + def test_div_one_should_become_no_op_with_initializer( + self, _, dtype, constant_value, input_order + ): + self._check( + f""" + + agraph ({dtype} input) => ({dtype} output) + <{constant_value}> + {{ + output = Div({input_order}) + }} + """ + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py new file mode 100644 index 000000000..0e6eb613a --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import onnx + +from onnxscript._legacy_ir import irbuilder, protobuilder +from onnxscript.optimizer import remove_unused, remove_unused_function +from onnxscript.rewriter import function_rule, pattern +from onnxscript.rewriter.onnxruntime import ( + group_normalization_merge_silu, + instance_to_group_normalization, + softmax, + transformers, +) + +ORT_FUNCTION_REWRITE_RULES = [*transformers.TRANSFORMERS_FUNCTION_REWRITE_RULES] + +ORT_PATTERN_REWRITE_RULES = [ + *softmax.rules.rules, + *instance_to_group_normalization.rules.rules, + # NOTE: group normalization merge silu should be applied after instance to group normalization + *group_normalization_merge_silu.rules.rules, +] + + +def rewrite( + model: onnx.ModelProto, + function_rules: list[type[function_rule.FunctionRewriteRule]] | None = None, + pattern_rules: list[pattern.RewriteRule] | None = None, 
+) -> onnx.ModelProto: + """Rewrite the model using the given rules. + + Args: + model: The model to rewrite. + function_rules: The function rewrite rules to apply. If None, the default rules + for onnxruntime are used. + pattern_rules: The pattern rewrite rules to apply. If None, the default rules + for onnxruntime are used. + + Returns: + The rewritten model. + """ + function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES + pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES + # TODO: Function rules first, or pattern rules first? + if function_rules: + model_ir = irbuilder.build_ir(model) + for rule_cls in function_rules: + rule_cls().apply_to_model(model_ir) + model = model_ir.original_model_proto + if pattern_rules: + model_ir = irbuilder.build_ir(model) + count = pattern.RewriteRuleSet(pattern_rules).apply_to_model(model_ir) + print(f"Applied {count} pattern rewrite rules.") + model = protobuilder.build_model_proto(model_ir) + # TODO: Does it make more sense we run DCE after each rewrite rule applied? + # If so, we need IR to support DCE. 
+ remove_unused.remove_unused_nodes(model) + remove_unused_function.remove_unused_functions(model) + return model diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py new file mode 100644 index 000000000..a6dfb54eb --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import logging + +from onnxscript.rewriter import pattern + +op = pattern.onnxop +msft_op = pattern.msft_op +torch_module_op = pattern.torch_module_op + +logger = logging.getLogger(__name__) + + +def group_normalization_and_silu_submodule( + input, + weight, + bias, + epsilon, + groups, +): + group_norm = msft_op.GroupNorm( + input, + weight, + bias, + activation=0, + channels_last=1, + epsilon=epsilon, + groups=groups, + ) + transposed = op.Transpose(group_norm, perm=[0, 3, 1, 2]) + return torch_module_op.submodule("torch_nn_modules_activation_SiLU")(transposed) + + +def group_normalization_with_silu( + input, + weight, + bias, + epsilon, + groups, +): + group_norm = msft_op.GroupNorm( + input, + weight, + bias, + activation=1, + channels_last=1, + epsilon=epsilon, + groups=groups, + ) + return op.Transpose(group_norm, perm=[0, 3, 1, 2]) + + +group_normalization_merge_silu_submodule_rule = pattern.RewriteRule( + group_normalization_and_silu_submodule, + group_normalization_with_silu, +) + +rules = pattern.RewriteRuleSet([group_normalization_merge_silu_submodule_rule]) diff --git a/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py new file mode 100644 index 000000000..254e526d4 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/group_normalization_merge_silu_test.py @@ -0,0 +1,125 @@ +import unittest + +import numpy as np +import onnx.parser + +from onnxscript._legacy_ir import irbuilder +from onnxscript.rewriter.onnxruntime 
import ( + group_normalization_merge_silu, + instance_to_group_normalization, +) + + +class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase): + def test_group_norm_with_silu_submodule_is_replaced_by_group_norm(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + group_norm = com.microsoft.GroupNorm (image, weight, bias) + transposed = Transpose (group_norm) + output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed) + } + + torch_nn_modules_activation_SiLU_time_embedding_act_19 (transposed) => (output) + { + _to_copy_38 = Cast (transposed) + sigmoid_18 = Sigmoid (_to_copy_38) + mul_26 = Mul (_to_copy_38, sigmoid_18) + output = Cast (mul_26) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_value = np.random.rand(320, 1, 1).astype(np.float16) + model.graph.initializer.extend( + [ + onnx.helper.make_tensor( + "weight", + onnx.TensorProto.FLOAT16, + weight_value.shape, + weight_value, + ), + onnx.helper.make_tensor( + "bias", + onnx.TensorProto.FLOAT16, + bias_value.shape, + bias_value, + ), + ] + ) + + ir = irbuilder.build_ir(model) + count = group_normalization_merge_silu.rules.apply_to_model(ir) + self.assertEqual(count, 1) + # plus 2 in model constants + self.assertEqual(len(ir.graph.nodes), 2) + + def test_simulated_instance_norm_is_replaced_by_group_norm_silu(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + 
add_output = Add (mul_output, bias_full) + output = pkg.torch230a0git77ef9d4.torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output) + } + + torch_nn_modules_activation_SiLU_time_embedding_act_19 (add_output) => (output) + { + _to_copy_38 = Cast (add_output) + sigmoid_18 = Sigmoid (_to_copy_38) + mul_26 = Mul (_to_copy_38, sigmoid_18) + output = Cast (mul_26) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + + model.graph.initializer.extend( + [ + onnx.helper.make_tensor( + "weight_for_norm", + onnx.TensorProto.FLOAT16, + weight_for_norm_value.shape, + weight_for_norm_value, + ), + onnx.helper.make_tensor( + "bias_for_norm", + onnx.TensorProto.FLOAT16, + bias_for_norm_value.shape, + bias_for_norm_value, + ), + onnx.helper.make_tensor( + "weight_full", + onnx.TensorProto.FLOAT16, + weight_full_value.shape, + weight_full_value, + ), + onnx.helper.make_tensor( + "bias_full", + onnx.TensorProto.FLOAT16, + bias_full_value.shape, + bias_full_value, + ), + ] + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + count += group_normalization_merge_silu.rules.apply_to_model(ir) + self.assertEqual(count, 2) + # plus 2 in model constants + self.assertEqual(len(ir.graph.nodes), 10) diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py new file mode 100644 index 000000000..0f6e76685 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np +import onnx + +import onnxscript._legacy_ir as ir +from 
onnxscript.rewriter import pattern + +op = pattern.onnxop +msft_op = pattern.msft_op +torch_module_op = pattern.torch_module_op + +logger = logging.getLogger(__name__) + + +def _check_if_simulated_instance_norm_is_used_impl( + input_x, + adjusted_input_shape, + original_input_shape, + weight_for_norm, + bias_for_norm, + weight_full, + bias_full, + **kwargs, +) -> bool: + if not np.all(weight_for_norm.value_as_np_array == 1): + return False + if not np.all(bias_for_norm.value_as_np_array == 0): + return False + + input_rank_minus_one = len(input_x.shape) - 1 + weight_full_rank = len(weight_full.shape) + bias_full_rank = len(bias_full.shape) + if weight_full_rank != input_rank_minus_one or bias_full_rank != input_rank_minus_one: + return False + + input_rank = len(input_x.shape) + if input_rank != 4: + return False + + weight_full_shape = weight_full.shape + if not all(dim == 1 for dim in weight_full_shape[1:]): + return False + bias_full_shape = bias_full.shape + if not all(dim == 1 for dim in bias_full_shape[1:]): + return False + + adjusted_input_shape = adjusted_input_shape.value_as_np_array + g = weight_for_norm.shape[0] + if adjusted_input_shape is None or adjusted_input_shape.tolist() != [0, g, -1]: + return False + + # NOTE: Restrict the rule to only support constant shape + original_input_shape = original_input_shape.value_as_np_array + if original_input_shape is None or original_input_shape.tolist() != input_x.shape: + return False + + return True + + +def check_if_simulated_instance_norm_is_used( + match_bindings: dict[str, ir.Value | Any], +) -> bool: + """Check if the simulated instance normalization is used. + + In torchlib with opset18, onnx.GroupNorm is using wrong definition, so + we use InstanceNormalization to simulate GroupNormalization. We need to check if there are arguments created to simulation. + If there are, then we need to replace the pattern. If they are not used, then we don't need to replace the pattern. 
+ + To validate this, we need to check the following: + 1. weight_for_norm are all 1 and bias_for_norm are all 0, as they are created for the simulation. + 2. weight_full and bias_full are unsqueezed to be easily broadcastable. + 3. input rank should be 4 + 4. weight_full and bias_full should have ones except first dim. + 5. adjusted_input_shape is a constant tensor of form [0, g, -1] + 6. original_input_shape is the same as input_x shape. + + Args: + match_bindings: The match binding dictionary from a MatchResult. + + Returns: + bool: True if the simulated instance normalization is used, False otherwise. + """ + return _check_if_simulated_instance_norm_is_used_impl(**match_bindings) + + +def instance_simulates_group_normalization_pattern( + input_x, + adjusted_input_shape, + original_input_shape, + weight_for_norm, + bias_for_norm, + weight_full, + bias_full, + epsilon, + match_bindings: dict[str, ir.Value | Any] | None = None, +): + adjusted_input = op.Reshape(input_x, adjusted_input_shape) + inst_norm = op.InstanceNormalization( + adjusted_input, weight_for_norm, bias_for_norm, epsilon=epsilon + ) + adjusted_inst_norm = op.Reshape(inst_norm, original_input_shape) + mul = op.Mul(adjusted_inst_norm, weight_full) + return op.Add(mul, bias_full) + + +def group_normalization( + input_x, + adjusted_input_shape, + original_input_shape, + weight_for_norm, + bias_for_norm, + weight_full, + bias_full, + epsilon, + match_bindings: dict[str, ir.Value | Any] | None = None, +): + # com.microsoft.GroupNorm only supports NHWC for now + nhwc_input = op.Transpose(input_x, perm=[0, 2, 3, 1]) + # com.microsoft.GroupNorm only supports gamma and beta as float type + weight_full = op.Cast(weight_full, to=onnx.TensorProto.FLOAT) + reshape_to_1d = op.Constant(value_ints=[-1]) + weight_full = op.Reshape(weight_full, reshape_to_1d) + bias_full = op.Cast(bias_full, to=onnx.TensorProto.FLOAT) + bias_full = op.Reshape(bias_full, reshape_to_1d) + # re-obtain attribute groups + groups = 
match_bindings["weight_for_norm"].shape[0] + output = msft_op.GroupNorm( + nhwc_input, + weight_full, + bias_full, + activation=0, + channels_last=1, + epsilon=epsilon, + groups=groups, + ) + return op.Transpose(output, perm=[0, 3, 1, 2]) + + +# Register the rewrite rules +instance_norm_to_group_norm_rule = pattern.RewriteRule( + instance_simulates_group_normalization_pattern, + pattern.ReplacementPatternFunction(group_normalization, delay_run=True), + check_if_simulated_instance_norm_is_used, +) + +# NOTE: instance_norm_to_group_norm_rule is subset of instance_norm_to_group_norm_with_silu_rule, +# so we need to run instance_norm_to_group_norm_with_silu_rule first. +rules = pattern.RewriteRuleSet([instance_norm_to_group_norm_rule]) diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py new file mode 100644 index 000000000..67ae0554f --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py @@ -0,0 +1,435 @@ +import unittest + +import numpy as np +import onnx.parser + +from onnxscript._legacy_ir import irbuilder +from onnxscript.rewriter.onnxruntime import instance_to_group_normalization + + +class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase): + def _set_up_model_initializers( + self, + model, + weight_for_norm_value, + weight_for_norm_shape, + bias_for_norm_value, + bias_for_norm_shape, + weight_full_value, + weight_full_shape, + bias_full_value, + bias_full_shape, + ): + """Set up the model initializers for the test.""" + model.graph.initializer.extend( + [ + onnx.helper.make_tensor( + "weight_for_norm", + onnx.TensorProto.FLOAT16, + weight_for_norm_shape, + weight_for_norm_value, + ), + onnx.helper.make_tensor( + "bias_for_norm", + onnx.TensorProto.FLOAT16, + bias_for_norm_shape, + bias_for_norm_value, + ), + onnx.helper.make_tensor( + "weight_full", + onnx.TensorProto.FLOAT16, + weight_full_shape, + 
weight_full_value, + ), + onnx.helper.make_tensor( + "bias_full", + onnx.TensorProto.FLOAT16, + bias_full_shape, + bias_full_value, + ), + ] + ) + + def test_simulated_instance_norm_is_replaced_by_group_norm(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 1) + # plus 2 in model constants + self.assertEqual(len(ir.graph.nodes), 10) + + def test_instance_norm_with_non_one_weight_for_norm_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add 
(mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.random.rand(32).astype(np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_non_zero_b_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.random.rand(32).astype(np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_non_broadcasted_weight_full_should_remain(self): + model = 
onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_non_broadcasted_bias_full_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + 
bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_rank_not_4_should_remain(self): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_weight_full_having_multiple_not_one_dim_should_remain( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = 
Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 2, 3).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 2, 3], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_bias_full_having_multiple_not_one_dim_should_remain( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 2, 3).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 2, 3], + ) + + ir = irbuilder.build_ir(model) + count = 
instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_not_0_g_negative_1_shape_of_adjusted_input_shape_should_remain( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + def test_instance_norm_with_non_equal_of_image_shape_and_original_input_shape_should_remain( + self, + ): + model = onnx.parser.parse_model( + """ + + agraph (float[1, 320, 128, 128] image) => (float[1, 4, 512, 64] output) + { + adjusted_input_shape = Constant() + image_reshape = Reshape (image, adjusted_input_shape) + instance_norm = InstanceNormalization (image_reshape, weight_for_norm, bias_for_norm) + original_input_shape = Constant() + instance_norm_reshape = Reshape (instance_norm, original_input_shape) + mul_output = Mul (instance_norm_reshape, weight_full) + output = Add (mul_output, bias_full) + } + """ + ) + # Use inserted 
initializers to avoid manually coding the large constants + weight_full_value = np.random.rand(320, 1, 1).astype(np.float16) + bias_full_value = np.random.rand(320, 1, 1).astype(np.float16) + weight_for_norm_value = np.ones(32, dtype=np.float16) + bias_for_norm_value = np.zeros(32, dtype=np.float16) + self._set_up_model_initializers( + model, + weight_for_norm_value, + [32], + bias_for_norm_value, + [32], + weight_full_value, + [320, 1, 1], + bias_full_value, + [320, 1, 1], + ) + + ir = irbuilder.build_ir(model) + count = instance_to_group_normalization.rules.apply_to_model(ir) + self.assertEqual(count, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/onnxruntime/softmax.py new file mode 100644 index 000000000..1a70d12e2 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/softmax.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import logging +from typing import Any + +import onnx + +import onnxscript._legacy_ir as ir +from onnxscript.rewriter import pattern + +op = pattern.onnxop +logger = logging.getLogger(__name__) + + +def softmax_with_fp32_upcast(input, axis): + upcast = op.Cast(input, to=onnx.TensorProto.FLOAT) + softmax = op.Softmax(upcast, axis=axis) # pylint: disable=redefined-outer-name + return op.Cast(softmax, to=onnx.TensorProto.FLOAT16) + + +def softmax(input, axis): + return op.Softmax(input, axis=axis) + + +def softmax_with_fp32_upcast_without_axis(input): + upcast = op.Cast(input, to=onnx.TensorProto.FLOAT) + softmax = op.Softmax(upcast) # pylint: disable=redefined-outer-name + return op.Cast(softmax, to=onnx.TensorProto.FLOAT16) + + +def softmax_without_axis(input): + return op.Softmax(input) + + +def check_if_fp16_input(match_bindings: dict[str, ir.Value | Any]) -> bool: + input_val = match_bindings.get("input") + if input_val is None: + logger.warning( + "Cannot perform softmax upcast removal: " + "cannot retrieve match_bindings for 'input' for dtype 
validation." + ) + return False + return input_val.element_type == onnx.TensorProto.FLOAT16 + + +# pylint: disable=pointless-string-statement +""" +This is an onnxruntime specific pattern. Softmax upcast is a common +pattern observed in transformers models to prevent overflow. However +this is not required since onnxruntime implementation already takes +overflow into account. Hence it is safe to remove the surrounding casts +to free up memory as well as saving performance. +""" +# pylint: enable=pointless-string-statement +rules = pattern.RewriteRuleSet( + [ + pattern.RewriteRule(softmax_with_fp32_upcast, softmax, check_if_fp16_input), + pattern.RewriteRule( + softmax_with_fp32_upcast_without_axis, + softmax_without_axis, + check_if_fp16_input, + ), + ] +) diff --git a/onnxscript/rewriter/onnxruntime/softmax_test.py b/onnxscript/rewriter/onnxruntime/softmax_test.py new file mode 100644 index 000000000..507c38c14 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/softmax_test.py @@ -0,0 +1,92 @@ +import unittest + +import onnx.parser +import parameterized + +from onnxscript._legacy_ir import irbuilder +from onnxscript.rewriter.onnxruntime import softmax + + +class SoftmaxUpcastRemovalTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ("Softmax",), + ("Softmax",), + ] + ) + def test_softmax_upcast_to_fp32_is_removed_when_input_and_final_output_is_fp16( + self, softmax_op_str + ): + model = onnx.parser.parse_model( + f""" + + agraph (float16[N] x) => (float16[N] z) + {{ + x_fp32 = Cast(x) + z_fp32 = {softmax_op_str}(x_fp32) + z = Cast(z_fp32) + }} + """ + ) + ir = irbuilder.build_ir(model) + count = softmax.rules.apply_to_model(ir) + self.assertEqual(count, 1) + self.assertNotIn("Cast", {node.op_type for node in ir.graph.nodes}) + + @parameterized.parameterized.expand( + [ + ("Softmax",), + ("Softmax",), + ] + ) + def test_softmax_upcast_to_fp32_is_not_removed_when_input_is_not_fp16( + self, softmax_op_str + ): + model = onnx.parser.parse_model( + 
f""" + + agraph (int32[N] x) => (float16[N] z) + {{ + x_fp32 = Cast(x) + z_fp32 = {softmax_op_str}(x_fp32) + z = Cast(z_fp32) + }} + """ + ) + ir = irbuilder.build_ir(model) + count = softmax.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual( + len([node.op_type for node in ir.graph.nodes if node.op_type == "Cast"]), 2 + ) + + @parameterized.parameterized.expand( + [ + ("Softmax",), + ("Softmax",), + ] + ) + def test_softmax_upcast_to_fp32_is_not_removed_when_final_output_is_not_fp16( + self, softmax_op_str + ): + model = onnx.parser.parse_model( + f""" + + agraph (float16[N] x) => (double[N] z) + {{ + x_fp32 = Cast(x) + z_fp32 = {softmax_op_str}(x_fp32) + z = Cast(z_fp32) + }} + """ + ) + ir = irbuilder.build_ir(model) + count = softmax.rules.apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual( + len([node.op_type for node in ir.graph.nodes if node.op_type == "Cast"]), 2 + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/__init__.py b/onnxscript/rewriter/onnxruntime/transformers/__init__.py new file mode 100644 index 000000000..53eeeef9a --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/transformers/__init__.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime.transformers import ( + fastgelu, + layernorm, + multihead_attention, +) + +TRANSFORMERS_FUNCTION_REWRITE_RULES: list[type[function_rule.FunctionRewriteRule]] = [ + multihead_attention.GQALlama2RewriteRule, + multihead_attention.GQALlamaSdpa2RewriteRule, + multihead_attention.AttnPhi15RewriteRule, + layernorm.LNRewriteRule, + fastgelu.GeluRewriteRule, +] diff --git a/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py b/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py new file mode 100644 index 000000000..faef84062 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py @@ -0,0 +1,31 @@ 
+from __future__ import annotations + +import logging + +import onnx + +import onnxscript +from onnxscript.rewriter import function_rule + +logger = logging.getLogger(__name__) + + +class GeluRewriteRule(function_rule.FunctionRewriteRule): + FUNCTION_KEYWORD = "GELUActivation" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + @_version_controller.register_version() + def _fusion( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, list[onnx.OperatorSetIdProto]]: + del function # Unused + op = self.onnx_opset + msft_opset = onnxscript.values.Opset("com.microsoft", 1) + + def gelu(input): + return msft_opset.FastGelu(input) + + return onnxscript.script(default_opset=op)(gelu).to_function_proto(), ( + onnx.helper.make_operatorsetid("com.microsoft", 1), + ) diff --git a/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py b/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py new file mode 100644 index 000000000..db26adf28 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/transformers/fastgelu_test.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import unittest + +import numpy as np + +from tests.common import testutils + + +class FastGeluParityTest(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + def test_gelu_phi_1_5(self): + testutils.test_onnxruntime_rewrite( + "gelu_phi_1_5", 4, {("com.microsoft", "FastGelu", "")} + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py new file mode 100644 index 000000000..0779bf2af --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import logging + +import onnx +from onnx import numpy_helper + +import onnxscript +from onnxscript.rewriter import function_rule + +logger = logging.getLogger(__name__) + + +class 
LNRewriteRule(function_rule.FunctionRewriteRule): + FUNCTION_KEYWORD = "layernorm" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + @_version_controller.register_version() + def _fusion( # type: ignore[misc] + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, list[onnx.OperatorSetIdProto]]: + # TODO(bowbao): Might be more desirable to annotate as attribute in nn.Module + aten_add_node = self._find_node_by_type(function, "", "Add") + if aten_add_node is None: + raise function_rule.FunctionRewriteError("Could not find Add node") + + eps_node = self._find_constant_node(function, aten_add_node.input[1]) + if eps_node is None: + raise function_rule.FunctionRewriteError("Could not find eps node") + + eps = numpy_helper.to_array(eps_node.attribute[0].t).item() + logger.info("eps: %s", eps) + + # TODO(ORT): SimplifiedLayerNormalization in ort is defined under onnx domain. + # https://github.com/microsoft/onnxruntime/issues/7573 + # msft_op = onnxscript.values.Opset("com.microsoft", 1) + op = self.onnx_opset + + def ln(input, weight): + return op.SimplifiedLayerNormalization( + input, weight, axis=-1, epsilon=eps, stash_type=1 + ) + + return onnxscript.script(default_opset=op)(ln).to_function_proto(), [] diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py new file mode 100644 index 000000000..f4f494aa1 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/transformers/layernorm_test.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import unittest + +import numpy as np + +from tests.common import testutils + + +class LNParityTest(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + def test_ln_llama2(self): + testutils.test_onnxruntime_rewrite( + "ln_llama2", 4, {("", "SimplifiedLayerNormalization", "")} + ) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py new file mode 100644 index 000000000..40a821af2 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py @@ -0,0 +1,604 @@ +r"""POC experimenting function aware pattern re-write. + +In this case we don't want to spell-out the entire source pattern. +Instead, we want to replace an entire function call a new subgraph. + +Source function: LlamaAttention +inputs (positional args, the names in function definition are unfortunately arbitrary and don't provide value): + - hidden_states + - position_id + - attention_mask + - q_proj.weight + - k_proj.weight + - v_proj.weight + - cos_cached + - sin_cached + - o_proj.weight +outputs (similarly, positional) + - present_value + - present_key + - attn_output (o_proj) + +The rewriting algorithm is as follows: + +The final new function graph should look like this: + + function_proj_q function_proj_k + | | + | | +com.microsoft::RotaryEmbedding com.microsoft::RotaryEmbedding function_proj_v + \ / / + \ / / + \ / / + \--------------- / -----------------------/ + com.microsoft::MultiHeadAttention + | | | + attn_output (present_key) (present_value) + | + function_proj_o + | + (output) + +So all we need, is to locate 'function_proj_q', 'function_proj_k', 'function_proj_v', 'function_proj_o'. +Construct the 4 nodes with new contrib op nodes, and properly name their inputs/outputs. 
+ +""" + +from __future__ import annotations + +import abc +import dataclasses +import logging + +import onnx +from onnx import helper as onnx_helper + +import onnxscript +from onnxscript.rewriter import function_rule + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class AttnSizeConfig: + num_attention_heads: int + num_key_value_heads: int + head_size: int + hidden_size: int + + +class AttentionRewriteRule(function_rule.FunctionRewriteRule, abc.ABC): + def infer_attn_size_config(self, function: onnx.FunctionProto) -> AttnSizeConfig: + if len(function.output) != 3: + raise function_rule.FunctionRewriteError( + f"Unexpected number of outputs. Expected 3, got {len(function.output)}." + ) + present_value, _, attn_output = function.output + if ( + present_value_ir := self.lookup(function, present_value) + ) is None or present_value_ir.shape is None: + raise function_rule.FunctionRewriteError("Failed to find shape for present_value.") + if ( + attn_output_ir := self.lookup(function, attn_output) + ) is None or attn_output_ir.shape is None: + raise function_rule.FunctionRewriteError("Failed to find shape for attn_output.") + head_size = present_value_ir.shape[3] + num_key_value_heads = present_value_ir.shape[1] + hidden_size = attn_output_ir.shape[2] + num_attention_heads = hidden_size // head_size + return AttnSizeConfig( + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_size=head_size, + hidden_size=hidden_size, + ) + + +class MHALlama2RewriteRule(AttentionRewriteRule): + FUNCTION_KEYWORD = "LlamaAttention" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + def __init__(self) -> None: + super().__init__() + + @_version_controller.register_version(min_version="4.33", max_version="4.36") + def _fusion_with_4d_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + if len(function.input) != 9: + raise 
function_rule.FunctionRewriteError( + f"Unexpected number of inputs. Expected 9, got {len(function.input)}." + ) + + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + # Workaround onnxscript error by specifying the output shape here. + cos_sin_gather_size = [attn_size_config.head_size // 2] + expand_shape = [1, attn_size_config.num_attention_heads, 1, 1] + + def mha( + hidden_states, + position_id, + attention_mask, + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + # TODO(onnxscript) + # ValueError: ERROR: Unsupported expression type . + # at: Function 'mha', line 16 + # cos = op.Slice(op.Squeeze(cos_cached, [0, 1]), [0], [cos_sin_gather_size], [1]) + # NOTE: Depending on transformers version, the shape of cos/sin is different. + # In later version, the shape is [seq_len, head_size], so the Squeeze is not needed. + # In this version, the shape is [1, 1, seq_len, head_size], hence the below Squeeze. + cos = op.Slice(op.Squeeze(cos_cached, [0, 1]), [0], cos_sin_gather_size, [1]) + sin = op.Slice(op.Squeeze(sin_cached, [0, 1]), [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) + k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) + + # TODO(onnxscript) + # ValueError: ERROR: Unsupported expression type . 
+ # expanded_mask = op.Expand(attention_mask, [1, self.num_heads, 1, 1]) + expanded_mask = op.Expand(attention_mask, expand_shape) + + mha_output, present_key, present_value = msft_op.MultiHeadAttention( + q_rope, + k_rope, + v, + None, + None, + expanded_mask, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(mha_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)(mha).to_function_proto(), ( + onnx.helper.make_operatorsetid("com.microsoft", 1), + ) + + @_version_controller.register_version(min_version="4.36", max_version="4.38") + def _fusion_with_2d_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + if len(function.input) != 9: + raise function_rule.FunctionRewriteError( + f"Unexpected number of inputs. Expected 9, got {len(function.input)}." + ) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + # Workaround onnxscript error by specifying the output shape here. 
+ cos_sin_gather_size = [attn_size_config.head_size // 2] + expand_shape = [1, attn_size_config.num_attention_heads, 1, 1] + + def mha( + hidden_states, + position_id, + attention_mask, + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) + sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) + k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) + + # TODO(onnxscript) + # ValueError: ERROR: Unsupported expression type . + # expanded_mask = op.Expand(attention_mask, [1, self.num_heads, 1, 1]) + expanded_mask = op.Expand(attention_mask, expand_shape) + + mha_output, present_key, present_value = msft_op.MultiHeadAttention( + q_rope, + k_rope, + v, + None, + None, + expanded_mask, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(mha_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)(mha).to_function_proto(), ( + onnx.helper.make_operatorsetid("com.microsoft", 1), + ) + + +class GQALlama2RewriteRule(AttentionRewriteRule): + FUNCTION_KEYWORD = "LlamaAttention" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + def __init__(self) -> None: + super().__init__() + + @_version_controller.register_version(min_version="4.33", max_version="4.36") + def _fusion_with_4d_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + if len(function.input) != 9: + raise function_rule.FunctionRewriteError( + f"Unexpected number of inputs. 
Expected 9, got {len(function.input)}." + ) + + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + # Workaround onnxscript error by specifying the output shape here. + cos_sin_gather_size = [attn_size_config.head_size // 2] + + def gqa( + hidden_states, + position_id, + attention_mask, + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + # NOTE: Depending on transformers version, the shape of cos/sin is different. + # In later version, the shape is [seq_len, head_size], so the Squeeze is not needed. + # In this version, the shape is [1, 1, seq_len, head_size], hence the below Squeeze. 
+ cos = op.Slice(op.Squeeze(cos_cached, [0, 1]), [0], cos_sin_gather_size, [1]) + sin = op.Slice(op.Squeeze(sin_cached, [0, 1]), [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) + k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) + + batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) + sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + past_seq_lengths = op.ConstantOfShape( + batch_size, + value=onnx_helper.make_tensor( + "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] + ), + ) + total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) + + gqa_output, present_key, present_value = msft_op.GroupQueryAttention( + q_rope, + k_rope, + v, + None, + None, + past_seq_lengths, + total_seq_lengths, + kv_num_heads=attn_size_config.num_key_value_heads, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)(gqa).to_function_proto(), ( + onnx.helper.make_operatorsetid("com.microsoft", 1), + ) + + @_version_controller.register_version(min_version="4.36", max_version="4.38") + def _fusion_with_2d_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + if len(function.input) != 9: + raise function_rule.FunctionRewriteError( + f"Unexpected number of inputs. Expected 9, got {len(function.input)}." + ) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + # Workaround onnxscript error by specifying the output shape here. 
+ cos_sin_gather_size = [attn_size_config.head_size // 2] + + def gqa( + hidden_states, + position_id, + attention_mask, + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) + sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) + k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) + + batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) + sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + past_seq_lengths = op.ConstantOfShape( + batch_size, + value=onnx_helper.make_tensor( + "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] + ), + ) + total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) + + gqa_output, present_key, present_value = msft_op.GroupQueryAttention( + q_rope, + k_rope, + v, + None, + None, + past_seq_lengths, + total_seq_lengths, + kv_num_heads=attn_size_config.num_key_value_heads, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)(gqa).to_function_proto(), ( + onnx.helper.make_operatorsetid("com.microsoft", 1), + ) + + +class GQALlamaSdpa2RewriteRule(AttentionRewriteRule): + # TODO: There are a lot of duplicated code with `MHALlama2RewriteRule`. + # The pitfall is that the source function signature is slightly different. + # One has `attention_mask` as input while the other does not. + # Possibly designing a function template system could help reduce the boilerplate. 
+ FUNCTION_KEYWORD = "LlamaSdpaAttention" + PACKAGE_NAME = "transformers" + _version_controller = function_rule.VersionController() + + def __init__(self) -> None: + super().__init__() + + @_version_controller.register_version(min_version="4.36", max_version="4.38") + def _fusion( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + cos_sin_gather_size = [attn_size_config.head_size // 2] + + def gqa( + hidden_states, + position_id, + q_proj_weight, + k_proj_weight, + v_proj_weight, + cos_cached, + sin_cached, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + cos = op.Slice(cos_cached, [0], cos_sin_gather_size, [1]) + sin = op.Slice(sin_cached, [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) + k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) + + batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) + sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + past_seq_lengths = op.ConstantOfShape( + batch_size, + value=onnx_helper.make_tensor( + "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] + ), + ) + total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) + + gqa_output, present_key, present_value = msft_op.GroupQueryAttention( + q_rope, + k_rope, + v, + None, + None, + past_seq_lengths, + total_seq_lengths, + kv_num_heads=attn_size_config.num_key_value_heads, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, 
[1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)( + gqa, + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) + + @_version_controller.register_version(min_version="4.38") + def _fusion_without_cos_sin_cache( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_op = onnxscript.values.Opset("com.microsoft", 1) + + cos_sin_gather_size = [attn_size_config.head_size // 2] + + def gqa( + hidden_states, + position_id, + causal_mask, + cache_position, + q_proj_weight, + k_proj_weight, + v_proj_weight, + inv_freq, + o_proj_weight, + ): + q = op.MatMul(hidden_states, op.Transpose(q_proj_weight, [1, 0])) + k = op.MatMul(hidden_states, op.Transpose(k_proj_weight, [1, 0])) + v = op.MatMul(hidden_states, op.Transpose(v_proj_weight, [1, 0])) + + # In 4.38 and later, cos/sin are not cached, but computed on the fly. + # This can be further optimized by constant folding for scenarios where + # the position_id is known at compile time. 
+ seq_len = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + seq_len_scalar = op.Squeeze(seq_len, [0]) + t = op.Unsqueeze( + op.Cast(op.Range(0, seq_len_scalar, 1), to=onnx.TensorProto.FLOAT), [1] + ) + inv_freq = op.Cast(op.Unsqueeze(inv_freq, [0]), to=onnx.TensorProto.FLOAT) + freqs = op.MatMul(t, inv_freq) + + emb = op.Concat(freqs, freqs, axis=-1) + cos = op.CastLike(op.Cos(emb), hidden_states) + sin = op.CastLike(op.Sin(emb), hidden_states) + cos = op.Slice(cos, [0], cos_sin_gather_size, [1]) + sin = op.Slice(sin, [0], cos_sin_gather_size, [1]) + + q_rope = msft_op.RotaryEmbedding(q, position_id, cos, sin, interleaved=False) + k_rope = msft_op.RotaryEmbedding(k, position_id, cos, sin, interleaved=False) + + batch_size = op.Slice(op.Shape(hidden_states), [0], [1], [0]) + sequence_length = op.Slice(op.Shape(hidden_states), [1], [2], [0]) + past_seq_lengths = op.ConstantOfShape( + batch_size, + value=onnx_helper.make_tensor( + "past_seq_lengths", onnx.TensorProto.INT32, [1], [0] + ), + ) + total_seq_lengths = op.Cast(sequence_length, to=onnx.TensorProto.INT32) + + gqa_output, present_key, present_value = msft_op.GroupQueryAttention( + q_rope, + k_rope, + v, + None, + None, + past_seq_lengths, + total_seq_lengths, + kv_num_heads=attn_size_config.num_key_value_heads, + num_heads=attn_size_config.num_attention_heads, + ) + attn_output = op.MatMul(gqa_output, op.Transpose(o_proj_weight, [1, 0])) + return present_value, present_key, attn_output + + return onnxscript.script(default_opset=onnxscript.opset18)( + gqa, + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) + + +class AttnPhi15RewriteRule(AttentionRewriteRule): + FUNCTION_KEYWORD = "PhiAttention" + PACKAGE_NAME = "transformers_modules" + _version_controller = function_rule.VersionController() + + def __init__(self) -> None: + super().__init__() + + @_version_controller.register_version() + def _fusion( + self, function: onnx.FunctionProto + ) -> tuple[onnx.FunctionProto, 
tuple[onnx.OperatorSetIdProto]]: + # Infer size configurations from the function. + attn_size_config = self.infer_attn_size_config(function) + + # Code new pattern with onnxscript. + op = onnxscript.opset18 + msft_opset = onnxscript.values.Opset("com.microsoft", 1) + + def phi_attention( + hidden_states, + position_id, + attention_mask, + q_proj_weight, + q_proj_bias, + k_proj_weight, + k_proj_bias, + v_proj_weight, + v_proj_bias, + cos_cached, + sin_cached, + dense_weight, + dense_bias, + ): + qkv_weight = op.Transpose( + op.Concat(q_proj_weight, k_proj_weight, v_proj_weight, axis=0), + perm=[1, 0], + ) + qkv_bias = op.Concat(q_proj_bias, k_proj_bias, v_proj_bias, axis=0) + + # [batch_size, sequence_length] + attention_mask_shape = op.Slice(op.Shape(hidden_states), [0], [2], [0]) + + # Create 2d mask to mimic 4d causal mask. + attention_mask = op.ConstantOfShape( + attention_mask_shape, + value=onnx_helper.make_tensor("mask_value", onnx.TensorProto.INT32, [1], [1]), + ) + attn_output, present = msft_opset.Attention( + hidden_states, + qkv_weight, + qkv_bias, + attention_mask, + unidirectional=1, + do_rotary=1, + # Attention.rotary_embedding_dim only supports 32, 64 or 128 + rotary_embedding_dim=attn_size_config.head_size // 2 // 32 * 32, + num_heads=attn_size_config.num_attention_heads, + ) + present_key = op.Gather(present, 0) + present_value = op.Gather(present, 1) + output = op.Add( + op.MatMul(attn_output, op.Transpose(dense_weight, [1, 0])), dense_bias + ) + + return present_value, present_key, output + + return onnxscript.script(default_opset=onnxscript.opset18)( + phi_attention + ).to_function_proto(), (onnx.helper.make_operatorsetid("com.microsoft", 1),) diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py new file mode 100644 index 000000000..26c8c12f2 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention_test.py @@ 
-0,0 +1,71 @@ +from __future__ import annotations + +import unittest + +import numpy as np + +from tests.common import testutils + + +class MHAParityTest(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_attn_llama2_4_34(self): + testutils.test_onnxruntime_rewrite( + "attn_llama2_4_34", 2, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_attn_llama2_4_36(self): + testutils.test_onnxruntime_rewrite( + "attn_llama2_4_36", 1, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_attn_yi_4_37(self): + testutils.test_onnxruntime_rewrite( + "attn_yi_4_37", 1, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_sdpa_llama2_4_36(self): + # TODO: Clean-up naming logic of test models. + # Package version was not considered. 
+ testutils.test_onnxruntime_rewrite( + "sdpa_llama2", 4, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @unittest.skip("TODO: Fails parity check") + def test_sdpa_llama2_4_38(self): + testutils.test_onnxruntime_rewrite( + "sdpa_llama2_4_38", 1, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @testutils.skip_if_no_cuda("GQA Kernel unsupported on CPU.") + def test_sdpa_yi_4_36(self): + testutils.test_onnxruntime_rewrite( + "sdpa_yi", 2, {("com.microsoft", "GroupQueryAttention", "")} + ) + + @unittest.skip("TODO: Fails parity check") + def test_sdpa_yi_4_38(self): + testutils.test_onnxruntime_rewrite( + "sdpa_yi_4_38", 1, {("com.microsoft", "GroupQueryAttention", "")} + ) + + +class AttnParityTest(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + @testutils.skip_if_no_cuda("CPU has parity issue.") + def test_attn_phi_1_5(self): + testutils.test_onnxruntime_rewrite( + "attn_phi_1_5", 4, {("com.microsoft", "Attention", "")} + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py new file mode 100644 index 000000000..e89057d47 --- /dev/null +++ b/onnxscript/rewriter/pattern.py @@ -0,0 +1,1069 @@ +from __future__ import annotations + +import inspect +import itertools +import math +from typing import Any, Callable, Sequence + +import numpy as np +import onnx +import onnx.numpy_helper +import onnx.printer + +import onnxscript._legacy_ir as ir +from onnxscript._legacy_ir import irbuilder + +# Overview of the pattern module: The classes below are used to define both +# patterns (that we search for) and replacements for rewrite rules. +# The matches() method of a pattern is used to check if an IR component +# matches the pattern. +# The to_ir() method of a pattern is used to create a new IR component +# TODO: Ensure that all matches() methods have same type signature (where +# appropriate) and that all to_ir() methods have same type signature (where +# appropriate). 
+ + +class ConstantPattern: + def __init__(self, value: int | str | list) -> None: + self._value = value + + @property + def value(self) -> int | str | list: + return self._value + + def matches(self, value: int | str | list) -> bool: + return value == self.value + + def to_ir(self, model, bindings=None) -> int | str | list: + return self.value + + +class PrefixPattern: + """This pattern is used to simplify submodule opset pattern matching.""" + + def __init__(self, value: str) -> None: + self._value = value + + @property + def value(self) -> str: + return self._value + + def matches(self, value: str) -> bool: + return value.startswith(self.value) + + def to_ir(self, model, bindings=None) -> str: + raise NotImplementedError("PrefixPattern should not be converted to IR") + + +class FloatConstantPattern: + def __init__(self, value: float, rel_tol: float = 1e-5, abs_tol: float = 1e-8) -> None: + self._value = value + self._rel_tol = rel_tol + self._abs_tol = abs_tol + + @property + def value(self): + return self._value + + def matches(self, value: float): + return math.isclose(value, self.value, rel_tol=self._rel_tol, abs_tol=self._abs_tol) + + def to_ir(self, model, bindings=None) -> float: + return self.value + + +class TensorConstantPattern: + def __init__( + self, value: np.ndarray, rel_tol: float = 1e-3, abs_tol: float = 1e-3 + ) -> None: + self._value = value + self._rel_tol = rel_tol + self._abs_tol = abs_tol + + @property + def value(self): + return self._value + + def matches(self, value: np.ndarray): + return ( + value.dtype == self._value.dtype + and value.shape == self._value.shape + and np.allclose( + value, + self._value, + rtol=self._rel_tol, + atol=self._abs_tol, + ) + ) + + def to_ir(self, model, bindings=None) -> onnx.TensorProto: + return onnx.helper.make_tensor( + "", + onnx.helper.np_dtype_to_tensor_dtype(self.value.dtype), + self.value.shape, + self.value, + ) + + +def _make_constant_pattern( + value: float | int | list | np.ndarray, +) -> 
ConstantPattern | FloatConstantPattern | TensorConstantPattern: + """Convert an attrbute value to a ConstantPattern.""" + if isinstance(value, float): + return FloatConstantPattern(value) + if isinstance(value, (int, list)): + return ConstantPattern(value) + if isinstance(value, np.ndarray): + return TensorConstantPattern(value) + raise TypeError(f"Cannot convert {type(value)} to ConstantPattern") + + +class AnyPattern: + def matches(self, value) -> bool: + return True + + +class AttrPattern: + def __init__(self, value: Var | int | float | list | np.ndarray) -> None: + if isinstance(value, Var): + self.value_pattern = value + elif isinstance(value, (int, float, list, np.ndarray)): + self.value_pattern = _make_constant_pattern(value) + else: + raise TypeError(f"Cannot convert {type(value)} to AttrPattern") + + def matches(self, attr_val: int | float | list, model: ir.Model) -> MatchResult: + if isinstance(self.value_pattern, Var): + return self.value_pattern.matches(attr_val, model) + return self.value_pattern.matches(attr_val) + + def to_ir(self, model: ir.Model, rewrite_cache: RewriteCache, bindings=None) -> ir.Val: + if isinstance(self.value_pattern, Var): + val, nodes = self.value_pattern.to_ir( + model, bindings, 1, rewrite_cache + ) # TODO: handle multiple outputs + return val + # constant pattern + return self.value_pattern.to_ir(model, bindings) + + +class OpsetPattern: + """Represents an opset pattern. + + It is used primarily to create a NodePattern (via OpPattern). + Example usage: + :: + + z = op.Matmul(x, y) + + Here, `op` is an instance of OpsetPattern and `op.Matmul` is an instance + of OpPattern, and `op.Matmul(x, y)` is an instance of NodePattern. + + An opset pattern is also matched against the actual opset used in the + input model. Typically, we match against an ONNX opset (ignoring the + version), but we can match against a specific version of the opset too. 
+ However, it is preferable that version-dependences are handled at the + level of a rewrite rule, rather than at the level of a pattern. + + """ + + def __init__( + self, + domain_pattern: ConstantPattern | PrefixPattern, + version_pattern: ConstantPattern | AnyPattern, + ) -> None: + self.domain_pattern = domain_pattern + self.version_pattern = version_pattern + + @classmethod + def singleton(cls, domain: str, version: int) -> OpsetPattern: + return cls(ConstantPattern(domain), ConstantPattern(version)) + + @classmethod + def domain(cls, domain: str) -> OpsetPattern: + return cls(ConstantPattern(domain), AnyPattern()) + + @classmethod + def domain_prefix(cls, domain: str) -> OpsetPattern: + return cls(PrefixPattern(domain), AnyPattern()) + + def matches(self, opset): + domain, version = opset + return self.domain_pattern.matches(domain) and self.version_pattern.matches(version) + + def to_ir(self, model, bindings=None) -> str: + domain = self.domain_pattern.to_ir(model, bindings) + # TODO: Should we ban other custom domains? + if domain not in model.version_map: + model.version_map[self.domain_pattern.value] = self.version_pattern.value + return domain + + def __getattr__(self, name: str) -> Any: + return OpPattern(self, ConstantPattern(name)) + + def submodule(self, name: str) -> Any: + """This method is used to match against submodule ops with prefix.""" + return OpPattern(self, PrefixPattern(name)) + + +opset17 = OpsetPattern.singleton("", 17) + +onnxop = OpsetPattern.domain("") + +msft_op = OpsetPattern.singleton("com.microsoft", 1) + +torch_module_op = OpsetPattern.domain_prefix("pkg.torch") + + +class OpPattern: + """A utility class to build a NodePattern. + + It is used primarily to create a NodePattern. + Example usage: + :: + + z = op.Matmul(x, y) + + Here, `op` is an instance of OpsetPattern and `op.Matmul` is an instance + of OpPattern, and `op.Matmul(x, y)` is an instance of NodePattern. 
+ + """ + + def __init__( + self, + opset_pattern: OpsetPattern, + op_name_pattern: ConstantPattern | PrefixPattern, + ) -> None: + self.opset_pattern = opset_pattern + self.op_name_pattern = op_name_pattern + + def __call__(self, *args, **kwargs): + if "_num_outputs" in kwargs: + num_outputs = kwargs["_num_outputs"] + del kwargs["_num_outputs"] + else: + num_outputs = 1 + attributes = {name: AttrPattern(value) for (name, value) in kwargs.items()} + node_pattern = NodePattern(self.opset_pattern, self.op_name_pattern, args, attributes) + if num_outputs == 1: + return NodeOutputPattern(node_pattern, 0) + else: + return [NodeOutputPattern(node_pattern, i) for i in range(num_outputs)] + + +def _to_value_pattern(x: ValuePattern | int | float) -> ValuePattern: + """Promotes an input-value used to construct a NodePattern to a ValuePattern. + + Example usage: + :: + x = op.MatMul(a, b) + z = op.Add(x, 0) + + In this example, `a, `b`, and `x` are ValuePatterns used to construct a NodePattern. + `0` is a constant (int) value, and is automatically promoted to a ValuePattern. + + Note that this is a shorthand for creating a Constant pattern. The user can more + explicitly write this as: + :: + z = op.Add(x, op.Constant(0)) + """ + if isinstance(x, ValuePattern): + return x + if isinstance(x, (int, float, list)): + return Constant(x) + # TODO(titaiwang): Could this be wrapped Constant? + raise TypeError(f"Cannot convert {type(x)} to ValuePattern") + + +class MatchResult: + """Represents the result of a match operation. + + A match can either succeed or fail. + If it succeeds, it returns a list of IR values that matched the pattern + and a set of bindings for the variables in the pattern. + + Example: + :: + def pattern(x, shape1, shape2): + t1 = op.Reshape(x, shape1) + t2 = op.Reshape(t1, shape2) + return t2 + The above pattern matches a sequence of two Reshape ops. 
+ The matched_values will contain the values representing the (output of) + the two Reshape ops, and the bindings will contain the values that + are bound to the variables `x`, `shape1`, and `shape2`. + """ + + def __init__( + self, matched_values=None, bindings: dict[str, ir.Value | Any] | None = None + ) -> None: + assert matched_values is None or isinstance(matched_values, list) + self.success: bool = matched_values is not None + # For a successful match, matched_values is a list of values that matched the pattern. + # These include the internal nodes of the pattern that were matched, but not + # the leaves (sub-trees) that match against the variables in the pattern. + # These represent the values that will be replaced by the replacement pattern. + self.matched_values: Sequence[Any] | None = matched_values + # For a successful match, bindings is a dictionary of mapping pattern-variable-names + # to values. + self.bindings: dict[str, Any] = bindings if bindings is not None else {} + + def __bool__(self): + return self.success + + @classmethod + def FAIL(cls): + return cls(None) + + @property + def values(self) -> Sequence[Any] | None: + return self.matched_values + + def fail(self): + self.success = False + self.matched_values = None + self.bindings = {} + + def extend(self, other: MatchResult | bool, model): + del model # Unused + if not self.success: + return + if not other: + self.fail() + return + if isinstance(other, bool): + return + for var, val in other.bindings.items(): + if var in self.bindings: + # TODO: handle attribute var bindings + if not self.bindings[var].is_same_as(val): + self.fail() + return + else: + self.bindings[var] = val + self.matched_values.extend(other.matched_values) + + +class ValuePattern: + """Base class for all patterns that match against IR values. + + This is used primarily to provide operator overloadings for arithmetic + operations, so that we can write patterns like `x + 1` and `1 + x`. 
+ """ + + def __init__(self) -> None: + pass + + def __add__(self, other): + return onnxop.Add(self, other) + + def __radd__(self, other): + return onnxop.Add(other, self) + + def __sub__(self, other): + return onnxop.Sub(self, other) + + def __rsub__(self, other): + return onnxop.Sub(other, self) + + def __mul__(self, other): + return onnxop.Mul(self, other) + + def __rmul__(self, other): + return onnxop.Mul(other, self) + + def __truediv__(self, other): + return onnxop.Div(self, other) + + def __rtruediv__(self, other): + return onnxop.Div(other, self) + + def __pow__(self, other): + return onnxop.Pow(self, other) + + +# NOTE(bowbao): Based on reading code, this is (nearly) the only place where `model` is used +# for (nearly) all the functions that passes `model` around. It seems the goal is to be able +# create unique value names. +def _make_node( + model: ir.Model, + domain: str, + op: str, + input, + attributes, + num_outputs: int, +) -> tuple[list[ir.Value], ir.Node]: + inputnames = [x.name for x in input] + outputs = [model.make_new_name() for i in range(num_outputs)] + node = onnx.helper.make_node(op, inputnames, outputs, domain=domain, **attributes) + newnode = ir.Node(node) + newnode.set_version_if_custom_op(model.version_map) + newvalues = [ir.Value(name=v, node=newnode, output_index=i) for i, v in enumerate(outputs)] + newnode.inputs = input + newnode.outputs = newvalues + newnode.attributes = attributes # TODO + return newvalues, newnode + + +class NodePattern: + """Represents a pattern that matches against a Node. + + This differs from a NodeOutputPattern in that it matches against a node (which + may produce 1 or more outputs), whereas a NodeOutputPattern matches against + a specific output of a node. 
+ """ + + def __init__( + self, + domain: OpsetPattern, + op: ConstantPattern, + inputs: Sequence[int | float | ValuePattern], + attributes: dict[str, AttrPattern], + ): + self.domain = domain + self.op = op + self.inputs = [_to_value_pattern(x) for x in inputs] + self.attributes = attributes + self.bound_value = None + + def matches(self, value: ir.Value, model: ir.Model): + if self.bound_value is not None: + # DAG-matching, not Tree-matching. + if self.bound_value.is_same_as(value): + return MatchResult([]) + else: + return MatchResult.FAIL() + node = value.def_node() + if node is None: + # Eg., value could be an input parameter, which will not match a value + # computed by the op in this pattern. + return MatchResult.FAIL() + return self.matches_node(node, model) + + def matches_node(self, node: ir.Node, model: ir.Model) -> MatchResult: + """Examine if the IR node matches the self pattern.""" + if not self.domain.matches((node.domain, node.version)): + return MatchResult.FAIL() + if not self.op.matches(node.op_type): + return MatchResult.FAIL() + match = MatchResult([]) + # TODO: We should add filtered logging starting from here to emit why + # matching failed. This should cut a lot of noises compared to logging everything, + # because at least the starting node op_type is already matched. + for arg_value, previous_node_output_pattern in zip(node.inputs, self.inputs): + # previous_node_output_pattern could be a Var, if it's the original arg. + sub_match = previous_node_output_pattern.matches(arg_value, model) + match.extend(sub_match, model) + if not match: # If sub-match failed, + return match + # Sub-graphs not handled yet. 
+ for name, attr_pattern in self.attributes.items(): + attr_value = node.get_attribute(name) + if attr_value is None: + return MatchResult.FAIL() + sub_match = attr_pattern.matches(attr_value, model) + if not sub_match: + return MatchResult.FAIL() + match.extend(sub_match, model) + for name in node.attributes: + # TODO: Support matching default values for attributes. + if name not in self.attributes: + return MatchResult.FAIL() + match.values.append(node) + return match + + def to_ir( + self, + model: ir.Model, + bindings: dict[str, ir.Value | Any], + num_outputs: int, + rewrite_cache: RewriteCache, + ) -> tuple[list[ir.Value], list[ir.Node]]: + domain = self.domain.to_ir(model) + op = self.op.to_ir(model) + inputs = [] + nodes = [] + for val_pattern in self.inputs: + if ( + value_and_node := rewrite_cache.get_node_output_pattern(val_pattern) + ) is not None: + val, n = value_and_node + else: + val, n = val_pattern.to_ir(model, bindings, 1, rewrite_cache) + rewrite_cache.set_node_output_pattern_with_ir(val_pattern, val, n) + nodes.extend(n) + # If one of the inputs was a the output of a previous node, + # unpack the new output ir value that is created for that node + if isinstance(val, list): + # TODO: Move implementation of output_index to NodeOutputPatter.to_ir + inputs.append(val[val_pattern.output_index]) + else: + inputs.append(val) + attributes = { + name: attr_pattern.to_ir(model, rewrite_cache, bindings) + for (name, attr_pattern) in self.attributes.items() + } + newvals, newnode = _make_node(model, domain, op, inputs, attributes, num_outputs) + nodes.append(newnode) + return newvals, nodes + + def commute(self) -> list[ValuePattern]: + list_of_lists = [pattern.commute() for pattern in self.inputs] + + def enumerate_inputs(inputs, index): + if index >= len(inputs): + yield [] + else: + for pattern in inputs[index]: + for rest in enumerate_inputs(inputs, index + 1): + yield [pattern, *rest] + + inputs = list(enumerate_inputs(list_of_lists, 0)) + if 
self.domain.matches(("", None)) and ( + self.op.matches("Add") or self.op.matches("Mul") + ): + # TODO: handle cases where number of inputs is not 2. + swapped = [[x[1], x[0]] for x in inputs] + inputs.extend(swapped) + return [NodePattern(self.domain, self.op, input, self.attributes) for input in inputs] + + +class NodeOutputPattern(ValuePattern): + """Represents a pattern that matches against a specific output of a Node. + + This is the primary pattern used to match against computed values, that + is values computed using a specific op. + """ + + def __init__(self, node_pattern: NodePattern, output_index: int) -> None: + self.node_pattern = node_pattern + self.output_index = output_index + + def matches(self, value: ir.Value, model: ir.Model): + """Match the StaticValueInfo from IR with the `matches_node()` in node pattern.""" + node = value.def_node() + if node is None: + return MatchResult.FAIL() + if value.def_index() != self.output_index: + return MatchResult.FAIL() + return self.node_pattern.matches_node(node, model) + + def to_ir( + self, + model: ir.Model, + bindings: dict[str, ir.Value | Any], + num_outputs: int, + rewrite_cache: RewriteCache, + ) -> tuple[list[ir.Value], list[ir.Node]]: + assert self.output_index == 0, "TODO: handle multiple outputs" + return self.node_pattern.to_ir(model, bindings, num_outputs, rewrite_cache) + + +class Var(ValuePattern): + """Represents a pattern variable.""" + + def __init__(self, name: str) -> None: + self.pattern_var_name = name + self.bound_value = None + + def __repr__(self) -> str: + return f"Var({self.pattern_var_name!r})" + + def matches(self, value: ir.Value, model: ir.Model): + return MatchResult([], {self.pattern_var_name: value}) + + def to_ir( + self, + model: ir.Model, + bindings: dict[str, ir.Value | Any], + num_outputs: int, + rewrite_cache: RewriteCache, + ) -> tuple[ir.Value, list[None]]: + del model # Unused + del num_outputs # Unused + del rewrite_cache # Unused + return 
bindings[self.pattern_var_name], [] + + def commute(self) -> list[ValuePattern]: + return [self] + + +class Constant(ValuePattern): + """Represents a pattern that matches against a scalar constant value.""" + + def __init__( + self, value: int | float, rel_tol: float = 1e-5, abs_tol: float = 1e-8 + ) -> None: + self.value = value + self.rel_tol = rel_tol + self.abs_tol = abs_tol + + def match_scalar(self, scalar_value, return_value: list[ir.Node]): + if math.isclose(scalar_value, self.value, rel_tol=self.rel_tol, abs_tol=self.abs_tol): + return MatchResult(return_value) + else: + return MatchResult.FAIL() + + def matches(self, value: ir.Value, model: ir.Model): + del model # Unused + constant_value = value.value_as_np_array + if isinstance(constant_value, np.ndarray): + # TODO (rama): allow users to specify shape requirement, if desired. + if constant_value.size != 1: + return MatchResult.FAIL() + + return_value = [] + # Note: If the value is produced by a Constant node, we could include + # the Constant node in the return_value list. However, we don't do that. + # Instead, we will rely on DCE to remove the constant node if it is not + # used elsewhere. + + return self.match_scalar(constant_value.item(), return_value) + return MatchResult.FAIL() + + def commute(self) -> list[ValuePattern]: + return [self] + + +def _handle_pattern_return_value( + node_output_pattern: NodeOutputPattern | list[NodeOutputPattern], +) -> tuple[NodePattern, int]: + """This checks and cleans up the return value of a pattern-construction function. + + A pattern-construction function will return values as below: + :: + def pattern(x, shape1, shape2): + ... + return op.SomeOp(...) + However, `SomeOp` may represent an ONNX op that produces multiple outputs. + This function validates that the return values represent the outputs of + a single NodePattern. It returns the node_pattern and the number of outputs. 
+ + This follows an important restriction of the pattern-matcher algorithm: it + only matches against subgraphs that end in a single terminal node. If we + permit two terminal nodes, then we would have to match against all possible + pairs of nodes in the graph, which produces an extra quadratic factor in the + complexity of the pattern-matching algorithm. In general, the complexity becomes + exponential in the number of terminal nodes. + + Args: + node_output_pattern: NodeOutputPattern | list[NodeOutputPattern] + + Returns: + tuple[NodePattern, int]: The last node_pattern, num_outputs + """ + if isinstance(node_output_pattern, NodeOutputPattern): + node_pattern = node_output_pattern.node_pattern + num_outputs = 1 + elif isinstance(node_output_pattern, (list, tuple)): + node_pattern = node_output_pattern[0].node_pattern + num_outputs = len(node_output_pattern) + for i, p in enumerate(node_output_pattern): + assert isinstance(p, NodeOutputPattern) + assert p.node_pattern is node_pattern + assert p.output_index == i + else: + raise TypeError(f"Invalid type {type(node_output_pattern)} for pattern") + return node_pattern, num_outputs + + +# Currently, the replacement graph function is the same as the pattern function. +# This may change in the future. +_handle_replacement_return_value = _handle_pattern_return_value + + +def _valid_to_replace(matched_nodes: Sequence[ir.Node]) -> bool: + """Check that values computed by the matched_nodes, except for the last one, are used only by the matched_nodes.""" + # * Must check that all values matched by pattern are used only by pattern, + # except for the value that is replaced. + # * Must ensure that replacement subgraph does not use any of the deleted + # (intermediate) values. (Not necessary for now. Guaranteed.) + deleted_nodes = matched_nodes[:-1] + for n in deleted_nodes: + for v in n.outputs: + if v.is_output: + # value is an output-value of the graph/function. 
+ return False + for use in v.uses: + if use not in matched_nodes: + return False + return True + + +class TargetPatternFunction: + """The targeted pattern that will be replaced by the replacement pattern. + + Attributes: + function (Callable): The pattern function that will be matched against the IR. + """ + + def __init__(self, function: Callable) -> None: + self._function = function + + @property + def function(self) -> Callable: + return self._function + + def get_pattern(self, *variables: Sequence[Var]) -> tuple[NodePattern, int]: + node_output_pattern = self._function(*variables) + return _handle_pattern_return_value(node_output_pattern) + + +class ReplacementPatternFunction: + """The replacement pattern that will replace the targeted pattern. + + Attributes: + function (Callable): The replacement function that will be used to replace the matched pattern. + delay_run (bool): If True, the replacement function will not be run until the matched pattern is found. + This is useful when we want to extract certain metavalue from the matched pattern and use it in the + replacement pattern. + """ + + def __init__(self, function, *, delay_run: bool = False): + self._function = function + self._delay_run = delay_run + + @property + def function(self) -> Callable: + return self._function + + @property + def delay_run(self) -> bool: + return self._delay_run + + # TODO: How do we merge it with to_ir function? 
+ def get_pattern( + self, + *vars: Sequence[Var], + match_bindings: dict[str, ir.Value | Any] | None = None, + ) -> tuple[NodePattern | None, int | None]: + if self._delay_run: + if match_bindings is None: + return None, None + node_output_pattern = self._function(*vars, match_bindings) + else: + node_output_pattern = self._function(*vars) + return _handle_pattern_return_value(node_output_pattern) + + +class RewriteCache: + def __init__(self): + self._node_output_pattern_to_ir: dict[NodeOutputPattern, tuple[ir.Value, ir.Node]] = ( + dict() + ) + + def get_node_output_pattern( + self, node_output_pattern: NodeOutputPattern + ) -> tuple[ir.Value, ir.Node] | None: + return self._node_output_pattern_to_ir.get(node_output_pattern, None) + + def set_node_output_pattern_with_ir( + self, node_output_pattern: NodeOutputPattern, value: ir.Value, node: ir.Node + ) -> bool: + self._node_output_pattern_to_ir[node_output_pattern] = (value, node) + + +class RewriteRule: + def __init__( + self, + target_pattern: TargetPatternFunction | Callable | None = None, + replacement_pattern: ReplacementPatternFunction | Callable | None = None, + condition_function: Callable | None = None, + ) -> None: + """Create a rewrite rule. + + Args: + target_pattern: The pattern function that will be + matched against the IR. + replacement_pattern: The replacement function that + will be used to replace the matched pattern. + condition_function: The condition function that + will be used to check if the pattern matches the IR with ir.Values + constraints in consideration. 
+ + """ + if target_pattern is None: + # NOTE: commute() generated rules will have target_pattern as None + # ReplacementPatternFunction is still needed in try_rewrite + assert replacement_pattern is None + assert condition_function is None + self._replacement_pattern = ReplacementPatternFunction(replacement_pattern) + return + elif replacement_pattern is None: + raise ValueError( + "replacement_pattern must be provided if target_pattern is provided" + ) + # TODO: Do we want to tolerate Callable inputs? + if callable(target_pattern): + target_pattern = TargetPatternFunction(target_pattern) + if callable(replacement_pattern): + replacement_pattern = ReplacementPatternFunction(replacement_pattern) + + self._target_pattern = target_pattern + self._replacement_pattern = replacement_pattern + self._condition_function = condition_function + + _pattern_vars = inspect.signature(self._target_pattern.function).parameters + _replacement_vars = inspect.signature(self._replacement_pattern.function).parameters + # TODO: accept _replacement_vars being subset of _pattern_vars? 
+ assert len(_pattern_vars) == len(_replacement_vars) + + self._vars = [Var(v) for v in _pattern_vars] + # Get the last node pattern and number of outputs from the pattern function + self._target_node_pattern, self._target_num_outputs = self._target_pattern.get_pattern( + *self._vars + ) + # NOTE: Return Nones if the replacement pattern is delayed running + self._replace_node_pattern, _replacement_num_outputs = replacement_pattern.get_pattern( + *self._vars + ) + if _replacement_num_outputs is not None: + assert self._target_num_outputs == _replacement_num_outputs + + def matches(self, node: ir.Node, model: ir.Model) -> MatchResult: + """Check if the node from IR matches the pattern.""" + if len(node.outputs) != self._target_num_outputs: + return MatchResult.FAIL() + match = self._target_node_pattern.matches_node(node, model) + if ( + self._condition_function is not None + and match + and not self._condition_function(match.bindings) + ): + return MatchResult.FAIL() + return match + + def try_rewrite( + self, model: ir.Model, node: ir.Node + ) -> tuple[list[ir.Node], list[ir.Node]] | None: + """If the node matches the pattern, then replace the node with the replacement pattern.""" + match = self.matches(node, model) + if match: + if _valid_to_replace(match.values): + # NOTE: delayed running as the replacement pattern needs bindings + if self._replacement_pattern.delay_run: + # bindings will be consumed by the replacement function + self._replace_node_pattern, _replacement_num_outputs = ( + self._replacement_pattern.get_pattern( + *self._vars[:-1], match_bindings=match.bindings + ) + ) + assert self._target_num_outputs == _replacement_num_outputs + rewrite_cache = RewriteCache() + _, _to_insert = self._replace_node_pattern.to_ir( + model, match.bindings, self._target_num_outputs, rewrite_cache + ) + + return (match.values, _to_insert) + return None + + def apply_to_model(self, model: ir.Model, *, commute: bool = False): + # TODO(titaiwang): Why do we need 
RewriteRuleSet? + return RewriteRuleSet([self], commute=commute).apply_to_model(model) + + def count_matches(self, model: ir.Model, *, commute: bool = False): + return RewriteRuleSet([self], commute=commute).count_matches(model) + + def commute(self) -> list[RewriteRule]: + def replace_pattern(new_pattern): + """Return a shallow copy of self with node_pattern replaced by new_pattern.""" + rule = RewriteRule() + rule._condition_function = self._condition_function + rule._target_node_pattern = new_pattern + rule._target_num_outputs = self._target_num_outputs + rule._replace_node_pattern = self._replace_node_pattern + return rule + + return [replace_pattern(p) for p in self._target_node_pattern.commute()] + + +def _apply_deltas( + graph_or_function: ir.Graph | ir.Function, + deltas: list[tuple[int, tuple[list[ir.Node], list[ir.Node]]]], +): + """Applies deltas. + + This code is valid is the considered pattern has only one output. + In case of multi output replacements, there is not need to rename + the outputs. + + In case of multi-output design, the nodes may not be necessary inserted + all at the same position. To be convinced, you can take a pattern + producing two outputs, but the second one needs the first one and + another input appeared after the first outputs. What could be + the right place to inserted all of the node. + + The current implementation insert all the nodes at the same position + but checks there is not inconsistency. In that case, it fails. + We could reorder (long) or do more clever changes. + The reordering would probably happen not very often. 
+ """ + nodes = graph_or_function.nodes + existing_ids = {id(n): (i, n) for i, n in enumerate(nodes)} + to_delete = set() + to_insert = {} + path_2 = False + + for i, delta in reversed(deltas): + if len(delta) == 3: + # multi-outut strategy + n_matches, deleted_nodes, inserted_nodes = delta + for d in deleted_nodes: + assert id(d) in existing_ids + to_delete.add(id(d)) + + # the position to insert must be chosen. + # we'll try position i + assert i not in to_insert # conflicts should avoid that case + to_insert[i] = inserted_nodes + + else: + deleted_nodes, inserted_nodes = delta + # Replace deleted nodes with inserted nodes. + # However, we merge the last deleted node and last inserted node + # to avoid replacing the values produced by the last deleted node + # in all places where they are used. So, we reuse the output + # values from the last deleted node and replace the node itself + # TODO: simplify this + last_deleted = deleted_nodes[-1] + last_inserted = inserted_nodes[-1] + + assert len(last_deleted.outputs) == len(last_inserted.outputs) + del last_inserted.outputs[:] + for v in last_deleted.outputs: + v.node = last_inserted + last_inserted.outputs.append(v) + + del nodes[i] + + for new_node in reversed(inserted_nodes): + nodes.insert(i, new_node) + # bind the outputs to the graph + for output_name, value in zip(new_node.output_names, new_node.outputs): + graph_or_function.values[output_name] = value + path_2 = True + + assert not to_delete or not path_2, ( + "Two different rules were applied. It will solved later. " + "Right now, the functions assumes all the changes come from one " + "rule." 
+ ) + + if path_2: + for _, delta in deltas: + deleted_nodes, inserted_nodes = delta + inserted_input_output = [] + for nd in inserted_nodes: + inserted_input_output += nd.inputs + nd.outputs + for old_node in deleted_nodes[0:-1]: + # Delete intermediary outputs from graph that are not used as + # outputs of the graph + for output in old_node.outputs: + if not output.is_output and output not in inserted_input_output: + graph_or_function.values.pop(output.name) + nodes.remove(old_node) + + for i in to_delete: + position = existing_ids[i][0] + nodes[position] = None + + for position, insert in sorted(to_insert.items(), reverse=True): + for v in reversed(insert): + nodes.insert(position, v) + + position_to_delete = [] + for i, n in enumerate(nodes): + if n is None: + position_to_delete.append(i) + + for p in reversed(position_to_delete): + del nodes[p] + + +class RewriteRuleSet: + def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: + if commute: + rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules])) + self.rules = rules + + def _apply_to_graph_or_function( + self, + model: ir.Model, + graph_or_function: ir.Graph | ir.Function, + ) -> int: + count = 0 + marked = set() + bridge = None + # NOTE: Rules should be prioritized in the order they are added to the RewriteRuleSet. + # And the graph is applied in order. + for rule in self.rules: + deltas = [] + for i, node in enumerate(graph_or_function.nodes): + if hasattr(rule, "pattern"): + from onnxscript.rewriter.generic_pattern import ( + GenericRewriteRule, + ModelWithGraphStructure, + ) + + assert isinstance( + rule, GenericRewriteRule + ), f"Unexpected type {type(rule)}" + # The successors and the predecessors do not change + # until the deltas are applied. We cache the structure + # to avoid building them again. 
+ if bridge is None: + bridge = ModelWithGraphStructure(model) + delta = rule.try_rewrite(bridge, node) + else: + delta = rule.try_rewrite(model, node) + if delta is None: + continue + + matched_nodes, _ = delta[-2:] + + conflict = False + for n in matched_nodes: + if id(n) in marked: + # The same node cannot be matched twice with different patterns. + conflict = True + break + + if conflict: + # Some nodes are already marked as rewritten. + continue + + marked |= set(map(id, matched_nodes)) + + deltas.append((i, delta)) + count += 1 + + _apply_deltas(graph_or_function, deltas) + return count + + def apply_to_model(self, model: ir.Model) -> int: + assert isinstance(model, ir.Model) + count = self._apply_to_graph_or_function(model, model.graph) + for function in model.functions: + count += self._apply_to_graph_or_function(model, function) + return count + + def _count_matches_in_graph_or_function( + self, model: ir.Model, graph_or_funciton: ir.Graph | ir.Function + ) -> int: + count = 0 + for node in graph_or_funciton.nodes: + for rule in self.rules: + if rule.matches(node, model): + count += 1 + break + return count + + def count_matches(self, model: onnx.ModelProto | ir.Model): + if isinstance(model, onnx.ModelProto): + model = irbuilder.build_ir(model) + else: + assert isinstance(model, ir.Model) + count = self._count_matches_in_graph_or_function(model, model.graph) + for function in model.functions: + count += self._count_matches_in_graph_or_function(model, function) + return count diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py new file mode 100644 index 000000000..c6319adc9 --- /dev/null +++ b/onnxscript/rewriter/pattern_test.py @@ -0,0 +1,305 @@ +import logging +import unittest + +import numpy as np +import onnx.parser + +from onnxscript._legacy_ir import irbuilder, protobuilder +from onnxscript.rewriter import cast_constant_of_shape, pattern + +logger = logging.getLogger(__name__) +op = pattern.onnxop +msft_op = 
pattern.msft_op + + +class ReciprocalMulTest(unittest.TestCase): + def rule(self) -> pattern.RewriteRule: + def reciprocal_mul_pattern(x, y): + return (1 / x) * y + + def div(x, y): + return y / x + + return pattern.RewriteRule(reciprocal_mul_pattern, div) + + def test_single_match(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + c1 = Constant() + t1 = Div(c1, x) + z1 = Mul(t1, y) + z = Identity(z1) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 3) + + def test_failed_match(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + c1 = Constant() + t1 = Div(c1, x) + z1 = Mul(t1, y) + z = Identity(z1) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 0) + self.assertEqual(len(ir.graph.nodes), 4) + + def test_multiple_matches(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + # {c1, t1, z1} is a valid match + # {c2, t2, z2} is a valid match + # {c3, t3, z3} is a match, but cannot be replaced since t3 has other-uses. 
+ c1 = Constant() + c2 = Constant() + t2 = Div(c2, y) + t1 = Div(c1, x) + z1 = Mul(t1, y) + z2 = Mul(t2, z1) + + c3 = Constant() + t3 = Div(c3, x) + z3 = Mul(t3, y) + reuse_t3 = Div(t3, x) + z = Add(z2, reuse_t3) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 2) + self.assertEqual(len(ir.graph.nodes), 9) + + +class FastGeluTest(unittest.TestCase): + def rule(self) -> pattern.RewriteRule: + def fast_gelu_pattern1(x): + b = 0.044715 + c = 0.79788 + tanh = op.Tanh(c * (x + (x**3) * b)) + return (1.0 + tanh) * (0.5 * x) + + def fast_gelu(x): + return msft_op.FastGelu(x) + + return pattern.RewriteRule(fast_gelu_pattern1, fast_gelu) + + def long_form_rule(self) -> pattern.RewriteRule: + def fast_gelu_pattern1_long(x): + three = pattern.Constant(3) + x_cube = op.Pow(x, three) + b = pattern.Constant(0.044715) + x_cube_mul_b = op.Mul(x_cube, b) # support OR op.Mul(B, x_cube) + sum_ = op.Add(x, x_cube_mul_b) + c = pattern.Constant(0.79788) + c_times_sum = op.Mul(c, sum_) + tanh = op.Tanh(c_times_sum) + one = pattern.Constant(1.0) + one_plus_tanh = op.Add(one, tanh) + half = pattern.Constant(0.5) + half_x = op.Mul(half, x) + return op.Mul(one_plus_tanh, half_x) + + def fast_gelu(x): + return msft_op.FastGelu(x) + + return pattern.RewriteRule(fast_gelu_pattern1_long, fast_gelu) + + def _check(self, rule): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[N] z) + { + three = Constant () + x_cube = Pow(x, three) + B = Constant () + x_cube_mul_B = Mul(x_cube, B) + sum = Add(x, x_cube_mul_B) + C = Constant () + C_times_sum = Mul(C, sum) + tanh = Tanh(C_times_sum) + one = Constant () + one_plus_tanh = Add(one, tanh) + half = Constant () + half_x = Mul(half, x) + z = Mul(one_plus_tanh, half_x) + } + """ + ) + ir = irbuilder.build_ir(model) + count = rule.apply_to_model(ir) + self.assertEqual(count, 1) + # 5 Constant nodes and 1 FastGelu node + self.assertEqual(len(ir.graph.nodes), 
6) + + def test_short_rule(self): + self._check(self.rule()) + + def test_long_rule(self): + self._check(self.long_form_rule()) + + +class ConcatTest(unittest.TestCase): + def rule(self) -> pattern.RewriteRule: + def concat_pattern(x, y, axis): + seq = op.SequenceConstruct(x, y) + return op.ConcatFromSequence(seq, axis=axis) + + def concat(x, y, axis): + return op.Concat(x, y, axis=axis) + + return pattern.RewriteRule(concat_pattern, concat) + + def test_concat(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[N] y) => (float[M] z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.graph.nodes), 1) + + def test_concat_in_function(self): + model = onnx.parser.parse_model( + """ + + agraph (float[N] x, float[M] y) => (float[Z] z) + { + z = afunction (x, y) + } + + afunction (x, y) => (z) + { + t = SequenceConstruct (x, y) + z = ConcatFromSequence (t) + } + """ + ) + ir = irbuilder.build_ir(model) + count = self.rule().apply_to_model(ir) + self.assertEqual(count, 1) + self.assertEqual(len(ir.functions), 1) + self.assertEqual(len(ir.functions[0].nodes), 1) + self.assertEqual(ir.functions[0].nodes[0].op_type, "Concat") + + +class RewriteRuleTest(unittest.TestCase): + def test_commute(self): + def add_0(x): + return x + 0 + + def identity(x): + return op.Identity(x) + + add_0_rule = pattern.RewriteRule(add_0, identity) + + model = onnx.parser.parse_model( + """ + + agraph (float[N] x) => (float[M] z) + { + zero = Constant () + z = Add (zero, x) + } + """ + ) + ir = irbuilder.build_ir(model) + count = pattern.RewriteRuleSet([add_0_rule], commute=True).apply_to_model(ir) + optimized_model = protobuilder.build_model_proto(ir) + self.assertEqual(count, 1) + nodes = optimized_model.graph.node + self.assertEqual(len(nodes), 2) + self.assertEqual(nodes[1].op_type, "Identity") + + def 
test_const_value(self): + def reshape(x, newshape): + return op.Reshape(x, newshape) + + def identity(x, newshape): + del newshape # Unused + return op.Identity(x) + + def _check_for_redundant_reshape(x, newshape): + oldshape = x.shape + if not isinstance(oldshape, list): + return False + newshape = newshape.value_as_np_array + if not isinstance(newshape, np.ndarray): + return False + newshape = newshape.tolist() + + if len(oldshape) != len(newshape): + return False + return all(not (d1 != d2 and d2 != -1) for d1, d2 in zip(oldshape, newshape)) # pylint: disable=consider-using-in + + def check_for_redundant_reshape(bindings): + return _check_for_redundant_reshape(**bindings) + + rule = pattern.RewriteRule(reshape, identity, check_for_redundant_reshape) + + model = onnx.parser.parse_model( + """ + + agraph (float[10, 20, 30] x) => (float[10, 20, 30] z) + { + shape = Constant () + z = Reshape (x, shape) + } + """ + ) + ir = irbuilder.build_ir(model) + count = pattern.RewriteRuleSet([rule]).apply_to_model(ir) + optimized_model = protobuilder.build_model_proto(ir) + self.assertEqual(count, 1) + nodes = optimized_model.graph.node + self.assertEqual(len(nodes), 2) + self.assertEqual(nodes[1].op_type, "Identity") + + def test_delayed_run_provides_correct_bindings_for_multiple_matches(self): + model = onnx.parser.parse_model( + """ + + agraph (int64[2] input_x) => (float16[1, 4] output, float[1, 4] output2) + { + constant = ConstantOfShape (input_x) + output = Cast (constant) + constant2 = ConstantOfShape (input_x) + output2 = Cast (constant2) + } + """ + ) + ir = irbuilder.build_ir(model) + count = cast_constant_of_shape.rules.apply_to_model(ir) + self.assertEqual(count, 2) + self.assertEqual(len(ir.graph.nodes), 2) + self.assertEqual(ir.graph.nodes[0].attributes["value"].data_type, 10) + self.assertEqual(ir.graph.nodes[1].attributes["value"].data_type, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/testing.py b/onnxscript/testing/__init__.py 
similarity index 66% rename from onnxscript/testing.py rename to onnxscript/testing/__init__.py index 8c927b091..e62a44d9a 100644 --- a/onnxscript/testing.py +++ b/onnxscript/testing/__init__.py @@ -1,16 +1,21 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Public utilities for testing onnxscript.""" +from __future__ import annotations +__all__ = [ + "assert_isomorphic", + "assert_isomorphic_graph", + "assert_isomorphic_function", + "assert_onnx_proto_equal", +] + +import difflib +from typing import Any, Collection, Sequence + +import google.protobuf.message import onnx from onnx import parser import onnxscript -__all__ = ["assert_isomorphic", "assert_isomorphic_graph", "assert_isomorphic_function"] - def assert_isomorphic(graph_or_function_1, graph_or_function_2): """Assert two graphs or functions are isomorphic.""" @@ -340,3 +345,119 @@ def _to_function_or_graph(obj): if isinstance(obj, onnxscript.OnnxFunction): return obj.to_function_proto() raise TypeError(f"Cannot convert {type(obj)} to FunctionProto or GraphProto") + + +def _opset_import_key(opset_import: onnx.OperatorSetIdProto) -> tuple[str, int]: + return (opset_import.domain, opset_import.version) + + +def _value_info_key(value_info: onnx.ValueInfoProto) -> str: + return value_info.name + + +def _function_key(function: onnx.FunctionProto) -> tuple[str, str, str]: + return (function.domain, function.name, getattr(function, "overload", "")) + + +def _find_duplicates(with_duplicates: Collection[Any]) -> list[Any]: + """Return a list of duplicated elements in a collection.""" + seen = set() + duplicates = [] + for x in with_duplicates: + if x in seen: + duplicates.append(x) + seen.add(x) + return duplicates + + +def assert_onnx_proto_equal( + a: google.protobuf.message.Message | Any, b: 
google.protobuf.message.Message | Any +) -> None: + """Assert that two ONNX protos are equal. + + Equality is defined as having the same fields with the same values. When + a field takes the default value, it is considered equal to the field + not being set. + + Sequential fields with name `opset_import`, `value_info`, and `functions` are + compared disregarding the order of their elements. + + Args: + a: The first ONNX proto. + b: The second ONNX proto. + """ + assert type(a) == type(b), f"Type not equal: {type(a)} != {type(b)}" # pylint: disable=unidiomatic-typecheck + + a_fields = {field.name: value for field, value in a.ListFields()} + b_fields = {field.name: value for field, value in b.ListFields()} + all_fields = sorted(set(a_fields.keys()) | set(b_fields.keys())) + for field in all_fields: + # Obtain the default value if the field is not set. This way we can compare the two fields. + a_value = getattr(a, field) + b_value = getattr(b, field) + if ( + isinstance(a_value, Sequence) + and isinstance(b_value, Sequence) + and not isinstance(a_value, (str, bytes)) + and not isinstance(b_value, (str, bytes)) + ): + # Check length first + a_keys: list[Any] = [] + b_keys: list[Any] = [] + if field == "opset_import": + a_value = sorted(a_value, key=_opset_import_key) + b_value = sorted(b_value, key=_opset_import_key) + a_keys = [_opset_import_key(opset_import) for opset_import in a_value] + b_keys = [_opset_import_key(opset_import) for opset_import in b_value] + elif field == "value_info": + a_value = sorted(a_value, key=_value_info_key) + b_value = sorted(b_value, key=_value_info_key) + a_keys = [_value_info_key(value_info) for value_info in a_value] + b_keys = [_value_info_key(value_info) for value_info in b_value] + elif field == "functions": + a_value = sorted(a_value, key=_function_key) + b_value = sorted(b_value, key=_function_key) + a_keys = [_function_key(functions) for functions in a_value] + b_keys = [_function_key(functions) for functions in b_value] + + if 
a_keys != b_keys: + keys_only_in_a = set(a_keys) - set(b_keys) + keys_only_in_b = set(b_keys) - set(a_keys) + error_message = ( + f"Field {field} not equal: keys_only_in_a={keys_only_in_a}, keys_only_in_b={keys_only_in_b}. " + f"Field type: {type(a_value)}. " + f"Duplicated a_keys: {_find_duplicates(a_keys)}, duplicated b_keys: {_find_duplicates(b_keys)}" + ) + raise AssertionError(error_message) + if len(a_value) != len(b_value): + error_message = ( + f"Field {field} not equal: len(a)={len(a_value)}, len(b)={len(b_value)} " + f"Field type: {type(a_value)}" + ) + raise AssertionError(error_message) + # Check every element + for i in range(len(a_value)): # pylint: disable=consider-using-enumerate + a_value_i = a_value[i] + b_value_i = b_value[i] + if isinstance(a_value_i, google.protobuf.message.Message) and isinstance( + b_value_i, google.protobuf.message.Message + ): + try: + assert_onnx_proto_equal(a_value_i, b_value_i) + except AssertionError as e: + error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}, a_value_i: {a_value_i}, b_value_i: {b_value_i}" + raise AssertionError(error_message) from e + elif a_value_i != b_value_i: + error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}" + for line in difflib.ndiff( + str(a_value_i).splitlines(), str(b_value_i).splitlines() + ): + error_message += "\n" + line + raise AssertionError(error_message) + elif isinstance(a_value, google.protobuf.message.Message) and isinstance( + b_value, google.protobuf.message.Message + ): + assert_onnx_proto_equal(a_value, b_value) + elif a_value != b_value: + error_message = f"Field {field} not equal. 
field_a: {a_value}, field_b: {b_value}" + raise AssertionError(error_message) diff --git a/onnxscript/tests/common/testutils.py b/onnxscript/tests/common/testutils.py deleted file mode 100644 index 9ce4a0c37..000000000 --- a/onnxscript/tests/common/testutils.py +++ /dev/null @@ -1,14 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import unittest - - -class TestBase(unittest.TestCase): - """The base class for testing ONNX Script functions for internal use.""" - - def validate(self, fn): - """Validate script function translation.""" - return fn.to_function_proto() diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index a4b64485e..18728ae76 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -12,7 +12,7 @@ import onnxscript.testing from onnxscript import FLOAT, INT64, script, type_annotation from onnxscript.onnx_opset import opset15 as op -from onnxscript.tests.common import testutils +from tests.common import testutils class TypeAnnotationTest(testutils.TestBase): diff --git a/onnxscript/utils/__init__.py b/onnxscript/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/onnxscript/utils/evaluation_utils.py b/onnxscript/utils/evaluation_utils.py new file mode 100644 index 000000000..eb93b79cb --- /dev/null +++ b/onnxscript/utils/evaluation_utils.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import pathlib + +import numpy as np +import onnx +from onnx import helper as onnx_helper + + +def load_test_data( + qual_model_dir: str, input_names: list[str] +) -> tuple[dict[str, np.ndarray], list[np.ndarray]]: + test_data_dir = pathlib.Path(qual_model_dir) / "test_data_set_0" + inputs = {} + expected_outputs = [] + for test_data in 
test_data_dir.glob("input_*.pb"): + idx = int(test_data.stem[len("input_") :]) + input_name = input_names[idx] + input_data = onnx.TensorProto() + with open(test_data, "rb") as f: + input_data.ParseFromString(f.read()) + inputs[input_name] = onnx.numpy_helper.to_array(input_data) + + output_file_paths = list(test_data_dir.glob("output_*.pb")) + expected_outputs = [None] * len(output_file_paths) + for test_data in test_data_dir.glob("output_*.pb"): + idx = int(test_data.stem[len("output_") :]) + output_data = onnx.TensorProto() + with open(test_data, "rb") as f: + output_data.ParseFromString(f.read()) + expected_outputs[idx] = onnx.numpy_helper.to_array(output_data) # type: ignore[call-overload] + + assert all(name in inputs for name in input_names), "Some inputs are missing." + assert not any(output is None for output in expected_outputs), "Some outputs are missing." + + return inputs, expected_outputs # type: ignore[return-value] + + +def generate_random_input(model: onnx.ModelProto) -> dict[str, np.ndarray]: + """Generate random input for the model. + + NOTE: This is unused. There is parity issue with randomly generated data. Need investigation. 
+ """ + inputs = {} + for _, input in enumerate(model.graph.input): + shape = [d.dim_value for d in input.type.tensor_type.shape.dim] + np_dtype = onnx_helper.tensor_dtype_to_np_dtype(input.type.tensor_type.elem_type) + if np_dtype is None: + raise ValueError(f"Unsupported dtype: {input.type.tensor_type.elem_type}") + if np_dtype in (np.float16, np.float32, np.float64): + inputs[input.name] = np.random.rand(*shape).astype(np_dtype) - 0.5 + else: + inputs[input.name] = np.random.randint(3, 100, size=shape, dtype=np_dtype) + return inputs diff --git a/onnxscript/utils/timing_utils.py b/onnxscript/utils/timing_utils.py new file mode 100644 index 000000000..6805a7e19 --- /dev/null +++ b/onnxscript/utils/timing_utils.py @@ -0,0 +1,33 @@ +import time + +import onnx + +from onnxscript import optimizer + +# from onnxscript.rewriter.rules import all_rules + + +def timeit(f, message): + def timed(*args, **kw): + ts = time.time() + result = f(*args, **kw) + te = time.time() + print(f"{message} time: {te-ts}") + return result + + return timed + + +load = timeit(onnx.load, "Load") + +save = timeit(onnx.save, "Save") + +infer = timeit(onnx.shape_inference.infer_shapes, "Infer") + +fold_constants = timeit(optimizer.fold_constants, "Fold Constants") + +remove_unused = timeit(optimizer.remove_unused_nodes, "Remove Unused") + +optimize = timeit(optimizer.optimize, "Optimize") + +# rewrite = timeit(all_rules.apply_to_model, "Rewrite") diff --git a/onnxscript/utils/utils.py b/onnxscript/utils/utils.py new file mode 100644 index 000000000..26ef525b1 --- /dev/null +++ b/onnxscript/utils/utils.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import Any + +import onnx + + +def normalize_domain(d: str) -> str: + return "" if d == "ai.onnx" else d + + +def is_onnx_domain(d: str) -> bool: + return normalize_domain(d) == "" + + +def is_onnx_op(node: onnx.NodeProto, op_type: str) -> bool: + return is_onnx_domain(node.domain) and node.op_type == op_type + + +def 
is_control_flow_op(node: onnx.NodeProto) -> bool: + return any(attr.HasField("g") or len(attr.graphs) > 0 for attr in node.attribute) + + +def get_node_attr_value(node: onnx.NodeProto, attr_name: str, default: Any) -> Any: + matching = [x for x in node.attribute if x.name == attr_name] + if len(matching) > 1: + raise ValueError(f"Node has multiple attributes with name {attr_name}") + if len(matching) < 1: + return default + return onnx.helper.get_attribute_value(matching[0]) + + +def get_initializer_type(initializer: onnx.TensorProto) -> onnx.TypeProto: + type = onnx.TypeProto() + type.tensor_type.elem_type = initializer.data_type + dims = type.tensor_type.shape.dim + for dim in initializer.dims: + dims.add().dim_value = dim + return type + + +def get_constant_node_value(node: onnx.NodeProto, name: str) -> onnx.TensorProto | None: + if ( + node.op_type != "Constant" + or node.domain not in {"", "ai.onnx"} + or len(node.attribute) != 1 + ): + return None + attr = node.attribute[0] + if attr.ref_attr_name: + return None + attr_name = attr.name + value = onnx.helper.get_attribute_value(attr) + + if isinstance(value, onnx.TensorProto): + # Two names exist in this case: we use tensorproto as is (with original name) + return value + shape: list[int] + if attr_name == "value_int": + dtype = onnx.TensorProto.INT64 + shape = [] + value = [value] + elif attr_name == "value_float": + dtype = onnx.TensorProto.FLOAT + shape = [] + value = [value] + elif attr_name == "value_string": + dtype = onnx.TensorProto.STRING + shape = [] + value = [value] + elif attr_name == "value_ints": + dtype = onnx.TensorProto.INT64 + shape = [len(value)] + elif attr_name == "value_floats": + dtype = onnx.TensorProto.FLOAT + shape = [len(value)] + elif attr_name == "value_strings": + dtype = onnx.TensorProto.STRING + shape = [len(value)] + else: + return None # sparse tensors not handled + return onnx.helper.make_tensor(name, dtype, shape, value) diff --git a/pyproject.toml b/pyproject.toml index 
2880f386f..57b607caa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,21 +67,55 @@ disallow_untyped_defs = true [[tool.mypy.overrides]] module = [ "setup", - "onnxscript.tests.models.*", - "onnxscript.tests.onnx_backend_test_code.*", + "tests.models.*", + "tests.onnx_backend_test_code.*", ] ignore_errors = true +[[tool.mypy.overrides]] +module = [ + "onnxrewriter.rewriter.generic_pattern_test.*", +] +check_untyped_defs = false +disable_error_code = 'override,import-untyped,no-untyped-def,assignment' +disallow_incomplete_defs = true +disallow_untyped_defs = true +disallow_untyped_decorators = true +show_column_numbers = true +strict_optional = true +warn_incomplete_stub = true +warn_no_return = true +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = false + +[[tool.mypy.overrides]] +module = [ + "onnxrewriter.rewriter.generic_pattern.*", +] +check_untyped_defs = false +disable_error_code = 'override,import-untyped,no-untyped-def,assignment,union-attr,func-returns-value,annotation-unchecked,arg-type,index,name-defined,attr-defined' +disallow_incomplete_defs = true +disallow_untyped_defs = true +disallow_untyped_decorators = true +show_column_numbers = true +strict_optional = true +warn_incomplete_stub = true +warn_no_return = true +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = false + [tool.black] target-version = ["py38", "py39", "py310", "py311"] # Black's extend-exclude needs to be a regex string -extend-exclude = "/onnxscript/tests/models|/onnxscript/tests/onnx_backend_test_code" +extend-exclude = "/tests/models|/tests/onnx_backend_test_code" line-length = 95 [tool.isort] profile = "black" extend_skip_glob = [ - "onnxscript/tests/onnx_backend_test_code/*.py", + "tests/onnx_backend_test_code/*.py", ] [tool.pylint.messages_control] @@ -157,8 +191,15 @@ ignore = [ line-length = 95 ignore-init-module-imports = true +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"pathlib".msg = "Using pathlib can 
impact performance. Use os.path instead" + [tool.ruff.per-file-ignores] "__init__.py" = ["TID252"] # Allow relative imports in init files +"**/{examples,tests,docs,tools,utils}/*" = ["TID251"] # pathlib is allowed in supporting code +"**/*_test.py" = ["TID251"] # pathlib is allowed in tests +"**/generic_pattern.py" = ["FBT003", "UP037"] # inline ignoring fails +"**/generic_pattern_test.py" = ["ARG001", "ARG002", "PLR2004"] [tool.ruff.flake8-tidy-imports] # Disallow all relative imports. diff --git a/requirements-dev.txt b/requirements-dev.txt index 4d4ecdcd0..0a11bf9bf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ numpy onnx-weekly>=1.17.0.dev20240325 onnxruntime>=1.17.0 typing_extensions -rich +rich>=13.7.1 # Docs site furo diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/Speech2Text2ForCausalLM_dynamo.onnx b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/Speech2Text2ForCausalLM_dynamo.onnx new file mode 100644 index 000000000..e0d380b46 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/Speech2Text2ForCausalLM_dynamo.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06d78f841f26ec59cea1d15dd2c2a086cb907d6644ef8dac15e6d366935413e8 +size 43087292 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/input_0.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/input_0.pb new file mode 100644 index 000000000..2cad62451 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e3134f09a4c5b06cbdf29158ccf758554f8844403c616c2905bfa26a99a8d0a0 +size 1034 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/input_1.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/input_1.pb new file mode 100644 index 000000000..d603c1f03 --- /dev/null +++ 
b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:951a7cddaed4a979aadd89c63f90517920f1b2ba5ad008393ef502b58b88535b +size 1034 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_0.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_0.pb new file mode 100644 index 000000000..aa6e4e687 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5f4d3b73ebce58ef8baa8b74f2aabff8b31337baf133fbd4db7be028d824b39 +size 8 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_1.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_1.pb new file mode 100644 index 000000000..5de4c752c --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27c22e64073f98615cf492bf813a6c0ae44261094dbe368f47eeb91a14f8a7f3 +size 5120015 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_10.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_10.pb new file mode 100644 index 000000000..74714ea41 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_10.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d599ff9f092e837ee9825af7b0f5f0239c43bbde485834fc971e4a10e5d6315 +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_11.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_11.pb new file mode 100644 index 000000000..3e5317f3e --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_11.pb @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:400d856a814258cbf95320dfcb36ecc8a5b47f8295fe034944ba51a14bc036c1 +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_12.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_12.pb new file mode 100644 index 000000000..b069d7b77 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_12.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c00b77782b8793a2dc8b3af42578c3ff3467248e536ee1cb7f85859da085b8b +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_13.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_13.pb new file mode 100644 index 000000000..8196efefe --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_13.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58ba0f2a898467cd5143c46a3389b7a7d572ebd4a7264c2264833ec69d14c803 +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_2.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_2.pb new file mode 100644 index 000000000..a88b4121f --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:015c8d9c96efbb22043be290f9a73ea1a74dcb3cf33d2e8fe6c13b7f6cb3f92c +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_3.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_3.pb new file mode 100644 index 000000000..369d9aed5 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b8fedb07987d6d782b3fa8d4dc07b1c91eafbb4059bc78cbf3b78a2f6759d4e 
+size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_4.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_4.pb new file mode 100644 index 000000000..2ee93f5c3 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1024114644e9c56161245fe9bfcc92a4b5788f6edd19278d50d6234b3b314cfa +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_5.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_5.pb new file mode 100644 index 000000000..dbca83f91 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63b19f4cf067fd128e37d5acde172f292ff4cd25aabaeeaf94e2dbff982d0583 +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_6.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_6.pb new file mode 100644 index 000000000..538c73f5b --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff2e579758329f7c0b44d2dcaca4b8afda5bdd9b169409ba20a83bb2aed495d5 +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_7.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_7.pb new file mode 100644 index 000000000..2001498a9 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12c7bbb02f1da162fd425fa59e82e8f783e79e0f0f849b42a7853dbf582c017b +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_8.pb 
b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_8.pb new file mode 100644 index 000000000..18cd9a3f4 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18665f3e0925a5d16cdb953cd6b3739e077e5713fbabdfcf987cdedde8f8c661 +size 131087 diff --git a/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_9.pb b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_9.pb new file mode 100644 index 000000000..8bc909f47 --- /dev/null +++ b/testdata/e2e_models/Speech2Text2ForCausalLM/dynamo/test_data_set_0/output_9.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d078a9178afaaefb2a87aaed53558b7b560b02f37d747c9bd82d61cb67eb5ab5 +size 131087 diff --git a/testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx b/testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx new file mode 100644 index 000000000..2eede96c9 --- /dev/null +++ b/testdata/e2e_models/mobilenetv2_100/dynamo/mobilenetv2_100_dynamo.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a336102b11d8439daa2c1a164a851f34414529a5610a046943fd869b1b44336f +size 14665355 diff --git a/testdata/e2e_models/mobilenetv2_100/dynamo/test_data_set_0/input_0.pb b/testdata/e2e_models/mobilenetv2_100/dynamo/test_data_set_0/input_0.pb new file mode 100644 index 000000000..dbc92487a --- /dev/null +++ b/testdata/e2e_models/mobilenetv2_100/dynamo/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9b234da9285eb41d80ba90c483e725c7fbe8fbf3323eb513a1c57468e89d5b1 +size 602128 diff --git a/testdata/e2e_models/mobilenetv2_100/dynamo/test_data_set_0/output_0.pb b/testdata/e2e_models/mobilenetv2_100/dynamo/test_data_set_0/output_0.pb new file mode 100644 index 000000000..d0df1d52f --- /dev/null +++ 
b/testdata/e2e_models/mobilenetv2_100/dynamo/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a50d91a03784ec6a6f650b58b62f6ca3bdef9483f94cbb62911b86991680f3a +size 4010 diff --git a/testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx b/testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx new file mode 100644 index 000000000..61122be18 --- /dev/null +++ b/testdata/e2e_models/resnet18/dynamo/resnet18_dynamo.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31fbebb580ff85ed8eefa7fb95d4e2cbda41fe267afeaae2d4f4177264d1f4e7 +size 46918368 diff --git a/testdata/e2e_models/resnet18/dynamo/test_data_set_0/input_0.pb b/testdata/e2e_models/resnet18/dynamo/test_data_set_0/input_0.pb new file mode 100644 index 000000000..c587cd85b --- /dev/null +++ b/testdata/e2e_models/resnet18/dynamo/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:144678f435f995960b1debb8b7c9125b2f68f65d5b175dee9c355be59949b199 +size 602128 diff --git a/testdata/e2e_models/resnet18/dynamo/test_data_set_0/output_0.pb b/testdata/e2e_models/resnet18/dynamo/test_data_set_0/output_0.pb new file mode 100644 index 000000000..3e12dc65a --- /dev/null +++ b/testdata/e2e_models/resnet18/dynamo/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4936df2381104ffe50ad98d42fadba95fbd295308a17766f4929db4f59957e29 +size 4010 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/attn_llama2_4_34_0.onnx b/testdata/unittest_models/attn_llama2_4_34_0/attn_llama2_4_34_0.onnx new file mode 100644 index 000000000..39ca15812 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/attn_llama2_4_34_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22c54ffbd19ec835f054ce4da66c34605654cdaeaab7e4cddb7afdf5daaa9e77 +size 30233 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_0.pb 
b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..2ff266b26 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:707e2139fa3b9db89e8667ad7b1829fbf1c05c678f064dbbdc855e77cee50b32 +size 4194319 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_1.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_2.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_2.pb new file mode 100644 index 000000000..ed46c3425 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d77af46feb251c72b7df238f10c3f2e3bd9baf84813dd0efd415105948e2adb8 +size 524304 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_3.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_3.pb new file mode 100644 index 000000000..3faff2d0e --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d34d627c716ad6e7cf2a302c29605655d94cd7bbceba7b02741612acdafbe27 +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_4.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_4.pb new file mode 100644 index 000000000..050a63c70 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_4.pb @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6fd6f8eb0ae294d2885e027e7c842377b48a95d4abfe0f4b4a0a191d9c846207 +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_5.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_5.pb new file mode 100644 index 000000000..dab47834b --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83ac8df43b3731643c69f628c9842a0b4b9c1a97d544a15906743a3071eca3f9 +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_6.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_6.pb new file mode 100644 index 000000000..2c611eb78 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5a18f5d0f1f88bc8d5a403763a11354846870bf8dee98d4fd5aed3eeb3c919a +size 1048592 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_7.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_7.pb new file mode 100644 index 000000000..cf4644605 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23b0dfc60d794b1291cf52f9b65b9190485e3c07976947467c1bdba6cc283443 +size 1048592 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_8.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_8.pb new file mode 100644 index 000000000..44398410c --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:983c5ebdb5e94954b10bb9dc90d9f21dc1b0d0618a6563db48b7dadcf5697f57 +size 33554445 diff --git 
a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_0.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..652744d71 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5cc227ef6865239083c47b73ad5b051fb8b348f61d6b0b56cdf7eafc56618ee +size 4194321 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_1.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_1.pb new file mode 100644 index 000000000..5ff5523e9 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74999e531bab9143795421abf912611699e976c95458c98f0c11236b3a6fd935 +size 4194321 diff --git a/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_2.pb b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_2.pb new file mode 100644 index 000000000..b232c6440 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_0/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7fe871ab7d24854fb3e216dbbb9e150204594685c1babdeb3debedd380ea605c +size 4194319 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/attn_llama2_4_34_1.onnx b/testdata/unittest_models/attn_llama2_4_34_1/attn_llama2_4_34_1.onnx new file mode 100644 index 000000000..e0b2cedd6 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/attn_llama2_4_34_1.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ed91e90b6f53c400de9576f9e3cd148575d836e4e0326e1775acf9e7462b8cc +size 30349 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_0.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_0.pb new file mode 100644 index 000000000..3c60034a6 --- /dev/null +++ 
b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fdace10343a44a6c0d71d2bfb3a0f9c27df1f787120eb40e62d5f10c967a187f +size 4194319 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_1.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_2.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_2.pb new file mode 100644 index 000000000..ed46c3425 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d77af46feb251c72b7df238f10c3f2e3bd9baf84813dd0efd415105948e2adb8 +size 524304 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_3.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_3.pb new file mode 100644 index 000000000..fa7de69d7 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6516581d07d3c0128caf311423f81ac04a11ed791126d68cfedc60bf0b5116bf +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_4.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_4.pb new file mode 100644 index 000000000..224633ff7 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:276c7a6a379955acf2c6bf1fbb69e7b1914ff420801e280d359f5f264415e981 
+size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_5.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_5.pb new file mode 100644 index 000000000..27e12e914 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca7d1ae2782defbe040a7129d40c1594783f282e4ffb63500269176d76a52743 +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_6.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_6.pb new file mode 100644 index 000000000..2c611eb78 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5a18f5d0f1f88bc8d5a403763a11354846870bf8dee98d4fd5aed3eeb3c919a +size 1048592 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_7.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_7.pb new file mode 100644 index 000000000..cf4644605 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23b0dfc60d794b1291cf52f9b65b9190485e3c07976947467c1bdba6cc283443 +size 1048592 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_8.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_8.pb new file mode 100644 index 000000000..ade7db3c2 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4aca46405c62018d694d58f3b65eaad75bad998f3452a1c11a3310062e637f4 +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_0.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_0.pb new file mode 100644 index 
000000000..1b063d6ef --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbc69d1383d85f23721912806104a8e74157bd11113da70c2c737f1eabccb77b +size 4194321 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_1.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_1.pb new file mode 100644 index 000000000..d98197af0 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70706688fd0d87108dedd615e7b32c4009c0821d9022701df3c3dd52c0c7a0a1 +size 4194321 diff --git a/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_2.pb b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_2.pb new file mode 100644 index 000000000..a0a0f4fe3 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_34_1/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c766421e7c87ed4d3be345aab6b1c32bcb229a8d122b683279cef9b123069beb +size 4194319 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/attn_llama2_4_36_0.onnx b/testdata/unittest_models/attn_llama2_4_36_0/attn_llama2_4_36_0.onnx new file mode 100644 index 000000000..f270c5332 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/attn_llama2_4_36_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:791944ad35bae89819a0a3448c8cb8a733a0a9bdeefff085af5cb3028bebdf16 +size 800895 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_0.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..e60e700f6 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:dcf6c8416670332de11f3fe6efc6e41197ba7c175394c0fc38a98710bfc4ff43 +size 4194319 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_1.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_2.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_2.pb new file mode 100644 index 000000000..ed46c3425 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d77af46feb251c72b7df238f10c3f2e3bd9baf84813dd0efd415105948e2adb8 +size 524304 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_3.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_3.pb new file mode 100644 index 000000000..42b063627 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a994e919c415c5b8acdca27fb87fcd06d7036d646fc0a770929d09024e6d96ae +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_4.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_4.pb new file mode 100644 index 000000000..8c209fce7 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e2fa19eab256f8a81b4757e37f99c49aaba1333a55a2c96fabc989d4de2d0f3 +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_5.pb 
b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_5.pb new file mode 100644 index 000000000..4a0bd6e9e --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db733c3917275c5341cba7bc2a64b5a2dc1abb1c4c01ae5852031eb94e0e2f0c +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_6.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_6.pb new file mode 100644 index 000000000..f29b63640 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a44bf5864390eef1f1b54f464b619a7a283041dd6ce5da65150af5c708051f9 +size 524300 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_7.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_7.pb new file mode 100644 index 000000000..634e449a0 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d15975f1d254f650321ccdd14c3157d9eac1ca2cf3567e6c159bcdd278488f82 +size 524300 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_8.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_8.pb new file mode 100644 index 000000000..a286694f7 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aad340617afb9b743c37a1c6fd08ec2e509c4b22642a81bb60aa29d9f27d5c1c +size 33554445 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_0.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..cb32480a7 --- /dev/null +++ 
b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e85913fc823ad2058c3c339a2aad4160e1fcbcab5a9881c46ca390256c1f17be +size 4194321 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_1.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_1.pb new file mode 100644 index 000000000..a3dd4c830 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:416bb2a6eb20bc2312862c6ff475b23a9975935f4de0f520e7e5dcff604a37ac +size 4194321 diff --git a/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_2.pb b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_2.pb new file mode 100644 index 000000000..3e3911a81 --- /dev/null +++ b/testdata/unittest_models/attn_llama2_4_36_0/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c04d8c70045f317d2e4603ea3737a2e82bdb4335a7e13bea37c09ee2ba8e9388 +size 4194319 diff --git a/testdata/unittest_models/attn_phi_1_5_0/attn_phi_1_5_0.onnx b/testdata/unittest_models/attn_phi_1_5_0/attn_phi_1_5_0.onnx new file mode 100644 index 000000000..de73964d6 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/attn_phi_1_5_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6620dce87374e515078354c52059ee5f4fb80e6f3f755dca3736d70097171f5a +size 1131586 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_0.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..47be63d03 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5941ca97ca1ca39a863353b19409ddca9bfda4da0e1f471e82e449207f5afcf7 +size 2097167 diff --git 
a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_1.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_10.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_10.pb new file mode 100644 index 000000000..ca5368e7a --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_10.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c500808f88f0c089b75163ca6677921f617e391736f13365471b3dc6fe199e84 +size 131083 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_11.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_11.pb new file mode 100644 index 000000000..6f91ec534 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_11.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e44bb9d0e8473fcc87d68760c92df021cae64b82fc0f73507e38ed8de2096aa +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_12.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_12.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_12.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_2.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_2.pb new file mode 100644 index 000000000..ed46c3425 --- /dev/null +++ 
b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d77af46feb251c72b7df238f10c3f2e3bd9baf84813dd0efd415105948e2adb8 +size 524304 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_3.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_3.pb new file mode 100644 index 000000000..110ff2868 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83c167b37147b1773b7aaa680be07baae7c3dd7adbf7c0eb2e33e15deecbe97e +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_4.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_4.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_5.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_5.pb new file mode 100644 index 000000000..9b095158b --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1abc4e3deaa0e5c79c2e0501e82dd8d42af85e79c9ba9d12fee672e8b506f785 +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_6.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_6.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git 
a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_7.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_7.pb new file mode 100644 index 000000000..474ff7950 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77b2a9c58e03d0cea4a130e2c93af40ca5522e809b505a788116c42f10a0bb6e +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_8.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_8.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_9.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_9.pb new file mode 100644 index 000000000..6bef11c83 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/input_9.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df68178de6799d7bd70879efd224c66eab433d71cd7396bd66f046ed5902bf6c +size 131083 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_0.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..1f194c2fb --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5ea134d4b49c784a9b44cce0e6396515cb0cde3e5c5e44f3d93e7ee8aa81146 +size 2097168 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_1.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_1.pb new file mode 100644 index 000000000..4ac6d56bd --- /dev/null +++ 
b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73f8e4c8da4389e456f9685732550ac52575e2789799837c36106b7564f376a9 +size 2097168 diff --git a/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_2.pb b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_2.pb new file mode 100644 index 000000000..50d3f20a7 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_0/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e08d3f65d45c98eaede04a47f5e1fec8917d19f8c84ad8ae720f29a7ca1bc26a +size 2097167 diff --git a/testdata/unittest_models/attn_phi_1_5_1/attn_phi_1_5_1.onnx b/testdata/unittest_models/attn_phi_1_5_1/attn_phi_1_5_1.onnx new file mode 100644 index 000000000..99a93a38c --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/attn_phi_1_5_1.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3e631386757f0ecc674dc955c2e04922f82bbf9d86edb9f5043eba9b801ad1d +size 1131684 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_0.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_0.pb new file mode 100644 index 000000000..88c6457dc --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9658a3943b84436266dffac8e1c3db88d3789da3f9311984c9160f12874e7c1c +size 2097167 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_1.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git 
a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_10.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_10.pb new file mode 100644 index 000000000..ca5368e7a --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_10.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c500808f88f0c089b75163ca6677921f617e391736f13365471b3dc6fe199e84 +size 131083 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_11.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_11.pb new file mode 100644 index 000000000..59bbfd554 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_11.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b381fb61a0f133b247d2b48fc9ae6d4539b19b73104b64784ba6c85527cf8fc9 +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_12.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_12.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_12.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_2.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_2.pb new file mode 100644 index 000000000..ed46c3425 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d77af46feb251c72b7df238f10c3f2e3bd9baf84813dd0efd415105948e2adb8 +size 524304 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_3.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_3.pb new file mode 100644 index 000000000..1a2326f6a --- /dev/null +++ 
b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:590b7408923d0def097d718cf3d9128a81c7f2fd17058730752d6a891bf6ced6 +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_4.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_4.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_5.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_5.pb new file mode 100644 index 000000000..262876bf1 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6262fda6a541a87aa383723ba045bef516dd5fe7a52619bdd59e9fe2ef4c9c1f +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_6.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_6.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_7.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_7.pb new file mode 100644 index 000000000..da8df1b6c --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d984f4273cd80fc37ff15e4bb08de8562059fdfdeffbf74c419015c729af6a1c +size 8388621 diff --git 
a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_8.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_8.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_9.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_9.pb new file mode 100644 index 000000000..6bef11c83 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/input_9.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df68178de6799d7bd70879efd224c66eab433d71cd7396bd66f046ed5902bf6c +size 131083 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_0.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_0.pb new file mode 100644 index 000000000..b1224983b --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43f5c6c09e8e2db1e20152aa0ae22aed6ab516a99dabb33f1029b83bb42706fc +size 2097168 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_1.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_1.pb new file mode 100644 index 000000000..7d5e19338 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff1c4e67ccb447e6031ceba071b505992aae7c45165b420a93cd981904b5b1ce +size 2097168 diff --git a/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_2.pb b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_2.pb new file mode 100644 index 000000000..8f741c919 --- /dev/null +++ 
b/testdata/unittest_models/attn_phi_1_5_1/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b031bddd7df63733371c82e48278c1f42d5051759a081981c9b86a1c74accabd +size 2097167 diff --git a/testdata/unittest_models/attn_phi_1_5_2/attn_phi_1_5_2.onnx b/testdata/unittest_models/attn_phi_1_5_2/attn_phi_1_5_2.onnx new file mode 100644 index 000000000..809547867 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/attn_phi_1_5_2.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2abfbaa33ecfad724c84ace6853a7791d9224e61115b044e7d3e68bd09721e1d +size 1131742 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_0.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_0.pb new file mode 100644 index 000000000..276e97ecf --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3236b6abb79011d00aae3042eb9b17df2682011ae37746b8919d9c0cbf869b3 +size 2097167 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_1.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_10.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_10.pb new file mode 100644 index 000000000..ca5368e7a --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_10.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c500808f88f0c089b75163ca6677921f617e391736f13365471b3dc6fe199e84 +size 131083 diff --git 
a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_11.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_11.pb new file mode 100644 index 000000000..493c1c7fb --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_11.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41d90523a8dcccc88d0443c0e7bd5d4cb8b85b4548910dfea63e5c9bc0cef70b +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_12.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_12.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_12.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_2.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_2.pb new file mode 100644 index 000000000..ed46c3425 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d77af46feb251c72b7df238f10c3f2e3bd9baf84813dd0efd415105948e2adb8 +size 524304 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_3.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_3.pb new file mode 100644 index 000000000..5e6c7fc64 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:38176fdc53d14638c91e46de47f3f3472bc3145e1782a1cd8c0b5decc1afe3d3 +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_4.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_4.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ 
b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_5.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_5.pb new file mode 100644 index 000000000..c8f49f0f5 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a6ce1210a9fc318fd9578c3ac94632114bf7fb67eb701620b60b125fc7b9231 +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_6.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_6.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_7.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_7.pb new file mode 100644 index 000000000..48d348808 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:82747fcde0ec8dcbc77dcad5b8917502b37ea85b3cef04cec3fbec1ffed12813 +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_8.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_8.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git 
a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_9.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_9.pb new file mode 100644 index 000000000..6bef11c83 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/input_9.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df68178de6799d7bd70879efd224c66eab433d71cd7396bd66f046ed5902bf6c +size 131083 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_0.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_0.pb new file mode 100644 index 000000000..97e9bdaf2 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae3b921f6da0f0afd4e5780d2179576c6567b8dabae0909161efee85bb5cd1f3 +size 2097168 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_1.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_1.pb new file mode 100644 index 000000000..e0eeef072 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb1c41ef18b7bcba3eb2e2f4ffb52bed556b0c90287ea07c8e5d6fcc059282ca +size 2097168 diff --git a/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_2.pb b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_2.pb new file mode 100644 index 000000000..b8b61b0bd --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_2/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6299375755089f487d590d4e989a644920b6f184b038fd94f5c7c4ee372af003 +size 2097167 diff --git a/testdata/unittest_models/attn_phi_1_5_3/attn_phi_1_5_3.onnx b/testdata/unittest_models/attn_phi_1_5_3/attn_phi_1_5_3.onnx new file mode 100644 index 000000000..eeaae0164 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/attn_phi_1_5_3.onnx 
@@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd0873efdc58c3d2eb549c86ea68ec909585aeec4fd18b32c8281ae217f71095 +size 1131758 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_0.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_0.pb new file mode 100644 index 000000000..a5bcbf498 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a68df946202c3c8d04414d13fd344f57fbf5e9b7190e3918b66672e88c5be127 +size 2097167 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_1.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_10.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_10.pb new file mode 100644 index 000000000..ca5368e7a --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_10.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c500808f88f0c089b75163ca6677921f617e391736f13365471b3dc6fe199e84 +size 131083 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_11.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_11.pb new file mode 100644 index 000000000..e82afd3de --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_11.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d06afa33b56aaa12620a9881a9665fe75509f3efe7cb5ff1ae7987ffb7e6ccc +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_12.pb 
b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_12.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_12.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_2.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_2.pb new file mode 100644 index 000000000..ed46c3425 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d77af46feb251c72b7df238f10c3f2e3bd9baf84813dd0efd415105948e2adb8 +size 524304 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_3.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_3.pb new file mode 100644 index 000000000..cdc9e0816 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:220b8f93f6f0337545121af549580672514477098edb64d6a63253c82c807cb9 +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_4.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_4.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_5.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_5.pb new file mode 100644 index 000000000..d07734ea9 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 
+oid sha256:60fca48da78a1feafb31d4b3c3bbedec56a5b0d99c47f35c2c080338e80e780f +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_6.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_6.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_7.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_7.pb new file mode 100644 index 000000000..8b373bd3a --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84e261119d5c0ef87d12263774805937b4ec9848e78cc08c042dae142764cbfe +size 8388621 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_8.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_8.pb new file mode 100644 index 000000000..0c69c5832 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:393e3d65cf8454705fdc60b67a114edf6f2608c5f16e7f92292cff4b9c7d623d +size 4104 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_9.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_9.pb new file mode 100644 index 000000000..6bef11c83 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/input_9.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df68178de6799d7bd70879efd224c66eab433d71cd7396bd66f046ed5902bf6c +size 131083 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_0.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_0.pb new file mode 100644 
index 000000000..1b5f00f58 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b05aac6e6f39644f01328190b4cda3be2f735125b13b929a1cba9ab9bef86a99 +size 2097168 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_1.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_1.pb new file mode 100644 index 000000000..1ecd5e5bb --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4cb6e3668905e8136ae1a4523295f15eb4b2db9b9c4844e0766cfa62b46d762 +size 2097168 diff --git a/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_2.pb b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_2.pb new file mode 100644 index 000000000..839f6fa85 --- /dev/null +++ b/testdata/unittest_models/attn_phi_1_5_3/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19c82d7dac70715375b0055e8e480304ce344217d76734b3719f942c06525ca1 +size 2097167 diff --git a/testdata/unittest_models/attn_yi_4_37_0/attn_yi_4_37_0.onnx b/testdata/unittest_models/attn_yi_4_37_0/attn_yi_4_37_0.onnx new file mode 100644 index 000000000..baa2c1047 --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/attn_yi_4_37_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22df78044fab14e44844aeb62cbf81325ea638e7570d6e12c1a6da9d32b592d1 +size 806387 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_0.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..e064b76fe --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e82a1c4ff1754d1fe9e470b56c14e91ba6837713c840d8a51b73a0c797f3823 +size 4194319 diff 
--git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_1.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_2.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_2.pb new file mode 100644 index 000000000..ed46c3425 --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d77af46feb251c72b7df238f10c3f2e3bd9baf84813dd0efd415105948e2adb8 +size 524304 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_3.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_3.pb new file mode 100644 index 000000000..5296fd63b --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe0ecdda5d8aa3579ba6a1d3a70e9527a0c9482eb3f28c06771fe9fff6780731 +size 33554445 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_4.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_4.pb new file mode 100644 index 000000000..97d69e8d8 --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b35459131cf1408e9a34589ce199cd7505c77ca1b6c36f387abddf025af2c4f6 +size 4194317 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_5.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_5.pb new file mode 100644 index 000000000..b1cb0c1d6 --- /dev/null +++ 
b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4cc6c3459eaf88b7f4666ab404d9a834e76484d3436246ccfcf0d2fa9886951 +size 4194317 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_6.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_6.pb new file mode 100644 index 000000000..318eb241e --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f1e30a65367a08ab0718ad3516bc12c2f033ff8899fc119c19b6165ea48c2be +size 1048588 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_7.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_7.pb new file mode 100644 index 000000000..844cf7887 --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e17201583d346161ca6872298552ab05e0c3d0e1cecbc67adba0848a240cdfc4 +size 1048588 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_8.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_8.pb new file mode 100644 index 000000000..8f48f067d --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:323a3642f548314af50de9c5e88e25de7aeb435ed10cd44086bbc55f44d50ebe +size 33554445 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_0.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..b6409905c --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a12d32a2fc58002ec4d17364d4e4584ae3fc9605fd916fa0a5ed5981ae20560e +size 524304 diff --git 
a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_1.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_1.pb new file mode 100644 index 000000000..ac643f22b --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71e19eb3a6b3380b73d5bb7e20305ac26e956adf6008b0a4538ebdde7e2724a2 +size 524304 diff --git a/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_2.pb b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_2.pb new file mode 100644 index 000000000..53deb7d04 --- /dev/null +++ b/testdata/unittest_models/attn_yi_4_37_0/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9aaca1e11d01b536c0f6246f1850cfc07ecf6c02c319253ccfe9b267bdbce86 +size 4194319 diff --git a/testdata/unittest_models/gelu_phi_1_5_0/gelu_phi_1_5_0.onnx b/testdata/unittest_models/gelu_phi_1_5_0/gelu_phi_1_5_0.onnx new file mode 100644 index 000000000..88cd18af0 --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_0/gelu_phi_1_5_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73ad4152d9d07634ae1a0b4972ded7a8ae2610fa88279b4383d2ca57fbfe7ad2 +size 3122 diff --git a/testdata/unittest_models/gelu_phi_1_5_0/test_data_set_0/input_0.pb b/testdata/unittest_models/gelu_phi_1_5_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..18df85a17 --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1eb70478bc7ff108daa03523ba309fd7ba6029160a67360b4e5bd8b3015e41fb +size 8388623 diff --git a/testdata/unittest_models/gelu_phi_1_5_0/test_data_set_0/output_0.pb b/testdata/unittest_models/gelu_phi_1_5_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..197345b87 --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_0/test_data_set_0/output_0.pb @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9ab6fa6a35446500dff970a00ee60dee0791664c09e33ee9c19eaa246983ac1 +size 8388623 diff --git a/testdata/unittest_models/gelu_phi_1_5_1/gelu_phi_1_5_1.onnx b/testdata/unittest_models/gelu_phi_1_5_1/gelu_phi_1_5_1.onnx new file mode 100644 index 000000000..646ab988a --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_1/gelu_phi_1_5_1.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d4cf937ac9813a7d936cd12666c222f5a256c604ad74b4aa8047e2756434256 +size 3147 diff --git a/testdata/unittest_models/gelu_phi_1_5_1/test_data_set_0/input_0.pb b/testdata/unittest_models/gelu_phi_1_5_1/test_data_set_0/input_0.pb new file mode 100644 index 000000000..ffb30d52f --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_1/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7af424dbc83aee3fbd3e181b1b17d2d5dbc96a969687aef01830820b4e451242 +size 8388623 diff --git a/testdata/unittest_models/gelu_phi_1_5_1/test_data_set_0/output_0.pb b/testdata/unittest_models/gelu_phi_1_5_1/test_data_set_0/output_0.pb new file mode 100644 index 000000000..d68fa2a12 --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_1/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f9692832fc46ed0f14fc0986be75579e55abde433a831b256f9bd1be6b20a06f +size 8388623 diff --git a/testdata/unittest_models/gelu_phi_1_5_2/gelu_phi_1_5_2.onnx b/testdata/unittest_models/gelu_phi_1_5_2/gelu_phi_1_5_2.onnx new file mode 100644 index 000000000..816e914c9 --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_2/gelu_phi_1_5_2.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a48fa9ebb495c4bec7ed175bdd1d8c639d727528ad2eb0de4f4b53128306a36 +size 3159 diff --git a/testdata/unittest_models/gelu_phi_1_5_2/test_data_set_0/input_0.pb b/testdata/unittest_models/gelu_phi_1_5_2/test_data_set_0/input_0.pb new 
file mode 100644 index 000000000..3fa6f316a --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_2/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e9f0098eeac6bdefa568d9f392111c60fe2ccd6e5613c244912d3cd175b18a61 +size 8388623 diff --git a/testdata/unittest_models/gelu_phi_1_5_2/test_data_set_0/output_0.pb b/testdata/unittest_models/gelu_phi_1_5_2/test_data_set_0/output_0.pb new file mode 100644 index 000000000..49ac37ad6 --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_2/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0611a4cf6b215e467fd1021749d7b9f9c6bcfcdc0fe071fdd0a9a0d8166a66c6 +size 8388623 diff --git a/testdata/unittest_models/gelu_phi_1_5_3/gelu_phi_1_5_3.onnx b/testdata/unittest_models/gelu_phi_1_5_3/gelu_phi_1_5_3.onnx new file mode 100644 index 000000000..c359da5cf --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_3/gelu_phi_1_5_3.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de65cdeb463aa1844ff4bc0da8ed5051be1e7eac42af5250f324384c28f4f801 +size 3163 diff --git a/testdata/unittest_models/gelu_phi_1_5_3/test_data_set_0/input_0.pb b/testdata/unittest_models/gelu_phi_1_5_3/test_data_set_0/input_0.pb new file mode 100644 index 000000000..989d4df71 --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_3/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:11233c9d4a4f43e93b5d046346c992da00a49a7db7ced1179ce70b44df5f3664 +size 8388623 diff --git a/testdata/unittest_models/gelu_phi_1_5_3/test_data_set_0/output_0.pb b/testdata/unittest_models/gelu_phi_1_5_3/test_data_set_0/output_0.pb new file mode 100644 index 000000000..ef4f1c822 --- /dev/null +++ b/testdata/unittest_models/gelu_phi_1_5_3/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a4ee9b1cef233ad5bd6ed49d01cd369fee05024e41662dcf2f20b5273784d98 +size 
8388623 diff --git a/testdata/unittest_models/ln_llama2_0/ln_llama2_0.onnx b/testdata/unittest_models/ln_llama2_0/ln_llama2_0.onnx new file mode 100644 index 000000000..da0b966bd --- /dev/null +++ b/testdata/unittest_models/ln_llama2_0/ln_llama2_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7bb94d323af9d7c16e5b1e16e1775c906ecac60b10387057a8f34c8f07c31c88 +size 3618 diff --git a/testdata/unittest_models/ln_llama2_0/test_data_set_0/input_0.pb b/testdata/unittest_models/ln_llama2_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..74b225508 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2baee9fd173155738cfd0629f851e14d134d677f28d87e856905f1370391f0e3 +size 4194319 diff --git a/testdata/unittest_models/ln_llama2_0/test_data_set_0/input_1.pb b/testdata/unittest_models/ln_llama2_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..8ce5ce277 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bca889d73daeeecbcbb56a6c58f65a7538715b65de5f55ee505c96cd1fff121c +size 8200 diff --git a/testdata/unittest_models/ln_llama2_0/test_data_set_0/output_0.pb b/testdata/unittest_models/ln_llama2_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..2ff266b26 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:707e2139fa3b9db89e8667ad7b1829fbf1c05c678f064dbbdc855e77cee50b32 +size 4194319 diff --git a/testdata/unittest_models/ln_llama2_1/ln_llama2_1.onnx b/testdata/unittest_models/ln_llama2_1/ln_llama2_1.onnx new file mode 100644 index 000000000..ca0c7da93 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_1/ln_llama2_1.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:0ef31dabf73e8517de32e8abb0523558a147ae8f35c914a09e75713a814e7d77 +size 3675 diff --git a/testdata/unittest_models/ln_llama2_1/test_data_set_0/input_0.pb b/testdata/unittest_models/ln_llama2_1/test_data_set_0/input_0.pb new file mode 100644 index 000000000..0effc28f3 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_1/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f7588cb2aee5caf1f67ecf93e1d8742405ca4094bd931785d8b277c389879a7 +size 4194319 diff --git a/testdata/unittest_models/ln_llama2_1/test_data_set_0/input_1.pb b/testdata/unittest_models/ln_llama2_1/test_data_set_0/input_1.pb new file mode 100644 index 000000000..8ce5ce277 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_1/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bca889d73daeeecbcbb56a6c58f65a7538715b65de5f55ee505c96cd1fff121c +size 8200 diff --git a/testdata/unittest_models/ln_llama2_1/test_data_set_0/output_0.pb b/testdata/unittest_models/ln_llama2_1/test_data_set_0/output_0.pb new file mode 100644 index 000000000..e690fce90 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_1/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62e46bc9154924a08494a7968e62f169db7fd1c8f26ba87e0d446712dfae1ea4 +size 4194319 diff --git a/testdata/unittest_models/ln_llama2_2/ln_llama2_2.onnx b/testdata/unittest_models/ln_llama2_2/ln_llama2_2.onnx new file mode 100644 index 000000000..f1bf549dc --- /dev/null +++ b/testdata/unittest_models/ln_llama2_2/ln_llama2_2.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60664ba124291396bf97dd4fcaac75614b9d0f04826fde2f483337d0704527f2 +size 3630 diff --git a/testdata/unittest_models/ln_llama2_2/test_data_set_0/input_0.pb b/testdata/unittest_models/ln_llama2_2/test_data_set_0/input_0.pb new file mode 100644 index 000000000..aa0e2a6bb --- /dev/null +++ 
b/testdata/unittest_models/ln_llama2_2/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62d5182ff70cdeb1e82bccb6fcfe1d6a54151586c35645d62e0c5211fd48d125 +size 4194319 diff --git a/testdata/unittest_models/ln_llama2_2/test_data_set_0/input_1.pb b/testdata/unittest_models/ln_llama2_2/test_data_set_0/input_1.pb new file mode 100644 index 000000000..8ce5ce277 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_2/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bca889d73daeeecbcbb56a6c58f65a7538715b65de5f55ee505c96cd1fff121c +size 8200 diff --git a/testdata/unittest_models/ln_llama2_2/test_data_set_0/output_0.pb b/testdata/unittest_models/ln_llama2_2/test_data_set_0/output_0.pb new file mode 100644 index 000000000..3c60034a6 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_2/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fdace10343a44a6c0d71d2bfb3a0f9c27df1f787120eb40e62d5f10c967a187f +size 4194319 diff --git a/testdata/unittest_models/ln_llama2_3/ln_llama2_3.onnx b/testdata/unittest_models/ln_llama2_3/ln_llama2_3.onnx new file mode 100644 index 000000000..01b3b6d9f --- /dev/null +++ b/testdata/unittest_models/ln_llama2_3/ln_llama2_3.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afa920c4388c2ca721e796b155e3b334a03471a49c6a0a7098827863241e1851 +size 3695 diff --git a/testdata/unittest_models/ln_llama2_3/test_data_set_0/input_0.pb b/testdata/unittest_models/ln_llama2_3/test_data_set_0/input_0.pb new file mode 100644 index 000000000..923ff8b77 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_3/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7804812f1235a04652e679dbbcc700fad2639a546ec6fe01e72c781bab69a1c6 +size 4194319 diff --git a/testdata/unittest_models/ln_llama2_3/test_data_set_0/input_1.pb 
b/testdata/unittest_models/ln_llama2_3/test_data_set_0/input_1.pb new file mode 100644 index 000000000..8ce5ce277 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_3/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bca889d73daeeecbcbb56a6c58f65a7538715b65de5f55ee505c96cd1fff121c +size 8200 diff --git a/testdata/unittest_models/ln_llama2_3/test_data_set_0/output_0.pb b/testdata/unittest_models/ln_llama2_3/test_data_set_0/output_0.pb new file mode 100644 index 000000000..7fe45cd17 --- /dev/null +++ b/testdata/unittest_models/ln_llama2_3/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e674beb11054441c229ddb7a027c64392411df0cdf4257738483ed79a4110aa +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_0/sdpa_llama2_0.onnx b/testdata/unittest_models/sdpa_llama2_0/sdpa_llama2_0.onnx new file mode 100644 index 000000000..fb18e3900 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/sdpa_llama2_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a16731ed446341ee45d728881d0e1605464a3f050469a8eef47a0b70615df462 +size 798936 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_0.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..efa4744e8 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:778b4cdb9a8da368ed72d53e069a8ddf9a0830d51e21d6d3633824ea5d3a1e59 +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_1.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_2.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_2.pb new file mode 100644 index 000000000..4e34e0e7a --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dcbb1bfe97bb59144e109dc32f00dbc7c6782691da11e841b150775d73c0e4e9 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_3.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_3.pb new file mode 100644 index 000000000..350bb29f9 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c21e9d621732c18df7a97d6f0e09c040ad25e59921265c5b5f2865166dccbd0d +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_4.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_4.pb new file mode 100644 index 000000000..dbe4dc5cc --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:373639b07d8966f9aaf2ad40bd21ecf1b501fbce44982324a66b3571fd87c04e +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_5.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_5.pb new file mode 100644 index 000000000..f29b63640 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a44bf5864390eef1f1b54f464b619a7a283041dd6ce5da65150af5c708051f9 +size 524300 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_6.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_6.pb new file mode 100644 index 
000000000..634e449a0 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d15975f1d254f650321ccdd14c3157d9eac1ca2cf3567e6c159bcdd278488f82 +size 524300 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_7.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_7.pb new file mode 100644 index 000000000..066bfbc57 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b38092781e2fdbd15794b253e331285aed57305e1639f8fb0da83fedf6f1110a +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_0.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..7b36ba4c8 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b6872fcb8348cce2cd37199979c9ef5fa046660f8c5f30ac68167c130b02283 +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_1.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_1.pb new file mode 100644 index 000000000..db181af5a --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4ecc3c6db5513df9e1d36cc96d671e1e1606fc26f10f3a05b0229d13f962306 +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_2.pb b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_2.pb new file mode 100644 index 000000000..0f8a16d0c --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_0/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:847df9acbc6d8ae0542420c32fb0b7eaee3c4f460b7b0f850a405586e9bac479 +size 4194319 diff 
--git a/testdata/unittest_models/sdpa_llama2_1/sdpa_llama2_1.onnx b/testdata/unittest_models/sdpa_llama2_1/sdpa_llama2_1.onnx new file mode 100644 index 000000000..780862193 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/sdpa_llama2_1.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:530265630aea82f5d40955b0351abe39abcf81ac88946d92b5be91c28da71ef0 +size 799028 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_0.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_0.pb new file mode 100644 index 000000000..89d2d84bf --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f438ee4996790f71091a1d411b27691a80eedc365b2b019a0cfedb05b9adc2f +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_1.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_2.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_2.pb new file mode 100644 index 000000000..058613bb2 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93dd9424a7dcffdeb1dc9323215afa04394d885a96f02168336b1499a5a2ce7b +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_3.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_3.pb new file mode 100644 index 000000000..42fbd294d --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:21d7d26f218433bf1ce7d00dae722a55ded01e2f2215b00111c87912a1f25493 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_4.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_4.pb new file mode 100644 index 000000000..9a4bd955f --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34f5cb05c1e0b2f205c97ee0bc2fb6ca8660dd6f6d2e2ccc425d09043fdb9b08 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_5.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_5.pb new file mode 100644 index 000000000..f29b63640 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a44bf5864390eef1f1b54f464b619a7a283041dd6ce5da65150af5c708051f9 +size 524300 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_6.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_6.pb new file mode 100644 index 000000000..634e449a0 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d15975f1d254f650321ccdd14c3157d9eac1ca2cf3567e6c159bcdd278488f82 +size 524300 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_7.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_7.pb new file mode 100644 index 000000000..4dbbace8d --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86b6d05cdabf287de5bd1a0100b943cc0a3279ae1e68b692c27071feeb3a7ded +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_0.pb 
b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_0.pb new file mode 100644 index 000000000..bc231f072 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16c69fdb0853b67c0b8db326391ee74b70708feecd1150e628abf25b4a080b42 +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_1.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_1.pb new file mode 100644 index 000000000..b7d55aa7b --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a2a7d6d62d0b7cbd671f9003b66b65debe5d3100bb43a4342ac87fc9de0cf27f +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_2.pb b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_2.pb new file mode 100644 index 000000000..11f2afb84 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_1/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:512c99995fb05e846f2f437e1d02f4594f53536986bfeb660e88bb97994ec5cc +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_2/sdpa_llama2_2.onnx b/testdata/unittest_models/sdpa_llama2_2/sdpa_llama2_2.onnx new file mode 100644 index 000000000..abe04edbb --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/sdpa_llama2_2.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13fb16901400a2ab02aefbd0478f49cf71a0047140b7c5b91ad1ccfdf11871f8 +size 799058 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_0.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_0.pb new file mode 100644 index 000000000..d15a270e6 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:cc8d9eea571158d4c0926be5ecbb9b749d6fdd4a9f48d9e5fb04ebb4d343f07c +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_1.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_2.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_2.pb new file mode 100644 index 000000000..4bbd9afd9 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d75cc8150f705405f5a98b6ecbf49de0ca0a732124e60a56e4cea9b8c0aa73e +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_3.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_3.pb new file mode 100644 index 000000000..eae67d62f --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:33c01d9c6ef6427e8a680a810dc17944efe9aec5bccd0a67ba38abb2213f2d21 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_4.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_4.pb new file mode 100644 index 000000000..be489c280 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c48cafb2ba842be9bb90f206056b4b78812cebfa3f66986dca0a697aff4ac8dd +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_5.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_5.pb new file mode 100644 index 
000000000..f29b63640 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a44bf5864390eef1f1b54f464b619a7a283041dd6ce5da65150af5c708051f9 +size 524300 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_6.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_6.pb new file mode 100644 index 000000000..634e449a0 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d15975f1d254f650321ccdd14c3157d9eac1ca2cf3567e6c159bcdd278488f82 +size 524300 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_7.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_7.pb new file mode 100644 index 000000000..6d2cc784c --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:101c20fe6e507456307fb25c583672fd163444df7468856823523d09427fc1b9 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_0.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_0.pb new file mode 100644 index 000000000..2bcbd560b --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97fa244b848212ffb598698dc6c6d865bcb80234946c804e751fd3e8739520af +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_1.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_1.pb new file mode 100644 index 000000000..d58322566 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:683587dcabb2862944bea11c6b55e950b7e5c0c7732daf8e3fe8b05a7e015453 +size 4194321 diff 
--git a/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_2.pb b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_2.pb new file mode 100644 index 000000000..705b3495c --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_2/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5972eb073f74f7599e9c0fdc348d2196fe5f448787d923346ad3bdd55752ed16 +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_3/sdpa_llama2_3.onnx b/testdata/unittest_models/sdpa_llama2_3/sdpa_llama2_3.onnx new file mode 100644 index 000000000..ac29d2cbf --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/sdpa_llama2_3.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5866fa751bae100e7742fd1d2dbd0530e962ba9de0555d41bbf645967bbc3a3 +size 799066 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_0.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_0.pb new file mode 100644 index 000000000..8dfb22350 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b6432a636da58197893e78a34b7db33770f165ee4534b3c6ff1023d7747d2a +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_1.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_2.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_2.pb new file mode 100644 index 000000000..d6c3b6c3e --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:aded299cc07e16816e2641e7245845296233948b432394894ee89c8f0e4a5173 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_3.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_3.pb new file mode 100644 index 000000000..c726ed1c9 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c57896892bda7fc78d66afd6f2af2f229912ca02e662192fe0d4629e094f405 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_4.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_4.pb new file mode 100644 index 000000000..f62cc0dfb --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:efa03ee5d35e6c46164b0ecad3045c0f09a282e55ce71d281e3866babdc57b30 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_5.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_5.pb new file mode 100644 index 000000000..f29b63640 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a44bf5864390eef1f1b54f464b619a7a283041dd6ce5da65150af5c708051f9 +size 524300 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_6.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_6.pb new file mode 100644 index 000000000..634e449a0 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d15975f1d254f650321ccdd14c3157d9eac1ca2cf3567e6c159bcdd278488f82 +size 524300 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_7.pb 
b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_7.pb new file mode 100644 index 000000000..e5e15dd0b --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8a47b688521256d24da5f1c8da624ddd439cc401755ea6e8b84ede14fe9b2e33 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_0.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_0.pb new file mode 100644 index 000000000..7289670b5 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfebdb51021c7fa776577e6629adea74d46a0e9a7ca76604e7433c054457ed55 +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_1.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_1.pb new file mode 100644 index 000000000..8c0a20af1 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9d1313b293d4c8fea2622f3ce4350bc420394af9e9c7fd700b8476341f391d92 +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_2.pb b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_2.pb new file mode 100644 index 000000000..c049f2c0e --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_3/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:51408621625025e6fb828f9712a9da6f56d9d6e50d2660d38c168abbca4a83d8 +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/sdpa_llama2_4_38_0.onnx b/testdata/unittest_models/sdpa_llama2_4_38_0/sdpa_llama2_4_38_0.onnx new file mode 100644 index 000000000..2ee6f5979 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/sdpa_llama2_4_38_0.onnx @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:32a29a23df6eb1dda62e073d80d0b42aa4d9506ac0d732a5857870f8282dc926 +size 274363 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_0.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..e60e700f6 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dcf6c8416670332de11f3fe6efc6e41197ba7c175394c0fc38a98710bfc4ff43 +size 4194319 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_1.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_2.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_2.pb new file mode 100644 index 000000000..2e24e6df7 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:259cd77fb1aa09aa4b06224754bdcbcecceda18dddb98dbb09301c79c0c06409 +size 8388625 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_3.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_3.pb new file mode 100644 index 000000000..331a38944 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20c03c1962748bb752bcf3e44c719f0c4723ce2f0e2d4654566dd858b3fe90e0 +size 4104 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_4.pb 
b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_4.pb new file mode 100644 index 000000000..42b063627 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a994e919c415c5b8acdca27fb87fcd06d7036d646fc0a770929d09024e6d96ae +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_5.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_5.pb new file mode 100644 index 000000000..8c209fce7 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e2fa19eab256f8a81b4757e37f99c49aaba1333a55a2c96fabc989d4de2d0f3 +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_6.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_6.pb new file mode 100644 index 000000000..4a0bd6e9e --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db733c3917275c5341cba7bc2a64b5a2dc1abb1c4c01ae5852031eb94e0e2f0c +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_7.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_7.pb new file mode 100644 index 000000000..a98036afc --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c60c6ca7c4386b6b76700b348fb75d0d42b23463fff73eaa4a106d23d801dde +size 135 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_8.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_8.pb new file mode 100644 index 000000000..a286694f7 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/input_8.pb 
@@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aad340617afb9b743c37a1c6fd08ec2e509c4b22642a81bb60aa29d9f27d5c1c +size 33554445 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_0.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..cb32480a7 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e85913fc823ad2058c3c339a2aad4160e1fcbcab5a9881c46ca390256c1f17be +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_1.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_1.pb new file mode 100644 index 000000000..e930a1e3e --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae24404fde2f02499993f7ba93fa23d5b8fc801791a0be85a46251a6722e66b9 +size 4194321 diff --git a/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_2.pb b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_2.pb new file mode 100644 index 000000000..e21f75468 --- /dev/null +++ b/testdata/unittest_models/sdpa_llama2_4_38_0/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:759f063994330e61e4aa715dc0fb734e7e416c0c822f84b2d695ec73e218b8a1 +size 4194319 diff --git a/testdata/unittest_models/sdpa_yi_0/sdpa_yi_0.onnx b/testdata/unittest_models/sdpa_yi_0/sdpa_yi_0.onnx new file mode 100644 index 000000000..f994e1cf4 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/sdpa_yi_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea8d7fc76d15329678f93c32d4c4b8c0b1fef583104369daa27d239f3e7f2021 +size 802629 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_0.pb 
b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..e064b76fe --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e82a1c4ff1754d1fe9e470b56c14e91ba6837713c840d8a51b73a0c797f3823 +size 4194319 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_1.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_2.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_2.pb new file mode 100644 index 000000000..5296fd63b --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe0ecdda5d8aa3579ba6a1d3a70e9527a0c9482eb3f28c06771fe9fff6780731 +size 33554445 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_3.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_3.pb new file mode 100644 index 000000000..97d69e8d8 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b35459131cf1408e9a34589ce199cd7505c77ca1b6c36f387abddf025af2c4f6 +size 4194317 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_4.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_4.pb new file mode 100644 index 000000000..b1cb0c1d6 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:c4cc6c3459eaf88b7f4666ab404d9a834e76484d3436246ccfcf0d2fa9886951 +size 4194317 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_5.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_5.pb new file mode 100644 index 000000000..318eb241e --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f1e30a65367a08ab0718ad3516bc12c2f033ff8899fc119c19b6165ea48c2be +size 1048588 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_6.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_6.pb new file mode 100644 index 000000000..844cf7887 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e17201583d346161ca6872298552ab05e0c3d0e1cecbc67adba0848a240cdfc4 +size 1048588 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_7.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_7.pb new file mode 100644 index 000000000..8f48f067d --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:323a3642f548314af50de9c5e88e25de7aeb435ed10cd44086bbc55f44d50ebe +size 33554445 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_0.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..b6409905c --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a12d32a2fc58002ec4d17364d4e4584ae3fc9605fd916fa0a5ed5981ae20560e +size 524304 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_1.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_1.pb new file mode 100644 index 000000000..ac643f22b --- /dev/null +++ 
b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71e19eb3a6b3380b73d5bb7e20305ac26e956adf6008b0a4538ebdde7e2724a2 +size 524304 diff --git a/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_2.pb b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_2.pb new file mode 100644 index 000000000..35c5bf2a2 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_0/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0af23d5fe016b5a835af49472c3bc85f1967600bdc37440836332c200467c966 +size 4194319 diff --git a/testdata/unittest_models/sdpa_yi_1/sdpa_yi_1.onnx b/testdata/unittest_models/sdpa_yi_1/sdpa_yi_1.onnx new file mode 100644 index 000000000..309719b25 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/sdpa_yi_1.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef0002880be094d3b8401a9154f06ca4a592fa96412550fd80ed905609c565d0 +size 802731 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_0.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_0.pb new file mode 100644 index 000000000..af4339690 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ff146382a59cae41d59167e77d9392b03331945426f34ad34c47a775d661631c +size 4194319 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_1.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_2.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_2.pb 
new file mode 100644 index 000000000..aa691cb76 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a707fc005c43a316ef7654e5a04fc8c682dbf5559bbb7e8314d9582047f68e51 +size 33554445 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_3.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_3.pb new file mode 100644 index 000000000..48173b436 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35a747a52ddaa8c95dda2880b79a2947c78c44889b3f7559bf41cf3ae8e06c0a +size 4194317 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_4.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_4.pb new file mode 100644 index 000000000..f181835ea --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f568141851f62cc7887af55a9e63cc4500965f9b21763a60241b9e39f7fcd5c +size 4194317 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_5.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_5.pb new file mode 100644 index 000000000..318eb241e --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f1e30a65367a08ab0718ad3516bc12c2f033ff8899fc119c19b6165ea48c2be +size 1048588 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_6.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_6.pb new file mode 100644 index 000000000..844cf7887 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e17201583d346161ca6872298552ab05e0c3d0e1cecbc67adba0848a240cdfc4 +size 1048588 diff --git 
a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_7.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_7.pb new file mode 100644 index 000000000..a62e3ca94 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bb415c586a70f56152ae09f6e5395c17efd82c6f16b2e86d9d4edb4dd8e5acc +size 33554445 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_0.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_0.pb new file mode 100644 index 000000000..01cb04e29 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8b7099a734e5be6876d05506c4062cfa94a6a4782eedb3a5e240f7cb6816a775 +size 524304 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_1.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_1.pb new file mode 100644 index 000000000..f7de77ef7 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d150a541d3723b5973eaadf2116221bff263f0744a271389de271f5fdc21f16 +size 524304 diff --git a/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_2.pb b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_2.pb new file mode 100644 index 000000000..bc01353a1 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_1/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8054c4d9bd772f419dc2ca0c649c441f4eb435df4d84503fde7df648e2a06cb1 +size 4194319 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/sdpa_yi_4_38_0.onnx b/testdata/unittest_models/sdpa_yi_4_38_0/sdpa_yi_4_38_0.onnx new file mode 100644 index 000000000..6a7b420c0 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/sdpa_yi_4_38_0.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 
+oid sha256:be2a8476c605bad08d6e774d5af77fb64548fc08d4e002b66d91cb00498ceecd +size 278064 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_0.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_0.pb new file mode 100644 index 000000000..e064b76fe --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e82a1c4ff1754d1fe9e470b56c14e91ba6837713c840d8a51b73a0c797f3823 +size 4194319 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_1.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_1.pb new file mode 100644 index 000000000..281b43611 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04286eebdadde98ff59ca4ebcbcb6aad797aa995728cd6b76f5948a8a7586fa9 +size 4106 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_2.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_2.pb new file mode 100644 index 000000000..51a909426 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:199ec6dcdba5dd17b7b36b221b0808db110eab945081f814a6f185e82b51950c +size 33554449 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_3.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_3.pb new file mode 100644 index 000000000..331a38944 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_3.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20c03c1962748bb752bcf3e44c719f0c4723ce2f0e2d4654566dd858b3fe90e0 +size 4104 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_4.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_4.pb new file mode 100644 
index 000000000..5296fd63b --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_4.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe0ecdda5d8aa3579ba6a1d3a70e9527a0c9482eb3f28c06771fe9fff6780731 +size 33554445 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_5.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_5.pb new file mode 100644 index 000000000..97d69e8d8 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_5.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b35459131cf1408e9a34589ce199cd7505c77ca1b6c36f387abddf025af2c4f6 +size 4194317 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_6.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_6.pb new file mode 100644 index 000000000..b1cb0c1d6 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_6.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4cc6c3459eaf88b7f4666ab404d9a834e76484d3436246ccfcf0d2fa9886951 +size 4194317 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_7.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_7.pb new file mode 100644 index 000000000..9c7bca13e --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_7.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e23d23d0b4793b8b4b676086d18d8eddde4a2396a55b65c3bae93b66177a0e0c +size 135 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_8.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_8.pb new file mode 100644 index 000000000..8f48f067d --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/input_8.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:323a3642f548314af50de9c5e88e25de7aeb435ed10cd44086bbc55f44d50ebe +size 
33554445 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_0.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_0.pb new file mode 100644 index 000000000..b6409905c --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a12d32a2fc58002ec4d17364d4e4584ae3fc9605fd916fa0a5ed5981ae20560e +size 524304 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_1.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_1.pb new file mode 100644 index 000000000..d8d5350a3 --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_1.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ec009fdc25a78ee3d6d7e0da303f60ff5312a87c6f1f2d4cc1dd144303b15ab +size 524304 diff --git a/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_2.pb b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_2.pb new file mode 100644 index 000000000..04771af3c --- /dev/null +++ b/testdata/unittest_models/sdpa_yi_4_38_0/test_data_set_0/output_2.pb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73e21e6c015872857614dd6f3682a6abcda3680e1bfebea41b95f7b47eb4ca6e +size 4194319 diff --git a/onnxscript/tests/README.md b/tests/README.md similarity index 100% rename from onnxscript/tests/README.md rename to tests/README.md diff --git a/onnxscript/tests/__init__.py b/tests/__init__.py similarity index 100% rename from onnxscript/tests/__init__.py rename to tests/__init__.py diff --git a/onnxscript/tests/common/__init__.py b/tests/common/__init__.py similarity index 100% rename from onnxscript/tests/common/__init__.py rename to tests/common/__init__.py diff --git a/onnxscript/tests/common/onnx_script_test_case.py b/tests/common/onnx_script_test_case.py similarity index 100% rename from onnxscript/tests/common/onnx_script_test_case.py rename to 
tests/common/onnx_script_test_case.py diff --git a/tests/common/testutils.py b/tests/common/testutils.py new file mode 100644 index 000000000..c0dafbff1 --- /dev/null +++ b/tests/common/testutils.py @@ -0,0 +1,117 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import functools +import os +import pathlib +import unittest + +import numpy as np +import onnx +import onnxruntime + +from onnxscript import optimizer +from onnxscript._legacy_ir import visitor +from onnxscript.rewriter import onnxruntime as ort_rewriter +from onnxscript.utils import evaluation_utils + + +class TestBase(unittest.TestCase): + """The base class for testing ONNX Script functions for internal use.""" + + def validate(self, fn): + """Validate script function translation.""" + return fn.to_function_proto() + + +def skip_if_no_cuda(reason: str): + def skip_dec(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if not onnxruntime.get_device() == "GPU": + raise unittest.SkipTest(f"GPU is not available. 
{reason}") + return func(self, *args, **kwargs) + + return wrapper + + return skip_dec + + +class OpTypeAnalysisVisitor(visitor.ProtoVisitorCore): + def __init__(self): + super().__init__() + self.op_types = set() + + def visit_model(self, model: onnx.ModelProto): + self.op_types = set() + super().visit_model(model) + + def process_node(self, node: onnx.NodeProto): + self.op_types.add((node.domain, node.op_type, getattr(node, "overload", ""))) + return super().process_node(node) + + +def test_onnxruntime_rewrite( + model_basename: str, + model_count: int, + expected_optypes: set[tuple[str, str, str]], + rtol: float = 1e-2, + atol: float = 1e-2, +): + dir_path = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) + unittest_root_dir = dir_path.parent.parent / "testdata" / "unittest_models" + for model_index in range(model_count): + model_name = f"{model_basename}_{model_index}" + model_dir = unittest_root_dir / f"{model_name}" + model_path = model_dir / f"{model_name}.onnx" + model = onnx.load(model_path) + + # TODO: Parity issue with randomly generated data. Need investigation. + # inputs = generate_random_input(model) + inputs, expected_outputs = evaluation_utils.load_test_data( + model_dir, [i.name for i in model.graph.input] + ) + + optimized = optimizer.optimize( + model, + onnx_shape_inference=False, + num_iterations=2, + ) + rewritten = ort_rewriter.rewrite(optimized) + # NOTE: uncomment this to save the optimized model. + # onnx.save(rewritten, model_dir / f"{model_name}_opt.onnx") + + # Check expected operator is found. + optype_analysis = OpTypeAnalysisVisitor() + optype_analysis.visit_model(rewritten) + for domain, op_type, overload in expected_optypes: + if (domain, op_type, overload) not in optype_analysis.op_types: + raise AssertionError( + f"Expected op type {domain}:{op_type}:{overload} not found in rewritten model." 
+ ) + + # Run baseline model + providers = ["CUDAExecutionProvider"] + + # Run optimized model + optimized_session = onnxruntime.InferenceSession( + rewritten.SerializeToString(), providers=providers + ) + optimized_outputs = optimized_session.run(None, inputs) + + for i, (baseline_output, optimized_output) in enumerate( + zip(expected_outputs, optimized_outputs) + ): + try: + np.testing.assert_equal(baseline_output.shape, optimized_output.shape) + np.testing.assert_allclose( + baseline_output, optimized_output, rtol=rtol, atol=atol + ) + except AssertionError as e: + print( + f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" + ) + raise diff --git a/onnxscript/tests/eager_mode_test.py b/tests/eager_mode_test.py similarity index 100% rename from onnxscript/tests/eager_mode_test.py rename to tests/eager_mode_test.py diff --git a/onnxscript/tests/eager_test.py b/tests/eager_test.py similarity index 99% rename from onnxscript/tests/eager_test.py rename to tests/eager_test.py index 78eb2d1ad..ffed8be5f 100644 --- a/onnxscript/tests/eager_test.py +++ b/tests/eager_test.py @@ -7,8 +7,8 @@ import numpy as np import parameterized -from onnxscript.tests.common import onnx_script_test_case -from onnxscript.tests.models import signal_dft +from tests.common import onnx_script_test_case +from tests.models import signal_dft def _fft(x, fft_length, axis=-1): diff --git a/onnxscript/tests/external_tensor_test.py b/tests/external_tensor_test.py similarity index 100% rename from onnxscript/tests/external_tensor_test.py rename to tests/external_tensor_test.py diff --git a/onnxscript/tests/function_libs/torch_lib/README.md b/tests/function_libs/torch_lib/README.md similarity index 100% rename from onnxscript/tests/function_libs/torch_lib/README.md rename to tests/function_libs/torch_lib/README.md diff --git a/onnxscript/tests/function_libs/torch_lib/error_reproduction.py b/tests/function_libs/torch_lib/error_reproduction.py similarity index 100% 
rename from onnxscript/tests/function_libs/torch_lib/error_reproduction.py rename to tests/function_libs/torch_lib/error_reproduction.py diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py similarity index 100% rename from onnxscript/tests/function_libs/torch_lib/extra_opinfo.py rename to tests/function_libs/torch_lib/extra_opinfo.py diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py similarity index 99% rename from onnxscript/tests/function_libs/torch_lib/ops_test.py rename to tests/function_libs/torch_lib/ops_test.py index 8a060098d..cf29a8b80 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -38,7 +38,7 @@ import onnxscript import onnxscript.evaluator -from onnxscript.tests.function_libs.torch_lib import ( +from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, ops_test_data, diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py similarity index 99% rename from onnxscript/tests/function_libs/torch_lib/ops_test_common.py rename to tests/function_libs/torch_lib/ops_test_common.py index f1c95ffc3..a218777b3 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -32,7 +32,7 @@ import onnxscript import onnxscript.evaluator from onnxscript.function_libs.torch_lib import graph_building -from onnxscript.tests.function_libs.torch_lib import error_reproduction +from tests.function_libs.torch_lib import error_reproduction T = TypeVar("T") diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py similarity index 99% rename from onnxscript/tests/function_libs/torch_lib/ops_test_data.py rename to tests/function_libs/torch_lib/ops_test_data.py index 
1b4855d6e..2ff6bf4dc 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -54,7 +54,7 @@ from onnxscript.function_libs.torch_lib.ops import nn as nn_ops from onnxscript.function_libs.torch_lib.ops import special as special_ops from onnxscript.function_libs.torch_lib.ops import vision as vision_ops -from onnxscript.tests.function_libs.torch_lib import extra_opinfo, ops_test_common +from tests.function_libs.torch_lib import extra_opinfo, ops_test_common # Create a copy of the op_db to modify OPS_DB = copy.deepcopy(common_methods_invocations.op_db) diff --git a/onnxscript/tests/functions/attr_test.py b/tests/functions/attr_test.py similarity index 95% rename from onnxscript/tests/functions/attr_test.py rename to tests/functions/attr_test.py index 803bf7364..778c7314c 100644 --- a/onnxscript/tests/functions/attr_test.py +++ b/tests/functions/attr_test.py @@ -9,7 +9,7 @@ from onnxscript import script from onnxscript.onnx_opset import opset17 as op -from onnxscript.tests.common.onnx_script_test_case import FunctionTestParams as Test +from tests.common.onnx_script_test_case import FunctionTestParams as Test @script() diff --git a/onnxscript/tests/functions/gemmgelu.py b/tests/functions/gemmgelu.py similarity index 100% rename from onnxscript/tests/functions/gemmgelu.py rename to tests/functions/gemmgelu.py diff --git a/onnxscript/tests/functions/gemmgelu_test.py b/tests/functions/gemmgelu_test.py similarity index 94% rename from onnxscript/tests/functions/gemmgelu_test.py rename to tests/functions/gemmgelu_test.py index e2ffc4155..3b38e6023 100644 --- a/onnxscript/tests/functions/gemmgelu_test.py +++ b/tests/functions/gemmgelu_test.py @@ -7,8 +7,8 @@ import numpy as np -from onnxscript.tests.common import onnx_script_test_case -from onnxscript.tests.functions import gemmgelu +from tests.common import onnx_script_test_case +from tests.functions import gemmgelu class 
TestGemmGelu(onnx_script_test_case.OnnxScriptTestCase): diff --git a/onnxscript/tests/functions/if_test.py b/tests/functions/if_test.py similarity index 92% rename from onnxscript/tests/functions/if_test.py rename to tests/functions/if_test.py index 799f7e463..bc80179ca 100644 --- a/onnxscript/tests/functions/if_test.py +++ b/tests/functions/if_test.py @@ -7,8 +7,8 @@ import numpy as np -from onnxscript.tests.common import onnx_script_test_case -from onnxscript.tests.models import if_statement +from tests.common import onnx_script_test_case +from tests.models import if_statement class TestOnnxIf(onnx_script_test_case.OnnxScriptTestCase): diff --git a/onnxscript/tests/functions/onnxfns1A_test.py b/tests/functions/onnxfns1A_test.py similarity index 93% rename from onnxscript/tests/functions/onnxfns1A_test.py rename to tests/functions/onnxfns1A_test.py index 09302a634..7f19ebaf7 100644 --- a/onnxscript/tests/functions/onnxfns1A_test.py +++ b/tests/functions/onnxfns1A_test.py @@ -2,8 +2,8 @@ import pytest -from onnxscript.tests.common import onnx_script_test_case -from onnxscript.tests.models import onnxfns1A +from tests.common import onnx_script_test_case +from tests.models import onnxfns1A class TestOnnxFns(onnx_script_test_case.OnnxScriptTestCase): diff --git a/onnxscript/tests/functions/onnxfns2_test.py b/tests/functions/onnxfns2_test.py similarity index 95% rename from onnxscript/tests/functions/onnxfns2_test.py rename to tests/functions/onnxfns2_test.py index f9dc5c9f0..3cf067dbd 100644 --- a/onnxscript/tests/functions/onnxfns2_test.py +++ b/tests/functions/onnxfns2_test.py @@ -1,7 +1,7 @@ import unittest -from onnxscript.tests.common import onnx_script_test_case -from onnxscript.tests.models import onnxfns2 +from tests.common import onnx_script_test_case +from tests.models import onnxfns2 class TestOnnxFns(onnx_script_test_case.OnnxScriptTestCase): diff --git a/onnxscript/tests/functions/onnxfns_test.py b/tests/functions/onnxfns_test.py similarity index 95% 
rename from onnxscript/tests/functions/onnxfns_test.py rename to tests/functions/onnxfns_test.py index 68ae2e3b8..105721459 100644 --- a/onnxscript/tests/functions/onnxfns_test.py +++ b/tests/functions/onnxfns_test.py @@ -5,8 +5,8 @@ import unittest -from onnxscript.tests.common import onnx_script_test_case -from onnxscript.tests.models import onnxfns1 +from tests.common import onnx_script_test_case +from tests.models import onnxfns1 class TestOnnxFns(onnx_script_test_case.OnnxScriptTestCase): diff --git a/onnxscript/tests/functions/ort_custom_ops.py b/tests/functions/ort_custom_ops.py similarity index 100% rename from onnxscript/tests/functions/ort_custom_ops.py rename to tests/functions/ort_custom_ops.py diff --git a/onnxscript/tests/if_test.py b/tests/if_test.py similarity index 97% rename from onnxscript/tests/if_test.py rename to tests/if_test.py index b048a290a..346334c09 100644 --- a/onnxscript/tests/if_test.py +++ b/tests/if_test.py @@ -8,7 +8,7 @@ import onnxscript.testing from onnxscript import script from onnxscript.onnx_opset import opset15 as op -from onnxscript.tests.common import testutils +from tests.common import testutils class IfOpTest(testutils.TestBase): diff --git a/tests/ir/serde_test.py b/tests/ir/serde_test.py new file mode 100644 index 000000000..3376451d1 --- /dev/null +++ b/tests/ir/serde_test.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import pathlib +import unittest + +import onnx +import parameterized + +import onnxscript.testing +from onnxscript import ir + +model_folder_path = pathlib.Path(__file__).resolve().parent.parent.parent / "testdata" + +model_paths = list(model_folder_path.rglob("*.onnx")) +test_args = [(model_path.name, model_path) for model_path in model_paths] + + +class SerdeTest(unittest.TestCase): + @parameterized.parameterized.expand(test_args) + def test_serialization_deserialization_produces_same_model( + self, _: str, model_path: pathlib.Path + ) -> None: + model = onnx.load(model_path) + ir_model = 
ir.serde.deserialize_model(model) + serialized = ir.serde.serialize_model(ir_model) + onnxscript.testing.assert_onnx_proto_equal(serialized, model) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/tests/loop_test.py b/tests/loop_test.py similarity index 96% rename from onnxscript/tests/loop_test.py rename to tests/loop_test.py index 4b36db25a..0be895c08 100644 --- a/onnxscript/tests/loop_test.py +++ b/tests/loop_test.py @@ -5,7 +5,7 @@ from onnxscript import script from onnxscript.onnx_opset import opset15 as op from onnxscript.onnx_types import FLOAT, INT64 -from onnxscript.tests.common import testutils +from tests.common import testutils class LoopOpTest(testutils.TestBase): diff --git a/onnxscript/tests/models/__init__.py b/tests/models/__init__.py similarity index 100% rename from onnxscript/tests/models/__init__.py rename to tests/models/__init__.py diff --git a/onnxscript/tests/models/attrref.py b/tests/models/attrref.py similarity index 100% rename from onnxscript/tests/models/attrref.py rename to tests/models/attrref.py diff --git a/onnxscript/tests/models/cast_like.py b/tests/models/cast_like.py similarity index 100% rename from onnxscript/tests/models/cast_like.py rename to tests/models/cast_like.py diff --git a/onnxscript/tests/models/different_opset.py b/tests/models/different_opset.py similarity index 100% rename from onnxscript/tests/models/different_opset.py rename to tests/models/different_opset.py diff --git a/onnxscript/tests/models/dropout.py b/tests/models/dropout.py similarity index 100% rename from onnxscript/tests/models/dropout.py rename to tests/models/dropout.py diff --git a/onnxscript/tests/models/eager_op.py b/tests/models/eager_op.py similarity index 100% rename from onnxscript/tests/models/eager_op.py rename to tests/models/eager_op.py diff --git a/onnxscript/tests/models/eg1.py b/tests/models/eg1.py similarity index 100% rename from onnxscript/tests/models/eg1.py rename to tests/models/eg1.py diff --git 
a/onnxscript/tests/models/getitem.py b/tests/models/getitem.py similarity index 98% rename from onnxscript/tests/models/getitem.py rename to tests/models/getitem.py index b0d23ffc9..ae7da8270 100644 --- a/onnxscript/tests/models/getitem.py +++ b/tests/models/getitem.py @@ -14,7 +14,7 @@ from onnxscript import script from onnxscript.onnx_opset import opset15 as op from onnxscript.onnx_types import INT32, INT64 -from onnxscript.tests.common.onnx_script_test_case import FunctionTestParams +from tests.common.onnx_script_test_case import FunctionTestParams x = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.int32) zero = np.array(0, dtype=np.int64) diff --git a/onnxscript/tests/models/graph_attr.py b/tests/models/graph_attr.py similarity index 100% rename from onnxscript/tests/models/graph_attr.py rename to tests/models/graph_attr.py diff --git a/onnxscript/tests/models/identity.py b/tests/models/identity.py similarity index 100% rename from onnxscript/tests/models/identity.py rename to tests/models/identity.py diff --git a/onnxscript/tests/models/if_statement.py b/tests/models/if_statement.py similarity index 100% rename from onnxscript/tests/models/if_statement.py rename to tests/models/if_statement.py diff --git a/onnxscript/tests/models/loops_break.py b/tests/models/loops_break.py similarity index 100% rename from onnxscript/tests/models/loops_break.py rename to tests/models/loops_break.py diff --git a/onnxscript/tests/models/loops_while.py b/tests/models/loops_while.py similarity index 100% rename from onnxscript/tests/models/loops_while.py rename to tests/models/loops_while.py diff --git a/onnxscript/tests/models/m1.py b/tests/models/m1.py similarity index 100% rename from onnxscript/tests/models/m1.py rename to tests/models/m1.py diff --git a/onnxscript/tests/models/multi.py b/tests/models/multi.py similarity index 100% rename from onnxscript/tests/models/multi.py rename to tests/models/multi.py diff --git a/onnxscript/tests/models/onnxfns1.py 
b/tests/models/onnxfns1.py similarity index 100% rename from onnxscript/tests/models/onnxfns1.py rename to tests/models/onnxfns1.py diff --git a/onnxscript/tests/models/onnxfns1A.py b/tests/models/onnxfns1A.py similarity index 100% rename from onnxscript/tests/models/onnxfns1A.py rename to tests/models/onnxfns1A.py diff --git a/onnxscript/tests/models/onnxfns2.py b/tests/models/onnxfns2.py similarity index 100% rename from onnxscript/tests/models/onnxfns2.py rename to tests/models/onnxfns2.py diff --git a/onnxscript/tests/models/opt_input.py b/tests/models/opt_input.py similarity index 100% rename from onnxscript/tests/models/opt_input.py rename to tests/models/opt_input.py diff --git a/onnxscript/tests/models/opt_output.py b/tests/models/opt_output.py similarity index 100% rename from onnxscript/tests/models/opt_output.py rename to tests/models/opt_output.py diff --git a/onnxscript/tests/models/renaming.py b/tests/models/renaming.py similarity index 100% rename from onnxscript/tests/models/renaming.py rename to tests/models/renaming.py diff --git a/onnxscript/tests/models/sequences.py b/tests/models/sequences.py similarity index 100% rename from onnxscript/tests/models/sequences.py rename to tests/models/sequences.py diff --git a/onnxscript/tests/models/signal_dft.py b/tests/models/signal_dft.py similarity index 100% rename from onnxscript/tests/models/signal_dft.py rename to tests/models/signal_dft.py diff --git a/onnxscript/tests/models/subfunction.py b/tests/models/subfunction.py similarity index 100% rename from onnxscript/tests/models/subfunction.py rename to tests/models/subfunction.py diff --git a/onnxscript/tests/models/type_double.py b/tests/models/type_double.py similarity index 100% rename from onnxscript/tests/models/type_double.py rename to tests/models/type_double.py diff --git a/onnxscript/tests/onnx_types_test.py b/tests/onnx_types_test.py similarity index 100% rename from onnxscript/tests/onnx_types_test.py rename to tests/onnx_types_test.py diff 
--git a/onnxscript/tests/operator_test.py b/tests/operator_test.py similarity index 100% rename from onnxscript/tests/operator_test.py rename to tests/operator_test.py diff --git a/tests/optimizer/test_models.py b/tests/optimizer/test_models.py new file mode 100644 index 000000000..29843a375 --- /dev/null +++ b/tests/optimizer/test_models.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import pathlib +import tempfile +import unittest + +import numpy as np +import onnx +import onnxruntime +import parameterized + +from onnxscript import optimizer +from onnxscript.utils import evaluation_utils + +_SKIP_TABLE = {} + +model_folder_path = ( + pathlib.Path(__file__).resolve().parent.parent.parent / "testdata" / "e2e_models" +) + +# List all entries in the directory and filter for directories +model_names = [entry.name for entry in model_folder_path.iterdir() if entry.is_dir()] + + +class ModelTest(unittest.TestCase): + @parameterized.parameterized.expand(model_names) + def test_model_runs_and_matches_accuracy_after_optimization(self, model_name): + test_id = model_name # This can be expanded in the future with more parameters, e.g. 
optimization options + if (skip_reason := _SKIP_TABLE.get(test_id)) is not None: + self.skipTest(skip_reason) + + model_dir = f"{model_folder_path}/{model_name}/dynamo" + model = onnx.load(f"{model_dir}/{model_name}_dynamo.onnx") + model = optimizer.optimize( + model, + onnx_shape_inference=False, + ) + + with tempfile.TemporaryDirectory() as tmp_folder: + optimized_model_path = f"{tmp_folder}/{model_name}_opt.onnx" + onnx.save( + model, + optimized_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + ) + + sess = onnxruntime.InferenceSession( + optimized_model_path, + providers=["CPUExecutionProvider"], + ) + + inputs, expected_outputs = evaluation_utils.load_test_data( + model_dir, [i.name for i in model.graph.input] + ) + + input_names = [i.name for i in sess.get_inputs()] + assert set(input_names) == set(inputs.keys()) + + outputs = sess.run(None, inputs) + + for output, expected_output in zip(outputs, expected_outputs): + np.testing.assert_allclose(output, expected_output, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + unittest.main() From 667ab08bc752f4246c420a999d1fc16d05db2771 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Thu, 4 Apr 2024 17:23:31 -0700 Subject: [PATCH 2/3] Attempt to fix linter on "Migrate onnxrewriter" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Squashed of the following steps: - #1328 - #1329 - #1330 - #1331 - #1332 - #1333 - #1343 - #1345 Co-authored-by: Shubham Bhokare <32080845+shubhambhokare1users.noreply.github.com> Co-authored-by: Justin Chu Co-authored-by: Xavier Dupré Co-authored-by: "G. 
Ramalingam" Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnaviusers.noreply.github.com> Co-authored-by: Ti-Tai Wang [ghstack-poisoned] --- .lintrunner.toml | 4 +++- pyproject.toml | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index e86109025..184a1a4ac 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -61,7 +61,8 @@ exclude_patterns = [ 'onnxscript/_legacy_ir/protobuilder.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME 'onnxscript/ir/serde.py', # FIXME - 'onnxrewriter/rewriter/pattern/generic_pattern_test.py', # FIXME + 'onnxrewriter/rewriter/generic_pattern_test.py', # FIXME + 'onnxrewriter/rewriter/generic_pattern.py', # FIXME ] command = [ 'python', @@ -118,6 +119,7 @@ include_patterns = [ '**/*.py', ] exclude_patterns = [ + 'examples/**', # TODO: Merge with docs/examples 'docs/examples/**', 'docs/tutorial/examples/**', 'onnxscript/converter_test.py', diff --git a/pyproject.toml b/pyproject.toml index 57b607caa..10933171b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -196,7 +196,8 @@ ignore-init-module-imports = true [tool.ruff.per-file-ignores] "__init__.py" = ["TID252"] # Allow relative imports in init files -"**/{examples,tests,docs,tools,utils}/*" = ["TID251"] # pathlib is allowed in supporting code +"setup.py" = ["TID251"] # pathlib is allowed in supporting code +"**/{examples,tests,docs,tools,utils,opgen}/*" = ["TID251"] # pathlib is allowed in supporting code "**/*_test.py" = ["TID251"] # pathlib is allowed in tests "**/generic_pattern.py" = ["FBT003", "UP037"] # inline ignoring fails "**/generic_pattern_test.py" = ["ARG001", "ARG002", "PLR2004"] From 596f3af8b189f20439df5a577e63d6b60635e695 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Fri, 5 Apr 2024 10:57:15 -0700 Subject: [PATCH 3/3] Update on "Migrate onnxrewriter" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Squashed of the following steps: - 
#1328 - #1329 - #1330 - #1331 - #1332 - #1333 - #1343 - #1345 Co-authored-by: Shubham Bhokare <32080845+shubhambhokare1users.noreply.github.com> Co-authored-by: Justin Chu Co-authored-by: Xavier Dupré Co-authored-by: "G. Ramalingam" Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnaviusers.noreply.github.com> Co-authored-by: Ti-Tai Wang [ghstack-poisoned] --- .lintrunner.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 184a1a4ac..8e9639c22 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -61,8 +61,8 @@ exclude_patterns = [ 'onnxscript/_legacy_ir/protobuilder.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME 'onnxscript/ir/serde.py', # FIXME - 'onnxrewriter/rewriter/generic_pattern_test.py', # FIXME - 'onnxrewriter/rewriter/generic_pattern.py', # FIXME + 'onnxscript/rewriter/generic_pattern_test.py', # FIXME + 'onnxscript/rewriter/generic_pattern.py', # FIXME ] command = [ 'python',