Skip to content

Commit

Permalink
replace any by variant in infermeta (#42181)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Apr 25, 2022
1 parent de8aa07 commit e009809
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 96 deletions.
34 changes: 33 additions & 1 deletion paddle/phi/core/infermeta_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void InferMetaContext::EmplaceBackAttr(paddle::any attr) {
void InferMetaContext::EmplaceBackAttr(Attribute attr) {
attrs_.emplace_back(std::move(attr));
}

Expand Down Expand Up @@ -120,6 +120,38 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
return result;
}

template <typename AttrType>
const AttrType& InferMetaContext::AttrAt(size_t idx) const {
try {
return paddle::get<AttrType>(attrs_.at(idx));
} catch (paddle::bad_variant_access const& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`.",
std::type_index(typeid(AttrType)).name()));
}
}

template const bool& InferMetaContext::AttrAt(size_t idx) const;
template const int& InferMetaContext::AttrAt(size_t idx) const;
template const int64_t& InferMetaContext::AttrAt(size_t idx) const;
template const float& InferMetaContext::AttrAt(size_t idx) const;
template const double& InferMetaContext::AttrAt(size_t idx) const;
template const std::string& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<bool>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<int>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<int64_t>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<float>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<double>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<std::string>& InferMetaContext::AttrAt(
size_t idx) const;
template const Scalar& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<Scalar>& InferMetaContext::AttrAt(size_t idx) const;
template const IntArray& InferMetaContext::AttrAt(size_t idx) const;
template const DataType& InferMetaContext::AttrAt(size_t idx) const;
template const DataLayout& InferMetaContext::AttrAt(size_t idx) const;
template const Place& InferMetaContext::AttrAt(size_t idx) const;

MetaFnFactory& MetaFnFactory::Instance() {
static MetaFnFactory g_meta_fn_map;
return g_meta_fn_map;
Expand Down
60 changes: 33 additions & 27 deletions paddle/phi/core/infermeta_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/meta_tensor.h"
Expand All @@ -41,7 +42,7 @@ class InferMetaContext {

void EmplaceBackInput(MetaTensor input);
void EmplaceBackOutput(MetaTensor output);
void EmplaceBackAttr(paddle::any attr);
void EmplaceBackAttr(Attribute attr);

void EmplaceBackInputs(
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs);
Expand All @@ -61,17 +62,7 @@ class InferMetaContext {
size_t end);

template <typename AttrType>
AttrType AttrAt(size_t idx) {
try {
return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`, but actual attribute type is `%s`.",
std::type_index(typeid(AttrType)).name(),
std::type_index(attrs_.at(idx).type()).name()));
}
}
const AttrType& AttrAt(size_t idx) const;

const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
Expand All @@ -81,7 +72,7 @@ class InferMetaContext {
protected:
MetaConfig config_;

paddle::SmallVector<paddle::any, kAttrSmallVectorSize> attrs_;
paddle::SmallVector<Attribute, kAttrSmallVectorSize> attrs_;

paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize>
input_range_;
Expand Down Expand Up @@ -111,6 +102,21 @@ class InferMetaContext {
} \
}

#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
template <typename... Tail> \
struct InferMetaFnCallHelper<const attr_type&, Tail...> { \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \
static_assert(out_idx == 0, \
"InferMeta's Attributes should appear before Outputs."); \
const attr_type& arg = ctx->AttrAt<attr_type>(attr_idx); \
InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \
} \
}

template <typename T>
struct InferMetaTypeTag {};

Expand Down Expand Up @@ -201,27 +207,27 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};

// TODO(chenweihang): support other attr type later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::string&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<int64_t>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<std::string>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const IntArray&);

// TODO(chenweihang): support vector<MetaTensor> input later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<bool>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<int64_t>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<float>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<double>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<std::string>);

template <typename... Tail>
struct InferMetaFnCallHelper<MetaTensor*, Tail...> {
Expand Down
29 changes: 0 additions & 29 deletions paddle/phi/core/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,8 @@
#include <string>
#include <vector>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"

#include "paddle/utils/variant.h"

namespace phi {

class Place;

// NOTE: Add needed type in the future
using Attribute = paddle::variant<bool,
int,
int64_t,
float,
double,
std::string,
std::vector<bool>,
std::vector<int>,
std::vector<int64_t>,
std::vector<float>,
std::vector<double>,
std::vector<std::string>,
Scalar,
std::vector<Scalar>,
IntArray,
DataType,
DataLayout,
Place>;

class Kernel;
class KernelKey;
class KernelArgsDef;
Expand Down
9 changes: 1 addition & 8 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,6 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
out->set_dtype(x.dtype());
}

void CopyToInferMeta(const MetaTensor& x,
Backend backend,
bool blocking,
MetaTensor* out) {
UnchangedInferMeta(x, out);
}

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
Expand Down Expand Up @@ -3002,5 +2995,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {

} // namespace phi

PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta);
5 changes: 0 additions & 5 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);

void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);

void CopyToInferMeta(const MetaTensor& x,
Backend backend,
bool blocking,
MetaTensor* out);

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

void CumsumInferMeta(const MetaTensor& x,
Expand Down
26 changes: 0 additions & 26 deletions paddle/phi/tests/core/test_meta_fn_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,6 @@ TEST(MetaFnFactory, InferMetaFnExists) {
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}

TEST(MetaFnFactory, CopyInferMetaFn) {
phi::DenseTensor dense_x;
dense_x.Resize({3, 4});

phi::MetaTensor meta_x(&dense_x);
phi::DenseTensor dense_out1;
phi::MetaTensor meta_out(&dense_out1);
phi::UnchangedInferMeta(meta_x, &meta_out);

auto shared_meat_x = phi::MetaTensor(&dense_x);
phi::DenseTensor dense_out2;
auto shared_meta_out = phi::MetaTensor(&dense_out2);

phi::InferMetaContext ctx;
ctx.EmplaceBackInput(shared_meat_x);
ctx.EmplaceBackAttr(Backend::CPU);
ctx.EmplaceBackAttr(false);
ctx.EmplaceBackOutput(shared_meta_out);
ctx.SetMetaConfig({/*is_runtime =*/true, /*is_run_mkldnn_kernel=*/false});
phi::MetaFnFactory::Instance().Get("copy_to")(&ctx);

EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size());
EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]);
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}

TEST(MetaFnFactory, SplitInferMetaFn) {
phi::DenseTensor dense_x;
dense_x.Resize({4, 10});
Expand Down

1 comment on commit e009809

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.