diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index d72868cd8fa9..56c8e2911e28 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); +typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); - p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); + void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); + p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { - ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); - return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); + bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; + IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 23fe5e1cd3d9..b737d735b977 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1015,9 +1015,19 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor inputs are on device, all non-tensor inputs are on CPU, + // except those specified in attribute cpu_input_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_input_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1025,7 +1035,7 @@ bool IsInputOnCpu(const Node& node, 
const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, true); } #else ORT_UNUSED_PARAMETER(node); @@ -1040,9 +1050,19 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor outputs are on device, all non-tensor outputs are on CPU, + // except those specified in attribute cpu_output_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_output_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1050,7 +1070,7 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index f06a3785f362..6709398c788f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3474,6 +3474,8 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 /*min_arity*/ 1) .Attr("operator", "Name of ATen operator.", AttributeProto::STRING) .Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false) + .Attr("cpu_input_args", "CPU input argument indices.", AttributeProto::INTS, false) + .Attr("cpu_output_args", "CPU output argument indices.", AttributeProto::INTS, false) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor."); #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 9c36eb635ffc..e5e0e81cb7da 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1327,14 +1327,14 @@ void addGlobalMethods(py::module& m) { #ifdef ENABLE_ATEN m.def("register_aten_op_executor", - [](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void { - size_t is_cpu_argument_address_int, aten_op_executor_address_int; + [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void { + size_t is_tensor_argument_address_int, aten_op_executor_address_int; ORT_THROW_IF_ERROR( - ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int)); + ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int)); ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); - 
void* p_is_cpu_argument = reinterpret_cast(is_cpu_argument_address_int); + void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); - contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor); + contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); }); #endif } diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py index 8bf7cbf80eb3..9dee6564509d 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py @@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor _C.register_aten_op_executor( - str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address()) + str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address()) ) diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index 903a394a06ef..e8be98cbfc0e 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -34,18 +34,23 @@ struct ATenOperator { std::vector is_optional_arguments; std::vector> default_values; size_t return_size; + std::vector ret_kinds; c10::IValue ToIValueArgument(const DLManagedTensor* dlpack, size_t index) const { TORCH_INTERNAL_ASSERT(index < argument_size); bool is_optional = is_optional_arguments[index]; - TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index]); + TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index] || + elem_kinds[index] == c10::TypeKind::TensorType); if (!dlpack) { if (is_optional) { // Optional argument always has no default value. return c10::IValue(c10::nullopt); } - - return *default_values[index]; + if (default_values[index]) { + return *default_values[index]; + } + // Fow bw func, it's possible that input is an undefined tensor from fw outputs, dlpack is nullptr for such case. + return c10::IValue(at::Tensor()); } bool is_list = is_list_arguments[index]; @@ -142,7 +147,10 @@ class ATenOperatorCache { } aten_op.return_size = schema.returns().size(); for (const auto& ret : schema.returns()) { - TORCH_INTERNAL_ASSERT(ret.type()->kind() == c10::TypeKind::TensorType); + c10::TypeKind ret_type = ret.type()->kind(); + // Support tensor or int only for now. + TORCH_INTERNAL_ASSERT(ret_type == c10::TypeKind::TensorType || ret_type == c10::TypeKind::IntType); + aten_op.ret_kinds.emplace_back(ret_type); } ops_.emplace(key, aten_op); } @@ -154,32 +162,15 @@ class ATenOperatorCache { std::unordered_map, ATenOperator, PairHash> ops_; }; -const std::unordered_map> kCpuTensorInputsMap = { - {"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}}; - -const std::unordered_map> kCpuTensorOutputsMap = { - {"_efficient_attention_forward", {2, 3}}}; - -// Backend uses this function to check if an argument is CPU input or not. 
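Reviewer note: the hard-coded kCpuTensorInputsMap/kCpuTensorOutputsMap tables above are removed, and the IsCpuArgument entry point that follows is replaced by IsTensorArgument, which only asks the operator schema whether an argument or return value is a tensor; the CPU-pinning decision moves into the graph via the new cpu_input_args/cpu_output_args attributes. A rough Python sketch of the same schema lookup is below. It is illustrative only: it assumes the operator lives under torch.ops.aten with a "default" overload when overload_name is empty, and it skips the Optional/Tensor[] element unwrapping that the real extension performs via elem_kinds.

    import torch

    def is_tensor_argument(op_name: str, overload_name: str, index: int, is_input: bool) -> bool:
        # Resolve e.g. torch.ops.aten._efficient_attention_forward.default and read its schema.
        packet = getattr(torch.ops.aten, op_name)
        schema = getattr(packet, overload_name or "default")._schema
        args = schema.arguments if is_input else schema.returns
        # kind() returns the JIT type kind as a string, e.g. "TensorType" or "IntType".
        return args[index].type.kind() == "TensorType"

An optional tensor argument reports "OptionalType" from this naive check, which is exactly why the C++ side records the unwrapped element kind (elem_kinds) rather than the raw kind.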
-bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { +// Backend uses this function to check if an argument is tensor type or not. +bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { + const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); if (is_input) { - // If the argument is non-tensor type, it's CPU argument. - const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); - if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) { - return true; - } - } - - std::string full_name = std::string(op_name); - std::string overload_name_str = std::string(overload_name); - if (overload_name_str != "") { - full_name += ("." + overload_name_str); + return aten_op.elem_kinds[index] == c10::TypeKind::TensorType; } - - const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap; - return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() && - cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end(); + TORCH_INTERNAL_ASSERT(index < aten_op.return_size); + return aten_op.ret_kinds[index] == c10::TypeKind::TensorType; } void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size, @@ -216,16 +207,23 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t TORCH_INTERNAL_ASSERT(output_size == aten_op.return_size); size_t output_index = 0; for (const auto& ret : torch::jit::pop(stack, output_size)) { - const auto& tensor = ret.toTensor(); - dlpack_outputs[output_index++] = - tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr; + if (ret.isTensor()) { + const auto& tensor = ret.toTensor(); + dlpack_outputs[output_index++] = + tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? 
tensor : tensor.contiguous()) : nullptr; + } else if (ret.isInt()) { + at::Tensor scalar = at::scalar_to_tensor(at::Scalar(ret.toInt())); + dlpack_outputs[output_index++] = at::toDLPack(scalar); + } else { + TORCH_INTERNAL_ASSERT(false); + } } } -size_t is_cpu_argument_address() { return reinterpret_cast(&IsCpuArgument); } +size_t is_tensor_argument_address() { return reinterpret_cast(&IsTensorArgument); } size_t execute_aten_operator_address() { return reinterpret_cast(&ExecuteATenOperator); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check."); + m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check."); m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor"); } diff --git a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py index 329fba5aa670..7d5716b85db3 100644 --- a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py @@ -5,7 +5,7 @@ from onnxruntime.capi import _pybind_state as _C -from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address +from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address def run_once_aten_op_executor(f): @@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs): @run_once_aten_op_executor def load_aten_op_executor_cpp_extension(): - _C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address())) + _C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address())) def init_aten_op_executor(): diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index a2b8407645c4..a963d30a9e6e 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -392,5 +392,8 @@ def __init__( for ir_node in kernel.sub_nodes: if isinstance(ir_node, DropoutNode): ir_node.global_offset = running_offset + kernel.offset_calc.symbolic_shape_variables.update( + [symbol.name for symbol in running_offset.free_symbols] + ) running_offset = running_offset + sympy.prod(ir_node.outputs[0].shape) self.has_dropout = True diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py index 3d3538a62da6..368d1b238fd9 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -13,7 +13,7 @@ if ( "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1 - and Version(torch.__version__) >= Version("2.1.1") + and Version(torch.__version__) >= Version("2.3.0") ): from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py index b1e8809f03fc..c1fb6e68568f 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py +++ 
b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -5,9 +5,12 @@ """ PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs is keep changing. Current implementation -is tested well on version 2.2.0.dev20231010+cu121, and should be run well since official version 2.2.0. If may fail to +is tested well on version 2.3.0.dev20240221+cu118, and should be run well since official version 2.3.0. If may fail to run is you are using PyTorch with older versions. +This file is more like an example of how to add a new graph optimizer. Ideally user can add graph optimizer according +to the specific model they are using on their own instead of putting every possible graph optimizer here. + PyTorch also has API for flash attention (currently doesn't support random attention mask or Dropout), we can add support if we want to try in the future. """ @@ -40,13 +43,14 @@ def _make_efficient_attention_nodes( scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale]) dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio]) causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0]) - int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0]) - true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True]) - false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False]) + one_node = make_constant_node("one_" + str(idx), TensorProto.INT64, [], [1]) + zero_node = make_constant_node("zero_" + str(idx), TensorProto.INT64, [], [0]) logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, []) seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, []) offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, []) - new_value_infos = [logsumexp, seed, offset] + msb_q = helper.make_tensor_value_info("msb_q_" + str(idx), TensorProto.INT64, []) + msb_k = helper.make_tensor_value_info("msb_k_" + str(idx), TensorProto.INT64, []) + new_value_infos = [logsumexp, seed, offset, msb_q, msb_k] if expand_bias: shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1) shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3) @@ -54,13 +58,13 @@ def _make_efficient_attention_nodes( shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2) concat = helper.make_node( "Concat", - ["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)], + [shape_0.output[0], shape_1.output[0], shape_2.output[0], shape_3.output[0]], ["concated_shape_" + str(idx)], axis=0, ) - expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)]) + expand = helper.make_node("Expand", [bias, concat.output[0]], ["expanded_bias_" + str(idx)]) nodes_to_add.extend([shape_0, shape_1, shape_2, shape_3, concat, expand]) - bias = "expanded_bias_" + str(idx) + bias = expand.output[0] fwd_node = helper.make_node( "ATen", [ @@ -71,18 +75,21 @@ def _make_efficient_attention_nodes( "", "", "", + "", dropout_ratio_node.output[0], causal_node.output[0], - true_node.output[0], + one_node.output[0], scale_node.output[0], "", "", ], - [y, logsumexp.name, seed.name, offset.name], + [y, logsumexp.name, seed.name, offset.name, msb_q.name, msb_k.name], "efficient_attention_forward_" + str(idx), None, "org.pytorch.aten", 
operator="_efficient_attention_forward", + cpu_input_args=[4, 5, 12, 13], + cpu_output_args=[2, 3, 4, 5], ) bwd_node = helper.make_node( "ATen", @@ -95,14 +102,14 @@ def _make_efficient_attention_nodes( y, "", "", - int_zero_node.output[0], - int_zero_node.output[0], + msb_q.name, + msb_k.name, logsumexp.name, dropout_ratio_node.output[0], seed.name, offset.name, causal_node.output[0], - false_node.output[0], + zero_node.output[0], scale_node.output[0], "", ], @@ -111,10 +118,9 @@ def _make_efficient_attention_nodes( None, "org.pytorch.aten", operator="_efficient_attention_backward", + cpu_input_args=[6, 7, 12, 13], ) - nodes_to_add.extend( - [scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node] - ) + nodes_to_add.extend([scale_node, dropout_ratio_node, causal_node, one_node, zero_node, fwd_node, bwd_node]) return nodes_to_add, new_value_infos @@ -240,140 +246,9 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro return nodes, nodes_to_add, new_value_infos -# No causal mask, no attention mask, without Dropout. -_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ - ("MatMul", False, []), # 0 - ("Mul", True, [(0, 0, 0)]), # 1 - ("Mul", True, [(0, 0, 1)]), # 2 - ("Transpose", True, [(1, 0, 0)]), # 3 - ("Transpose", True, [(2, 0, 0)]), # 4 - ("Softmax", False, [(0, 0, 0)]), # 5 - ("MatMul", False, [(5, 0, 0)]), # 6 - ("Transpose", True, [(6, 0, 1)]), # 7 - ("Transpose", False, [(6, 0, 0)]), # 8 - ("FusedMatMul", False, [(7, 0, 1)]), # 9 - ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 - ("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11 - ("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12 - ("Mul", False, [(11, 0, 0)]), # 13 - ("Mul", False, [(12, 0, 0)]), # 14 - ("Identity", False, [(13, 0, 0)]), # 15 - ("Identity", False, [(14, 0, 0)]), # 16 - ("Transpose", False, [(15, 0, 0)]), # 17 - ("Transpose", False, [(16, 0, 0)]), # 18 - ("FusedMatMul", False, [(5, 0, 0)]), # 19 - ("Transpose", True, [(19, 0, 1)]), # 20 - ("Transpose", False, [(19, 0, 0)]), # 21 -] - - -def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): - # Check forward only as the backward is expected to be consistent if it's built correctly. - scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) - scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 - scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) - scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 - if not ( - check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) - and scale_value_1 == scale_value_2 - ): - return [], [], [] - - nodes_to_add, new_value_infos = _make_efficient_attention_nodes( - idx, - nodes[3].input[0], - nodes[4].input[0], - nodes[7].input[0], - nodes[8].output[0], - nodes[20].input[0], - nodes[17].output[0], - nodes[18].output[0], - nodes[21].output[0], - "", - False, - scale_value_1, - 0.0, - False, - ) - return nodes, nodes_to_add, new_value_infos - - -# Has causal mask, no attention mask, without Dropout. 
-_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ - ("MatMul", False, []), # 0 - ("Mul", True, [(0, 0, 0)]), # 1 - ("Mul", True, [(0, 0, 1)]), # 2 - ("Transpose", True, [(1, 0, 0)]), # 3 - ("Transpose", True, [(2, 0, 0)]), # 4 - ("Add", False, [(0, 0, 0)]), # 5 - ("Slice", True, [(5, 0, 1)]), # 6 - ("Slice", True, [(6, 0, 0)]), # 7 - ("Unsqueeze", True, [(6, 0, 2)]), # 8 - ("Gather", True, [(8, 0, 0)]), # 9 - ("Shape", True, [(9, 0, 0)]), # 10 - ("Softmax", False, [(5, 0, 0)]), # 11 - ("MatMul", False, [(11, 0, 0)]), # 12 - ("Transpose", True, [(12, 0, 1)]), # 13 - ("Transpose", False, [(12, 0, 0)]), # 14 - ("FusedMatMul", False, [(13, 0, 1)]), # 15 - ("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16 - ("Identity", False, [(16, 0, 0)]), # 17 - ("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18 - ("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19 - ("Mul", False, [(18, 0, 0)]), # 20 - ("Mul", False, [(19, 0, 0)]), # 21 - ("Identity", False, [(20, 0, 0)]), # 22 - ("Identity", False, [(21, 0, 0)]), # 23 - ("Transpose", False, [(22, 0, 0)]), # 24 - ("Transpose", False, [(23, 0, 0)]), # 25 - ("FusedMatMul", False, [(11, 0, 0)]), # 26 - ("Transpose", True, [(26, 0, 1)]), # 27 - ("Transpose", False, [(26, 0, 0)]), # 28 -] - - -def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): - # Check forward only as the backward is expected to be consistent if it's built correctly. - scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) - scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 - scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) - scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 - if not ( - check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3]) - and scale_value_1 == scale_value_2 - ): - return [], [], [] - - nodes_to_add, new_value_infos = _make_efficient_attention_nodes( - idx, - nodes[3].input[0], - nodes[4].input[0], - nodes[13].input[0], - nodes[14].output[0], - nodes[27].input[0], - nodes[24].output[0], - nodes[25].output[0], - nodes[28].output[0], - "", - False, - scale_value_1, - 0.0, - True, - ) - return nodes, nodes_to_add, new_value_infos - - _PATTERNS = [ (_PATTERN_0, _optimize_for_pattern_0), (_PATTERN_1, _optimize_for_pattern_1), - (_PATTERN_2, _optimize_for_pattern_2), - (_PATTERN_3, _optimize_for_pattern_3), ]
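Putting the pieces together, the placement rule that the modified utils.cc applies to an ATen node input can be restated roughly as follows. This is a sketch in Python over an onnx NodeProto and a torch FunctionSchema, not the actual ORT code.

    def aten_input_on_cpu(node, schema, index):
        # On CPU if the index is listed in the node's cpu_input_args attribute ...
        cpu_args = next((list(a.ints) for a in node.attribute if a.name == "cpu_input_args"), [])
        if index in cpu_args:
            return True
        # ... or if the schema says the argument is not a tensor; non-tensor arguments always live on CPU.
        return schema.arguments[index].type.kind() != "TensorType"

The output side is symmetric, using cpu_output_args and schema.returns.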
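One more detail worth calling out: the msb_q/msb_k outputs added to the forward node appear to correspond to the int returns (max sequence lengths) of _efficient_attention_forward, which is why aten_op_executor.cc now accepts IntType returns and materializes them as scalar tensors before crossing the DLPack boundary. A hypothetical PyTorch-side equivalent of that round trip:

    import torch
    from torch.utils.dlpack import from_dlpack, to_dlpack

    max_seqlen_q = 512  # hypothetical int returned by the ATen op
    capsule = to_dlpack(torch.scalar_tensor(max_seqlen_q, dtype=torch.int64))
    ort_view = from_dlpack(capsule)  # arrives as a 0-dim int64 CPU tensor

Because these values are not tensors in the schema, they land on CPU under the new rule anyway; the forward node also lists them explicitly in cpu_output_args.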