Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rule RFactor #551

Merged
merged 3 commits into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Integer> vector_load_max_len, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);
/*!
* \brief Create a rule: add-rfactor to some blocks if needed
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
* uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable
* parallelism.
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, //
Optional<Integer> max_innermost_factor);
/*!
* \brief A rule that randomly select a compute-at location for a free block
* \return The rule created
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Meta Schedule schedule rules are used for modification of
blocks in a schedule. See also PostOrderApply.
"""
from .add_rfactor import AddRFactor
from .auto_inline import AutoInline
from .multi_level_tiling import MultiLevelTiling, ReuseType
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
Expand Down
49 changes: 49 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/add_rfactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Add-rfactor Rule that add-rfactor to some blocks if needed"""
from typing import Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.AddRFactor")
class AddRFactor(ScheduleRule):
"""Rules for add-rfactor to some blocks if needed.

Parameters
----------
max_jobs_per_core: int
The maximum number of jobs to be launched per CPU core. It sets the uplimit of CPU
parallelism, i.e. `num_cores * max_jobs_per_core`.
Use -1 to disable parallelism.
max_innermost_factor: Optional[int] = None
The maximum size of the innermost factor. NullOpt means no limit.
"""

def __init__(
self,
max_jobs_per_core: int = 16,
max_innermost_factor: Optional[int] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleAddRFactor, # type: ignore # pylint: disable=no-member
max_jobs_per_core,
max_innermost_factor,
)
9 changes: 9 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import List

from tvm.meta_schedule.schedule_rule import (
AddRFactor,
AutoInline,
MultiLevelTiling,
ParallelizeVectorizeUnroll,
Expand All @@ -33,6 +34,7 @@ def get(target: Target) -> List[ScheduleRule]:
if target.kind.name == "llvm":
return [
auto_inline(target),
add_rfactor(target),
multi_level_tiling(target),
parallel_vectorize_unroll(target),
]
Expand Down Expand Up @@ -182,3 +184,10 @@ def random_compute_location(target: Target) -> ScheduleRule:
if target.kind.name == "llvm":
return RandomComputeLocation()
raise NotImplementedError(f"{target.kind.name} is not supported")


def add_rfactor(target: Target) -> ScheduleRule:
"""Default schedule rules for with add_rfactor"""
if target.kind.name == "llvm":
return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
raise NotImplementedError(f"{target.kind.name} is not supported")
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _sch_rules() -> List[ScheduleRule]:
require_ordered=True,
disallow_op=["tir.exp"],
),
M.AddRFactor(max_job_per_core=16, max_inner_most_factor=64),
M.MultiLevelTiling(
structure="SSRSRS",
tile_binds=None,
Expand Down
115 changes: 115 additions & 0 deletions src/meta_schedule/schedule_rule/add_rfactor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class AddRFactorNode : public ScheduleRuleNode {
public:
// Inherited from ScheduleRuleNode
void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(context->target.defined());
Target target = context->target.value();
this->max_parallel_basic_ = GetTargetNumCores(target);
if (this->max_jobs_per_core != -1) {
this->max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core;
}
}

// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);

public:
/*!
* \brief The maximum number of jobs to be launched per core.
* It sets the uplimit of parallelism, i.e. `num_cores * max_jobs_per_core`.
* Use -1 to disable parallelism.
*/
int max_jobs_per_core;
/*! \brief The maximum size of the innermost factor */
int max_innermost_factor;
/*! \brief The number of uplimit of parallelism. */
int max_parallel_extent_;
/*! \brief The number of cores. */
int max_parallel_basic_;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("max_jobs_per_core", &max_jobs_per_core);
v->Visit("max_innermost_factor", &max_innermost_factor);
// `max_parallel_extent_` is not visited
// `max_parallel_basic_` is not visited
}

static constexpr const char* _type_key = "meta_schedule.AddRFactor";
TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
Optional<Integer> max_innermost_factor) {
ObjectPtr<AddRFactorNode> n = make_object<AddRFactorNode>();
n->max_jobs_per_core = max_jobs_per_core;
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
n->max_parallel_extent_ = -1;
n->max_parallel_basic_ = -1;
return ScheduleRule(n);
}

Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
max_parallel_basic_)) {
return {sch};
}

// Make a copy of the original schedule.
tir::Schedule ori_sch = sch->Copy();
ori_sch->Seed(sch->ForkSeed());

// Reorder the loop axes if reduction loops are not innermost.
// After the reordering, fuse all the reduction loops.
size_t num_spatial_loops;
tir::LoopRV fused_reduce_loop;
ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);

// Split the fused reduction loop.
Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
const Array<tir::LoopRV>& split_loops =
sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});

Array<tir::Schedule> res;
for (const tir::LoopRV& split_loop : split_loops) {
tir::Schedule sch_tmp = sch->Copy();
sch_tmp->Seed(sch->ForkSeed());
const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
ICHECK_GT(axes.size(), num_spatial_loops);
res.push_back(sch_tmp);
}

res.push_back(ori_sch);
return res;
}

TVM_REGISTER_NODE_TYPE(AddRFactorNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor")
.set_body_typed(ScheduleRule::AddRFactor);

} // namespace meta_schedule
} // namespace tvm
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ std::vector<State> SubRule(std::vector<State> states, FLambda sub_rule) {
std::vector<State> results;
for (auto&& state : states) {
std::vector<State> next = sub_rule(std::move(state));
results.insert(results.end(),
results.insert(results.end(), //
std::make_move_iterator(next.begin()), //
std::make_move_iterator(next.end()));
}
Expand Down
54 changes: 54 additions & 0 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,60 @@ struct ThreadedTraceApply {
Item* items_;
};

/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/

/*!
* \brief Reorder the reduction loops to innermost positions if needed.
* \param sch The schedule
* \param block_rv The block where to apply the reorder
* \param fused_reduce_loop The fusion-generated loop to return.
* \param num_spatial_loops The number of spatial loops to return.
* \note Before invoking this helper function, make sure that the block has only spatial and
* reduction loop axes.
*/
inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv,
tir::LoopRV* fused_reduce_loop,
size_t* num_spatial_loops) {
Array<tir::LoopRV> loops = sch->GetLoops(block_rv);
Array<tir::StmtSRef> loop_srefs;
for (const tir::LoopRV& loop_rv : loops) {
loop_srefs.push_back(sch->GetSRef(loop_rv));
}

Array<tir::LoopRV> new_order;
// Step 1. Add spatial loops.
*num_spatial_loops = 0;
for (size_t i = 0; i < loops.size(); ++i) {
if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) {
new_order.push_back(loops[i]);
(*num_spatial_loops)++;
}
}
// Step 2. Add reduction loops.
Array<tir::LoopRV> reduction_loops;
for (size_t i = 0; i < loops.size(); ++i) {
if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) {
new_order.push_back(loops[i]);
reduction_loops.push_back(loops[i]);
}
}
// Step 3. Apply reordering if new_order differs from the original order.
ICHECK_EQ(new_order.size(), loops.size());
for (size_t i = 0; i < loops.size(); ++i) {
if (!new_order[i].same_as(loops[i])) {
sch->Reorder(new_order);
break;
}
}
// Step 4. Fuse all the reduction loops if there are multiple reduction loops.
CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop";
if (reduction_loops.size() > 1) {
*fused_reduce_loop = sch->Fuse(reduction_loops);
} else {
*fused_reduce_loop = reduction_loops[0];
}
}

} // namespace meta_schedule
} // namespace tvm

Expand Down
13 changes: 13 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,19 @@ AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& wri
*/
bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref);

/*!
* \brief Checks if the rfactor or cross thread reduction is beneficial to the given block.
* \param self The schedule state.
* \param block_sref The block to be checked.
* \param max_parallel_extent The maximum parallel jobs on the target.
* \param max_parallel_extent The maximum cores on the target.
* \return A boolean indicating whether the operation is beneficial.
*/
bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
const tir::StmtSRef& block_sref, //
int64_t max_parallel_extent, //
int64_t max_parallel_basic);

/*!
* \brief Checks if the given AST contains the specific operators
* \param stmt The AST to be checked
Expand Down
Loading