[MetaSchedule][M4a] Mutator: Mutate Parallel #10096

Merged 1 commit on Jan 28, 2022
3 changes: 3 additions & 0 deletions include/tvm/tir/schedule/instruction.h
@@ -121,6 +121,9 @@ class InstructionKindNode : public runtime::Object {
// not visited: f_attrs_from_json
}

/*! \brief Checks if the instruction kind is EnterPostproc */
bool IsPostproc() const;

static constexpr const char* _type_key = "tir.InstructionKind";
TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object);
};
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/mutator/__init__.py
@@ -21,4 +21,5 @@
"""
from .mutator import Mutator, PyMutator
from .mutate_compute_location import MutateComputeLocation
from .mutate_parallel import MutateParallel
from .mutate_unroll import MutateUnroll
33 changes: 33 additions & 0 deletions python/tvm/meta_schedule/mutator/mutate_parallel.py
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mutator that mutates the parallel extent"""
from tvm._ffi.registry import register_object

from .. import _ffi_api
from .mutator import Mutator


@register_object("meta_schedule.MutateParallel")
class MutateParallel(Mutator):
    """Mutator that mutates the parallel extent"""

    def __init__(self, max_jobs_per_core: int) -> None:
        """Construct the mutator with `max_jobs_per_core`, the maximum number of jobs launched per CPU core."""
        self.__init_handle_by_constructor__(
            _ffi_api.MutatorMutateParallel,  # type: ignore # pylint: disable=no-member
            max_jobs_per_core,
        )
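
For reference, a minimal usage sketch of the new Python class (illustration only: the import path comes from the `__init__.py` change above; how the mutator is plugged into a search strategy is outside this diff):

    from tvm.meta_schedule.mutator import MutateParallel

    # Allow the mutated `parallel` annotation to use up to 16 jobs per CPU core.
    mutator = MutateParallel(max_jobs_per_core=16)
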
312 changes: 312 additions & 0 deletions src/meta_schedule/mutator/mutate_parallel.cc
@@ -0,0 +1,312 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <algorithm>
#include <unordered_map>

#include "../utils.h"

namespace tvm {
namespace tir {

/*!
* \brief Check if the instruction is an annotation with the key `meta_schedule_parallel`
* \param inst The instruction to be checked
* \return Whether the instruction is an annotation with the key `meta_schedule_parallel`
*/
bool IsAnnotateWithParallel(const Instruction& inst) {
static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
if (!inst->kind.same_as(inst_annotate)) {
return false;
}
ICHECK_EQ(inst->attrs.size(), 1);
String ann_key = Downcast<String>(inst->attrs[0]);
return ann_key == attr::meta_schedule_parallel;
}

/*!
* \brief Replace the annotation value
* \param inst The instruction to be replaced
* \param ann_val The new annotation value
* \return The replaced instruction
*/
Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) {
ICHECK_EQ(inst->inputs.size(), 2);
return Instruction(/*kind=*/inst->kind, //
/*inputs=*/{inst->inputs[0], Integer(ann_val)}, //
/*attrs=*/inst->attrs,
/*outputs=*/inst->outputs);
}

/*!
* \brief Get the output of a `GetBlock` instruction
* \param inst The instruction to be checked
* \return The output of the `GetBlock` instruction, or `nullptr` if `inst` is not `GetBlock`
*/
const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) {
static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock");
if (!inst->kind.same_as(inst_get_block)) {
return nullptr;
}
ICHECK_EQ(inst->outputs.size(), 1);
const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode);
return block;
}

/*!
* \brief Analyze the parallel structure
* \param self The schedule state
* \param block_name The name of the root block
* \param func_name The name of the PrimFunc
* \param limit The upper limit of parallelism
* \return The parallel structure
*/
std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
const String& block_name, const String& func_name,
int64_t limit) {
Array<StmtSRef> block_srefs = tir::GetBlocks(self, block_name, func_name);
ICHECK_EQ(block_srefs.size(), 1);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
std::vector<std::vector<int64_t>> results;
results.reserve(info.realizes.size());
for (const BlockRealize& realize : info.realizes) {
// Step 1. Extract static loop extents for spatial loops
std::vector<int64_t> loop_extents;
const ForNode* loop = nullptr;
for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent;
(loop = loop_sref->StmtAs<ForNode>()) != nullptr; //
loop_sref = loop_sref->parent) {
int64_t loop_extent = -1;
if (const auto* ext = GetLoopIntExtent(loop)) {
if (!info.non_spatial_vars.count(loop->loop_var.get())) {
loop_extent = *ext;
}
}
if (loop_extent != -1) {
loop_extents.push_back(loop_extent);
} else {
loop_extents.clear();
}
}
// Step 2. Take the prefix product of loop extents
if (!loop_extents.empty()) {
results.emplace_back();
std::vector<int64_t>& result = results.back();
result.reserve(loop_extents.size());
int64_t prod_extent = 1;
for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) {
result.push_back(prod_extent *= *it);
if (prod_extent >= limit) {
break;
}
}
}
}
return results;
}

/*!
* \brief Get the number of parallelizable loops for each subtree
* \param loop_extent_prods The parallel structure for each subtree
* \param limit The upper limit of parallelism
* \return The number of parallelizable loops for each subtree
*/
std::vector<int> GetNumFusedLoops(const std::vector<std::vector<int64_t>>& loop_extent_prods,
int64_t limit) {
std::vector<int> results;
results.reserve(loop_extent_prods.size());
for (const std::vector<int64_t>& prods : loop_extent_prods) {
int n = prods.size();
int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin();
if (i > 0 && prods[i - 1] == limit) {
--i;
}
if (i != n) {
++i;
}
results.push_back(i);
}
return results;
}

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

using tir::Instruction;
using tir::Trace;

/*! \brief Create a Mutator that mutates the parallel extent */
class MutateParallelNode : public MutatorNode {
public:
/*!
* \brief The maximum number of jobs to be launched per CPU core.
* It sets the upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`.
* Use -1 to disable parallelism.
*/
int64_t max_jobs_per_core;
/*! \brief The maximum parallel extent, i.e. `num_cores * max_jobs_per_core`. */
int max_parallel_extent_;
/*! \brief JSON representation of the workload */
std::string json_mod_;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("max_jobs_per_core", &max_jobs_per_core);
// `max_parallel_extent_` is not visited.
// `json_mod_` is not visited.
}

static constexpr const char* _type_key = "meta_schedule.MutateParallel";
TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode);

public:
struct Candidate;
// Inherit from `MutatorNode`
void InitializeWithTuneContext(const TuneContext& context) final {
Target target = context->target.value();
this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core;
this->json_mod_ = SaveJSON(context->mod.value());
}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
};

/*! \brief The candidate to be mutated */
struct MutateParallelNode::Candidate {
/*! \brief The annotation instruction */
Instruction inst;
/*! \brief The current parallel extent */
int64_t parallel_extent;
/*! \brief The name of the root block */
String block_name;
/*! \brief The name of the PrimFunc */
String func_name;
};

/*!
* \brief Get an instruction that annotates the maximum parallel extent
* \param trace The trace to be mutated
* \param rand_state The random state
* \param candidate The candidate to be mutated
* \return Whether a decision is found
*/
bool FindParallelDecision(const Trace& trace, TRandState* rand_state,
MutateParallelNode::Candidate* candidate) {
using tir::BlockRVNode;
using tir::InstructionNode;
std::unordered_map<const BlockRVNode*, const InstructionNode*> get_block_insts;
std::vector<const InstructionNode*> ann_insts;
get_block_insts.reserve(trace->insts.size());
ann_insts.reserve(trace->insts.size());
for (const Instruction& inst : trace->insts) {
if (tir::IsAnnotateWithParallel(inst)) {
ann_insts.push_back(inst.get());
}
if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) {
get_block_insts[block_rv] = inst.get();
}
}
int n_ann_insts = ann_insts.size();
if (n_ann_insts == 0) {
return false;
}
const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
ICHECK_EQ(ann_inst->inputs.size(), 2);
const InstructionNode* get_block_inst =
get_block_insts.at(Downcast<tir::BlockRV>(ann_inst->inputs[0]).get());
ICHECK_EQ(get_block_inst->attrs.size(), 2);
candidate->inst = GetRef<Instruction>(ann_inst);
candidate->parallel_extent = Downcast<IntImm>(ann_inst->inputs[1])->value;
candidate->block_name = Downcast<String>(get_block_inst->attrs[0]);
candidate->func_name = Downcast<String>(get_block_inst->attrs[1]);
return true;
}

Optional<Trace> MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) {
// Step 1. Find a parallel decision.
Candidate candidate;
if (!FindParallelDecision(trace, rand_state, &candidate)) {
return NullOpt;
}
// Step 2. Replay the instructions to recover loop extents
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), //
/*rand_state=*/ForkSeed(rand_state), //
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
// Step 3. Find all possible parallel plans
std::vector<std::vector<int64_t>> loop_extent_prods = tir::AnalyzeParallel(
sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_);
std::unordered_map<int64_t, std::vector<int>> limit2plan;
std::map<std::vector<int>, int64_t> plan2limit;
for (const std::vector<int64_t>& prods : loop_extent_prods) {
for (int64_t limit : prods) {
if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) {
std::vector<int> plan = tir::GetNumFusedLoops(loop_extent_prods, limit);
limit2plan[limit] = plan;
plan2limit[plan] = limit;
}
}
}
// Step 4. Remove the original plan from the candidate plans
std::vector<int> original_plan =
tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent);
auto it = plan2limit.find(original_plan);
if (it != plan2limit.end()) {
plan2limit.erase(it);
}
// Step 5. Pick a new plan
int n_plans = plan2limit.size();
if (n_plans == 0) {
return NullOpt;
}
it = plan2limit.begin();
for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) {
++it;
}
int64_t limit = it->second;
// Step 6. Assemble a new trace
Array<Instruction> insts;
insts.reserve(trace->insts.size());
for (const Instruction& inst : trace->insts) {
if (inst.same_as(candidate.inst)) {
insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit));
} else if (inst->kind->IsPostproc()) {
break;
} else {
insts.push_back(inst);
}
}
return Trace(insts, trace->decisions);
}

Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) {
ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>();
n->max_jobs_per_core = max_jobs_per_core;
return Mutator(n);
}

TVM_REGISTER_NODE_TYPE(MutateParallelNode);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel);

} // namespace meta_schedule
} // namespace tvm
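
To make the analysis above concrete, here is a standalone sketch (plain Python, independent of TVM; the helper names are mine, not part of this PR) of the prefix-product bookkeeping in `AnalyzeParallel` and the fused-loop counting in `GetNumFusedLoops`:

    import bisect
    from typing import List

    def prefix_products(spatial_extents: List[int], limit: int) -> List[int]:
        """Mirror of Step 2 in AnalyzeParallel: running products of the spatial
        loop extents (outermost first), stopping once the limit is reached."""
        prods, running = [], 1
        for extent in spatial_extents:
            running *= extent
            prods.append(running)
            if running >= limit:
                break
        return prods

    def num_fused_loops(prods: List[int], limit: int) -> int:
        """Mirror of GetNumFusedLoops for a single subtree: the smallest number of
        outer loops to fuse so the fused extent reaches at least `limit`
        (or all of them if the total extent stays below the limit)."""
        i = bisect.bisect_right(prods, limit)  # index of the first prefix product > limit
        if i > 0 and prods[i - 1] == limit:    # an exact match needs no extra loop
            i -= 1
        if i != len(prods):
            i += 1
        return i

    # Example: spatial loops of extents 4, 8, 16 (outer to inner), parallelism limit 32.
    prods = prefix_products([4, 8, 16], limit=32)  # -> [4, 32]
    print(num_fused_loops(prods, limit=32))        # -> 2, i.e. fuse the two outermost loops

In `Apply`, each prefix product that stays within `max_parallel_extent_` becomes a candidate limit; the plan matching the original annotation is removed, and one of the remaining plans is sampled to produce the mutated annotation value.
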
20 changes: 20 additions & 0 deletions src/tir/schedule/analysis.h
@@ -91,6 +91,26 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline,
bool require_subtree_compact_dataflow);

/*!
* \brief The information of a block scope, including the leaf blocks,
* as well as the loop types (spatial, reduction) for each loop in the scope.
*/
struct ScopeBlockLoopInfo {
/*! \brief A list of the leaf blocks, from left to right */
std::vector<BlockRealize> realizes;
/*! \brief The loop vars bound to spatial block iters */
std::unordered_set<const VarNode*> spatial_vars;
/*! \brief The loop vars bound to non-spatial block iters */
std::unordered_set<const VarNode*> non_spatial_vars;
};

/*!
* \brief Inspect the scope rooted at the given block
* \param scope_block The root block of the scope
* \return The information of the scope
*/
ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block);

/*!
* \brief Checks whether the block is a complete block under the scope
* \param self The schedule state