Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Refine IR builder and throw methods #54396

Merged
merged 19 commits into from
Jun 9, 2023
605 changes: 411 additions & 194 deletions paddle/fluid/ir/dialect/op_gen.py

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions paddle/fluid/ir/dialect/pd_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,23 @@ phi::DataLayout DataLayoutAttribute::data() const {
return storage()->GetAsKey();
}

phi::Scalar ScalarAttribute::data() {
if (isa<ir::FloatAttribute>()) {
return phi::Scalar(dyn_cast<ir::FloatAttribute>().data());
} else if (isa<ir::DoubleAttribute>()) {
return phi::Scalar(dyn_cast<ir::DoubleAttribute>().data());
} else if (isa<ir::Int32_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int32_tAttribute>().data());
} else if (isa<ir::Int64_tAttribute>()) {
return phi::Scalar(dyn_cast<ir::Int64_tAttribute>().data());
} else if (isa<ir::BoolAttribute>()) {
return phi::Scalar(dyn_cast<ir::BoolAttribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir attribute when casting it into "
"phi scalar."));
}
}

} // namespace dialect
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/ir/dialect/pd_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "paddle/fluid/ir/dialect/pd_attribute_storage.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace dialect {
Expand Down Expand Up @@ -45,6 +47,8 @@ class ScalarAttribute : public ir::Attribute {
(val.type_id() == ir::Int32_tAttribute::type_id()) ||
(val.type_id() == ir::Int64_tAttribute::type_id());
}

phi::Scalar data();
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要内联


class DataTypeAttribute : public ir::Attribute {
Expand Down
7 changes: 5 additions & 2 deletions paddle/fluid/ir/dialect/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,17 @@ struct OpInputInfo {
std::string type_name;
bool optional = false;
bool no_need_buffer = false;
bool is_mutable_attribute = false;
OpInputInfo(std::string name,
std::string type_name,
bool optional,
bool no_need_buffer)
bool no_need_buffer,
bool is_mutable_attribute)
: name(name),
type_name(type_name),
optional(optional),
no_need_buffer(no_need_buffer) {}
no_need_buffer(no_need_buffer),
is_mutable_attribute(is_mutable_attribute) {}
};

struct OpOutputInfo {
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Builder {
template <typename OpTy, typename... Args>
OpTy Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(argument, std::forward<Args>(args)...);
OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument));
return op->dyn_cast<OpTy>();
}
Expand Down
159 changes: 64 additions & 95 deletions paddle/ir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,15 @@ void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs type:
if (inputs.size() != 0) {
throw("The size of inputs must be equal to 0.");
}
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");

// Verify if attributes contain attribute name in attributes_name:
auto iter = attributes.find("program");
if (iter == attributes.end() || !iter->second.isa<PointerAttribute>()) {
throw("Type of attribute: program is not right.");
}
IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(),
"Type of attribute: program is not right.");

// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}

const char *GetParameterOp::attributes_name[attributes_num] = {
Expand All @@ -81,17 +76,15 @@ void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs type:
if (inputs.size() != 0) {
throw("The size of inputs must be equal to 0.");
}
// Verify outputs type:
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
}
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");

// Verify if attributes contain attribute name in attributes_name:
if (!attributes.at("parameter_name").isa<StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right.");

// Verify outputs type:
IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
}

const char *SetParameterOp::attributes_name[attributes_num] = {
Expand All @@ -102,54 +95,45 @@ void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs type:
if (inputs.size() != 1) {
throw("The size of inputs must be equal to 1.");
}
// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1.");

// Verify if attributes contain attribute name in attributes_name:
if (!attributes.at("parameter_name").isa<StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
}
auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right.");

// Verify outputs type:
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
}

void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
IR_ENFORCE(outputs.size() == 1,
"The size %d of outputs must be equal to 1.",
outputs.size());

// outputs[0].type == Vector<Type>
PADDLE_ENFORCE(outputs[0].isa<ir::VectorType>(),
phi::errors::PreconditionNotMet(
"The type %s of outputs[0] must be equal to VectorType.",
outputs[0]));
IR_ENFORCE(outputs[0].isa<ir::VectorType>(),
"The type %s of outputs[0] must be equal to VectorType.",
outputs[0]);
ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>();
// inputs.size() == outputs[0].size()
PADDLE_ENFORCE_EQ(
output_type.size(),
inputs.size(),
phi::errors::PreconditionNotMet(
"The size %d of outputs[0] must be equal to size %d of inputs.",
output_type.size(),
inputs.size()));
IR_ENFORCE(output_type.size() == inputs.size(),
"The size %d of outputs[0] must be equal to size %d of inputs.",
output_type.size(),
inputs.size());

// forall i in inputs.size(): inputs[i].type == outputs[0][i].type
for (size_t i = 0; i < inputs.size(); i++) {
PADDLE_ENFORCE_EQ(
output_type[i],
inputs[i].type(),
phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be "
"equal to type %s of inputs[%d].",
output_type[i],
i,
inputs[i].type(),
i));
IR_ENFORCE(output_type[i] == inputs[i].type(),
"The type %s of outputs[0][%d] must be "
"equal to type %s of inputs[%d].",
output_type[i],
i,
inputs[i].type(),
i);
}
}

Expand All @@ -158,65 +142,50 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// inputs.size() == 1
PADDLE_ENFORCE_EQ(
inputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", inputs.size()));
IR_ENFORCE(inputs.size() == 1,
"The size %d of inputs must be equal to 1.",
inputs.size());

// inputs[0].type == Vector<Type>
PADDLE_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet(
"The type %s of inputs[0] must be equal to VectorType.",
inputs[0].type()));
IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
"The type %s of inputs[0] must be equal to VectorType.",
inputs[0].type());
ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();

// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
IR_ENFORCE(outputs.size() == 1,
"The size %d of outputs must be equal to 1.",
outputs.size());

// attributes contains index: Int32
PADDLE_ENFORCE_NE(
attributes.count("index"),
0,
phi::errors::PreconditionNotMet("The attributes must contains index."));
IR_ENFORCE(attributes.count("index") != 0,
"The attributes must contains index.");
const ir::Attribute &attr = attributes.at("index");
PADDLE_ENFORCE(
attr.isa<ir::Int32_tAttribute>(),
phi::errors::PreconditionNotMet("The attribute index must be INT32."));
IR_ENFORCE(attr.isa<ir::Int32_tAttribute>(),
"The attribute index must be INT32.");
auto index = attr.dyn_cast<ir::Int32_tAttribute>().data();

// index >= 0 and < inputs[0].size()
PADDLE_ENFORCE_GE(
index,
0,
phi::errors::PreconditionNotMet(
"The index %d must be greater or equal than 0.", index));
PADDLE_ENFORCE_LT(
index,
input_type.size(),
phi::errors::PreconditionNotMet(
"The index %d must be less or equal than size %d of inputs[0].",
index,
input_type.size()));
IR_ENFORCE(
index >= 0, "The index %d must be greater or equal than 0.", index);
IR_ENFORCE(static_cast<size_t>(index) < input_type.size(),
"The index %d must be less or equal than size %d of inputs[0].",
index,
input_type.size());

// inputs[index].type == outputs[0].type
PADDLE_ENFORCE_EQ(
IR_ENFORCE(
input_type[index] == outputs[0],
"The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index],
outputs[0],
phi::errors::PreconditionNotMet(
"The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index],
index,
outputs[0]));
index,
outputs[0]);
}

const char *ConstantOp::attributes_name[attributes_num] = {"value"};

void ConstantOp::Build(OperationArgument &argument,
void ConstantOp::Build(Builder &builder,
OperationArgument &argument,
Attribute value,
Type output_type) {
argument.AddAttribute("value", value);
Expand Down
4 changes: 3 additions & 1 deletion paddle/ir/core/builtin_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class CombineOp : public ir::Op<CombineOp> {
static constexpr uint32_t attributes_num = 0;

static constexpr const char **attributes_name = nullptr;

static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
Expand Down Expand Up @@ -125,7 +126,8 @@ class ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];

static void Build(OperationArgument &argument, // NOLINT
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Attribute value,
Type output_type);

Expand Down
9 changes: 5 additions & 4 deletions paddle/ir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
Expand Down Expand Up @@ -85,7 +86,7 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
throw("The address of OpOperandImpl must be divisible by 8.");
IR_THROW("The address of OpOperandImpl must be divisible by 8.");
}
for (size_t idx = 0; idx < num_operands; idx++) {
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
Expand Down Expand Up @@ -147,7 +148,7 @@ void Operation::Destroy() {
// 2.2. Deconstruct Operation.
if (reinterpret_cast<uintptr_t>(base_ptr) !=
reinterpret_cast<uintptr_t>(this)) {
throw("Operation address error");
IR_THROW("Operation address error");
}
reinterpret_cast<Operation *>(base_ptr)->~Operation();
base_ptr += sizeof(Operation);
Expand Down Expand Up @@ -178,7 +179,7 @@ Operation::Operation(const AttributeMap &attributes,

ir::OpResult Operation::GetResultByIndex(uint32_t index) const {
if (index >= num_results_) {
throw("index exceeds OP output range.");
IR_THROW("index exceeds OP output range.");
}
uint32_t max_inline_idx = detail::OpResultImpl::GetMaxInlineResultIndex();
const char *ptr =
Expand All @@ -199,7 +200,7 @@ ir::OpResult Operation::GetResultByIndex(uint32_t index) const {

ir::OpOperand Operation::GetOperandByIndex(uint32_t index) const {
if (index >= num_operands_) {
throw("index exceeds OP input range.");
IR_THROW("index exceeds OP input range.");
}
const char *ptr = reinterpret_cast<const char *>(this) + sizeof(Operation) +
(index) * sizeof(detail::OpOperandImpl);
Expand Down
8 changes: 5 additions & 3 deletions paddle/ir/core/storage_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <memory>
#include <unordered_map>

#include "paddle/ir/core/enforce.h"

namespace ir {
// This is a structure for creating, caching, and looking up Storage of
// parametric types.
Expand Down Expand Up @@ -76,7 +78,7 @@ StorageManager::StorageBase *StorageManager::GetParametricStorageImpl(
<< std::hash<ir::TypeId>()(type_id) << ", param_hash=" << hash_value
<< "].";
if (parametric_instance_.find(type_id) == parametric_instance_.end()) {
throw("The input data pointer is null.");
IR_THROW("The input data pointer is null.");
}
ParametricStorageManager &parametric_storage = *parametric_instance_[type_id];
return parametric_storage.GetOrCreate(hash_value, equal_func, constructor);
Expand All @@ -88,7 +90,7 @@ StorageManager::StorageBase *StorageManager::GetParameterlessStorageImpl(
VLOG(4) << "Try to get a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instance_.find(type_id) == parameterless_instance_.end())
throw("TypeId not found in IrContext.");
IR_THROW("TypeId not found in IrContext.");
StorageBase *parameterless_instance = parameterless_instance_[type_id];
return parameterless_instance;
}
Expand All @@ -107,7 +109,7 @@ void StorageManager::RegisterParameterlessStorageImpl(
VLOG(4) << "Register a parameterless storage of: [TypeId_hash="
<< std::hash<ir::TypeId>()(type_id) << "].";
if (parameterless_instance_.find(type_id) != parameterless_instance_.end())
throw("storage class already registered");
IR_THROW("storage class already registered");
parameterless_instance_.emplace(type_id, constructor());
}

Expand Down
Loading