Skip to content

Commit

Permalink
clean up using namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent eb05d25 commit d8b2aa3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 24 deletions.
13 changes: 9 additions & 4 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ std::vector<int> GetReadBufferNDims(const StmtSRef& block_sref) {
namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::IterVarType;
using tir::LoopRV;
using tir::Schedule;

// Do nothing; Inherited from ScheduleRuleNode
void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context) {
if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) {
Expand Down Expand Up @@ -163,12 +168,12 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
}
// Do the split
int n_tiles = idx->size();
Array<ExprRV> factors = sch->SamplePerfectTile(
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
/*max_innermost_factor=*/max_innermost_factor);
Array<LoopRV> splits = sch->Split(/*loop=*/loop,
/*factors=*/{factors.begin(), factors.end()});
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
/*factors=*/{factors.begin(), factors.end()});
// Put every tile to its slot
for (int j = 0; j < n_tiles; ++j) {
tiles[idx->at(j)].push_back(splits[j]);
Expand Down Expand Up @@ -230,7 +235,7 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
ExprRV vector_load_len =
tir::ExprRV vector_load_len =
sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
Expand Down
19 changes: 7 additions & 12 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@
namespace tvm {
namespace meta_schedule {

using tir::BlockRV;
using tir::ExprRV;
using tir::IterVarType;
using tir::LoopRV;
using tir::Schedule;

/*!
* \brief Configuration of data reuse type:
* 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed.
Expand Down Expand Up @@ -83,15 +77,16 @@ struct ReuseConfig {
/*! \brief The state of auto scheduling for the multi-level tiling rule */
struct State {
/*! \brief The schedule to date */
Schedule sch;
tir::Schedule sch;
/*! \brief The block to be tiled */
BlockRV block_rv;
tir::BlockRV block_rv;
/*! \brief The loop tiles */
Array<Array<LoopRV>> tiles;
Array<Array<tir::LoopRV>> tiles;

/*! \brief Default constructor */
explicit State(Schedule sch, BlockRV block_rv, Optional<BlockRV> write_cache = NullOpt,
bool write_cache_is_added = false, Array<Array<LoopRV>> tiles = {})
explicit State(tir::Schedule sch, tir::BlockRV block_rv,
Optional<tir::BlockRV> write_cache = NullOpt, bool write_cache_is_added = false,
Array<Array<tir::LoopRV>> tiles = {})
: sch(sch), block_rv(block_rv), tiles(tiles) {}
};

Expand Down Expand Up @@ -131,7 +126,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
void InitializeWithTuneContext(const TuneContext& context) final;

// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final;
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;

protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);
Expand Down
13 changes: 5 additions & 8 deletions src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
namespace tvm {
namespace meta_schedule {

using tir::LoopRV;

/*! \brief Necessary information used for tensorization */
class TensorizeInfoNode : public Object {
public:
Expand Down Expand Up @@ -182,7 +184,7 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
return TensorizeInfo(ret);
}

Optional<LoopRV> TilingwithTensorIntrin(const Schedule& sch, const BlockRV& block_rv,
Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name) {
Optional<TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
Expand Down Expand Up @@ -244,15 +246,12 @@ Optional<LoopRV> TilingwithTensorIntrin(const Schedule& sch, const BlockRV& bloc
}

std::vector<State> TileForVNNI(State state) {
std::vector<State> result;
BlockRV block_rv = state.block_rv;
const std::string intrin_name = "dot_16x4_vnni";
Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, intrin_name);
Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, state.block_rv, intrin_name);
ICHECK(tiled_loop_rv.defined());
state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
result.push_back(state);
return result;
return {state};
}

class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
Expand All @@ -267,8 +266,6 @@ class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingVNNINode, MultiLevelTilingNode);
};

// Constructor

ScheduleRule ScheduleRule::MultiLevelTilingVNNI(String structure,
Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor,
Expand Down

0 comments on commit d8b2aa3

Please sign in to comment.