Skip to content

Commit

Permalink
Merge pull request #14 from Fridge003/multi-down
Browse files Browse the repository at this point in the history
LiftToAnchorPattern Implementation
  • Loading branch information
feifei-111 authored Apr 29, 2024
2 parents d173c22 + ada835e commit b36e48d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/operator_fusion/graph_transformer/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ struct LiftToHorizontalFusionPatternOperation {
struct LiftToAnchorPatternOperation {
template <typename Phrase>
void operator()(PatternGraph<Phrase>* graph, PatternNodePtr<Phrase> node) {
// TODO(@wuzhanfei)
node->set_stmt_pattern(AnchorPattern<Phrase>(node->stmt_pattern()));
}
};

Expand Down
35 changes: 19 additions & 16 deletions paddle/cinn/operator_fusion/pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,24 @@ struct ReduceTreePlusTrivialPattern {
std::vector<size_t> fake_reduce_iter_idx;
};

template <typename T>
struct AnchorPattern {
explicit AnchorPattern(const StmtPattern<T>& pattern) : pattern_(pattern) {
ExtendVector(ops_, GetOpsInPattern(pattern));
// TODO(@wuzhanfei): initialize anchor_ and anchor_state using ops_ and
// pattern
}

StmtPattern<T> pattern_;
std::vector<pir::Operation*> ops_;
pir::Value anchor_; // Choose only one anchor
AnchorState<T> anchor_state;
std::vector<pir::Operation*> ops() const { return ops_; }
std::vector<pir::Value> outputs() const { return outputs_; }
pir::Value anchor() const { return anchor_; }
static std::string name() { return "AnchorPattern"; }
};

template <typename T>
class UnsupportPattern {};

Expand All @@ -90,6 +108,7 @@ using StmtPatternBase = std::variant<TrivialPattern<T>,
ReducePattern<T>,
ReduceTreePattern<T>,
ReduceTreePlusTrivialPattern<T>,
AnchorPattern<T>,
HorizontalFusionPattern<T>,
UnsupportPattern<T>>;

Expand All @@ -100,20 +119,4 @@ struct StmtPattern final : public StmtPatternBase<T> {
return static_cast<const StmtPatternBase<T>&>(*this);
}
};

template <typename T>
struct AnchorPattern {
explicit AnchorPattern(
const std::vector<pir::Operation*>& ops,
const pir::Value& anchor const AnchorState<T>& anchor_state)
: ops_(ops), anchor_(anchor), {}
std::vector<pir::Operation*> ops_;
pir::Value anchor_; // Choose only one anchor
AnchorState<T> anchor_state;
std::vector<pir::Operation*> ops() const { return ops_; }
std::vector<pir::Value> outputs() const { return outputs_; }
pir::Value anchor() const { return anchor_; }
static std::string name() { return "AnchorPattern"; }
};

} // namespace cinn::fusion

0 comments on commit b36e48d

Please sign in to comment.