ATen Op Supports Int Return Type and CPU Tensor Arguments (#19773)
This PR:
- adds support for int as a return type; a CPU scalar tensor is created for it.
- adds attributes (cpu_input_args / cpu_output_args) to specify which arguments or returns are CPU tensors; a usage sketch follows below.
- adjusts the ATen efficient attention ops to match the latest PyTorch native function signature.
- includes a Triton codegen bugfix along the way.
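For illustration, a minimal sketch (not code from this commit) of how a graph author might attach the new attributes to an ATen node. The attribute names ("operator", "overload_name", "cpu_input_args", "cpu_output_args") come from the schema change to contrib_defs.cc in this commit; the "org.pytorch.aten" domain string is assumed to match kPytorchAtenDomain, and the tensor names and index choices are hypothetical and do not reflect the full _efficient_attention_forward signature.

from onnx import helper

# Hypothetical node construction: mark inputs 4/5 and outputs 2/3 as CPU tensors.
aten_node = helper.make_node(
    "ATen",
    inputs=["query", "key", "value", "bias", "cu_seqlens_q", "cu_seqlens_k"],
    outputs=["output", "logsumexp", "philox_seed", "philox_offset"],
    domain="org.pytorch.aten",
    operator="_efficient_attention_forward",
    overload_name="",
    cpu_input_args=[4, 5],   # keep these tensor inputs on CPU
    cpu_output_args=[2, 3],  # produce these outputs as CPU tensors
)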
centwang committed Mar 6, 2024
1 parent d102569 commit 1bfc266
Showing 10 changed files with 96 additions and 198 deletions.
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h
@@ -10,7 +10,7 @@ namespace onnxruntime {
namespace contrib {
namespace aten_ops {

typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input);
typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input);
typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size,
DLManagedTensor** dlpack_inputs, size_t output_size,
DLManagedTensor** dlpack_outputs);
@@ -22,17 +22,17 @@ class ATenOperatorExecutor {
return instance;
}

void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) {
ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw);
p_is_cpu_argument_func_ = reinterpret_cast<IsCpuArgumentFunc>(p_is_cpu_argument_func_raw);
void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw);
p_is_tensor_argument_func_ = reinterpret_cast<IsTensorArgumentFunc>(p_is_tensor_argument_func_raw);
p_execute_aten_op_func_ = reinterpret_cast<ExecuteATenOperatorFunc>(p_execute_aten_op_func_raw);
}

bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; }

bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) {
ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized.");
return p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input);
bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) {
ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized.");
return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input);
}

void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size,
@@ -43,7 +43,7 @@
}

private:
IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr;
IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr;
ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr;
};

24 changes: 22 additions & 2 deletions onnxruntime/core/framework/utils.cc
@@ -1015,17 +1015,27 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index)
}

#ifdef ENABLE_ATEN
// For an ATen node, we assume that all tensor inputs are on device and all non-tensor inputs are on CPU,
// except those specified in the cpu_input_args attribute.
if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
node.Domain() == kPytorchAtenDomain) {
const auto& attrs = node.GetAttributes();
if (auto entry = attrs.find("cpu_input_args"); entry != attrs.end()) {
const auto& attr = entry->second;
if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(),
[index](int64_t arg) { return static_cast<int64_t>(index) == arg; })) {
return true;
}
}

ORT_ENFORCE(utils::HasString(attrs.at("operator")));
std::string op_name = attrs.at("operator").s();
std::string overload_name = "";
if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) {
overload_name = attrs.at("overload_name").s();
}

return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true);
return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, true);
}
#else
ORT_UNUSED_PARAMETER(node);
@@ -1040,17 +1050,27 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index
}

#ifdef ENABLE_ATEN
// For an ATen node, we assume that all tensor outputs are on device and all non-tensor outputs are on CPU,
// except those specified in the cpu_output_args attribute.
if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
node.Domain() == kPytorchAtenDomain) {
const auto& attrs = node.GetAttributes();
if (auto entry = attrs.find("cpu_output_args"); entry != attrs.end()) {
const auto& attr = entry->second;
if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(),
[index](int64_t arg) { return static_cast<int64_t>(index) == arg; })) {
return true;
}
}

ORT_ENFORCE(utils::HasString(attrs.at("operator")));
std::string op_name = attrs.at("operator").s();
std::string overload_name = "";
if (attrs.find("overload_name") != attrs.end() && utils::HasString(attrs.at("overload_name"))) {
overload_name = attrs.at("overload_name").s();
}

return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false);
return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, false);
}
#else
ORT_UNUSED_PARAMETER(node);
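A rough Python sketch of the placement rule the two hunks above implement (the function and parameter names are hypothetical, not ORT APIs): an index listed in cpu_input_args / cpu_output_args is forced to CPU; otherwise non-tensor arguments default to CPU and tensor arguments stay on device.

def aten_arg_on_cpu(attrs, index, is_input, is_tensor_argument):
    # Hypothetical summary of IsInputOnCpu / IsOutputOnCpu above, not the real API.
    explicit_cpu = attrs.get("cpu_input_args" if is_input else "cpu_output_args", [])
    if index in explicit_cpu:
        return True  # explicitly pinned to CPU via the new attributes
    op = attrs["operator"]
    overload = attrs.get("overload_name", "")
    # Non-tensor arguments (ints, bools, scalars) are CPU; tensors stay on device.
    return not is_tensor_argument(op, overload, index, is_input)

# Example: output 2 is listed in cpu_output_args, so it is a CPU output.
attrs = {"operator": "_efficient_attention_forward", "cpu_output_args": [2, 3]}
print(aten_arg_on_cpu(attrs, 2, False, lambda *_: True))  # True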
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3474,6 +3474,8 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4
/*min_arity*/ 1)
.Attr("operator", "Name of ATen operator.", AttributeProto::STRING)
.Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false)
.Attr("cpu_input_args", "CPU input argument indices.", AttributeProto::INTS, false)
.Attr("cpu_output_args", "CPU output argument indices.", AttributeProto::INTS, false)
.TypeConstraint("T", OpSchema::all_tensor_types_ir4(),
"Allow inputs and outputs to be any kind of tensor.");
#endif
10 changes: 5 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1327,14 +1327,14 @@ void addGlobalMethods(py::module& m) {

#ifdef ENABLE_ATEN
m.def("register_aten_op_executor",
[](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
size_t is_cpu_argument_address_int, aten_op_executor_address_int;
[](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void {
size_t is_tensor_argument_address_int, aten_op_executor_address_int;
ORT_THROW_IF_ERROR(
ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int));
ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int));
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
void* p_is_cpu_argument = reinterpret_cast<void*>(is_cpu_argument_address_int);
void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
});
#endif
}
@@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension():
from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor

_C.register_aten_op_executor(
str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address())
str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address())
)
@@ -34,18 +34,23 @@ struct ATenOperator {
std::vector<bool> is_optional_arguments;
std::vector<c10::optional<c10::IValue>> default_values;
size_t return_size;
std::vector<c10::TypeKind> ret_kinds;

c10::IValue ToIValueArgument(const DLManagedTensor* dlpack, size_t index) const {
TORCH_INTERNAL_ASSERT(index < argument_size);
bool is_optional = is_optional_arguments[index];
TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index]);
TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index] ||
elem_kinds[index] == c10::TypeKind::TensorType);
if (!dlpack) {
if (is_optional) {
// Optional argument always has no default value.
return c10::IValue(c10::nullopt);
}

return *default_values[index];
if (default_values[index]) {
return *default_values[index];
}
// For the bw func, it's possible that the input is an undefined tensor from the fw outputs; dlpack is nullptr in such a case.
return c10::IValue(at::Tensor());
}

bool is_list = is_list_arguments[index];
@@ -142,7 +147,10 @@ class ATenOperatorCache {
}
aten_op.return_size = schema.returns().size();
for (const auto& ret : schema.returns()) {
TORCH_INTERNAL_ASSERT(ret.type()->kind() == c10::TypeKind::TensorType);
c10::TypeKind ret_type = ret.type()->kind();
// Support tensor or int only for now.
TORCH_INTERNAL_ASSERT(ret_type == c10::TypeKind::TensorType || ret_type == c10::TypeKind::IntType);
aten_op.ret_kinds.emplace_back(ret_type);
}
ops_.emplace(key, aten_op);
}
@@ -154,32 +162,15 @@
std::unordered_map<std::pair<std::string, std::string>, ATenOperator, PairHash> ops_;
};

const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorInputsMap = {
{"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}};

const std::unordered_map<std::string, std::unordered_set<size_t>> kCpuTensorOutputsMap = {
{"_efficient_attention_forward", {2, 3}}};

// Backend uses this function to check if an argument is CPU input or not.
bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) {
// Backend uses this function to check whether an argument is of tensor type or not.
bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) {
const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
if (is_input) {
// If the argument is non-tensor type, it's CPU argument.
const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name);
TORCH_INTERNAL_ASSERT(index < aten_op.argument_size);
if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) {
return true;
}
}

std::string full_name = std::string(op_name);
std::string overload_name_str = std::string(overload_name);
if (overload_name_str != "") {
full_name += ("." + overload_name_str);
return aten_op.elem_kinds[index] == c10::TypeKind::TensorType;
}

const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap;
return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() &&
cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end();
TORCH_INTERNAL_ASSERT(index < aten_op.return_size);
return aten_op.ret_kinds[index] == c10::TypeKind::TensorType;
}

void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size,
@@ -216,16 +207,23 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t
TORCH_INTERNAL_ASSERT(output_size == aten_op.return_size);
size_t output_index = 0;
for (const auto& ret : torch::jit::pop(stack, output_size)) {
const auto& tensor = ret.toTensor();
dlpack_outputs[output_index++] =
tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr;
if (ret.isTensor()) {
const auto& tensor = ret.toTensor();
dlpack_outputs[output_index++] =
tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr;
} else if (ret.isInt()) {
at::Tensor scalar = at::scalar_to_tensor(at::Scalar(ret.toInt()));
dlpack_outputs[output_index++] = at::toDLPack(scalar);
} else {
TORCH_INTERNAL_ASSERT(false);
}
}
}

size_t is_cpu_argument_address() { return reinterpret_cast<size_t>(&IsCpuArgument); }
size_t is_tensor_argument_address() { return reinterpret_cast<size_t>(&IsTensorArgument); }
size_t execute_aten_operator_address() { return reinterpret_cast<size_t>(&ExecuteATenOperator); }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check.");
m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check.");
m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor");
}
@@ -5,7 +5,7 @@

from onnxruntime.capi import _pybind_state as _C

from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address
from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address


def run_once_aten_op_executor(f):
@@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs):

@run_once_aten_op_executor
def load_aten_op_executor_cpp_extension():
_C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address()))
_C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address()))


def init_aten_op_executor():
3 changes: 3 additions & 0 deletions orttraining/orttraining/python/training/ort_triton/_ir.py
@@ -392,5 +392,8 @@ def __init__(
for ir_node in kernel.sub_nodes:
if isinstance(ir_node, DropoutNode):
ir_node.global_offset = running_offset
kernel.offset_calc.symbolic_shape_variables.update(
[symbol.name for symbol in running_offset.free_symbols]
)
running_offset = running_offset + sympy.prod(ir_node.outputs[0].shape)
self.has_dropout = True
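For the Triton codegen fix above, the dropout global offset is a sympy expression built from output shapes, and the fix records the shape symbols that expression depends on. A standalone sketch with hypothetical shape symbols:

import sympy

batch, seq = sympy.symbols("batch seq")  # hypothetical symbolic dimensions

running_offset = sympy.Integer(0)
# After a DropoutNode whose output shape is [batch, seq], advance the global offset.
running_offset = running_offset + sympy.prod([batch, seq])

# The fix: register every symbol the offset depends on as a symbolic shape variable,
# so the generated kernel's offset calculation receives those shape values.
symbolic_shape_variables = {symbol.name for symbol in running_offset.free_symbols}
print(symbolic_shape_variables)  # e.g. {'batch', 'seq'}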
@@ -13,7 +13,7 @@
if (
"ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ
and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1
and Version(torch.__version__) >= Version("2.1.1")
and Version(torch.__version__) >= Version("2.3.0")
):
from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401
