Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] standardize the use of value[-3]. #57418

Merged
merged 1 commit into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def GenBuildInserFullForMutableAttribute(

def GenBuildInputs(op_input_name_list, op_mutable_attribute_name_list):
BUILD_INPUT_TEMPLATE = """ std::vector<pir::OpResult> argument_inputs = {{{inputs_args}}};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());
"""
build_input_str = ' VLOG(4) << "Builder construction inputs";\n'
input_name_list = op_input_name_list + op_mutable_attribute_name_list
Expand Down
16 changes: 8 additions & 8 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ void AddNOp::Build(pir::Builder &builder, // NOLINT
pir::OpResult inputs) {
VLOG(4) << "Builder construction inputs";
std::vector<pir::OpResult> argument_inputs = {inputs};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());

VLOG(4) << "Builder construction attributes";

Expand Down Expand Up @@ -179,7 +179,7 @@ void AddN_Op::Build(pir::Builder &builder,
pir::OpResult inputs_) {
VLOG(4) << "Builder construction inputs";
std::vector<pir::OpResult> argument_inputs = {inputs_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());

VLOG(4) << "Builder construction attributes";

Expand Down Expand Up @@ -307,7 +307,7 @@ void AddNWithKernelOp::Build(pir::Builder &builder,
pir::OpResult inputs_) {
VLOG(4) << "Builder construction inputs";
std::vector<pir::OpResult> argument_inputs = {inputs_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());

VLOG(4) << "Builder construction attributes";

Expand Down Expand Up @@ -477,7 +477,7 @@ void FusedGemmEpilogueOp::Build(pir::Builder &builder,

VLOG(4) << "Builder construction inputs";
std::vector<pir::OpResult> argument_inputs = {x_, y_, bias_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());

VLOG(4) << "Builder construction attributes";
pir::Attribute attr_trans_x =
Expand Down Expand Up @@ -732,7 +732,7 @@ void FusedGemmEpilogueGradOp::Build(pir::Builder &builder,
VLOG(4) << "Builder construction inputs";
std::vector<pir::OpResult> argument_inputs = {
x_, y_, reserve_space_, out_grad_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());

VLOG(4) << "Builder construction attributes";
pir::Attribute attr_trans_x =
Expand Down Expand Up @@ -916,7 +916,7 @@ void SplitGradOp::Build(pir::Builder &builder,

VLOG(4) << "Builder construction inputs";
std::vector<pir::OpResult> argument_inputs = {out_grad_, axis_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());

VLOG(4) << "Builder construction attributes";

Expand Down Expand Up @@ -974,7 +974,7 @@ void SplitGradOp::Build(pir::Builder &builder,
pir::OpResult axis_) {
VLOG(4) << "Builder construction inputs";
std::vector<pir::OpResult> argument_inputs = {out_grad_, axis_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());

VLOG(4) << "Builder construction attributes";

Expand Down Expand Up @@ -1095,7 +1095,7 @@ void IfOp::Build(pir::Builder &builder, // NOLINT
pir::OpResult cond,
std::vector<pir::Type> &&output_types) {
argument.num_regions = 2;
argument.AddOperand(cond);
argument.AddInput(cond);
argument.output_types.swap(output_types);
}
pir::Block *IfOp::true_block() {
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

namespace pir {
/// Create an operation given the fields represented as an OperationState.
Operation *Builder::Build(OperationArgument &&argument) {
return Insert(Operation::Create(std::move(argument)));
Operation *Builder::Build(const OperationArgument &argument) {
return Insert(Operation::Create(argument));
}

/// Creates an operation with the given fields.
Expand Down
6 changes: 3 additions & 3 deletions paddle/pir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Builder {
Block *block() const { return block_; }

/// Creates an operation given the fields represented as an OperationState.
IR_API Operation *Build(OperationArgument &&argument);
IR_API Operation *Build(const OperationArgument &argument);

/// Creates an operation with the given fields.
IR_API Operation *Build(const std::vector<pir::OpResult> &inputs,
Expand All @@ -107,8 +107,8 @@ class Builder {
OpTy Build(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::Build(*this, argument, std::forward<Args>(args)...);
Operation *op = Build(std::move(argument));
return op->dyn_cast<OpTy>();
Operation *op = Build(argument);
return OpTy(op);
}

IR_API UInt8Type uint8_type();
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ ModuleOp ModuleOp::Create(IrContext *context, Program *pointer) {
OperationArgument argument(info);
argument.num_regions = 1;
argument.AddAttribute("program", PointerAttribute::get(context, pointer));
Operation *op = Operation::Create(std::move(argument));
Operation *op = Operation::Create(argument);
op->region(0).emplace_back();
return ModuleOp(op);
}
Expand Down Expand Up @@ -140,7 +140,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name) {
argument.AddOperand(parameter);
argument.AddInput(parameter);
argument.AddAttribute(attributes_name[0],
pir::StrAttribute::get(builder.ir_context(), name));
}
Expand Down
9 changes: 0 additions & 9 deletions paddle/pir/core/op_operand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,11 @@
CHECK_NULL_IMPL(OpOpernad, func_name)

namespace pir {

OpOperand::OpOperand(const detail::OpOperandImpl *impl)
: impl_(const_cast<detail::OpOperandImpl *>(impl)) {}

OpOperand &OpOperand::operator=(const OpOperand &rhs) {
impl_ = rhs.impl_;
return *this;
}

OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
if (this->impl_ == impl) return *this;
impl_ = const_cast<detail::OpOperandImpl *>(impl);
return *this;
}
OpOperand::operator bool() const { return impl_ && impl_->source(); }

OpOperand OpOperand::next_use() const {
Expand Down
4 changes: 1 addition & 3 deletions paddle/pir/core/op_operand.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ class IR_API OpOperand {

OpOperand(const OpOperand &other) = default;

OpOperand(const detail::OpOperandImpl *impl); // NOLINT
OpOperand(detail::OpOperandImpl *impl) : impl_(impl) {} // NOLINT

OpOperand &operator=(const OpOperand &rhs);

OpOperand &operator=(const detail::OpOperandImpl *impl);

bool operator==(const OpOperand &other) const { return impl_ == other.impl_; }

bool operator!=(const OpOperand &other) const { return !operator==(other); }
Expand Down
14 changes: 0 additions & 14 deletions paddle/pir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,10 @@ struct OperationArgument {
num_regions(num_regions),
successors(successors) {}

// Will be deleted in the next pr.
void AddOperand(OpResult operand) { inputs.emplace_back(operand); }

void AddInput(Value input) {
inputs.emplace_back(input.dyn_cast<OpResult>());
}

// Will be deleted in the next pr.
template <class InputIt>
void AddOperands(InputIt first, InputIt last);

template <class InputIt>
void AddInputs(InputIt first, InputIt last);

Expand Down Expand Up @@ -99,13 +92,6 @@ struct OperationArgument {
void AddSuccessor(Block* successor) { successors.emplace_back(successor); }
};

template <class InputIt>
void OperationArgument::AddOperands(InputIt first, InputIt last) {
while (first != last) {
inputs.emplace_back(*first++);
}
}

template <class InputIt>
void OperationArgument::AddInputs(InputIt first, InputIt last) {
while (first != last) {
Expand Down
4 changes: 0 additions & 4 deletions paddle/pir/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@
#define CHECK_VALUE_NULL_IMPL(func_name) CHECK_NULL_IMPL(Value, func_name)

namespace pir {

Value::Value(const detail::ValueImpl *impl)
: impl_(const_cast<detail::ValueImpl *>(impl)) {}

bool Value::operator==(const Value &other) const {
return impl_ == other.impl_;
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class IR_API Value {
public:
Value() = default;

Value(const detail::ValueImpl *impl); // NOLINT
Value(detail::ValueImpl *impl) : impl_(impl) {} // NOLINT

Value(const Value &other) = default;

Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/dialect/control_flow/ir/cf_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace pir {
void YieldOp::Build(Builder &builder,
OperationArgument &argument,
std::vector<OpResult> &&inputs) {
argument.AddOperands(inputs.begin(), inputs.end());
argument.AddInputs(inputs.begin(), inputs.end());
}
} // namespace pir

Expand Down
3 changes: 1 addition & 2 deletions test/cpp/pir/core/ir_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ TEST(op_test, region_test) {
argument.output_types = {pir::Float32Type::get(ctx)};
argument.num_regions = 1;

pir::Operation *op3 = pir::Operation::Create(std::move(argument));
// argument.regions.emplace_back(std::make_unique<pir::Region>());
pir::Operation *op3 = pir::Operation::Create(argument);

pir::Region &region = op3->region(0);
EXPECT_EQ(region.empty(), true);
Expand Down
12 changes: 6 additions & 6 deletions test/cpp/pir/core/ir_program_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class AddOp : public pir::Op<AddOp> {
void Verify();
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::OpResult l_operand,
pir::OpResult r_operand,
pir::Value l_operand,
pir::Value r_operand,
pir::Type sum_type);
};
void AddOp::Verify() {
Expand All @@ -58,11 +58,11 @@ void AddOp::Verify() {
}
void AddOp::Build(pir::Builder &,
pir::OperationArgument &argument,
pir::OpResult l_operand,
pir::OpResult r_operand,
pir::Value l_operand,
pir::Value r_operand,
pir::Type sum_type) {
argument.AddOperand(l_operand);
argument.AddOperand(r_operand);
argument.AddInput(l_operand);
argument.AddInput(r_operand);
argument.AddOutput(sum_type);
}
IR_DECLARE_EXPLICIT_TYPE_ID(AddOp)
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ TEST(kernel_dialect, legacy_op_test) {
"kernel_key",
kernel_key);

pir::Operation* op = pir::Operation::Create(std::move(argument));
pir::Operation* op = pir::Operation::Create(argument);
EXPECT_EQ("pd_op.kernel_op",
op->dyn_cast<paddle::dialect::LegacyKernelOp>().op_name());
EXPECT_EQ("kernel_op",
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/pir/pass/pass_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ void AddOp::Build(pir::Builder &,
pir::OpResult l_operand,
pir::OpResult r_operand,
pir::Type sum_type) {
argument.AddOperand(l_operand);
argument.AddOperand(r_operand);
argument.AddInput(l_operand);
argument.AddInput(r_operand);
argument.AddOutput(sum_type);
}
IR_DECLARE_EXPLICIT_TYPE_ID(AddOp)
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ void Conv2dFusionOpTest::Build(pir::Builder &builder,
VLOG(4) << "Builder construction inputs";
std::vector<pir::OpResult> argument_inputs = {
input_, filter_, bias_, residual_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
argument.AddInputs(argument_inputs.begin(), argument_inputs.end());

VLOG(4) << "Builder construction attributes";
std::vector<pir::Attribute> vec_strides;
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/tools/test_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void BranchOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument,
const std::vector<pir::OpResult> &target_operands,
pir::Block *target) {
argument.AddOperands(target_operands.begin(), target_operands.end());
argument.AddInputs(target_operands.begin(), target_operands.end());
argument.AddSuccessor(target);
}

Expand Down