Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoScheduler] New layout rewrite option: Weight pre-transpose #6750

Merged
merged 15 commits into from
Nov 2, 2020
30 changes: 24 additions & 6 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,22 @@ class ComputeDAGNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
};

/*!
* \brief Several options for applying layout rewrite.
* This is a optimization to rewrite the shape of input tensor according to the schedule we get.
*/
enum class LayoutRewriteOption : int {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enum class LayoutRewriteOption : uint8 should be enough.

/*! \brief Do not process layout rewrite. */
NoRewrite = 0,
/*!
* \brief Modify the placeholder to suit the schedule.
* \note This should be used along with the graph optimization in Relay.
*/
RewriteWithPlaceholder = 1,
/*! \brief Insert a pre-transpose stage between placeholer and compute op to suit the schedule. */
RewriteWithPreTranspose = 2
};

/*!
* \brief Managed reference to ComputeDAGNode.
* \sa ComputeDAGNode
Expand All @@ -214,8 +230,10 @@ class ComputeDAG : public ObjectRef {
* \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders`
* according to the loop nest derived with `transform_steps`.
* \param transform_steps Transform steps of a state.
* \param layout_rewrite Different options in layout rewrite.
* \return The updated ComputeDAG after layout rewrite.
*/
void RewriteLayout(const Array<Step>& transform_steps);
ComputeDAG RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite) const;

/*!
* \brief Apply the history transform steps to get a TVM schedule.
Expand All @@ -225,14 +243,14 @@ class ComputeDAG : public ObjectRef {
* \param stage_to_axes The map that stores all axes for one stage.
* Pass a valid pointer if this information needs to be used outside this function.
* \param layout_rewrite Rewrite the layout of placeholders specified by
* attr `layout_free_placeholders`
* attr `layout_free_placeholders`.
* \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower`
* or `tvm.build`.
*/
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(const Array<Step>& transform_steps,
Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr,
bool layout_rewrite = false) const;
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr,
LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const;

/*!
* \brief Print transform steps as equivalent python schedule API.
Expand Down
46 changes: 31 additions & 15 deletions include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,23 @@ class StepNode : public Object {
*/
class Step : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
/*!
* \brief CopyOnWrite function for Step.
* This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different
* steps.
* \return A base StepNode pointer, need to cast to its real StepNode type before doing any
* modifies.
jcf94 marked this conversation as resolved.
Show resolved Hide resolved
* \code
*
* SplitStep ref;
* StepNode* mutable_ref = ref.CopyOnWrite();
* dynamic_cast<SplitStepNode*>(mutable_ref)->... = ...;
*
* \endcode
*/
StepNode* CopyOnWrite();

TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
};

// Forward declaration
Expand Down Expand Up @@ -267,7 +283,7 @@ class AnnotationStepNode : public StepNode {
static constexpr const char* record_prefix_str = "AN";

static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -330,7 +346,7 @@ class FuseStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FU";

static constexpr const char* _type_key = "auto_scheduler.FuseStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -390,7 +406,7 @@ class PragmaStepNode : public StepNode {
static constexpr const char* record_prefix_str = "PR";

static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -452,7 +468,7 @@ class ReorderStepNode : public StepNode {
static constexpr const char* record_prefix_str = "RE";

static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -527,7 +543,7 @@ class SplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "SP";

static constexpr const char* _type_key = "auto_scheduler.SplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -607,7 +623,7 @@ class FollowSplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FSP";

static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -688,7 +704,7 @@ class FollowFusedSplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FFSP";

static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -754,7 +770,7 @@ class StorageAlignStepNode : public StepNode {
static constexpr const char* record_prefix_str = "SA";

static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -822,7 +838,7 @@ class ComputeAtStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CA";

static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -879,7 +895,7 @@ class ComputeInlineStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CI";

static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -938,7 +954,7 @@ class ComputeRootStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CR";

static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1010,7 +1026,7 @@ class CacheReadStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CHR";

static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1081,7 +1097,7 @@ class CacheWriteStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CHW";

static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep";
TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1148,7 +1164,7 @@ class RfactorStepNode : public StepNode {
static constexpr const char* record_prefix_str = "RF";

static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode);
};

/*!
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ class ComputeDAG(Object):
compute : Union[List[Tensor], str, Schedule]
Input/output tensors or workload key for a compute declaration.
"""
LAYOUT_REWRITE_TABLE = {
"NoRewrite": 0,
"RewriteWithPlaceholder": 1,
"RewriteWithPreTranspose": 2,
}
jcf94 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, compute_or_sche):
if isinstance(compute_or_sche, str):
Expand Down Expand Up @@ -81,7 +86,7 @@ def get_init_state(self):
"""
return State(self.init_state, self)

def apply_steps_from_state(self, state, layout_rewrite=False):
def apply_steps_from_state(self, state, layout_rewrite=LAYOUT_REWRITE_TABLE["NoRewrite"]):
"""
Apply the history transform steps from a State to get a TVM schedule.

Expand Down
Loading