Skip to content

Commit

Permalink
rewrite_vnni -> rewrite_tensorize
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent 823797e commit 43e0b2f
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 18 deletions.
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/postproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
from .rewrite_reduction_block import RewriteReductionBlock
from .rewrite_unbound_block import RewriteUnboundBlock
from .verify_gpu_code import VerifyGPUCode
from .rewrite_vnni import RewriteVNNI
from .rewrite_tensorize import RewriteTensorize
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A postprocessor that tensorize VNNI related components."""
"""A postprocessor that tensorize related components."""

from tvm._ffi.registry import register_object
from .. import _ffi_api
from .postproc import Postproc
import tvm.tir.tensor_intrin


@register_object("meta_schedule.RewriteVNNI")
class RewriteVNNI(Postproc):
"""A postprocessor that tensorize VNNI related components."""
@register_object("meta_schedule.RewriteTensorize")
class RewriteTensorize(Postproc):
"""A postprocessor that tensorize related components."""

def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocRewriteVNNI, # type: ignore # pylint: disable=no-member
_ffi_api.PostprocRewriteTensorize, # type: ignore # pylint: disable=no-member
)
1 change: 0 additions & 1 deletion python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def _postproc() -> List[Postproc]:
M.DisallowDynamicLoop(),
M.RewriteParallelVectorizeUnroll(),
M.RewriteReductionBlock(),
M.RewriteVNNI(),
]

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ using tir::LoopRV;

using BlockPosition = std::tuple<String, String, String>;

class RewriteVNNINode : public PostprocNode {
class RewriteTensorizeNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {}
Expand All @@ -37,8 +37,8 @@ class RewriteVNNINode : public PostprocNode {

void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "meta_schedule.RewriteVNNI";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteVNNINode, PostprocNode);
static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode);
};

void CollectPostProcTasks(const tir::Schedule& sch, const String& func_name,
Expand All @@ -60,7 +60,7 @@ void CollectPostProcTasks(const tir::Schedule& sch, const String& func_name,
});
}

bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
std::vector<BlockPosition> tasks;
for (const auto& kv : sch->mod()->functions) {
GlobalVar g_var = kv.first;
Expand Down Expand Up @@ -88,13 +88,13 @@ bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
return true;
}

Postproc RewriteVNNI() {
ObjectPtr<RewriteVNNINode> n = make_object<RewriteVNNINode>();
Postproc RewriteTensorize() {
ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>();
return Postproc(n);
}

TVM_REGISTER_NODE_TYPE(RewriteVNNINode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteVNNI").set_body_typed(RewriteVNNI);
TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize").set_body_typed(RewriteTensorize);

} // namespace meta_schedule
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
}

public:
String intrin_name;

static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin";
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode);

String intrin_name;
};

ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2096,7 +2096,7 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
break;
}
}
if (block_loop == nullptr) {
if (desc_loop == nullptr) {
return NullOpt;
}
// Step 4.3. Check divisibility of loop extents
Expand Down

0 comments on commit 43e0b2f

Please sign in to comment.