From 361cdc5a98df23488f831436486770d03e08a692 Mon Sep 17 00:00:00 2001 From: ronny1996 Date: Wed, 19 Jul 2023 02:31:56 +0000 Subject: [PATCH 1/2] [PHI CAPI] Add support for registering a new operator, PART2 --- paddle/fluid/framework/custom_operator.cc | 113 +++++++++ paddle/phi/capi/all.h | 2 + paddle/phi/capi/capi.h | 2 + paddle/phi/capi/include/c_kernel_registry.h | 10 + paddle/phi/capi/include/kernel_registry.h | 241 +++++++++++++++++++- paddle/phi/capi/include/kernel_utils.h | 148 ++++++++++++ paddle/phi/capi/lib/CMakeLists.txt | 2 + paddle/phi/capi/lib/c_infer_meta_context.cc | 215 +++++++++++++++++ paddle/phi/capi/lib/c_kernel_registry.cc | 2 + paddle/phi/capi/lib/c_meta_tensor.cc | 150 ++++++++++++ 10 files changed, 884 insertions(+), 1 deletion(-) create mode 100644 paddle/phi/capi/lib/c_infer_meta_context.cc create mode 100644 paddle/phi/capi/lib/c_meta_tensor.cc diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 059524b21c6d61..84c06cea91cbc2 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -38,7 +38,11 @@ limitations under the License. */ #include "paddle/phi/core/tensor_utils.h" #include "paddle/utils/any.h" #ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/phi/backends/device_manager.h" +#include "paddle/phi/capi/include/c_infer_meta_context.h" +#include "paddle/phi/capi/include/c_kernel_registry.h" +#include "paddle/phi/capi/include/c_meta_tensor.h" #endif #include "paddle/phi/api/include/operants_manager.h" @@ -1226,3 +1230,112 @@ LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) { } // namespace framework } // namespace paddle + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +void PD_RegisterOperator(const char* kernel_name_cstr, + size_t in_nargs, + PD_KernelArgumentType* in_args_type, + size_t attr_nargs, + PD_KernelArgumentType* attr_args_type, + size_t out_nargs, + PD_KernelArgumentType* out_args_type, + void (*infer_shape_fn)(PD_InferMetaContext*)) { + std::string kernel_name(kernel_name_cstr); + if (infer_shape_fn && + !paddle::framework::OpInfoMap::Instance().Has(kernel_name)) { + VLOG(8) << "Registering a new operator: " << kernel_name; + + std::vector op_inputs, op_outputs, op_attrs; + + for (size_t i = 0; i < in_nargs; ++i) { + if (in_args_type[i] == PD_KernelArgumentType::PD_ARG_TYPE_TENSOR) { + op_inputs.push_back("Input_" + std::to_string(i)); + } else if (in_args_type[i] == + PD_KernelArgumentType::PD_ARG_TYPE_LIST_TENSOR) { + op_inputs.push_back("Input_" + std::to_string(i) + + paddle::kTensorVectorSuffix); + } else if (in_args_type[i] == + PD_KernelArgumentType::PD_ARG_TYPE_OPTIONAL_TENSOR) { + op_inputs.push_back("Input_" + std::to_string(i) + + paddle::kOptionalSuffix); + } else { + op_inputs.push_back("Input_unknown"); + } + } + for (size_t i = 0; i < out_nargs; ++i) { + if (out_args_type[i] == PD_KernelArgumentType::PD_ARG_TYPE_TENSOR) { + op_outputs.push_back("Output_" + std::to_string(i)); + } else if (out_args_type[i] == + PD_KernelArgumentType::PD_ARG_TYPE_LIST_TENSOR) { + op_outputs.push_back("Output_" + std::to_string(i) + + paddle::kTensorVectorSuffix); + } else { + op_outputs.push_back("Output_unknown"); + } + } + for (size_t i = 0; i < attr_nargs; ++i) { + auto attr_type = attr_args_type[i]; + if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_BOOL) { + op_attrs.push_back("Attr_" + std::to_string(i) + ":bool"); + } else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_INT32) { + op_attrs.push_back("Attr_" + std::to_string(i) + ":int"); + } else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_FLOAT32) { + op_attrs.push_back("Attr_" + std::to_string(i) + ":float"); + } else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_INT64) { + op_attrs.push_back("Attr_" + std::to_string(i) + ":int64_t"); + } else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_STRING) { + op_attrs.push_back("Attr_" + std::to_string(i) + ":std::string"); + } else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_INT32) { + op_attrs.push_back("Attr_" + std::to_string(i) + ":std::vector"); + } else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_FLOAT32) { + op_attrs.push_back("Attr_" + std::to_string(i) + ":std::vector"); + } else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_INT64) { + op_attrs.push_back("Attr_" + std::to_string(i) + + ":std::vector"); + } else if (attr_type == PD_KernelArgumentType::PD_ARG_TYPE_LIST_STRING) { + op_attrs.push_back("Attr_" + std::to_string(i) + + ":std::vector"); + } else { + op_attrs.push_back("Attr_unknown"); + } + } + + paddle::framework::OpInfo info; + // Op + info.creator_ = [](const std::string& op_name, + const paddle::framework::VariableNameMap& inputs, + const paddle::framework::VariableNameMap& outputs, + const paddle::framework::AttributeMap& attrs) { + return new paddle::framework::OperatorWithKernel( + op_name, inputs, outputs, attrs); + }; + + // OpMaker + info.proto_ = new paddle::framework::proto::OpProto; + info.proto_->set_type(kernel_name); + + info.checker_ = new paddle::framework::OpAttrChecker(); + + paddle::framework::CustomOpMaker custom_maker( + op_inputs, op_outputs, op_attrs); + custom_maker(info.proto_, info.checker_); + PADDLE_ENFORCE_EQ( + info.proto_->IsInitialized(), + true, + phi::errors::PreconditionNotMet( + "Fail to initialize %s's OpProto, because %s is not initialized.", + kernel_name, + info.proto_->InitializationErrorString())); + + info.infer_shape_ = [infer_shape_fn, kernel_name]( + paddle::framework::InferShapeContext* ctx) { + auto infer_meta_context = + paddle::framework::BuildInferMetaContext(ctx, kernel_name); + infer_shape_fn( + reinterpret_cast(&infer_meta_context)); + }; + + paddle::framework::OpInfoMap::Instance().Insert(kernel_name, info); + } +} +#endif diff --git a/paddle/phi/capi/all.h b/paddle/phi/capi/all.h index 4e3a133fda362c..5dc7e09f6536f3 100644 --- a/paddle/phi/capi/all.h +++ b/paddle/phi/capi/all.h @@ -17,10 +17,12 @@ #include "paddle/phi/capi/include/c_data_type.h" #include "paddle/phi/capi/include/c_device_context.h" +#include "paddle/phi/capi/include/c_infer_meta_context.h" #include "paddle/phi/capi/include/c_int_array.h" #include "paddle/phi/capi/include/c_kernel_context.h" #include "paddle/phi/capi/include/c_kernel_factory.h" #include "paddle/phi/capi/include/c_kernel_registry.h" +#include "paddle/phi/capi/include/c_meta_tensor.h" #include "paddle/phi/capi/include/c_place.h" #include "paddle/phi/capi/include/c_scalar.h" #include "paddle/phi/capi/include/c_tensor.h" diff --git a/paddle/phi/capi/capi.h b/paddle/phi/capi/capi.h index 85b09315597d55..7cc2f9568fdb6e 100644 --- a/paddle/phi/capi/capi.h +++ b/paddle/phi/capi/capi.h @@ -23,6 +23,8 @@ PD_DECLARE_CAPI(int_array); PD_DECLARE_CAPI(kernel_context); PD_DECLARE_CAPI(kernel_factory); PD_DECLARE_CAPI(kernel_registry); +PD_DECLARE_CAPI(infer_meta_context); +PD_DECLARE_CAPI(meta_tensor); PD_DECLARE_CAPI(place); PD_DECLARE_CAPI(scalar); PD_DECLARE_CAPI(tensor); diff --git a/paddle/phi/capi/include/c_kernel_registry.h b/paddle/phi/capi/include/c_kernel_registry.h index bb2842051f8159..801d65cdd03d00 100644 --- a/paddle/phi/capi/include/c_kernel_registry.h +++ b/paddle/phi/capi/include/c_kernel_registry.h @@ -19,6 +19,7 @@ #include #include "paddle/phi/capi/include/c_data_type.h" +#include "paddle/phi/capi/include/c_infer_meta_context.h" #include "paddle/phi/capi/include/c_kernel_context.h" #include "paddle/phi/capi/include/c_kernel_factory.h" @@ -71,6 +72,15 @@ void PD_RegisterPhiKernel(const char *kernel_name_cstr, void (*fn)(PD_KernelContext *), void *variadic_kernel_fn); +void PD_RegisterOperator(const char *kernel_name_cstr, + size_t in_nargs, + PD_KernelArgumentType *in_args_type, + size_t attr_nargs, + PD_KernelArgumentType *attr_args_type, + size_t out_nargs, + PD_KernelArgumentType *out_args_type, + void (*infer_shape_fn)(PD_InferMetaContext *)); + #ifdef __cplusplus } // extern "C" #endif diff --git a/paddle/phi/capi/include/kernel_registry.h b/paddle/phi/capi/include/kernel_registry.h index e626cd422770db..71fcc76c2c213f 100644 --- a/paddle/phi/capi/include/kernel_registry.h +++ b/paddle/phi/capi/include/kernel_registry.h @@ -187,6 +187,30 @@ inline std::vector PD_MultiOutputAt( return ret; } +inline std::vector PD_InferMetaMultiInputAt( + PD_InferMetaContext *ctx, size_t index) { + std::vector ret; + auto list = PD_InferMetaContextMultiInputAt(ctx, index); + auto data = reinterpret_cast(list.data); + for (size_t i = 0; i < list.size; ++i) { + ret.emplace_back(data[i]); + } + PD_DeletePointerList(list); + return ret; +} + +inline std::vector PD_InferMetaMultiOutputAt( + PD_InferMetaContext *ctx, size_t index) { + std::vector ret; + auto list = PD_InferMetaContextMultiOutputAt(ctx, index); + auto data = reinterpret_cast(list.data); + for (size_t i = 0; i < list.size; ++i) { + ret.emplace_back(data[i]); + } + PD_DeletePointerList(list); + return ret; +} + template inline std::vector PD_GetPointerVector(std::vector *vec) { std::vector ret; @@ -336,6 +360,152 @@ inline std::vector PD_AttrAt>(PD_KernelContext *ctx, return list; } +template +inline T PD_InferMetaAttrAt(PD_InferMetaContext *ctx, size_t index); + +template <> +inline bool PD_InferMetaAttrAt(PD_InferMetaContext *ctx, size_t index) { + return PD_InferMetaContextBoolAttrAt(ctx, index); +} + +template <> +inline int32_t PD_InferMetaAttrAt(PD_InferMetaContext *ctx, + size_t index) { + return PD_InferMetaContextInt32AttrAt(ctx, index); +} + +template <> +inline int64_t PD_InferMetaAttrAt(PD_InferMetaContext *ctx, + size_t index) { + return PD_InferMetaContextInt64AttrAt(ctx, index); +} + +template <> +inline float PD_InferMetaAttrAt(PD_InferMetaContext *ctx, size_t index) { + return PD_InferMetaContextFloatAttrAt(ctx, index); +} + +template <> +inline double PD_InferMetaAttrAt(PD_InferMetaContext *ctx, + size_t index) { + return PD_InferMetaContextDoubleAttrAt(ctx, index); +} + +template <> +inline std::string PD_InferMetaAttrAt(PD_InferMetaContext *ctx, + size_t index) { + return PD_InferMetaContextStringAttrAt(ctx, index); +} + +template <> +inline PD_DataType PD_InferMetaAttrAt(PD_InferMetaContext *ctx, + size_t index) { + return PD_InferMetaContextDataTypeAttrAt(ctx, index); +} + +template <> +inline PD_DataLayout PD_InferMetaAttrAt(PD_InferMetaContext *ctx, + size_t index) { + return PD_InferMetaContextDataLayoutAttrAt(ctx, index); +} + +template <> +inline std::vector PD_InferMetaAttrAt>( + PD_InferMetaContext *ctx, size_t index) { + auto list = PD_InferMetaContextListInt32AttrAt(ctx, index); + auto data = reinterpret_cast(list.data); + std::vector cc_list(data, data + list.size); + return cc_list; +} + +template <> +inline std::vector PD_InferMetaAttrAt>( + PD_InferMetaContext *ctx, size_t index) { + auto list = PD_InferMetaContextListInt64AttrAt(ctx, index); + auto data = reinterpret_cast(list.data); + std::vector cc_list(data, data + list.size); + return cc_list; +} + +template <> +inline std::vector PD_InferMetaAttrAt>( + PD_InferMetaContext *ctx, size_t index) { + auto list = PD_InferMetaContextListFloatAttrAt(ctx, index); + auto data = reinterpret_cast(list.data); + std::vector cc_list(data, data + list.size); + return cc_list; +} + +template <> +inline std::vector PD_InferMetaAttrAt>( + PD_InferMetaContext *ctx, size_t index) { + auto list = PD_InferMetaContextListDoubleAttrAt(ctx, index); + auto data = reinterpret_cast(list.data); + std::vector cc_list(data, data + list.size); + return cc_list; +} + +template <> +inline phi::capi::Scalar PD_InferMetaAttrAt( + PD_InferMetaContext *ctx, size_t index) { + auto scalar = PD_InferMetaContextScalarAttrAt(ctx, index); + return phi::capi::Scalar(scalar); +} + +template <> +inline phi::capi::IntArray PD_InferMetaAttrAt( + PD_InferMetaContext *ctx, size_t index) { + auto int_array = PD_InferMetaContextIntArrayAttrAt(ctx, index); + return phi::capi::IntArray(int_array); +} + +template <> +inline phi::capi::Place PD_InferMetaAttrAt( + PD_InferMetaContext *ctx, size_t index) { + auto place = PD_InferMetaContextPlaceAttrAt(ctx, index); + return phi::capi::Place(place); +} + +template <> +inline std::vector +PD_InferMetaAttrAt>(PD_InferMetaContext *ctx, + size_t index) { + auto c_list = PD_InferMetaContextListScalarAttrAt(ctx, index); + auto data = reinterpret_cast(c_list.data); + std::vector list; + for (size_t i = 0; i < c_list.size; ++i) { + list.emplace_back(data[i]); + } + PD_DeletePointerList(c_list); + return list; +} + +template <> +inline std::vector PD_InferMetaAttrAt>( + PD_InferMetaContext *ctx, size_t index) { + auto c_list = PD_InferMetaContextListStringAttrAt(ctx, index); + auto data = reinterpret_cast(c_list.data); + std::vector list; + for (size_t i = 0; i < c_list.size; ++i) { + list.emplace_back(data[i]); + } + PD_DeletePointerList(c_list); + return list; +} + +template <> +inline std::vector PD_InferMetaAttrAt>( + PD_InferMetaContext *ctx, size_t index) { + auto c_list = PD_InferMetaContextListBoolAttrAt(ctx, index); + std::vector list; + auto data = reinterpret_cast(c_list.data); + for (size_t i = 0; i < c_list.size; ++i) { + list[i] = static_cast(data[i]); + } + PD_DeleteUInt8List(c_list); + return list; +} + #define CPP_TYPE_TO_PD_ARG_TYPE_REGISTER(_) \ _(phi::capi::DenseTensor, ::PD_KernelArgumentType::PD_ARG_TYPE_TENSOR) \ _(phi::capi::DeviceContext, ::PD_KernelArgumentType::PD_ARG_TYPE_CONTEXT) \ @@ -391,13 +561,82 @@ using IntArray = capi::IntArray; using Place = capi::Place; using DataType = ::PD_DataType; using DataLayout = ::PD_DataLayout; - +using DenseTensor = capi::DenseTensor; +using MetaTensor = capi::MetaTensor; } // namespace phi #include "paddle/phi/capi/include/kernel_utils.h" // clang-format off +#define PD_BUILD_NEW_PHI_KERNEL(kernel_name, \ + backend, \ + layout, \ + meta_kernel_fn, \ + infer_shape_fn, \ + ...) \ + static void \ + __CUSTOM_adefs_CFN_##kernel_name##_##backend##_##layout( \ + const PD_KernelKey* kernel_key, PD_Kernel* kernel); \ + template \ + struct __##kernel_name##_##backend##_##layout##__ { \ + __##kernel_name##_##backend##_##layout##__() { \ + ::phi::capi::CustomKernelArgsParseFunctor)> \ + parser; \ + PD_RegisterOperator(#kernel_name, \ + parser.in_args_type.size(), \ + parser.in_args_type.data(), \ + parser.attr_args_type.size(), \ + parser.attr_args_type.data(), \ + parser.out_args_type.size(), \ + parser.out_args_type.data(), \ + PHI_CAPI_INFER_META(infer_shape_fn)); \ + PD_RegisterPhiKernel( \ + #kernel_name, \ + #backend, \ + ::phi::capi::CppTypeToPDType::Type(), \ + PD_DATALAYOUT(layout), \ + parser.in_args_type.size(), \ + parser.in_args_type.data(), \ + parser.attr_args_type.size(), \ + parser.attr_args_type.data(), \ + parser.out_args_type.size(), \ + parser.out_args_type.data(), \ + __CUSTOM_adefs_CFN_##kernel_name##_##backend##_##layout, \ + CUSTOM_PHI_KERNEL(meta_kernel_fn), \ + CUSTOM_PHI_VARIADIC_KERNEL( \ + meta_kernel_fn)); \ + } \ + static void Touch() {} \ + }; \ + PD_CUSTOM_PHI_KERNEL_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + CUSTOM_tp_ns_check_##kernel_name##_##backend##_##layout, \ + "PD_BUILD_KERNEL must be called in global namespace."); \ + static void \ + __CUSTOM_adefs_FN_##kernel_name##_##backend##_##layout( \ + const ::phi::capi::KernelKey &kernel_key, \ + ::phi::capi::Kernel* kernel); \ + _PD_BUILD_PHI_KERNEL(__##kernel_name##_##backend##_##layout##__, \ + kernel_name, \ + backend, \ + layout, \ + meta_kernel_fn, \ + __VA_ARGS__) \ + void \ + __CUSTOM_adefs_CFN_##kernel_name##_##backend##_##layout( \ + const PD_KernelKey* kernel_key, PD_Kernel* kernel) { \ + auto cc_kernel = ::phi::capi::Kernel(kernel); \ + __CUSTOM_adefs_FN_##kernel_name##_##backend##_##layout( \ + ::phi::capi::KernelKey( \ + const_cast(kernel_key)), \ + &cc_kernel); \ + } \ + void \ + __CUSTOM_adefs_FN_##kernel_name##_##backend##_##layout( \ + const ::phi::capi::KernelKey &kernel_key, \ + ::phi::capi::Kernel* kernel) + #define PD_BUILD_PHI_KERNEL(kernel_name, \ backend, \ layout, \ diff --git a/paddle/phi/capi/include/kernel_utils.h b/paddle/phi/capi/include/kernel_utils.h index 3822aaa32a8db8..3e052240cbbb29 100644 --- a/paddle/phi/capi/include/kernel_utils.h +++ b/paddle/phi/capi/include/kernel_utils.h @@ -24,6 +24,9 @@ namespace capi { #define CUSTOM_PHI_KERNEL(...) \ ::phi::capi::CustomKernelImpl::Compute +#define PHI_CAPI_INFER_META(...) \ + ::phi::capi::InferMetaFnImpl::Call + #define CUSTOM_PHI_VARIADIC_KERNEL(...) \ reinterpret_cast( \ &::phi::capi::CustomKernelImpl { } }; +#define PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(attr_type) \ + template \ + struct InferMetaFnCallHelper { \ + template \ + static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) { \ + static_assert(out_idx == 0, \ + "InferMeta's Attributes should appear before Outputs."); \ + attr_type arg = PD_InferMetaAttrAt(ctx, attr_idx); \ + InferMetaFnCallHelper< \ + Tail...>::template Call(ctx, \ + pargs..., \ + arg); \ + } \ + } + +#define PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( \ + attr_type) \ + template \ + struct InferMetaFnCallHelper { \ + template \ + static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) { \ + static_assert(out_idx == 0, \ + "InferMeta's Attributes should appear before Outputs."); \ + attr_type arg = PD_InferMetaAttrAt(ctx, attr_idx); \ + InferMetaFnCallHelper< \ + Tail...>::template Call(ctx, \ + pargs..., \ + arg); \ + } \ + } + +template +struct InferMetaTypeTag {}; + +template +struct InferMetaFnImpl; + +template +struct InferMetaFnImpl { + static void Call(PD_InferMetaContext *ctx) { + InferMetaFnCallHelper>::template Call<0, 0, 0>(ctx); + } + + private: + template + struct InferMetaFnCallHelper; + + template + struct InferMetaFnCallHelper { + template + static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) { + static_assert(attr_idx == 0, + "InferMeta's Input should appear before Attributes."); + static_assert(out_idx == 0, + "InferMeta's Input should appear before Outputs."); + auto arg = MetaTensor(PD_InferMetaContextInputAt(ctx, in_idx)); + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; + + template + struct InferMetaFnCallHelper &, + Tail...> { + template + static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) { + static_assert(attr_idx == 0, + "InferMeta's Input should appear before Attributes."); + static_assert(out_idx == 0, + "InferMeta's Input should appear before Outputs."); + auto arg = PD_InferMetaMultiInputAt(ctx, in_idx); + std::vector tensor_ptr_vec; + for (auto &tensor : arg) { + tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); + } + InferMetaFnCallHelper:: + template Call( + ctx, pargs..., tensor_ptr_vec); + } + }; + + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(float); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); + // PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + PD_SPECIALIZE_CAPI_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF( + std::vector); + + template + struct InferMetaFnCallHelper { + template + static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) { + auto arg = MetaTensor(PD_InferMetaContextOutputAt(ctx, out_idx)); + auto *arg_ptr = &arg; + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg_ptr); + } + }; + + template + struct InferMetaFnCallHelper, Tail...> { + template + static void Call(PD_InferMetaContext *ctx, PreviousArgs &...pargs) { + auto arg = PD_InferMetaMultiOutputAt(ctx, out_idx); + std::vector tensor_ptr_vec; + for (auto &tensor : arg) { + tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); + } + InferMetaFnCallHelper:: + template Call( + ctx, pargs..., tensor_ptr_vec); + } + }; + + /* End case */ + template + struct InferMetaFnCallHelper> { + template + static void Call(PD_InferMetaContext *ctx, Args &...args) { + return infer_meta_fn(args...); + } + }; +}; + } // namespace capi } // namespace phi diff --git a/paddle/phi/capi/lib/CMakeLists.txt b/paddle/phi/capi/lib/CMakeLists.txt index 8cf3c9caf8ece6..251b0ec9c51dc5 100644 --- a/paddle/phi/capi/lib/CMakeLists.txt +++ b/paddle/phi/capi/lib/CMakeLists.txt @@ -7,6 +7,8 @@ collect_srcs( c_kernel_context.cc c_kernel_factory.cc c_kernel_registry.cc + c_infer_meta_context.cc + c_meta_tensor.cc c_place.cc c_scalar.cc c_tensor.cc) diff --git a/paddle/phi/capi/lib/c_infer_meta_context.cc b/paddle/phi/capi/lib/c_infer_meta_context.cc new file mode 100644 index 00000000000000..e11121c0d8caaf --- /dev/null +++ b/paddle/phi/capi/lib/c_infer_meta_context.cc @@ -0,0 +1,215 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/capi/include/c_infer_meta_context.h" + +#include "paddle/phi/capi/include/common.h" +#include "paddle/phi/capi/include/type_utils.h" +#include "paddle/phi/core/infermeta_utils.h" + +PD_MetaTensor* PD_InferMetaContextInputAt(PD_InferMetaContext* ctx, + size_t index) { + auto* meta_ctx = reinterpret_cast(ctx); + const std::pair range = meta_ctx->InputRangeAt(index); + const phi::MetaTensor& arg = meta_ctx->InputAt(range.first); + return reinterpret_cast(const_cast(&arg)); +} + +PD_List PD_InferMetaContextMultiInputAt(PD_InferMetaContext* ctx, + size_t index) { + auto* meta_ctx = reinterpret_cast(ctx); + const std::pair range = meta_ctx->InputRangeAt(index); + std::vector tensor_vec = + meta_ctx->InputsBetween(range.first, range.second); + PD_List list; + list.size = tensor_vec.size(); + list.data = new void*[list.size]; + for (size_t i = 0; i < list.size; ++i) { + (reinterpret_cast(list.data))[i] = + reinterpret_cast(const_cast(tensor_vec[i])); + } + return list; +} + +PD_MetaTensor* PD_InferMetaContextOutputAt(PD_InferMetaContext* ctx, + size_t index) { + auto* meta_ctx = reinterpret_cast(ctx); + const std::pair range = meta_ctx->OutputRangeAt(index); + phi::MetaTensor* arg = meta_ctx->MutableOutputAt(range.first); + return reinterpret_cast(arg); +} + +PD_List PD_InferMetaContextMultiOutputAt(PD_InferMetaContext* ctx, + size_t index) { + auto* meta_ctx = reinterpret_cast(ctx); + const std::pair range = meta_ctx->OutputRangeAt(index); + std::vector tensor_vec = + meta_ctx->MutableOutputBetween(range.first, range.second); + PD_List list; + list.size = tensor_vec.size(); + list.data = new void*[list.size]; + for (size_t i = 0; i < list.size; ++i) { + (reinterpret_cast(list.data))[i] = + reinterpret_cast(tensor_vec[i]); + } + return list; +} + +bool PD_InferMetaContextBoolAttrAt(PD_InferMetaContext* ctx, size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return meta_ctx->AttrAt(index); +} + +int32_t PD_InferMetaContextInt32AttrAt(PD_InferMetaContext* ctx, size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return meta_ctx->AttrAt(index); +} + +int64_t PD_InferMetaContextInt64AttrAt(PD_InferMetaContext* ctx, size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return meta_ctx->AttrAt(index); +} + +float PD_InferMetaContextFloatAttrAt(PD_InferMetaContext* ctx, size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return meta_ctx->AttrAt(index); +} + +double PD_InferMetaContextDoubleAttrAt(PD_InferMetaContext* ctx, size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return meta_ctx->AttrAt(index); +} + +PD_Scalar* PD_InferMetaContextScalarAttrAt(PD_InferMetaContext* ctx, + size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return reinterpret_cast( + const_cast(&meta_ctx->AttrAt(index))); +} + +PD_IntArray* PD_InferMetaContextIntArrayAttrAt(PD_InferMetaContext* ctx, + size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return reinterpret_cast( + const_cast(&meta_ctx->AttrAt(index))); +} + +PD_List PD_InferMetaContextListBoolAttrAt(PD_InferMetaContext* ctx, + size_t index) { + PD_List list; + auto meta_ctx = reinterpret_cast(ctx); + const auto& cc_list = meta_ctx->AttrAt>(index); + list.size = cc_list.size(); + auto data = reinterpret_cast(new uint8_t[cc_list.size()]); + for (size_t i = 0; i < cc_list.size(); ++i) { + data[i] = static_cast(cc_list[i]); + } + list.data = data; + return list; +} + +PD_List PD_InferMetaContextListInt32AttrAt(PD_InferMetaContext* ctx, + size_t index) { + PD_List list; + auto meta_ctx = reinterpret_cast(ctx); + const auto& cc_list = meta_ctx->AttrAt>(index); + list.size = cc_list.size(); + list.data = const_cast(cc_list.data()); + return list; +} + +PD_List PD_InferMetaContextListInt64AttrAt(PD_InferMetaContext* ctx, + size_t index) { + PD_List list; + auto meta_ctx = reinterpret_cast(ctx); + const auto& cc_list = meta_ctx->AttrAt>(index); + list.size = cc_list.size(); + list.data = const_cast(cc_list.data()); + return list; +} + +PD_List PD_InferMetaContextListFloatAttrAt(PD_InferMetaContext* ctx, + size_t index) { + PD_List list; + auto meta_ctx = reinterpret_cast(ctx); + const auto& cc_list = meta_ctx->AttrAt>(index); + list.size = cc_list.size(); + list.data = const_cast(cc_list.data()); + return list; +} + +PD_List PD_InferMetaContextListDoubleAttrAt(PD_InferMetaContext* ctx, + size_t index) { + PD_List list; + auto meta_ctx = reinterpret_cast(ctx); + const auto& cc_list = meta_ctx->AttrAt>(index); + list.size = cc_list.size(); + list.data = const_cast(cc_list.data()); + return list; +} + +char* PD_InferMetaContextStringAttrAt(PD_InferMetaContext* ctx, size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return const_cast(meta_ctx->AttrAt(index).data()); +} + +PD_List PD_InferMetaContextListStringAttrAt(PD_InferMetaContext* ctx, + size_t index) { + PD_List list; + auto meta_ctx = reinterpret_cast(ctx); + const auto& cc_list = meta_ctx->AttrAt>(index); + list.size = cc_list.size(); + auto data = new char*[list.size]; + for (size_t i = 0; i < list.size; ++i) { + data[i] = const_cast(cc_list[i].data()); + } + list.data = reinterpret_cast(data); + return list; +} + +PD_List PD_InferMetaContextListScalarAttrAt(PD_InferMetaContext* ctx, + size_t index) { + PD_List list; + auto meta_ctx = reinterpret_cast(ctx); + const auto& cc_list = meta_ctx->AttrAt>(index); + list.size = cc_list.size(); + auto data = new PD_Scalar*[list.size]; + for (size_t i = 0; i < list.size; ++i) { + data[i] = + const_cast(reinterpret_cast(&cc_list[i])); + } + list.data = data; + return list; +} + +PD_Place* PD_InferMetaContextPlaceAttrAt(PD_InferMetaContext* ctx, + size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return reinterpret_cast( + const_cast(&meta_ctx->AttrAt(index))); +} + +PD_DataType PD_InferMetaContextDataTypeAttrAt(PD_InferMetaContext* ctx, + size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return phi::capi::ToPDDataType(meta_ctx->AttrAt(index)); +} + +PD_DataLayout PD_InferMetaContextDataLayoutAttrAt(PD_InferMetaContext* ctx, + size_t index) { + auto meta_ctx = reinterpret_cast(ctx); + return phi::capi::ToPDDataLayout(meta_ctx->AttrAt(index)); +} + +PD_REGISTER_CAPI(infer_meta_context); diff --git a/paddle/phi/capi/lib/c_kernel_registry.cc b/paddle/phi/capi/lib/c_kernel_registry.cc index 6cf6208856bfad..a85e3cd5d20a49 100644 --- a/paddle/phi/capi/lib/c_kernel_registry.cc +++ b/paddle/phi/capi/lib/c_kernel_registry.cc @@ -18,6 +18,8 @@ #include "paddle/phi/capi/include/type_utils.h" #include "paddle/phi/core/kernel_registry.h" +#include "glog/logging.h" + void PD_KernelArgsParseFn(const phi::KernelKey& default_key, phi::KernelArgsDef* args_def, size_t in_nargs, diff --git a/paddle/phi/capi/lib/c_meta_tensor.cc b/paddle/phi/capi/lib/c_meta_tensor.cc new file mode 100644 index 00000000000000..d2493058081584 --- /dev/null +++ b/paddle/phi/capi/lib/c_meta_tensor.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/capi/include/c_meta_tensor.h" + +#include "paddle/phi/capi/include/common.h" +#include "paddle/phi/capi/include/type_utils.h" +#include "paddle/phi/core/meta_tensor.h" + +PD_DataType PD_MetaTensorGetPDDataType(const PD_MetaTensor *tensor, + PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return PD_DataType::UNDEFINED; + } + *status = C_SUCCESS; + } + auto cc_tensor = reinterpret_cast(tensor); + return phi::capi::ToPDDataType(cc_tensor->dtype()); +} + +PD_DataLayout PD_MetaTensorGetDataLayout(const PD_MetaTensor *tensor, + PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return PD_DataLayout::ALL_LAYOUT; + } + *status = C_SUCCESS; + } + auto cc_tensor = reinterpret_cast(tensor); + return phi::capi::ToPDDataLayout(cc_tensor->layout()); +} + +int64_t PD_MetaTensorGetElementCount(const PD_MetaTensor *tensor, + PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return 0; + } + *status = C_SUCCESS; + } + + auto cc_tensor = reinterpret_cast(tensor); + return cc_tensor->numel(); +} + +int64_t PD_MetaTensorGetNumDims(const PD_MetaTensor *tensor, + PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return 0; + } + *status = C_SUCCESS; + } + + auto cc_tensor = reinterpret_cast(tensor); + return cc_tensor->dims().size(); +} + +int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor, + size_t index, + PD_Status *status) { + auto cc_tensor = reinterpret_cast(tensor); + + if (status) { + if (!tensor || index >= static_cast(cc_tensor->dims().size())) { + *status = C_FAILED; + return 0; + } + *status = C_SUCCESS; + } + + return cc_tensor->dims()[index]; +} + +bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return false; + } + *status = C_SUCCESS; + } + + auto cc_tensor = reinterpret_cast(tensor); + return cc_tensor->initialized(); +} + +void PD_MetaTensorSetDims(PD_MetaTensor *tensor, + int64_t ndims, + const int64_t *dims, + PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return; + } + *status = C_SUCCESS; + } + auto cc_tensor = reinterpret_cast(tensor); + std::vector shape(dims, dims + ndims); + cc_tensor->set_dims(phi::make_ddim(shape)); +} + +void PD_MetaTensorSetDataType(PD_MetaTensor *tensor, + PD_DataType dtype, + PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return; + } + *status = C_SUCCESS; + } + + auto cc_tensor = reinterpret_cast(tensor); + cc_tensor->set_dtype(phi::capi::ToPhiDataType(dtype)); +} + +void PD_MetaTensorSetDataLayout(PD_MetaTensor *tensor, + PD_DataLayout layout, + PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return; + } + *status = C_SUCCESS; + } + + auto cc_tensor = reinterpret_cast(tensor); + cc_tensor->set_layout(phi::capi::ToPhiDataLayout(layout)); +} + +PD_REGISTER_CAPI(meta_tensor); From 0df304b06db3894badfaba5e1c0d2ed1a5932c43 Mon Sep 17 00:00:00 2001 From: ronny1996 Date: Fri, 21 Jul 2023 02:31:10 +0000 Subject: [PATCH 2/2] update --- paddle/phi/capi/lib/c_kernel_registry.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/phi/capi/lib/c_kernel_registry.cc b/paddle/phi/capi/lib/c_kernel_registry.cc index a85e3cd5d20a49..6cf6208856bfad 100644 --- a/paddle/phi/capi/lib/c_kernel_registry.cc +++ b/paddle/phi/capi/lib/c_kernel_registry.cc @@ -18,8 +18,6 @@ #include "paddle/phi/capi/include/type_utils.h" #include "paddle/phi/core/kernel_registry.h" -#include "glog/logging.h" - void PD_KernelArgsParseFn(const phi::KernelKey& default_key, phi::KernelArgsDef* args_def, size_t in_nargs,