Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/a2064968462/Paddle into …
Browse files Browse the repository at this point in the history
…typos_tools
  • Loading branch information
a2064968462 committed Nov 21, 2024
2 parents d3eff2b + 7ca7f2c commit e09d3f7
Show file tree
Hide file tree
Showing 67 changed files with 1,470 additions and 478 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ repos:
rev: v1.27.3
hooks:
- id: typos
args: []
args: [--force-exclude]
# For Python files
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
Expand Down
8 changes: 6 additions & 2 deletions _typos.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
[files]
# The following files will be excluded from spell check during commits
extend-exclude = [
"test/dataset/imikolov_test.py"
]

[default.extend-words]
# PaddlePaddle specific words
lod = "lod"
Expand Down Expand Up @@ -154,7 +160,6 @@ CACH = 'CACH'
endianess = 'endianess'
VAILD = 'VAILD'
ues = 'ues'
aer = 'aer'
elemenents = 'elemenents'
CANN = 'CANN'
pathes = 'pathes'
Expand Down Expand Up @@ -764,7 +769,6 @@ distrubuted = 'distrubuted'
Localy = 'Localy'
PARM = 'PARM'
thi = 'thi'
Oll = 'Oll'
Infor = 'Infor'
statment = 'statment'
varn = 'varn'
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/pybind11.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ set(PYBIND_PATCH_COMMAND "")
if(LINUX
AND (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9)
set(PYBIND_TAG v2.12.0)
set(PYBIND_TAG v2.13.6)
file(TO_NATIVE_PATH
${PADDLE_SOURCE_DIR}/patches/pybind/detail/internals.h.patch native_dst)
# Note: [Why calling some `git` commands before `patch`?]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ class FuseSingleElementShapeOpsIntoGenerateShapeOpPattern
auto* user = iter->owner();
if (IsSingleElementShapeOp(user, &shape_analysis)) return false;
if (user->isa<cinn::dialect::GenerateShapeOp>()) return false;
if (user->isa<pir::ShadowOutputOp>()) return false;
}

return true;
Expand Down
95 changes: 75 additions & 20 deletions paddle/cinn/hlir/framework/pir/trivial_op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h"

#include "paddle/cinn/common/dim_expr_converter.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/compile_error.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h"
Expand Down Expand Up @@ -547,9 +548,6 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
* remove it in axes.bind()
*/
const auto& f = [=](const ir::Expr& e) -> ir::Expr {
VLOG(4) << "Start RemoveVarInScheduleBlockRealize(" << target_vars << ", "
<< replaced_expr << ")";
VLOG(4) << " Input is " << e;
PADDLE_ENFORCE_NE(
e.As<ir::ScheduleBlockRealize>(),
nullptr,
Expand All @@ -562,22 +560,11 @@ ExprTransformer RemoveVarInScheduleBlockRealize(const ir::Var& target_vars,
auto block_bound_vars = copied_ir.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars;
for (const auto& i_var : schedule_block_iter_vars) {
PADDLE_ENFORCE_EQ(
i_var.is_var(),
true,
::common::errors::InvalidArgument("RemoveVarInScheduleBlockRealize: "
"axes.bind rhs is is not a Var."));
}
// find replace idx
int target_idx = -1;
for (int i = 0; i < schedule_block_iter_vars.size(); ++i) {
VLOG(4) << "RemoveVarInScheduleBlockRealize: compare with "
<< schedule_block_iter_vars[i] << " vs " << target_vars
<< ", and equality is: "
<< (schedule_block_iter_vars[i].as_var()->name ==
target_vars->name);
if (schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
if (schedule_block_iter_vars[i].is_var() &&
schedule_block_iter_vars[i].as_var()->name == target_vars->name) {
target_idx = i;
}
}
Expand Down Expand Up @@ -688,8 +675,6 @@ ExprTransformer RemoveOneTransformer(int one) {
.GetSingle(copied);
const ir::Expr& target_block =
ExprSetFinderUtils::DirectlyFather(copied).GetSingle(target_for);
VLOG(4) << "RemoveOneTransformer: directly target_block of for is "
<< target_block;
if (target_block.As<ir::ScheduleBlockRealize>() != nullptr) {
VLOG(4) << "RemoveOneTransformer: father block is root realize";
ir::Expr shedule_block =
Expand All @@ -708,7 +693,6 @@ ExprTransformer RemoveOneTransformer(int one) {
shedule_block.As<ir::ScheduleBlock>()->body = for_body;
}
} else if (target_block.As<ir::Block>() != nullptr) {
VLOG(4) << "RemoveOneTransformer: father block is Block";
std::vector<ir::Expr> new_bodies;
for (const auto& expr : target_block.As<ir::Block>()->stmts) {
if (expr != target_for) {
Expand All @@ -728,7 +712,6 @@ ExprTransformer RemoveOneTransformer(int one) {
"RemoveOneTransformer: target for father should be a ir::Block or "
"ir::ScheduleBlockRealize."));
}
VLOG(4) << "Remove Var to 0 in ScheduleBlockRealizer: " << copied;
// Remove var to 0 in ScheduleBlockRealizer
InplaceMutateSingleExpr(
&copied,
Expand Down Expand Up @@ -949,6 +932,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root) {

ir::Expr GetBodyBlock(const ir::Expr& root) {
const auto& iters = GetNonReduceLoopVars(root);
if (iters.empty()) {
return ir::Block::Make(
{ExprSetFinderUtils::ChildScheduleBlockRealizes.GetSingle(root)});
}
const size_t reduce_size =
std::count_if(iters.begin(), iters.end(), [](const ir::Var& v) {
return v->is_reduce_axis;
Expand All @@ -965,6 +952,74 @@ ir::Expr GetBodyBlock(const ir::Expr& root) {
->body;
}

ir::Expr ReshapeLoop(const ir::Expr& root,
const std::vector<symbol::DimExpr>& in_shape,
const std::vector<symbol::DimExpr>& out_shape) {
auto copied = ir::ir_utils::IRCopy(root);

ir::ModuleExpr mod_expr({copied});
ir::IRSchedule ir_sch(
mod_expr, -1, false, cinn::utils::ErrorMessageLevel::kGeneral, true);

const auto block_realize =
(ExprSetFinderUtils::ChildScheduleBlockRealizes).GetSingle(copied);
const auto block_name = block_realize.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
const auto shape_partion = fusion::PartionReshapeAxes(in_shape, out_shape);

for (int idx = shape_partion.size() - 1; idx > 0; --idx) {
const auto& in_s = shape_partion[idx - 1].first;
const auto& in_e = shape_partion[idx].first;
const auto& out_s = shape_partion[idx - 1].second;
const auto& out_e = shape_partion[idx].second;

std::vector<int> fuse_indices;
for (int i = in_e - 1; i >= in_s; --i) {
if (in_shape[i] != symbol::DimExpr(1)) {
fuse_indices.insert(fuse_indices.begin(), i);
} else {
VLOG(4) << "Remove index[" << i << "]: " << in_shape[i]
<< " for expr: \n"
<< copied;
copied = ExprTransformerUtils::RemoveOneTransformer(i)(copied);
ir_sch.SetExprs({copied});
for (auto& index : fuse_indices) {
index--;
}
}
}
if (fuse_indices.size() > 1) {
VLOG(4) << "fuse_indices: " << cinn::utils::Join(fuse_indices, ",");
ir_sch.Fuse(block_name, fuse_indices);
}

std::vector<ir::Expr> split_shapes;
for (int i = out_s; i < out_e; ++i) {
if (out_shape[i] != symbol::DimExpr(1)) {
split_shapes.push_back(
cinn::common::DimExprConverter().ConvertToIrExpr(out_shape[i]));
}
}
if (split_shapes.size() > 1) {
ir_sch.Split(ir_sch.GetLoops(block_name)[in_s], split_shapes)[0];
}
}

std::vector<int> insert_axis;
std::vector<ir::Var> ones_var;
for (int i = 0; i < out_shape.size(); ++i) {
if (out_shape[i] == symbol::DimExpr(1)) {
insert_axis.push_back(i);
ones_var.push_back(ir::Var(1, "one_" + std::to_string(ones_var.size())));
}
}
copied = ExprTransformerUtils::InsertForsTransformer(insert_axis,
ones_var)(copied);

return copied;
}

} // namespace trivial_fusion_detail
} // namespace pir
} // namespace framework
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/hlir/framework/pir/trivial_op_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ std::vector<ir::Var> GetAllLoopVars(const ir::Expr& root);

ir::Expr GetBodyBlock(const ir::Expr& root);

ir::Expr ReshapeLoop(const ir::Expr& root,
const std::vector<symbol::DimExpr>& in_shape,
const std::vector<symbol::DimExpr>& out_shape);

} // namespace trivial_fusion_detail
} // namespace pir
} // namespace framework
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/ir/group_schedule/config/group_tile_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
auto* block = expr_block.As<ir::ScheduleBlockRealize>();
auto& iter_vars = block->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
for (int i = 0; i < iter_vars.size(); i++) {
ir::Var loop_var = block->iter_values[i].as_var_ref();
if (reduce_loop_vars.count(loop_var->name) > 0) {
if (block->iter_values[i].is_var() &&
reduce_loop_vars.count(block->iter_values[i].as_var()->name) > 0) {
reduce_iter_vars.insert(iter_vars[i]->name);
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/operator_fusion/fusion_tracker/expr_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ ir::Expr ApplyItersTransform::operator()(const TransposeItersTransform& trans) {

ir::Expr ApplyItersTransform::operator()(const RemoveOnesTransform& trans) {
VLOG(4) << "[ItersTransform] Before RemoveOnesTransform("
<< utils::Join(trans.ones_, ",") << "'): " << expr_;
<< utils::Join(trans.ones_, ",") << "): " << expr_;
auto result = RemoveOnesTransformer(trans.ones_)(expr_);
VLOG(4) << "[ItersTransform] After RemoveOnesTransform: " << result;
return result;
Expand Down
18 changes: 18 additions & 0 deletions paddle/cinn/operator_fusion/fusion_tracker/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,20 @@ void RunItersTransformInstr(const std::shared_ptr<ItersTransformInstr>& instr,
interpreter->scope[instr->target_] = new_pattern;
}

void RunReshapeAlignInstr(const std::shared_ptr<ReshapeAlignInstr>& instr,
FusionInterpreter* interpreter) {
const auto expr = std::visit(
FusibleOp2Expr(), interpreter->scope[instr->input_]->fusion_ops[0])[0];
VLOG(4) << "Before RunReshapeAlignInstr: \n" << expr;
auto result = cinn::hlir::framework::pir::trivial_fusion_detail::ReshapeLoop(
expr, instr->in_shape_, instr->out_shape_);

auto new_pattern = std::make_shared<ScopeElement>();
new_pattern->fusion_ops.emplace_back(TrivialOp(result));
interpreter->scope[instr->result_] = new_pattern;
VLOG(4) << "After ReshapeAlignInstr: \n" << result;
}

void RunPaddingInstr(const std::shared_ptr<PaddingInstr>& instr,
FusionInterpreter* interpreter) {
ScopeElementPtr new_pattern = std::make_shared<ScopeElement>();
Expand Down Expand Up @@ -229,6 +243,10 @@ std::vector<ir::Expr> FusionInterpreter::Run() {
RunItersTransformInstr(
dynamic_cast_instr_with_err<ItersTransformInstr>(instr), this);
break;
case T_ReshapeAlign:
RunReshapeAlignInstr(
dynamic_cast_instr_with_err<ReshapeAlignInstr>(instr), this);
break;
default:
PADDLE_THROW(
::common::errors::Unavailable("Unsupported Fusion Instrution"));
Expand Down
27 changes: 27 additions & 0 deletions paddle/cinn/operator_fusion/fusion_tracker/tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ enum InstructionType {
T_Return,
T_InitPattern,
T_TrivialInline,
T_ReshapeAlign,
T_TmpTransform,
T_TrivialLoopAlign,
T_ItersTransform,
Expand Down Expand Up @@ -143,6 +144,32 @@ struct TrivialInlineInstr : public FusionInstruction {
}
};

struct ReshapeAlignInstr : public FusionInstruction {
ReshapeAlignInstr(const std::string& input,
const std::vector<symbol::DimExpr>& in_shape,
const std::vector<symbol::DimExpr>& out_shape,
const std::string& result)
: input_(input),
in_shape_(in_shape),
out_shape_(out_shape),
result_(result) {}
virtual InstructionType type() const { return T_ReshapeAlign; }
virtual FusionInstrPtr Clone() {
return std::make_shared<ReshapeAlignInstr>(*this);
}

std::string input_;
std::vector<symbol::DimExpr> in_shape_;
std::vector<symbol::DimExpr> out_shape_;
std::string result_;

virtual std::string DebugStr() const {
return "ReshapeAlignInstr || " + input_ + "(" +
cinn::utils::Join(in_shape_, ",") + ") => " + result_ + "(" +
cinn::utils::Join(out_shape_, ",") + ")";
}
};

struct TmpTransformInstr : public FusionInstruction {
TmpTransformInstr(const std::string& upstream,
const std::string& downstream,
Expand Down
Loading

0 comments on commit e09d3f7

Please sign in to comment.