diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 97cf467cca07..e69de29bb2d1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,158 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Github code owners file -# This file is used as a convenient tool to map -# committers' areas of expertise and faciliate the review process. -# -# This may not be the non-comprehensive list and is meant to be -# updated over time. - -# Per ASF policy, committer have global write permission. -# We normally recommend committers to shepherd code in their area of expertise. -* @apache/tvm-committers - -# Order is important; the last matching pattern takes the most precedence. -# The sub modules should be ordered first by depth. -# Making sure we append new sub-module rules after exisiting modules rules. - -############################## -# Top-level Fallbacks -############################## -include/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -src/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -apps/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -python/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics - -# Thirdparty license audit -3rdparty/** @tqchen @jroesch -licenses/** @tqchen @jroesch - -# JVM language -jvm/** @yzhliu - -# Golang -golang/** @srkreddy1238 - -# WASM -web/** @tqchen @jroesch - -# Docker -docker/** @areusch @leandron @jroesch - -# Conda -conda/** @tqchen @junrushao1994 @comaniac - -# CMake -cmake/** @jroesch @tqchen @areusch @junrushao1994 @comaniac - -# rust bindings -rust/** @jroesch @nhynes @nhynes - -# vta -vta/** @tmoreau89 @vegaluisjose - -# docs -docs/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon -tutorials/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon - -# tests -tests/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon - -############################## -# Specific modules -############################## - -# automation related -src/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy -include/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy -python/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy - -python/tvm/autotvm/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 - -# node system and reflection -src/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -include/tvm/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac - -# ir: Common IR -src/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -include/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -python/tvm/ir/** @junrushao1994 @vinx13 @tqchen 
@jroesch @comaniac - -# tir -src/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy -include/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy -python/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy - -# te -src/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were -include/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were -python/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were - -# target -src/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi -include/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi -python/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi - -# arith: Arithmetic module and simplifiers -src/arith/** @tqchen @junrushao1994 @vinx13 -include/tvm/arith/** @tqchen @junrushao1994 @vinx13 -python/tvm/arith/** @tqchen @junrushao1994 @vinx13 - -# parser -src/parser/** @jroesch @slyubomirsky - -# runtime -src/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 -include/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 -python/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 - -# runtime/micro -src/runtime/micro/** @areusch @liangfu @tmoreau89 @manupa-arm -src/runtime/crt/** @areusch @liangfu @tmoreau89 @manupa-arm -include/tvm/runtime/crt/** @areusch @liangfu @tmoreau89 @manupa-arm -include/tvm/runtime/micro/** @areusch @liangfu @tmoreau89 @manupa-arm -python/tvm/micro/** @areusch @liangfu @tmoreau89 @manupa-arm - -# relay -src/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 -include/tvm/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 -python/tvm/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 - - -# relay/qnn -src/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang -inlcude/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang -python/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang - -# relay/backend/contrib: BYOC -src/relay/backend/contrib/** @zhiics @trevor-m @comaniac @mbaret @manupa-arm - -# relay/frontends -python/tvm/relay/frontend/** @jwfromm @mbrookhart @srkreddy1238 @siju-samuel @Huyuwei @hlu1 @kazum @PariksheetPinjari909 - -# topi: Operator definitions -src/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 -include/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 -python/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 - - -# tvm/driver/ -python/tvm/driver/** @leandron @jwfromm @tqchen @jroesch - -# tvm/driver/tvmc -python/tvm/driver/tvmc/** @leandron @jwfromm diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 85b48b35f621..718b25437281 100644 --- 
a/python/tvm/meta_schedule/testing/__init__.py
+++ b/python/tvm/meta_schedule/testing/__init__.py
@@ -18,3 +18,4 @@
 from .byoc_trt import relay_build_with_tensorrt
 from .local_rpc import LocalRPC
 from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
+from .te_workload import create_te_workload
diff --git a/python/tvm/meta_schedule/testing/e2e.py b/python/tvm/meta_schedule/testing/e2e.py
new file mode 100644
index 000000000000..2b884701339a
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/e2e.py
@@ -0,0 +1,256 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""End-to-end network loading and task-extraction helpers for meta schedule testing"""
+import multiprocessing
+import os
+import pickle
+from typing import Any, Dict, List, Optional, Tuple
+
+import tvm
+import tvm.relay.testing
+from tvm import relay
+from tvm.ir import IRModule
+from tvm.meta_schedule.integration import ExtractedTask, extract_task_from_relay
+from tvm.runtime import NDArray, load_param_dict, save_param_dict
+from tvm.target import Target
+
+SUPPORTED = [
+    # TorchVision
+    "resnet_18",
+    "resnet_50",
+    "mobilenet_v2",
+    "mobilenet_v3",
+    "wide_resnet_50",
+    "resnext_50",
+    "resnet3d_18",
+    "inception_v3",
+    "densenet_121",
+    "vgg_16",
+    # Transformer
+    "bert_tiny",
+    "bert_base",
+    "bert_medium",
+    "bert_large",
+    # Relay testing
+    "dcgan",
+]
+
+
+def _get_network(
+    args: Tuple[str, List[int]]
+) -> Tuple[IRModule, bytearray, Tuple[str, List[int], str]]:
+    name: str
+    input_shape: List[int]
+    name, input_shape = args
+
+    mod: IRModule
+
+    if name in [
+        "resnet_18",
+        "resnet_50",
+        "wide_resnet_50",
+        "resnext_50",
+        "mobilenet_v2",
+        "mobilenet_v3",
+        "inception_v3",
+        "densenet_121",
+        "resnet3d_18",
+        "vgg_16",
+    ]:
+        # torchvision>=0.9.0
+        import torch  # type: ignore
+        import torchvision.models as models  # type: ignore
+
+        if name in ["resnet_18", "resnet_50"]:
+            model = getattr(models, name.replace("_", ""))(pretrained=False)
+        elif name == "wide_resnet_50":
+            model = getattr(models, "wide_resnet50_2")(pretrained=False)
+        elif name == "resnext_50":
+            model = getattr(models, "resnext50_32x4d")(pretrained=False)
+        elif name == "mobilenet_v2":
+            model = getattr(models, name)(pretrained=False)
+        elif name == "mobilenet_v3":
+            model = getattr(models, name + "_large")(pretrained=False)
+        elif name == "inception_v3":
+            model = getattr(models, name)(pretrained=False, aux_logits=False)
+        elif name == "densenet_121":
+            model = getattr(models, name.replace("_", ""))(pretrained=False)
+        elif name == "resnet3d_18":
+            model = models.video.r3d_18(pretrained=False)
+        elif name == "vgg_16":
+            model = getattr(models, name.replace("_", ""))(pretrained=False)
+
+        dtype = "float32"
+        input_data = torch.randn(input_shape).type(
+            {
+                "float32": torch.float32,
+            }[dtype]
+        )
+        scripted_model = torch.jit.trace(model, input_data).eval()
+        input_name = "input0"
+        shape_list = [(input_name, input_shape)]
+        mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
+        with tvm.transform.PassContext(opt_level=3):
+            mod = tvm.transform.Sequential(
+                [
+                    relay.transform.RemoveUnusedFunctions(),
+                    relay.transform.ConvertLayout(
+                        {
+                            "nn.conv2d": ["NHWC", "default"],
+                            "nn.conv3d": ["NDHWC", "default"],
+                            "nn.max_pool2d": ["NHWC", "default"],
+                            "nn.avg_pool2d": ["NHWC", "default"],
+                        }
+                    ),
+                ]
+            )(mod)
+        inputs = (input_name, input_shape, dtype)
+    elif name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        # pip3 install transformers==3.5 torch==1.7
+        import torch  # type: ignore
+        import transformers  # type: ignore
+
+        config_dict = {
+            "bert_tiny":
transformers.BertConfig( + num_hidden_layers=6, + hidden_size=512, + intermediate_size=2048, + num_attention_heads=8, + return_dict=False, + ), + "bert_base": transformers.BertConfig( + num_hidden_layers=12, + hidden_size=768, + intermediate_size=3072, + num_attention_heads=12, + return_dict=False, + ), + "bert_medium": transformers.BertConfig( + num_hidden_layers=12, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + "bert_large": transformers.BertConfig( + num_hidden_layers=24, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + } + configuration = config_dict[name] + model = transformers.BertModel(configuration) + input_name = "input_ids" + input_dtype = "int64" + A = torch.randint(10000, input_shape) + model.eval() + scripted_model = torch.jit.trace(model, [A], strict=False) + input_name = "input_ids" + shape_list = [(input_name, input_shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + mod = relay.transform.FastMath()(mod) + mod = relay.transform.CombineParallelBatchMatmul()(mod) + inputs = (input_name, input_shape, input_dtype) + elif name == "dcgan": + output_shape = input_shape + batch_size = output_shape[0] + oshape = output_shape[1:] + mod, params = relay.testing.dcgan.get_workload( + batch_size=batch_size, + oshape=oshape, + layout="NHWC", + ) + inputs = ("data", [100], "float32") + else: + raise ValueError("Invalid name: " + name) + + params_bytearray: bytearray = save_param_dict(params) + return mod, params_bytearray, inputs + + +def _load_cache(cache_dir: Optional[str], filename: str) -> Optional[List[Any]]: + if cache_dir is None: + return None + path = os.path.join(os.path.expanduser(cache_dir), filename) + if not os.path.exists(path): + return None + print(f"Load from cache: {path}") + with open(path, "rb") as i_f: + return pickle.load(i_f) + + +def _save_cache(cache_dir: Optional[str], filename: str, objects: List[Any]) -> None: + if cache_dir is None: + return + path = os.path.join(os.path.expanduser(cache_dir), filename) + with open(path, "wb") as o_f: + pickle.dump(objects, o_f) + + +def get_network( + name: str, + input_shape: List[int], + *, + cache_dir: Optional[str] = None, +) -> Tuple[IRModule, Dict[str, NDArray], Tuple[str, List[int], str]]: + mod: IRModule + params: Dict[str, NDArray] + inputs: Tuple[str, List[int], str] + params_bytearray: bytearray + + filename = f'{name}-{",".join(str(i) for i in input_shape)}.json' + cached = _load_cache(cache_dir, filename) + if cached is None: + with multiprocessing.Pool(processes=1) as pool: + result = pool.map(_get_network, [(name, input_shape)]) + ((mod, params_bytearray, inputs),) = result + cached = [mod, params_bytearray, inputs] + _save_cache(cache_dir, filename, cached) + mod, params_bytearray, inputs = cached + params = load_param_dict(params_bytearray) + return mod, params, inputs + + +def extract( + filename: str, + mod: IRModule, + target: Target, + params: Optional[Dict[str, NDArray]] = None, + *, + cache_dir: Optional[str] = None, + opt_level: int = 3, + pass_config: Dict[str, Any] = { + "relay.backend.use_meta_schedule": True, + }, + disabled_pass: List[str] = [], +) -> List[ExtractedTask]: + extracted_tasks = _load_cache(cache_dir, filename) + if extracted_tasks is None: + extracted_tasks = extract_task_from_relay( + mod=mod, + target=target, + params=params, + opt_level=opt_level, + pass_config=pass_config, + disabled_pass=disabled_pass, + ) + extracted_tasks = 
list(extracted_tasks) + _save_cache(cache_dir, filename, extracted_tasks) + return extracted_tasks diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 2f1ffdd407fa..bf9287a8eb18 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -18,7 +18,6 @@ from enum import Enum from typing import Dict, Tuple -import tvm.relay.testing # pylint: disable=unused-import from tvm import relay from tvm.ir import IRModule from tvm.runtime import NDArray @@ -34,9 +33,74 @@ class MODEL_TYPE(Enum): # pylint: disable=invalid-name # Specify the type of each model MODEL_TYPES = { + # Image classification models "resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnet50": MODEL_TYPE.IMAGE_CLASSIFICATION, + "alexnet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "vgg16": MODEL_TYPE.IMAGE_CLASSIFICATION, + "squeezenet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet121": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet161": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet169": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet201": MODEL_TYPE.IMAGE_CLASSIFICATION, + "inception_v3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "googlenet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "shufflenet_v2_x1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_large": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_small": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnext50_32x4d": MODEL_TYPE.IMAGE_CLASSIFICATION, + "wide_resnet50_2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mnasnet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b1": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b4": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b5": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b6": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b7": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + # Semantic Segmentation models + "fcn_resnet50": MODEL_TYPE.SEGMENTATION, + "fcn_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet50": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + "lraspp_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + # Object detection models + # @Sung: Following networks are not runnable since Torch frontend cannot handle aten::remainder. 
+ # "retinanet_resnet50_fpn", "keypointrcnn_resnet50_fpn", + "fasterrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_320_fpn": MODEL_TYPE.OBJECT_DETECTION, + "retinanet_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "maskrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "keypointrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "ssd300_vgg16": MODEL_TYPE.OBJECT_DETECTION, + "ssdlite320_mobilenet_v3_large": MODEL_TYPE.OBJECT_DETECTION, + # Video classification + "r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + # Text classification + "bert_tiny": MODEL_TYPE.TEXT_CLASSIFICATION, "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION, + "bert_medium": MODEL_TYPE.TEXT_CLASSIFICATION, + "bert_large": MODEL_TYPE.TEXT_CLASSIFICATION, } @@ -73,31 +137,104 @@ def do_trace(model, inp): return model_trace # Load model from torchvision - if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION: + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + model = getattr(models, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + model = getattr(models.segmentation, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + model = getattr(models.detection, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + model = getattr(models.video, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION: os.environ["TOKENIZERS_PARALLELISM"] = "false" - model = transformers.BertModel( - transformers.BertConfig( + config_dict = { + "bert_tiny": transformers.BertConfig( + num_hidden_layers=6, + hidden_size=512, + intermediate_size=2048, + num_attention_heads=8, + return_dict=False, + ), + "bert_base": transformers.BertConfig( num_hidden_layers=12, hidden_size=768, intermediate_size=3072, num_attention_heads=12, return_dict=False, - ) - ) + ), + "bert_medium": transformers.BertConfig( + num_hidden_layers=12, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + "bert_large": transformers.BertConfig( + num_hidden_layers=24, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + } + configuration = config_dict[model_name] + model = transformers.BertModel(configuration) + A = torch.randint(10000, input_shape) + model.eval() - input_data = torch.randint(10000, input_shape) + scripted_model = torch.jit.trace(model, [A], strict=False) + shape_list = [("input_ids", input_shape)] - scripted_model = torch.jit.trace(model, [input_data], strict=False) - elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: - model = getattr(models, model_name)() - # Setup input - input_data = torch.randn(input_shape).type(torch.float32) - shape_list = [("input0", input_shape)] - # Get trace. Depending on the model type, wrapper may be necessary. - scripted_model = do_trace(model, input_data) + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + return mod, params else: raise ValueError("Unsupported model in Torch model zoo.") + # Setup input + input_data = torch.randn(input_shape).type(torch.float32) + shape_list = [("input0", input_shape)] + + # Get trace. Depending on the model type, wrapper may be necessary. 
+ if MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return out["out"] + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + scripted_model = do_trace(wrapped_model, input_data) + + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + + def dict_to_tuple(out_dict): + if "masks" in out_dict.keys(): + return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"] + return out_dict["boxes"], out_dict["scores"], out_dict["labels"] + + class TraceWrapper(torch.nn.Module): # type: ignore + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return dict_to_tuple(out[0]) + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + _ = wrapped_model(input_data) + scripted_model = do_trace(wrapped_model, input_data) + else: + scripted_model = do_trace(model, input_data) + # Convert torch model to relay module mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) return mod, params @@ -110,6 +247,8 @@ def get_network( dtype: str = "float32", ) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]: """Get the symbol definition and random weight of a network""" + import tvm.relay.testing # pylint: disable=import-outside-toplevel,unused-import + # meta-schedule prefers NHWC layout if layout == "NHWC": image_shape = (224, 224, 3) diff --git a/python/tvm/meta_schedule/testing/run_ansor.sh b/python/tvm/meta_schedule/testing/run_ansor.sh new file mode 100644 index 000000000000..d5ea9df34485 --- /dev/null +++ b/python/tvm/meta_schedule/testing/run_ansor.sh @@ -0,0 +1,40 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +NUM_TRIALS=800 +LOG_DIR=$HOME/logs/ansor-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor_cpu.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/python/tvm/meta_schedule/testing/run_meta_schedule.sh b/python/tvm/meta_schedule/testing/run_meta_schedule.sh new file mode 100644 index 000000000000..fa0c7ca42562 --- /dev/null +++ b/python/tvm/meta_schedule/testing/run_meta_schedule.sh @@ -0,0 +1,38 @@ +# set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +LOG_DIR=$HOME/logs/ms-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_tune_te_cpu.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials 5000 \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +# run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +# run NRM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index b149f20c52e3..93af4febaf09 
100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Default schedule rules"""
+from typing import List
 from tvm.meta_schedule.schedule_rule import (
     AddRFactor,
     AutoInline,
@@ -28,6 +29,26 @@
 from tvm.target import Target
 
 
+def get(target: Target) -> List[ScheduleRule]:
+    """Default schedule rules per target kind"""
+    if target.kind.name == "llvm":
+        return [
+            auto_inline(target),
+            add_rfactor(target),
+            multi_level_tiling(target),
+            parallel_vectorize_unroll(target),
+            random_compute_location(target),
+        ]
+    if target.kind.name == "cuda":
+        return [
+            multi_level_tiling(target),
+            auto_inline_after_tiling(target),
+            cross_thread_reduction(target),
+            parallel_vectorize_unroll(target),
+        ]
+    raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
 def auto_inline(target: Target) -> ScheduleRule:
     """Default schedule rules for auto inline"""
     if target.kind.name == "llvm":
@@ -53,6 +74,31 @@ def auto_inline(target: Target) -> ScheduleRule:
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
 
+def auto_inline_after_tiling(target: Target) -> ScheduleRule:
+    """Default schedule rules for auto inline after tiling"""
+    if target.kind.name == "llvm":
+        return AutoInline(
+            into_producer=True,
+            into_consumer=True,
+            inline_const_tensor=True,
+            disallow_if_then_else=True,
+            require_injective=True,
+            require_ordered=True,
+            disallow_op=["tir.exp"],
+        )
+    if target.kind.name == "cuda":
+        return AutoInline(
+            into_producer=True,
+            into_consumer=True,
+            inline_const_tensor=True,
+            disallow_if_then_else=False,
+            require_injective=False,
+            require_ordered=False,
+            disallow_op=None,
+        )
+    raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
 def add_rfactor(target: Target) -> ScheduleRule:
     """Default schedule rules for with add_rfactor"""
     if target.kind.name == "llvm":
@@ -109,6 +155,29 @@ def random_compute_location(target: Target) -> ScheduleRule:
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
 
+def multi_level_tiling_tensor_core(target: Target) -> ScheduleRule:
+    """Default schedule rules for multi-level tiling with Tensor Cores and reuse"""
+    if target.kind.name == "cuda":
+        return MultiLevelTiling(
+            structure="SSSRRSRS",
+            tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
+            use_tensor_core=True,
+            max_innermost_factor=64,
+            vector_load_lens=[1, 2, 3, 4],
+            reuse_read=ReuseType(
+                req="must",
+                levels=[4],
+                scope="shared",
+            ),
+            reuse_write=ReuseType(
+                req="must",
+                levels=[3],
+                scope="local",
+            ),
+        )
+    raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
 def parallel_vectorize_unroll(target: Target) -> ScheduleRule:
     """Default schedule rules for with parallel-vectorize-unroll"""
     if target.kind.name == "llvm":
@@ -126,3 +195,10 @@
         unroll_explicit=True,
     )
     raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
+def cross_thread_reduction(target: Target) -> ScheduleRule:
+    """Default schedule rules for cross-thread reduction"""
+    if target.kind.name == "cuda":
+        return CrossThreadReduction(thread_extents=[4, 8, 16,
32, 64, 128, 256, 512]) + raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index 10e31e7213cb..4abf090ddf95 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -15,11 +15,31 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -from typing import List +from typing import List, Union -from tvm.tir import Schedule +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.target import Target +from tvm.tir import PrimFunc, Schedule from tvm.tir.schedule import Trace +from . import schedule_rule as sch_rule + + +def create_context(mod: Union[IRModule, PrimFunc], target: Target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=sch_rule.get(target), + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for rule in ctx.sch_rules: + rule.initialize_with_tune_context(ctx) + return ctx + def check_trace(spaces: List[Schedule], expected: List[List[str]]): expected_traces = {"\n".join(t) for t in expected} @@ -31,3 +51,15 @@ def check_trace(spaces: List[Schedule], expected: List[List[str]]): actual_traces.add(str_trace) assert str_trace in expected_traces, "\n" + str_trace assert len(expected_traces) == len(actual_traces) + + +def debug_print_spaces(spaces: List[Schedule], trace_as_list: bool) -> None: + for i, space in enumerate(spaces): + print(f"##### Space {i}") + print(space.mod.script()) + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + if trace_as_list: + print(str(trace).strip().splitlines()) + else: + print(trace) diff --git a/python/tvm/meta_schedule/testing/test_ansor_cpu.py b/python/tvm/meta_schedule/testing/test_ansor_cpu.py new file mode 100644 index 000000000000..36e42c2ab636 --- /dev/null +++ b/python/tvm/meta_schedule/testing/test_ansor_cpu.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
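+# Example invocation (a sketch; the host, port, key, and target strings are
+# the placeholder values from run_ansor.sh, not defaults of this script):
+#
+#   python tests/python/meta_schedule/test_ansor_cpu.py \
+#       --workload C2D --target "raspberry-pi/4b-64" \
+#       --rpc-host 192.168.6.66 --rpc-port 4445 --rpc-key raspi4b-aarch64 \
+#       --num-trials 800 --log-dir ~/logs/ansor-cpu
+#
+# Note that the target must carry a "num-cores" attribute, which is read
+# below to construct auto_scheduler.HardwareParams.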
+# pylint: disable=missing-docstring +import argparse +import os + +import tvm +from tvm import auto_scheduler +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.te_workload import CONFIGS + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--log-dir", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = rpc_config.count_num_servers(allow_missing=False) + return parsed + + +ARGS = _parse_args() + + +def main(): + log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json") + workload_func, params = CONFIGS[ARGS.workload] + params = params[0] + workload_func = auto_scheduler.register_workload(workload_func) + task = auto_scheduler.SearchTask( + func=workload_func, + args=params, + target=ARGS.target, + hardware_params=auto_scheduler.HardwareParams( + num_cores=int(ARGS.target.attrs["num-cores"]), + target=ARGS.target, + ), + ) + runner = auto_scheduler.RPCRunner( + key=ARGS.rpc_key, + host=ARGS.rpc_host, + port=ARGS.rpc_port, + n_parallel=ARGS.rpc_workers, + ) + + # Inspect the computational graph + print("Computational DAG:") + print(task.compute_dag) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=ARGS.num_trials, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + runner=runner, + ) + print("Running AutoTuning:") + task.tune(tune_option) + print("History Best:") + print(task.print_best(log_file)) + sch, args = task.apply_best(log_file) + print("Lowered TIR:") + print(tvm.lower(sch, args, simple_mode=True)) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/testing/test_tune_te_cpu.py b/python/tvm/meta_schedule/testing/test_tune_te_cpu.py new file mode 100644 index 000000000000..b48fc4f9a04c --- /dev/null +++ b/python/tvm/meta_schedule/testing/test_tune_te_cpu.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
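+# Example invocation, mirroring run_meta_schedule.sh (the values are that
+# script's placeholders, not defaults):
+#
+#   python tests/python/meta_schedule/test_tune_te_cpu.py \
+#       --workload GMM --target "raspberry-pi/4b-64" \
+#       --rpc-host 192.168.6.66 --rpc-port 4445 --rpc-key raspi4b-aarch64 \
+#       --num-trials 5000
+#
+# Unlike test_ansor_cpu.py there is no --log-dir flag; progress goes to the
+# tvm.meta_schedule logger, which is why the shell script captures stdout
+# with tee.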
+# pylint: disable=missing-docstring +import argparse +import logging + +import tvm +from tvm import meta_schedule as ms +from tvm import tir +from tvm.meta_schedule.testing import create_te_workload + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False) + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def main(): + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, + alloc_repeat=3, + max_workers=ARGS.rpc_workers, + ) + sch: tir.Schedule = ms.tune_tir( + mod=create_te_workload(ARGS.workload, 0), + target=ARGS.target, + config=ms.ReplayTraceConfig( + num_trials_per_iter=64, + num_trials_total=ARGS.num_trials, + ), + runner=runner, + task_name=ARGS.workload, + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/testing/tir_tensor_intrin.py b/python/tvm/meta_schedule/testing/tir_tensor_intrin.py new file mode 100644 index 000000000000..6aad875e445c --- /dev/null +++ b/python/tvm/meta_schedule/testing/tir_tensor_intrin.py @@ -0,0 +1,337 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
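+# The desc/impl pairs below follow the TensorIntrin convention: the "desc"
+# PrimFunc spells out the computation in plain TIR so a schedule can match
+# it, while the "impl" PrimFunc is what gets emitted in its place. A rough
+# usage sketch (assuming the loops around a matmul block named "update" have
+# already been tiled to the 16x16x16 intrinsic shape):
+#
+#   sch = tir.Schedule(mod)
+#   i, j, k = sch.get_loops(sch.get_block("update"))[-3:]
+#   sch.tensorize(i, "test.tensorcore.wmma")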
+"""A collection of TIR tensor intrinsics""" +# pylint: disable=missing-function-docstring +from tvm import tir +from tvm.script import tir as T + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks + + +@T.prim_func +def tensorcore_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def tensorcore_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads([C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]]) + T.writes(C[0:16, 0:16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + for i in range(0, 4): + with T.block("update"): + vi = T.axis.R(4, i) + C[0] = C[0] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + T.reads([C[0:1], A[0:4], B[0:4]]) + T.writes([C[0:1]]) + T.evaluate( + T.call_extern( # pylint: disable=redundant-keyword-arg + "vec4add", + C.data, + C.elem_offset, + A.data, + A.elem_offset, + B.data, + B.elem_offset, + dtype="int32", + ) + ) + + +@T.prim_func +def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=1, scope="wmma.accumulator") + + with T.block("root"): + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast( + B[vkk, vjj], "float32" + ) + + +@T.prim_func +def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + C = T.match_buffer( + c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" + ) + + with T.block("root"): + T.reads( + [ + C[0:16, 0:16], + A[0:16, 0:16], + B[0:16, 0:16], + ] + ) + T.writes(C[0:16, 0:16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + A.data, + A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), + B.data, + B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16), + C.data, + C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + dtype="handle", + ) + ) + + 
+@T.prim_func +def wmma_load_a_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + + with T.block("root"): + for i, j in T.grid(16, 16): + with T.block("load"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_load_a_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0] + ) + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + T.evaluate( + T.tvm_load_matrix_sync( + C.data, + 16, + 16, + 16, + C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + A.access_ptr("r"), + s1, + "row_major", + dtype="handle", + ) + ) + + +@T.prim_func +def wmma_load_b_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): + for i, j in T.grid(16, 16): + with T.block("load"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_load_b_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0] + ) + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): + T.reads(A[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + T.evaluate( + T.tvm_load_matrix_sync( + C.data, + 16, + 16, + 16, + C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + A.access_ptr("r"), + s1, + "row_major", + dtype="handle", + ) + ) + + +@T.prim_func +def wmma_fill_desc(c: T.handle) -> None: + C = T.match_buffer( + c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" + ) + with T.block("root"): + for i, j in T.grid(16, 16): + with T.block("init"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = T.float32(0) + + +@T.prim_func +def wmma_fill_impl(c: T.handle) -> None: + C = T.match_buffer( + c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" + ) + with T.block("root"): + T.reads([]) + T.writes(C[0:16, 0:16]) + T.evaluate( + T.tvm_fill_fragment( + C.data, + 16, + 16, + 16, + C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + T.float32(0), + dtype="handle", + ) + ) + + +@T.prim_func +def wmma_store_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" + ) + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global") + with T.block("root"): + for i, j in T.grid(16, 16): + with T.block("store"): + vii, vjj = T.axis.remap("SS", [i, j]) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_store_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer( + a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator" + ) + C = T.match_buffer( + c, (16, 16), "float32", align=128, offset_factor=16, scope="global", strides=[s1, s0] + ) + with T.block("root"): + 
T.reads(A[0:16, 0:16])
+        T.writes(C[0:16, 0:16])
+        T.evaluate(
+            T.tvm_store_matrix_sync(
+                A.data,
+                16,
+                16,
+                16,
+                A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),
+                C.access_ptr("w"),
+                s1,
+                "row_major",
+                dtype="handle",
+            )
+        )
+
+
+# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks
+
+TENSORCORE_WMMA = tir.TensorIntrin.register(
+    "test.tensorcore.wmma",
+    tensorcore_desc,
+    tensorcore_impl,
+)
+
+NEON_DOT = tir.TensorIntrin.register(
+    "test.neon.dot",
+    dot_product_desc,
+    dot_product_impl,
+)
+
+WMMA_SYNC = tir.TensorIntrin.register(
+    "wmma_sync",
+    wmma_sync_desc,
+    wmma_sync_impl,
+)
+
+WMMA_LOAD_A = tir.TensorIntrin.register(
+    "wmma_load_a",
+    wmma_load_a_desc,
+    wmma_load_a_impl,
+)
+
+WMMA_LOAD_B = tir.TensorIntrin.register(
+    "wmma_load_b",
+    wmma_load_b_desc,
+    wmma_load_b_impl,
+)
+
+WMMA_FILL = tir.TensorIntrin.register(
+    "wmma_fill",
+    wmma_fill_desc,
+    wmma_fill_impl,
+)
+
+WMMA_STORE = tir.TensorIntrin.register(
+    "wmma_store",
+    wmma_store_desc,
+    wmma_store_impl,
+)
diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc
new file mode 100644
index 000000000000..7925f4e090c4
--- /dev/null
+++ b/src/tir/transforms/memhammer_coalesce.cc
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../../runtime/thread_storage_scope.h"
+#include "./memhammer_rewrite_rule.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Fuse consecutive loops
+ * \param body the outer-most loop
+ * \return the fused loop
+ */
+Stmt FuseNestLoops(Stmt body) {
+  std::vector<const ForNode*> loops;
+  while (const ForNode* loop = body.as<ForNode>()) {
+    loops.push_back(loop);
+    body = loop->body;
+  }
+  std::string suffix;
+  int n = loops.size();
+  for (int i = 1; i < n; i++) {
+    suffix += "_" + loops[i]->loop_var->name_hint;
+  }
+  suffix += "_fused";
+  Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix);
+  Map<Var, PrimExpr> subst_map;
+  PrimExpr tot = fused_var;
+  for (int i = n - 1; i >= 0; i--) {
+    subst_map.Set(loops[i]->loop_var, floormod(tot, loops[i]->extent));
+    tot = floordiv(tot, loops[i]->extent);
+  }
+  auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
+    return subst_map.Get(v).value_or(v);
+  };
+  PrimExpr fused_extent = 1;
+  for (int i = 0; i < n; i++) {
+    fused_extent *= loops[i]->extent;
+  }
+  return For(fused_var, 0, fused_extent, ForKind::kSerial,
+             Substitute(std::move(body), f_substitute));
+}
+
+/*!
+ * \brief A combination of split, bind and vectorize;
+ *        a helper function to perform coalesced load/store
+ * \param stmt the stmt to do transformation
+ * \param constraints The constraints, including thread extents, vector bytes, and data bits.
+ * \return The stmt after transformation
+ */
+Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) {
+  const ForNode* loop = TVM_TYPE_AS(loop, stmt, ForNode);
+  int loop_extent = Downcast<Integer>(loop->extent)->value;
+  int vector_bytes = constraints.vector_bytes;
+  int data_bits = constraints.data_bits;
+  int vector_len = std::max(1, vector_bytes * 8 / data_bits);
+  int tot_threads = 1;
+  // generate thread binding loops
+  std::vector<int> factors{-1};
+  std::vector<String> thread_axis;
+  if (Optional<Integer> o_t = constraints.thread_extent.Get("threadIdx.z")) {
+    int t = o_t.value()->value;
+    tot_threads *= t;
+    factors.push_back(t);
+    thread_axis.push_back("threadIdx.z");
+  }
+  if (Optional<Integer> o_t = constraints.thread_extent.Get("threadIdx.y")) {
+    int t = o_t.value()->value;
+    tot_threads *= t;
+    factors.push_back(t);
+    thread_axis.push_back("threadIdx.y");
+  }
+  if (Optional<Integer> o_t = constraints.thread_extent.Get("threadIdx.x")) {
+    int t = o_t.value()->value;
+    tot_threads *= t;
+    factors.push_back(t);
+    thread_axis.push_back("threadIdx.x");
+  }
+  // generate vectorized loop
+  factors.push_back(vector_len);
+  // generate outer loop
+  ICHECK_EQ(loop_extent % (tot_threads * vector_len), 0);
+  factors[0] = loop_extent / (tot_threads * vector_len);
+  // create new loop vars
+  int n = factors.size();
+  std::vector<Var> new_loop_vars;
+  new_loop_vars.reserve(n);
+  for (int i = 0; i < n; i++) {
+    new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i)));
+  }
+  // substitute fused loop var with new loop vars
+  PrimExpr substitute_value = 0;
+  for (int i = 0; i < n; i++) {
+    substitute_value *= factors[i];
+    substitute_value += new_loop_vars[i];
+  }
+  // Construct the new loop nest
+  Stmt body = Substitute(loop->body, [&](const Var& v) -> Optional<PrimExpr> {
+    if (v.same_as(loop->loop_var)) {
+      return substitute_value;
+    } else {
+      return NullOpt;
+    }
+  });
+  body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body));
+  for (int i = n - 2; i >= 1; i--) {
+    body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body),
+               IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1]));
+  }
+  return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body));
+}
+
+Stmt CoalescedAccess::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
+                              OutputSet* output) const {
+  Stmt after_fuse = FuseNestLoops(stmt);
+  Stmt after_split = SplitBindVectorize(std::move(after_fuse), constraints);
+  return after_split;
+}
+
+/*!
+ * \brief Get the index mapping of a specific stmt.
+ *        The stmt is like:
+ *          for i0:
+ *            ...
+ *              for in:
+ *                A[f(i0, ..., in)] = B[i0, ..., in],
+ *        where f is the index mapping we want to get.
+ * \param constraints The constraints, including the write region that is required to calculate
+ *        the index mapping
+ * \return The mapping in the form of j0, ..., jm, where j0, ..., jm = f(i0, ..., in)
+ */
+Array<PrimExpr> GetMapping(const Stmt& stmt, const ConstraintSet& constraints) {
+  Stmt body = stmt;
+  while (const ForNode* loop = body.as<ForNode>()) {
+    body = loop->body;
+  }
+  const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode);
+  BufferRegion write_region = constraints.write_region;
+  const Array<PrimExpr>& write_index = buf_store->indices;
+  ICHECK(write_region->region.size() == write_index.size() &&
+         write_region->buffer.same_as(buf_store->buffer));
+  Array<PrimExpr> result;
+  arith::Analyzer analyzer;
+  for (int i = 0; i < static_cast<int>(write_region->region.size()); i++) {
+    PrimExpr pattern = analyzer.Simplify(write_index[i] - write_region->region[i]->min);
+    if (!is_zero(pattern)) {
+      result.push_back(pattern);
+    }
+  }
+  return result;
+}
+
+Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
+                             OutputSet* output) const {
+  Stmt body = stmt;
+  Map<Var, Range> var_range;
+  Array<Var> loop_vars;
+  // Step 1. Get index mapping
+  Array<PrimExpr> mapping_pattern = GetMapping(stmt, constraints);
+  while (const ForNode* loop = body.as<ForNode>()) {
+    var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
+    loop_vars.push_back(loop->loop_var);
+    body = loop->body;
+  }
+  // Step 2. Get Inverse mapping
+  arith::Analyzer analyzer;
+  DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
+  Array<arith::IterSumExpr> iter_map =
+      arith::DetectIterMap(mapping_pattern, var_range, Bool(true), true, &analyzer, diag_ctx);
+  CHECK_EQ(iter_map.size(), loop_vars.size());
+  Map<Var, PrimExpr> inverse_mapping = arith::InverseAffineIterMap(iter_map, loop_vars);
+  // Step 3. Generate new body
+  BufferRegion read_region = constraints.read_region;
+  BufferRegion write_region = constraints.write_region;
+  Array<PrimExpr> write_index;
+  Array<PrimExpr> read_index;
+  Array<Var> new_loop_vars;
+  Map<Var, PrimExpr> substitute_map;
+  // Step 3.1 construct target buffer indices
+  for (int i = 0, j = 0; i < static_cast<int>(write_region->region.size()); i++) {
+    if (is_one(write_region->region[i]->extent)) {
+      write_index.push_back(write_region->region[i]->min);
+    } else {
+      Var var = runtime::Downcast<Var>(loop_vars[j]).copy_with_suffix("_inverse");
+      new_loop_vars.push_back(var);
+      substitute_map.Set(runtime::Downcast<Var>(loop_vars[j++]), var);
+      write_index.push_back(write_region->region[i]->min + var);
+    }
+  }
+  // Step 3.2 construct source buffer indices
+  for (int i = 0, j = 0; i < static_cast<int>(read_region->region.size()); i++) {
+    if (is_one(read_region->region[i]->extent)) {
+      read_index.push_back(read_region->region[i]->min);
+    } else {
+      read_index.push_back(
+          read_region->region[i]->min +
+          Substitute(inverse_mapping[Downcast<Var>(loop_vars[j++])], substitute_map));
+    }
+  }
+  BufferLoad new_buf_load = BufferLoad(read_region->buffer, read_index);
+  BufferStore new_buf_store = BufferStore(write_region->buffer, new_buf_load, write_index);
+  Stmt ret = new_buf_store;
+  // Step 3.3 construct loop body
+  for (int i = static_cast<int>(new_loop_vars.size()) - 1; i >= 0; i--) {
+    PrimExpr extent = write_region->region[i]->extent;
+    ret = For(new_loop_vars[i], 0, extent, ForKind::kSerial, std::move(ret));
+  }
+  return ret;
+}
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc
new file mode 100644
index 000000000000..4ffffc9fdeab
--- /dev/null
+++ b/src/tir/transforms/memhammer_intermediate_stage.cc
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_body, int ith = -1, + Stmt* ith_loop = nullptr) { + Stmt ret = inner_body; + for (int i = static_cast(loops.size() - 1); i >= 0; i--) { + ObjectPtr new_loop = make_object(*loops[i]); + new_loop->body = ret; + ret = For(new_loop); + if (ith == i) { + *ith_loop = ret; + } + } + return ret; +} + +/*! + * \brief lift all the thread binding loops + * \param stmt the top loop + * \return a pair. The first is the transformed stmt. + * The second is the lowest thread binding loop. + */ +std::pair> LiftThreadBindingLoops(Stmt stmt) { + std::vector normal_loops; + std::vector thread_binding_loops; + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + if (loop->kind == ForKind::kThreadBinding) { + thread_binding_loops.push_back(loop); + } else { + normal_loops.push_back(loop); + } + body = loop->body; + } + body = CopyLoopChain(normal_loops, std::move(body)); + For compute_location{nullptr}; + body = CopyLoopChain(thread_binding_loops, std::move(body), + static_cast(thread_binding_loops.size()) - 1, &compute_location); + return std::make_pair(body, compute_location); +} + +/*! + * \brief Analyze the access pattern for buffer rank promotion. + * Rank promotion is a transformation that reshapes the buffer + * but doesn't change its underlying data layout. + * After the reshape, we expect that all dimensions of the access indices + * will be in the form of floormod(floordiv(x, a), b). + * Rank promotion removes strided access, thus enabling further buffer compacting + */ +class IndexPatternFinder : public ExprVisitor { + public: + IndexPatternFinder(const Map& var_range, Array* resulting_index) + : var_range_(var_range), resulting_index_(resulting_index) {} + + /*! + * \brief Calculate the new buffer shape after rank promotion. + * For each dimension of original shape, it will be split into multiple parts. + * The inner array represents the multiple parts of one original dimension, + * and the outer array represents the original dimensions + * For example, original shape [4, 8] may be split into [[2, 2], [2, 4]] + * \param indices The access indices of the buffer + * \param var_range The iter range of the vars in the indices + * \param rewrite_indices The access indices after rank promotion + * \return The new buffer shape after rank promotion. 
+ */ + static Array> GetRankPromotedShape(Array indices, + const Map& var_range, + Array* rewrite_indices) { + Map var_dom = AsIntSet(var_range); + Array> new_shape; + for (const PrimExpr& expr : indices) { + IndexPatternFinder extractor(var_range, rewrite_indices); + arith::IntSet intset = arith::EvalSet(expr, var_dom); + extractor.mod_ = intset.max() + 1; + extractor.div_ = 1; + extractor.offset_ = 0; + extractor(expr); + Array access_shape = extractor.access_shape_; + for (int i = static_cast(access_shape.size()) - 1; i >= 1; i--) { + if (!is_zero(floormod(extractor.offset_, access_shape[i]))) { + return {}; + } else { + extractor.offset_ = floordiv(extractor.offset_, access_shape[i]); + } + } + access_shape.Set(0, extractor.offset_ + access_shape[0]); + new_shape.push_back(access_shape); + } + return new_shape; + } + + private: + void VisitExpr_(const VarNode* op) final { + arith::Analyzer analyzer; + PrimExpr extent = var_range_[GetRef(op)]->extent; + PrimExpr access_iter_range = min(mod_, (max(1, floordiv(extent, div_)))); + if (!analyzer.CanProveEqual(1, access_iter_range)) { + access_shape_.push_back(access_iter_range); + resulting_index_->push_back(floormod(floordiv(GetRef(op), div_), mod_)); + } + } + + void VisitExpr_(const FloorDivNode* op) final { + PrimExpr old_div = div_; + div_ *= op->b; + ExprVisitor::VisitExpr_(op); + div_ = old_div; + } + + void VisitExpr_(const FloorModNode* op) final { + PrimExpr old_mod = mod_; + mod_ = max(1, min(floordiv(op->b, div_), mod_)); + ExprVisitor::VisitExpr_(op); + mod_ = old_mod; + } + + void VisitExpr_(const MulNode* op) final { + PrimExpr old_mod = mod_; + PrimExpr old_div = div_; + div_ = max(1, floordiv(div_, op->b)); + mod_ = max(1, floordiv(mod_, floordiv(op->b, floordiv(old_div, div_)))); + ExprVisitor::VisitExpr_(op); + mod_ = old_mod; + div_ = old_div; + } + + void VisitExpr_(const AddNode* op) final { + if (is_const_int(op->b)) { + offset_ += floormod(floordiv(op->b, div_), mod_); + } + ExprVisitor::VisitExpr_(op); + } + + PrimExpr div_; + PrimExpr mod_; + PrimExpr offset_; + Map var_range_; + Array access_shape_; + Array* resulting_index_; +}; + +/*! + * \brief Utilities to perform rank promotion + */ +class RankPromoter : public StmtExprMutator { + public: + /*! + * \brief Flatten the buffer shape like performing inverse rank promotion. 
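+ * (i.e., the split parts of each original dimension are multiplied back into a single extent)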
+ * For example, [[i0, i1], [j0, j1]] to [i0 * i1, j0 * j1] + * \param new_shape The buffer shape in the special form as returned by GetRankPromotedShape + * \return The buffer shape after flatten + */ + static Array FlattenNewShape(const Array>& new_shape) { + Array ret; + ret.reserve(new_shape.size()); + for (int i = 0; i < static_cast(new_shape.size()); i++) { + PrimExpr prod = 1; + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + prod *= new_shape[i][j]; + } + ret.push_back(prod); + } + return ret; + } + /** + * \brief Rewrite the index given the shape after rank promotion + * \param indices The original indices + * \param new_shape The buffer shape after rank promotion + * \return The new indices + */ + static Array RewriteIndex(const Array& indices, + const Array>& new_shape) { + Array new_indices; + ICHECK_EQ(indices.size(), new_shape.size()); + for (int i = 0; i < static_cast(indices.size()); i++) { + PrimExpr index = indices[i]; + // The indices transformed from one original dimension + Array index_dim(new_shape[i].size(), 0); + for (int j = static_cast(new_shape[i].size()) - 1; j >= 0; j--) { + index_dim.Set(j, floormod(index, new_shape[i][j])); + index = floordiv(index, new_shape[i][j]); + } + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + new_indices.push_back(index_dim[j]); + } + } + return new_indices; + } + /*! + * \brief Rewrite the index after buffer flattening + * \param indices The original indices + * \param new_shape The shape before buffer flattening + * \return The indices after buffer flattening + */ + static Array RewriteBackIndex(const Array& indices, + const Array>& new_shape) { + Array new_indices; + int offset = 0; + for (int i = 0; i < static_cast(new_shape.size()); i++) { + PrimExpr index = 0; + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + index *= new_shape[i][j]; + index += indices[offset + j]; + } + new_indices.push_back(index); + offset += new_shape[i].size(); + } + return new_indices; + } + RankPromoter(const Buffer& src, const Buffer& dst, const Array>& new_shape, + const Array>& relaxed_new_shape, const Array& relaxed_region) + : src_(src), + dst_(dst), + new_shape_(new_shape), + relaxed_new_shape_(relaxed_new_shape), + relaxed_region_(relaxed_region) {} + + static Stmt RewriteBody(Stmt stmt, const Buffer& src, const Buffer& dst, + const Array>& new_shape, + const Array>& relaxed_new_shape, + const Array& relaxed_region) { + RankPromoter promoter(src, dst, new_shape, relaxed_new_shape, relaxed_region); + return promoter(stmt); + } + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + new_store->indices = ConvertIndices(new_store->indices); + return BufferStore(new_store); + } + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + new_load->indices = ConvertIndices(new_load->indices); + return BufferLoad(new_load); + } + return std::move(load); + } + + /*! + * \brief Rewrite the indices after performing buffer rank promotion + + * buffer compacting + buffer flattening. 
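+ * As an illustrative trace (assuming, for simplicity, a relaxed region whose min is 0):
+ * with new_shape [[2, 4]], an index [i] is first rewritten by RewriteIndex to
+ * [floormod(floordiv(i, 4), 2), floormod(i, 4)], then the relaxed-region offset is
+ * subtracted per dimension, and RewriteBackIndex folds the result back into a flat
+ * index of the compacted buffer.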
+ * \param indices The original indices + * \return The indices after these transformations + */ + Array ConvertIndices(const Array& indices) { + Array rewrite_indices = RewriteIndex(indices, new_shape_); + arith::Analyzer analyzer; + for (int i = 0; i < static_cast(rewrite_indices.size()); i++) { + rewrite_indices.Set(i, analyzer.Simplify(rewrite_indices[i] - relaxed_region_[i]->min)); + } + return RewriteBackIndex(rewrite_indices, relaxed_new_shape_); + } + + const Buffer& src_; + const Buffer& dst_; + Array> new_shape_; + Array> relaxed_new_shape_; + Array relaxed_region_; +}; + +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + Optional compute_location, + const Array& outer_loops, Buffer* alloc_buffer) { + Stmt body = stmt; + std::vector loops; + bool need_relax = !compute_location.defined(); + Map relax_var_range; + Map all_var_range; + PrimExpr vector_bytes = -1; + // Step 1. Perform rank promotion on the buffer access, turning a strided-changing dimension into + // several contiguous-changing dimensions + // Step 1.1 collect loop var range for rank promotion + while (const ForNode* loop = body.as()) { + if (need_relax) { + relax_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } else { + loops.push_back(loop); + } + all_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + if (loop == compute_location.get()) { + need_relax = true; + } + if (loop->kind == ForKind::kVectorized) { + vector_bytes = loop->extent; + } + body = loop->body; + } + for (const For& loop : outer_loops) { + if (loop->kind == ForKind::kThreadBinding) { + const String& thread_tag = loop->thread_binding.value()->thread_tag; + if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope), + runtime::ThreadScope::Create(thread_tag))) { + relax_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + } + all_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer orig_buffer = is_write_cache ? buf_store->buffer : buf_load->buffer; + Array indices = is_write_cache ? buf_store->indices : buf_load->indices; + // Step 1.2 get the new shape and new access indices after rank promotion + Array rewrite_indices; + Array> new_shape = + IndexPatternFinder::GetRankPromotedShape(indices, all_var_range, &rewrite_indices); + // Step 2. relax the access region after rank promotion + arith::Analyzer analyzer; + analyzer.Bind(all_var_range); + Array relaxed_region; + relaxed_region.reserve(rewrite_indices.size()); + { + Map relax_var_intset = AsIntSet(relax_var_range); + for (const PrimExpr& index : rewrite_indices) { + arith::IntSet int_set = arith::EvalSet(index, relax_var_intset); + relaxed_region.push_back(Range::FromMinExtent( + int_set.min(), analyzer.Simplify(int_set.max() - int_set.min() + 1))); + } + } + // Step 3. 
generate the data copy bodies + // preparation work + Array new_loop_vars; + Array orig_buf_indices, new_buf_indices; + Array> relaxed_new_shape; + for (int i = 0; i < static_cast(relaxed_region.size()); i++) { + Var new_loop_var = Var("ax" + std::to_string(i)); + new_loop_vars.push_back(new_loop_var); + orig_buf_indices.push_back(relaxed_region[i]->min + new_loop_var); + new_buf_indices.push_back(new_loop_var); + } + relaxed_new_shape.reserve(new_shape.size()); + for (int i = 0, ct = 0; i < static_cast(new_shape.size()); i++) { + Array layer; + for (int j = 0; j < static_cast(new_shape[i].size()); j++, ct++) { + layer.push_back(relaxed_region[ct]->extent); + } + relaxed_new_shape.push_back(layer); + } + // Step 3.1 create a buffer for the cache + Buffer new_buffer = WithScope(orig_buffer, storage_scope); + new_buffer.CopyOnWrite()->shape = RankPromoter::FlattenNewShape(relaxed_new_shape); + *alloc_buffer = new_buffer; + Array real_orig_buf_indices = + RankPromoter::RewriteBackIndex(orig_buf_indices, new_shape); + Array real_new_buf_indices = + RankPromoter::RewriteBackIndex(new_buf_indices, relaxed_new_shape); + // Step 3.2 generate a body that writes to the cache + Stmt generate_body = is_write_cache + ? BufferStore(orig_buffer, BufferLoad(new_buffer, real_new_buf_indices), + real_orig_buf_indices) + : BufferStore(new_buffer, BufferLoad(orig_buffer, real_orig_buf_indices), + real_new_buf_indices); + for (int i = static_cast(relaxed_region.size()) - 1; i >= 0; i--) { + if (i == static_cast(relaxed_region.size()) - 1 && !is_const_int(vector_bytes, -1)) { + ICHECK(analyzer.CanProve(vector_bytes == relaxed_region[i]->extent)); + generate_body = + For(new_loop_vars[i], 0, relaxed_region[i]->extent, ForKind::kVectorized, generate_body); + } else { + generate_body = + For(new_loop_vars[i], 0, relaxed_region[i]->extent, ForKind::kSerial, generate_body); + } + } + // Step 3.3 rewrite the original body to load from cache + Stmt rewrite_body; + if (compute_location.defined()) { + rewrite_body = compute_location.value()->body; + } else { + rewrite_body = stmt; + } + rewrite_body = RankPromoter::RewriteBody(rewrite_body, orig_buffer, new_buffer, new_shape, + relaxed_new_shape, relaxed_region); + SeqStmt insert_location; + if (is_write_cache) { + generate_body = insert_location = SeqStmt({rewrite_body, generate_body}); + } else { + generate_body = insert_location = SeqStmt({generate_body, rewrite_body}); + } + generate_body = CopyLoopChain(loops, generate_body); + return std::make_pair(generate_body, insert_location); +} + +Stmt CreateLocalStage::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body; + Optional compute_location; + std::tie(body, compute_location) = LiftThreadBindingLoops(std::move(stmt)); + Buffer cache_buffer; + Stmt after_caching = InsertCacheStage(body, false, "local", compute_location, + constraints.outer_loops, &cache_buffer) + .first; + output->alloc_buffer.push_back(cache_buffer); + return after_caching; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc new file mode 100644 index 000000000000..a0103aab380b --- /dev/null +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -0,0 +1,763 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "../../runtime/thread_storage_scope.h"
+#include "../schedule/utils.h"
+#include "./ir_utils.h"
+#include "./memhammer_rewrite_rule.h"
+
+namespace tvm {
+namespace tir {
+
+using support::NDIntSet;
+
+// rewrite rules
+static InverseMapping inverse_mapping;
+static CoalescedAccess coalesced_access;
+static CreateLocalStage create_local_stage;
+static SharedToWmma shared_to_wmma;
+static WmmaToGlobal wmma_to_global;
+static WmmaToShared wmma_to_shared;
+
+/*!
+ * \brief A class to perform auto padding.
+ *
+ * One simple way to perform auto padding is to fix the padding size of every dimension at the
+ * same time, calculate the precise access indices and the bank conflicts they cause, and choose
+ * the combination with minimal conflict. However, this algorithm has exponential complexity:
+ * with d dimensions and padding sizes in [0, 31], we would need to calculate the bank conflict
+ * 32^{d-1} times.
+ * We propose a fast incremental algorithm that works for affine inputs and only calculates the
+ * bank conflict 32*(d-1) times. To be specific, we first decide the optimal padding size for
+ * dimension d-2, then for dimension d-3, ..., and finally for dimension 0. It involves 2 steps.
+ *
+ * First, we analyze how a typical warp accesses the shared memory banks.
+ * A typical warp means setting all irrelevant loop vars to 0 and keeping only the threads in a
+ * warp. For each dimension, the access index is represented by
+ * x_1 * scale_1 + ... + x_n * scale_n (where each x_i is a loop var).
+ * Note: the affine property guarantees that the {x_i} are independent;
+ * otherwise the algorithm is wrong.
+ * We use this information to keep, for each dimension, a list called the "iteration space" that
+ * records the resulting index as each x_i takes every possible value.
+ *
+ * For example, suppose the index is [outer*2+ty, tx*4+vec], where tx is threadIdx.x and ty is
+ * threadIdx.y, with tx in [0, 16) and ty in [0, 2).
+ * We first get the warp access [ty, tx*4], because outer and vec are irrelevant loop vars.
+ * Clearly, ty and tx*4 are both of the form x_1 * scale_1 + ... + x_n * scale_n.
+ * In this case, we keep the lists {{0, 1}, {0, 4, ..., 60}}.
+ *
+ * Next, we choose the padding size with minimal conflict, from the last dimension to the first.
+ * To calculate the conflict, we take the Cartesian product of the iteration spaces of all
+ * dimensions not higher than the current one. Each point of the product space represents the
+ * access index of a particular thread, from which we can calculate the accessed memory bank.
+ * The conflict is the highest access frequency among the banks.
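+ *
+ * As a concrete (illustrative) case: a float32 buffer whose row is 32 words,
+ * accessed as A[tx, 0] with tx in [0, 32). The iteration spaces are
+ * {{0, 1, ..., 31}, {0}}. With pad = 0 the stride stays 32 words, every thread
+ * maps to bank (tx * 32) % 32 = 0, a 32-way conflict; with pad = 1 the stride
+ * becomes 33 words, thread tx maps to bank tx, and the conflict drops to 1,
+ * so pad = 1 would be selected for this dimension.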
+ *
+ */
+class AutoPadder {
+ public:
+  /**
+   * \brief Pad the given buffers in shared memory
+   * \param buffers the given buffers
+   * \return the list of new padded buffers
+   */
+  Array<Buffer> PadSharedMemory(const Array<Buffer>& buffers) {
+    Array<Buffer> result;
+
+    for (const Buffer& buffer : buffers) {
+      runtime::StorageScope scope = runtime::StorageScope::Create(buffer.scope());
+      if (scope.rank == runtime::StorageRank::kShared) {
+        auto iter_spaces = iter_spaces_[buffer.get()];
+        if (iter_spaces.empty()) {
+          result.push_back(buffer);
+          continue;
+        }
+        // The access indices represented by points in the Cartesian product of the
+        // lower-dimension iteration spaces
+        std::vector<std::vector<int>> low_dim_iter_space(iter_spaces.size(), std::vector<int>());
+
+        int n = buffer->shape.size();
+        int data_bits = buffer->dtype.bits();
+        // Step 1. initialize `low_dim_iter_space` with the iteration space of the last dim
+        for (int i = 0; i < static_cast<int>(iter_spaces.size()); i++) {
+          auto last_dim_iter_space = iter_spaces[i][n - 1];
+          low_dim_iter_space[i] = last_dim_iter_space;
+        }
+        PrimExpr stride = 1;
+        Array<PrimExpr> reverse_strides;
+        int pad_min = padding_min_.Get(buffer).value_or(Integer(1));
+        // Step 2. For each dimension, select a padding that has minimal bank conflict
+        for (int k = n - 2; k >= 0; k--) {  // dims
+          int max_pad_size = std::min(
+              int(max_pad_factor_ * (stride * buffer->shape[k + 1]).as<IntImmNode>()->value),
+              32 * 32 / data_bits);
+          int min_conflict = INT32_MAX;
+          int min_conflict_pad = -1;
+          for (int pad = 0; pad <= max_pad_size; pad += pad_min) {  // select padding
+            int padded_stride = ((stride * buffer->shape[k + 1]).as<IntImmNode>()->value + pad) %
+                                (32 * 32 / data_bits);
+            int conflict = 0;
+            for (int i = 0; i < static_cast<int>(iter_spaces.size()); i++) {  // accesses
+              auto iter_space = iter_spaces[i][k];
+              int bank[32]{0};
+              for (int v1 : iter_space) {
+                for (int v2 : low_dim_iter_space[i]) {
+                  int comb = (v1 * padded_stride + v2) * data_bits / 32 % 32;
+                  bank[comb]++;
+                }
+              }
+              for (int j = 0; j < 32; j++) {
+                conflict = std::max(conflict, bank[j]);
+              }
+            }
+            if (conflict < min_conflict) {
+              min_conflict = conflict;
+              min_conflict_pad = pad;
+            }
+          }
+          // update low_dim_iter_space with the chosen padding
+          for (int i = 0; i < static_cast<int>(iter_spaces.size()); i++) {  // accesses
+            auto iter_space = iter_spaces[i][k];
+            if (!iter_space.empty()) {
+              int padded_stride =
+                  ((stride * buffer->shape[k + 1]).as<IntImmNode>()->value + min_conflict_pad) %
+                  (32 * 32 / data_bits);
+              std::vector<int> span;
+              for (int v1 : iter_space) {
+                for (int v2 : low_dim_iter_space[i]) {
+                  span.push_back(((v1 * padded_stride + v2) * data_bits) % (32 * 32 / data_bits));
+                }
+              }
+              low_dim_iter_space[i] = span;
+            } else {
+              ICHECK(min_conflict_pad == 0);
+            }
+          }
+          stride = stride * buffer->shape[k + 1] + min_conflict_pad;
+          reverse_strides.push_back(stride);
+        }
+        // Step 3. create the new padded buffer
+        ObjectPtr<BufferNode> b = make_object<BufferNode>(*buffer.get());
+        Array<PrimExpr> strides;
+        for (int i = static_cast<int>(reverse_strides.size()) - 1; i >= 0; i--) {
+          strides.push_back(reverse_strides[i]);
+        }
+        strides.push_back(1);
+        b->strides = strides;
+        Buffer new_buffer(b);
+        result.push_back(new_buffer);
+        padded_buffer_map_.Set(buffer, new_buffer);
+      } else {
+        result.push_back(buffer);
+      }
+    }
+    return result;
+  }
+
+  /**
+   * \brief Replace all occurrences of the old buffers with the new padded buffers in the stmt
+   * \param stmt the stmt to do replacement
+   * \return the stmt after replacement
+   */
+  Stmt RewriteBufferAccess(const Stmt& stmt) {
+    class Rewriter : public StmtExprMutator {
+     public:
+      explicit Rewriter(const Map<Buffer, Buffer>& buffer_map) : buffer_map_(buffer_map) {}
+
+     private:
+      PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
+        BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
+        BufferLoadNode* op = load.CopyOnWrite();
+        if (buffer_map_.count(op->buffer)) {
+          op->buffer = buffer_map_[op->buffer];
+        }
+        return std::move(load);
+      }
+
+      Stmt VisitStmt_(const BufferStoreNode* _op) final {
+        BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
+        BufferStoreNode* op = store.CopyOnWrite();
+        if (buffer_map_.count(op->buffer)) {
+          op->buffer = buffer_map_[op->buffer];
+        }
+        return std::move(store);
+      }
+
+      Stmt VisitStmt_(const BlockNode* op) final {
+        // To reduce the number of blocks in the block sref reuse map, we check whether the block
+        // is really mutated (i.e., the old buffer appears in the block). If so, we return the
+        // block after mutation. Otherwise we just return the original block.
+        bool changed = false;
+        // Step 1. Mutate the read region.
+        Array<BufferRegion> reads;
+        for (const BufferRegion& read : op->reads) {
+          if (buffer_map_.count(read->buffer)) {
+            changed = true;
+            reads.push_back(BufferRegion(buffer_map_[read->buffer], read->region));
+          } else {
+            reads.push_back(read);
+          }
+        }
+        // Step 2. Mutate the write region.
+        Array<BufferRegion> writes;
+        for (const BufferRegion& write : op->writes) {
+          if (buffer_map_.count(write->buffer)) {
+            changed = true;
+            writes.push_back(BufferRegion(buffer_map_[write->buffer], write->region));
+          } else {
+            writes.push_back(write);
+          }
+        }
+        // Step 3. Mutate `match_buffers`. If an old buffer appears as the source of a
+        // MatchBufferRegion, the source region needs to be rebound to the padded buffer.
+        Array<MatchBufferRegion> match_buffers;
+        for (const MatchBufferRegion& match_buffer : op->match_buffers) {
+          if (buffer_map_.count(match_buffer->source->buffer)) {
+            changed = true;
+            Buffer new_buffer = buffer_map_[match_buffer->source->buffer];
+            match_buffers.push_back(MatchBufferRegion(
+                match_buffer->buffer, BufferRegion(new_buffer, match_buffer->source->region)));
+          } else {
+            match_buffers.push_back(match_buffer);
+          }
+        }
+        // Step 4. Recursively mutate the block.
+ Stmt res = StmtMutator::VisitStmt_(op); + if (res.get() != op) { + changed = true; + } + + if (changed) { + ObjectPtr block = CopyOnWrite(res.as()); + block->reads = std::move(reads); + block->writes = std::move(writes); + block->match_buffers = std::move(match_buffers); + return Stmt(block); + } else { + return GetRef(op); + } + } + const Map& buffer_map_; + }; + Rewriter rewriter(padded_buffer_map_); + return rewriter(stmt); + } + + /** + * \brief an equivalent of scale * loop_var with loop_var: {min=0, extent=extent} + */ + struct Pattern { + int extent; + int scale; + }; + + /** + * \brief Collect pattern from indices + */ + class PatternCollector : public StmtExprVisitor { + void VisitExpr_(const VarNode* op) final { + if (!success_) { + return; + } + int extent = var_range_[GetRef(op)]->extent.as()->value; + if (extent > 1) { + stack_.push({{extent, 1}}); + } else { + stack_.push({}); + } + } + + void VisitExpr_(const AddNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector merged_patterns; + std::vector r = stack_.top(); + stack_.pop(); + std::vector l = stack_.top(); + stack_.pop(); + for (const Pattern& pattern : l) { + merged_patterns.push_back(pattern); + } + for (const Pattern& pattern : r) { + merged_patterns.push_back(pattern); + } + if (merged_patterns.empty()) { + stack_.push({}); + return; + } + std::vector ret; + ret.push_back(merged_patterns[0]); + for (int i = 0; i < static_cast(merged_patterns.size()); i++) { + Pattern prev_pattern = ret.back(); + if (merged_patterns[i].extent * merged_patterns[i].scale == prev_pattern.scale) { + ret.pop_back(); + ret.push_back( + {prev_pattern.extent * merged_patterns[i].extent, merged_patterns[i].scale}); + } + } + stack_.push(ret); + } + + void VisitExpr_(const FloorDivNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int lower_factor = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + if (pattern.scale >= lower_factor) { + if (pattern.scale % lower_factor == 0) { + ret.push_back({pattern.extent, pattern.scale / lower_factor}); + } else { + success_ = false; + } + } else if (pattern.scale * pattern.extent > lower_factor) { + if ((pattern.scale * pattern.extent) % lower_factor == 0) { + ret.push_back({pattern.extent * pattern.scale / lower_factor, 1}); + } else { + success_ = false; + } + } + } + stack_.push(ret); + } + + void VisitExpr_(const FloorModNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int extent = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + if (pattern.scale < extent) { + if (extent % pattern.scale == 0) { + if (extent / pattern.scale < pattern.extent) { + ret.push_back({extent / pattern.scale, pattern.scale}); + } else { + ret.push_back({pattern.extent, pattern.scale}); + } + } else { + success_ = false; + } + } + } + stack_.push(ret); + } + + void VisitExpr_(const MulNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int scale = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + ret.push_back({pattern.extent, pattern.scale * scale}); + } + stack_.push(ret); + } + + public: + PatternCollector(const Map& var_range) : var_range_(var_range) {} + + /*! + * \brief Collect the iteration space for given indices. 
The iteration space is the possible + * values that an index can take (do not remove duplicate). + * For example, the input is [ty, tx*4], where tx is in [0, 16), and ty is in [0, 2). + * The output would be {{0, 1}, {0, 4, ..., 60}} + * \param indices The indices to analyze + * \param var_range The range of loop variables + * \param data_bits The size of dtype in bits + * \return The iteration space. The first array represents dimensions, and the second array + * represents the iteration space of one dimension + */ + static std::vector> CollectIterationSpace(const Array& indices, + const Map& var_range, + int data_bits) { + PatternCollector collector(var_range); + std::vector> ret; + for (int i = 0; i < static_cast(indices.size()); i++) { + collector(indices[i]); + if (collector.success_ && collector.stack_.size() == 1) { + auto patterns = collector.stack_.top(); + int extent_prod = 1; + for (const Pattern& p : patterns) { + extent_prod *= p.extent; + } + std::vector iter_space; + for (int thread_id = 0; thread_id < extent_prod; thread_id++) { + int index = 0; + int n = thread_id; + for (int j = static_cast(patterns.size()) - 1; j >= 0; j--) { + int val = n % patterns[j].extent; + index += val * patterns[j].scale; + n /= patterns[j].extent; + } + iter_space.push_back(index); + } + + ret.push_back(iter_space); + collector.stack_.pop(); + } else { + ret.push_back({}); + } + } + return ret; + } + + std::stack> stack_; + const Map& var_range_; + bool success_ = true; + }; + + /*! A utility class for calling CollectIterationSpace to each buffer access*/ + class IterSpaceAnalyzer : public StmtExprVisitor { + public: + IterSpaceAnalyzer(const Map& substitute_map, AutoPadder* self, int data_bits, + const Map warp_thread_extent) + : substitute_map_(substitute_map), + self(self), + data_bits_(data_bits), + warp_thread_extent_(warp_thread_extent) {} + + private: + bool CheckVarContiguous(PrimExpr e, Var var) { + PrimExpr e1 = Substitute(e, [var](const Var& v) -> Optional { + if (v.same_as(var)) { + return Integer(0); + } else { + return v; + } + }); + PrimExpr e2 = Substitute(e, [var](const Var& v) -> Optional { + if (v.same_as(var)) { + return Integer(1); + } else { + return v; + } + }); + arith::Analyzer analyzer; + return analyzer.CanProve(e2 - e1 == 1); + } + + void VisitStmt_(const ForNode* op) final { + if (op->kind != ForKind::kThreadBinding) { + substitute_map_.Set(op->loop_var, op->min); + } else { + Integer extent = + warp_thread_extent_.Get(op->thread_binding.value()->thread_tag).value_or(1); + var_range_.Set(op->loop_var, Range::FromMinExtent(op->min, extent)); + } + if (op->kind == ForKind::kVectorized) { + vector_var = op->loop_var; + vector_length_ = op->extent.as()->value; + } + StmtExprVisitor::VisitStmt_(op); + if (op->kind == ForKind::kVectorized) { + vector_length_ = -1; + } + if (op->kind != ForKind::kThreadBinding) { + substitute_map_.erase(op->loop_var); + } + } + /*! + * \brief Take a typical warp and collect the iteration space for buffer store + * For example, the access is A[outer*2+ty, tx*4+vec] = xxx, where tx is threadIdx.x, and ty is + * threadIdx.y. tx is in [0, 16), and ty is in [0, 2). + * The iteration space would be {{0, 1}, {0, 4, ..., 60}}. 
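+ * In addition, if the innermost store index is contiguous in the vectorized loop
+ * var (checked by CheckVarContiguous), the buffer's minimal padding size is raised
+ * to the vector length, so that, e.g., a hypothetical vec loop of extent 4 keeps
+ * padded rows 4-element aligned and thus still vectorizable.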
+ * \param op the buffer store + */ + void VisitStmt_(const BufferStoreNode* op) final { + runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : op->indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = + PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[op->buffer.get()].push_back(iter_space); + } + if (vector_length_ != -1 && CheckVarContiguous(substitued_indices.back(), vector_var)) { + Integer m = self->padding_min_.Get(op->buffer).value_or(1); + self->padding_min_.Set(op->buffer, Downcast(max(vector_length_, m))); + } + } + StmtExprVisitor::VisitStmt_(op); + } + /*! + * \brief Take a typical warp and collect the iteration space for buffer load + * For example, the access is xxx = A[outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is + * threadIdx.y. tx is in [0, 16), and ty is in [0, 2). + * The iteration space would be {{0, 1}, {0, 4, ..., 60}}. + * \param op the buffer load + */ + void VisitExpr_(const BufferLoadNode* op) final { + runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : op->indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = + PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[op->buffer.get()].push_back(iter_space); + } + if (vector_length_ != -1 && CheckVarContiguous(substitued_indices.back(), vector_var)) { + Integer m = self->padding_min_.Get(op->buffer).value_or(1); + self->padding_min_.Set(op->buffer, Downcast(max(vector_length_, m))); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + /*! + * \brief Take a typical warp and collect the iteration space for load_matrix_sync and + * store_matrix_sync + * For example, the access region is A[y*16+16, x*16+16], where y and x are not bound to + * threadIdx. The iteration space would be {{0, 1, ..., 15}, {0, 1, ..., 15}}. 
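+ * (Each region dimension of extent 16 is materialized as a fresh var region_i
+ * ranging over [0, 16), so the typical warp sweeps the whole 16x16 tile
+ * regardless of y and x.)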
+ * \param op the call node + */ + void VisitStmt_(const BlockNode* op) final { + if (const auto* eval = op->body.as()) { + if (const auto* call = eval->value.as()) { + if (call->op == builtin::tvm_load_matrix_sync() || + call->op == builtin::tvm_store_matrix_sync()) { + for (const MatchBufferRegion& r : op->match_buffers) { + Buffer src_buffer = r->source->buffer; + runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Region region = r->source->region; + Array indices; + for (int i = 0; i < static_cast(region.size()); i++) { + Var var("region" + std::to_string(i)); + indices.push_back(region[i]->min + var); + var_range_.Set(var, Range::FromMinExtent(0, region[i]->extent)); + } + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = PatternCollector::CollectIterationSpace( + substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[src_buffer.get()].push_back(iter_space); + } + } + } + } + } + } + } + + Map substitute_map_; + AutoPadder* self; + int data_bits_; + Map warp_thread_extent_; + Map var_range_; + int vector_length_ = -1; + Var vector_var; + }; + + /*! + * \brief Analyze the shared memory access + * \param stmt The data copy + * \param outer_loops The outer loops of the stmt + * \param data_bits The length of dtype in bits + * \param thread_extent The extents of all thread binding loops + */ + void AnalyzeSharedMemoryAccess(const Stmt& stmt, const Array& outer_loops, int data_bits, + const Map& thread_extent) { + Map warp_thread_extent; + Integer prod = 1; + Array thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"}; + arith::Analyzer analyzer; + for (int i = 0; i < 3; i++) { + Integer extent = thread_extent.Get(thread_tags[i]).value_or(1); + if (analyzer.CanProve(prod * extent >= 32)) { + warp_thread_extent.Set(thread_tags[i], Downcast(floordiv(32, prod))); + prod *= floordiv(32, prod); + break; + } else { + warp_thread_extent.Set(thread_tags[i], Downcast(extent)); + prod *= extent; + } + } + Map substitute_map; + for (const For& loop : outer_loops) { + substitute_map.Set(loop->loop_var, loop->min); + } + IterSpaceAnalyzer iter_space_analyzer(substitute_map, this, data_bits, warp_thread_extent); + iter_space_analyzer(stmt); + } + + private: + /*! \brief A map from the old buffers to the new padded buffers */ + Map padded_buffer_map_; + /*! \brief A map from each buffer to the iteration spaces of the accesses*/ + std::unordered_map>>> iter_spaces_; + /*! \brief A map from each buffer to their minimal padding size */ + Map padding_min_; + /*! 
\brief max padding size relative to the original shape */
+  const double max_pad_factor_ = 0.25;
+
+  friend class AutoCopyMutator;
+};
+
+class AutoCopyMutator : public StmtExprMutator {
+ public:
+  explicit AutoCopyMutator(Map<String, Integer> thread_extent) : thread_extent_(thread_extent) {}
+  /**
+   * \brief Replace old buffers with padded buffers in the stmt
+   * \param stmt The stmt to rewrite
+   * \return The stmt after rewrite
+   */
+  Stmt RewritePaddingBody(const Stmt& stmt) { return padder.RewriteBufferAccess(stmt); }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
+    // only rewrite the block annotated with "auto_copy"
+    if (GetAnn<Integer>(op, "auto_copy").value_or(0)->value == 0) {
+      BlockNode* n = block.CopyOnWrite();
+      n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers));
+      return std::move(block);
+    }
+    ICHECK_EQ(block->reads.size(), 1);
+    ICHECK_EQ(block->writes.size(), 1);
+    int data_bits = block->reads[0]->buffer->dtype.bits();
+    ConstraintSet constraints(this->thread_extent_,  //
+                              this->outer_loops_,    //
+                              block->reads[0],       //
+                              block->writes[0],      //
+                              data_bits,             //
+                              block->annotations);
+    BlockNode* n = block.CopyOnWrite();
+    OutputSet outputs;
+    for (RewriteRule* rule : rules) {
+      n->body = rule->Apply(std::move(n->body), constraints, &outputs);
+    }
+    for (const Buffer& buffer : outputs.alloc_buffer) {
+      n->alloc_buffers.push_back(buffer);
+    }
+    for (const auto& p : outputs.padding_min) {
+      Integer m = padder.padding_min_.Get(p.first).value_or(1);
+      padder.padding_min_.Set(p.first, Downcast<Integer>(max(p.second, m)));
+    }
+    padder.AnalyzeSharedMemoryAccess(block->body, outer_loops_, data_bits, thread_extent_);
+    n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers));
+    return std::move(block);
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    outer_loops_.push_back(GetRef<For>(op));
+    Stmt stmt = StmtMutator::VisitStmt_(op);
+    outer_loops_.pop_back();
+    return stmt;
+  }
+
+  /*! \brief Thread extents collected. */
+  Map<String, Integer> thread_extent_;
+  /*! \brief The outer loops during recursive visit */
+  Array<For> outer_loops_;
+  /*! \brief Calculating optimal padding size */
+  AutoPadder padder;
+
+  /*! \brief All rewrite rules. */
+  const std::array<RewriteRule*, 6> rules = {
+      &inverse_mapping,     //
+      &coalesced_access,    //
+      &create_local_stage,  //
+      &shared_to_wmma,      //
+      &wmma_to_global,      //
+      &wmma_to_shared,
+  };
+};
+
+/*!
+ * \brief Collect the extent for all thread binding loops.
+ */
+class ThreadExtentCollector : public StmtVisitor {
+ public:
+  static Map<String, Integer> CollectThreadExtent(const Stmt& stmt) {
+    ThreadExtentCollector collector;
+    collector(stmt);
+    return collector.thread_extent_;
+  }
+
+ private:
+  void VisitStmt_(const BlockNode* op) final {
+    if (Optional<Integer> warp_execution = GetAnn<Integer>(op, "warp_execution")) {
+      if (warp_execution.value()->value != 0) {
+        thread_extent_.Set("threadIdx.x", Integer(32));
+      }
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+  void VisitStmt_(const ForNode* op) final {
+    if (op->thread_binding.defined() && op->thread_binding.value()->iter_type == kThreadIndex) {
+      thread_extent_.Set(op->thread_binding.value()->thread_tag, Downcast<Integer>(op->extent));
+    }
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  /*!
\brief the map from thread tag to its extent */ + Map thread_extent_; +}; + +namespace transform { + +Pass LowerAutoCopy() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + AutoCopyMutator mutator(ThreadExtentCollector::CollectThreadExtent(n->body)); + n->body = mutator(std::move(n->body)); + n->body = mutator.RewritePaddingBody(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h new file mode 100644 index 000000000000..1cb0ea496a03 --- /dev/null +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../schedule/utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The set containing all possible constraints of a data copy */ +struct ConstraintSet { + /*! \brief The extents of the thread binding loops */ + Map thread_extent; + /*! \brief The outer loops surrounding the data copy */ + Array outer_loops; + /*! \brief The read region of the data copy */ + BufferRegion read_region; + /*! \brief The write region of the data copy */ + BufferRegion write_region; + /*! \brief The dtype size in bits */ + int data_bits; + /*! \brief Whether to insert a local stage in the data copy */ + int add_local_stage = 0; + /*! \brief The vectorization length in bytes */ + int vector_bytes = 1; + + explicit ConstraintSet(Map thread_extent, // + Array outer_loops, // + BufferRegion read_region, // + BufferRegion write_region, // + int data_bits, // + const Map& ann) + : thread_extent(thread_extent), + outer_loops(outer_loops), + read_region(read_region), + write_region(write_region), + data_bits(data_bits) { + if (Optional add_local_stage = ann.Get("local_stage")) { + this->add_local_stage = Downcast(add_local_stage.value())->value; + } + if (Optional vector_bytes = ann.Get("vector_bytes")) { + this->vector_bytes = Downcast(vector_bytes.value())->value; + } + } +}; + +/*! \brief The set containing all possible outputs of a rewrite rule */ +struct OutputSet { + /*! \brief New buffers allocated after rewrite */ + Array alloc_buffer; + /*! \brief The minimal padding size of a buffer in base 2 logarithm */ + Map padding_min; +}; + +/*! + * \brief Rules to rewrite a data copy. + */ +class RewriteRule { + protected: + /* RewriteRule() = default; */ + /*! 
+ * \brief Rewrite the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \param output Some additional information that the rewrite rule produces. (including the new + * buffer to be allocated, etc.) + * \return the stmt after rewrite + */ + virtual Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const = 0; + /*! + * \brief Whether the rewrite rule can be applied to the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \return A boolean flag indicating whether the rule can be applied + */ + virtual bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const { return true; } + + public: + inline Stmt Apply(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { + if (CanApply(stmt, constraints)) { + return Rewrite(stmt, constraints, output); + } else { + return stmt; + } + } +}; + +inline bool IsCopyBetweenScope(const Buffer& src_buffer, const Buffer& tgt_buffer, + runtime::StorageRank src_rank, runtime::StorageRank tgt_rank) { + runtime::StorageScope src_scope = runtime::StorageScope::Create(src_buffer.scope()); + runtime::StorageScope tgt_scope = runtime::StorageScope::Create(tgt_buffer.scope()); + return src_scope.rank == src_rank && tgt_scope.rank == tgt_rank; +} + +/*! + * \brief Coalesce and vectorize memory access. + */ +class CoalescedAccess : public RewriteRule { + public: + CoalescedAccess() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Transform from A[f(i,j)] = B[i,j] to A[i,j] = B[f^{-1}(i,j)] + */ +class InverseMapping : public RewriteRule { + public: + InverseMapping() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Create a local stage when loading from global memory to shared memory. + */ +class CreateLocalStage : public RewriteRule { + public: + CreateLocalStage() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) && + is_one(constraints.add_local_stage); + } +}; + +/*! + * \brief Add a cache stage in shared memory. Perform tensor core rewrite for wmma->shared, and + * perform coalescing and vectorizing for shared->global. 
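+ * Schematically, the rewritten copy becomes:
+ *   wmma.accumulator --(store_matrix_sync)--> shared.dyn cache
+ *                    --(coalesced + vectorized copy)--> global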
+ */ +class WmmaToGlobal : public RewriteRule { + public: + WmmaToGlobal() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Rewrite shared->wmma data copy with load_matrix_sync + */ +class SharedToWmma : public RewriteRule { + public: + SharedToWmma() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixA) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixB); + } +}; + +/*! + * \brief Rewrite wmma->shared data copy with store_matrix_sync + */ +class WmmaToShared : public RewriteRule { + public: + WmmaToShared() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kShared); + } +}; + +/*! + * \brief Insert a cache stage to the compute location + * \param stmt the stmt + * \param is_write_cache whether to write a read cache or write cache + * \param storage_scope the storage scope of the new cache + * \param compute_location the compute location. + * \param outer_loops the outer loops of this stmt + * \param alloc_buffer the new cache block + * \return a pair. The first is the stmt after transformation. + * The second is the SeqStmt that contains 2 stages (one original and another inserted). + */ +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + Optional compute_location, + const Array& outer_loops, Buffer* alloc_buffer); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc new file mode 100644 index 000000000000..6e880146d618 --- /dev/null +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Tile the 2 innermost loops to extent=16. This helps further tensor core rewrite. + * \param stmt The stmt + * \return A pair. The first is the stmt after transformation. + * The second is the compute location where we may add write cache. + */ +std::pair> TileWmmaBlock(Stmt stmt) { + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + PrimExpr extent_last1 = loops[n - 1]->extent; + PrimExpr extent_last2 = loops[n - 2]->extent; + { + arith::Analyzer analyzer; + if (!analyzer.CanProveEqual(floormod(extent_last1, 16), 0) || + !analyzer.CanProveEqual(floormod(extent_last2, 16), 0)) { + return std::make_pair(stmt, NullOpt); + } + } + Var new_loop_vars[4] = { + /*0:*/ loops[n - 2]->loop_var.copy_with_suffix("_0"), + /*1:*/ loops[n - 1]->loop_var.copy_with_suffix("_0"), + /*2:*/ loops[n - 2]->loop_var.copy_with_suffix("_1"), + /*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"), + }; + body = Substitute(std::move(body), + Map{ + {loops[n - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]}, + {loops[n - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]}, + }); + { + PrimExpr factor[4] = { + /*0:*/ floordiv(extent_last2, 16), // + /*1:*/ floordiv(extent_last1, 16), // + /*3:*/ 16, // + /*4:*/ 16, // + }; + body = For(new_loop_vars[3], 0, factor[3], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[2], 0, factor[2], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[1], 0, factor[1], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[0], 0, factor[0], ForKind::kSerial, std::move(body)); + } + For compute_location = Downcast(body); + for (int i = n - 3; i >= 0; i--) { + body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body), + loops[i]->thread_binding, loops[i]->annotations); + } + return {body, compute_location}; +} + +Array RelaxIndices(const Array& indices, const Array& shape, + const Map& var_dom) { + Array int_set = arith::EvalSet(indices, var_dom); + int ndim = int_set.size(); + Array region; + region.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + region.push_back(int_set[i].CoverRange(Range::FromMinExtent(0, shape[i]))); + }; + return region; +} + +/*! 
+ * \brief Rewrite the data copy that stores to wmma fragment with wmma::load_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaLoad(Stmt stmt) { + using arith::IntSet; + const DataType dtype = DataType::Float(16); + const DataType int32 = DataType::Int(32); + + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + + Map var_dom{ + {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, + {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, + }; + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + Buffer new_src_buffer( + /*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{Var("s1", int32), Var("s0", int32)}, + /*elem_offset=*/Var("src_elem_offset", int32), + /*name=*/"src", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Buffer new_tgt_buffer( + /*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{}, + /*elem_offset=*/Var("tgt_elem_offset", int32), + /*name=*/"tgt", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + Stmt wmma_body = BlockRealize( + /*iter_values=*/{}, + /*predicate=*/Bool(true), + Block( + /*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_load", + /*body=*/ + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_load_matrix_sync(), + { + /*0:*/ new_tgt_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_tgt_buffer->elem_offset, 256) + + floordiv(floormod(new_tgt_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*dtype=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + /*args=*/ + { + /*0:*/ TypeAnnotation(new_src_buffer->dtype), + /*1:*/ new_src_buffer->data, + /*2:*/ new_src_buffer->elem_offset, + /*3:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2] * 16, + /*4:*/ 1, + }), + /*6:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2], + /*7:*/ StringImm("row_major"), + })), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + /*0:*/ MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + /*1:*/ MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); + for (int i = n - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, + std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +/*! 
+ * \brief Rewrite the data copy that loads from wmma fragment with wmma::store_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaStore(Stmt stmt) { + using arith::IntSet; + const DataType dtype = DataType::Float(32); + const DataType int32 = DataType::Int(32); + + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + + Map var_dom{ + {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, + {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, + }; + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + Buffer new_src_buffer(/*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{}, + /*elem_offset=*/Var("src_elem_offset", int32), + /*name=*/"src", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Buffer new_tgt_buffer(/*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{Var("s1", int32), Var("s0", int32)}, + /*elem_offset=*/Var("tgt_elem_offset", int32), + /*name=*/"tgt", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + + Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + + Stmt wmma_body = BlockRealize( + /*iter_values=*/{}, // + /*predicate=*/Bool(true), + Block(/*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_store", + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_store_matrix_sync(), + {/*0:*/ new_src_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_src_buffer->elem_offset, 256) + + floordiv(floormod(new_src_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + { + /*0:*/ TypeAnnotation(new_tgt_buffer->dtype), + /*1:*/ new_tgt_buffer->data, + /*2:*/ new_tgt_buffer->elem_offset, + /*3:*/ new_tgt_buffer->strides[0] * 16, + /*4:*/ 2, + }), + /*6:*/ new_tgt_buffer->strides[0], + /*7:*/ StringImm("row_major")})), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); + for (int i = n - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, + std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +Stmt SharedToWmma::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.read_region->buffer, 8); + return RewriteWmmaLoad(after_tiling); +} + +Stmt WmmaToShared::Rewrite(const Stmt& stmt, const 
ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.write_region->buffer, 8); + return RewriteWmmaStore(after_tiling); +} + +class WmmaToGlobalRewriter : public StmtExprMutator { + public: + WmmaToGlobalRewriter(const SeqStmtNode* tgt_stmt, const ConstraintSet& constraints) + : tgt_stmt_(tgt_stmt), constraints_(constraints) {} + + private: + Stmt VisitStmt_(const SeqStmtNode* op) final { + if (op == tgt_stmt_) { + ICHECK_EQ(op->seq.size(), 2); + Stmt wmma_to_shared = RewriteWmmaStore(op->seq[0]); + Stmt shared_to_global = CoalescedAccess().Rewrite(op->seq[1], constraints_, nullptr); + return SeqStmt({wmma_to_shared, shared_to_global}); + } else { + return StmtMutator::VisitStmt_(op); + } + } + + const SeqStmtNode* tgt_stmt_; + const ConstraintSet& constraints_; +}; + +Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body{nullptr}; + Optional compute_location{nullptr}; + std::tie(body, compute_location) = TileWmmaBlock(stmt); + SeqStmt seq{nullptr}; + Buffer cache_buffer; + // Step 1. add a shared memory cache + std::tie(body, seq) = InsertCacheStage(std::move(body), true, "shared.dyn", compute_location, + constraints.outer_loops, &cache_buffer); + output->alloc_buffer.push_back(cache_buffer); + output->padding_min.Set(cache_buffer, 8); + // Step 2. do coalesced rewrite and tensor core rewrite respectively for 2 parts + WmmaToGlobalRewriter rewriter(seq.get(), constraints); + return rewriter(body); +} + +} // namespace tir +} // namespace tvm diff --git a/tests/python/meta_schedule/run_ansor_cpu.sh b/tests/python/meta_schedule/run_ansor_cpu.sh new file mode 100644 index 000000000000..a080ded8fdd9 --- /dev/null +++ b/tests/python/meta_schedule/run_ansor_cpu.sh @@ -0,0 +1,41 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +NUM_TRIALS=800 +LOG_DIR=$HOME/logs/ansor-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run SFM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/tests/python/meta_schedule/run_ansor_cuda.sh b/tests/python/meta_schedule/run_ansor_cuda.sh new file mode 100644 index 000000000000..6eda12fe119c --- /dev/null +++ b/tests/python/meta_schedule/run_ansor_cuda.sh @@ -0,0 +1,39 @@ +# set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="jetson-agx-xavier" +TARGET="nvidia/jetson-agx-xavier" +LOG_DIR=$HOME/logs/ansor-cuda/ +NUM_TRIALS=2000 + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +run C1D +run C2D +run CAP +run DEP +run DIL +run GMM +run GRP +run T2D +run C2d-BN-RELU +run TBG + +run C3D +run NRM +run SFM diff --git a/tests/python/meta_schedule/run_meta_schedule_cpu.sh b/tests/python/meta_schedule/run_meta_schedule_cpu.sh new file mode 
100644 index 000000000000..87bc17f9e8b6 --- /dev/null +++ b/tests/python/meta_schedule/run_meta_schedule_cpu.sh @@ -0,0 +1,40 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +LOG_DIR=$HOME/logs/ms-cpu/ +NUM_TRIALS=2000 + +mkdir -p "$LOG_DIR" + +run () { + name=$1 + work_dir=$LOG_DIR/$name/ + mkdir -p "$work_dir" + echo "Running workload $name" + python tests/python/meta_schedule/test_meta_schedule.py \ + --workload "$name" \ + --target "$TARGET" \ + --work-dir "$work_dir" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run SFM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/tests/python/meta_schedule/run_meta_schedule_cuda.sh b/tests/python/meta_schedule/run_meta_schedule_cuda.sh new file mode 100644 index 000000000000..28132a05045a --- /dev/null +++ b/tests/python/meta_schedule/run_meta_schedule_cuda.sh @@ -0,0 +1,41 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="jetson-agx-xavier" +TARGET="nvidia/jetson-agx-xavier" +LOG_DIR=$HOME/logs/ms-cuda/ +NUM_TRIALS=2000 + +mkdir -p "$LOG_DIR" + +run () { + name=$1 + work_dir=$LOG_DIR/$name/ + mkdir -p "$work_dir" + echo "Running workload $name" + python tests/python/meta_schedule/test_meta_schedule.py \ + --workload "$name" \ + --target "$TARGET" \ + --work-dir "$work_dir" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + 2>&1 | tee "$work_dir/$name.log" +} + +run C1D +run C2D +run CAP +run DEP +run DIL +run GMM +run GRP +run T2D +run C2d-BN-RELU +run TBG + +run C3D +run NRM +run SFM diff --git a/tests/python/meta_schedule/test_ansor.py b/tests/python/meta_schedule/test_ansor.py new file mode 100644 index 000000000000..1e548c49afa3 --- /dev/null +++ b/tests/python/meta_schedule/test_ansor.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
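+# +# Usage: tune one TE workload from tvm.meta_schedule.testing.te_workload with Ansor +# (auto_scheduler), building locally and measuring on the device behind the RPC tracker, +# e.g. (mirroring run_ansor_cpu.sh): +#   python tests/python/meta_schedule/test_ansor.py --workload C1D --target "raspberry-pi/4b-64" \ +#     --rpc-host 192.168.6.66 --rpc-port 4445 --rpc-key raspi4b-aarch64 --num-trials 800 --log-dir ~/logs/ansor-cpu/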
+# pylint: disable=missing-docstring +import argparse +import os + +import tvm +from tvm import auto_scheduler +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.te_workload import CONFIGS + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--log-dir", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = rpc_config.count_num_servers(allow_missing=False) + return parsed + + +ARGS = _parse_args() + + +def main(): + log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json") + workload_func, params = CONFIGS[ARGS.workload] + params = params[0] + workload_func = auto_scheduler.register_workload(workload_func) + + if ARGS.target.device_name == "cpu": + hardware_params = auto_scheduler.HardwareParams( + num_cores=int(ARGS.target.attrs["num-cores"]), + target=ARGS.target, + ) + else: + hardware_params = auto_scheduler.HardwareParams( + num_cores=-1, + vector_unit_bytes=16, + cache_line_bytes=64, + max_shared_memory_per_block=int(ARGS.target.attrs["shared_memory_per_block"]), + max_local_memory_per_block=int(ARGS.target.attrs["registers_per_block"]), + max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]), + max_vthread_extent=8, + warp_size=32, + ) + task = auto_scheduler.SearchTask( + func=workload_func, + args=params, + target=ARGS.target, + hardware_params=hardware_params, + ) + runner = auto_scheduler.RPCRunner( + key=ARGS.rpc_key, + host=ARGS.rpc_host, + port=ARGS.rpc_port, + n_parallel=ARGS.rpc_workers, + ) + + # Inspect the computational graph + print("Computational DAG:") + print(task.compute_dag) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=ARGS.num_trials, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + runner=runner, + ) + print("Running AutoTuning:") + task.tune(tune_option) + print("History Best:") + print(task.print_best(log_file)) + sch, args = task.apply_best(log_file) + print("Lowered TIR:") + print(tvm.lower(sch, args, simple_mode=True)) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_debug_ansor.py b/tests/python/meta_schedule/test_debug_ansor.py new file mode 100644 index 000000000000..be562963a1a0 --- /dev/null +++ b/tests/python/meta_schedule/test_debug_ansor.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import Tuple + +import tvm +from tvm import te, topi + + +TARGET = tvm.target.Target("nvidia/jetson-agx-xavier") + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + import os + if not os.path.exists("/tmp/perf"): + os.mkdir("/tmp/perf") + with open("/tmp/perf/te.cu", "w") as f: + f.write(code) + return code + + +def func( # pylint: disable=invalid-name,missing-docstring + N: int, + L: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, L, CI), name="inputs") + weight = te.placeholder((kernel_size, CI // groups, CO), name="weight") + + batch_size, in_len, _ = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rl = te.reduce_axis((0, k_len), name="rl") + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + ( + padded[ + n, + l * stride + rl * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rl, rc, co] + ), + axis=[rl, rc], + ), + name="conv1d_nlc", + ) + return (inputs, weight, padded, output) + + +def main(): + inputs, weight, PadInput, conv1d_nlc = func(1, 256, 64, 128, 3, 2, 1) + s = te.create_schedule(conv1d_nlc.op) + # fmt: off + PadInput_i0, PadInput_i1, PadInput_i2 = tuple(PadInput.op.axis) + tuple(PadInput.op.reduce_axis) + conv1d_nlc_n, conv1d_nlc_l, conv1d_nlc_co, conv1d_nlc_rl, conv1d_nlc_rc = tuple(conv1d_nlc.op.axis) + tuple(conv1d_nlc.op.reduce_axis) + conv1d_nlc_local, = s.cache_write([conv1d_nlc], "local") + conv1d_nlc_local_n_c, conv1d_nlc_local_l_c, conv1d_nlc_local_co_c, conv1d_nlc_local_rl, conv1d_nlc_local_rc = tuple(conv1d_nlc_local.op.axis) + tuple(conv1d_nlc_local.op.reduce_axis) + conv1d_nlc_local_n_c_o_i, conv1d_nlc_local_n_c_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c, factor=1) + conv1d_nlc_local_n_c_o_o_i, conv1d_nlc_local_n_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_i, factor=1) + conv1d_nlc_local_n_c_o_o_o_i, conv1d_nlc_local_n_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_o_i, factor=1) + conv1d_nlc_local_n_c_o_o_o_o, conv1d_nlc_local_n_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_o_o_i, factor=1) + conv1d_nlc_local_l_c_o_i, conv1d_nlc_local_l_c_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c, factor=1) + conv1d_nlc_local_l_c_o_o_i, conv1d_nlc_local_l_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_i, factor=4) + conv1d_nlc_local_l_c_o_o_o_i, conv1d_nlc_local_l_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_o_i, factor=8) + conv1d_nlc_local_l_c_o_o_o_o, conv1d_nlc_local_l_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_o_o_i, factor=1) + conv1d_nlc_local_co_c_o_i, conv1d_nlc_local_co_c_i = 
s[conv1d_nlc_local].split(conv1d_nlc_local_co_c, factor=2) + conv1d_nlc_local_co_c_o_o_i, conv1d_nlc_local_co_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_i, factor=1) + conv1d_nlc_local_co_c_o_o_o_i, conv1d_nlc_local_co_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_o_i, factor=16) + conv1d_nlc_local_co_c_o_o_o_o, conv1d_nlc_local_co_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_o_o_i, factor=1) + conv1d_nlc_local_rl_o_i, conv1d_nlc_local_rl_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rl, factor=3) + conv1d_nlc_local_rl_o_o, conv1d_nlc_local_rl_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rl_o_i, factor=1) + conv1d_nlc_local_rc_o_i, conv1d_nlc_local_rc_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rc, factor=2) + conv1d_nlc_local_rc_o_o, conv1d_nlc_local_rc_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rc_o_i, factor=8) + s[conv1d_nlc_local].reorder(conv1d_nlc_local_n_c_o_o_o_o, conv1d_nlc_local_l_c_o_o_o_o, conv1d_nlc_local_co_c_o_o_o_o, conv1d_nlc_local_n_c_o_o_o_i, conv1d_nlc_local_l_c_o_o_o_i, conv1d_nlc_local_co_c_o_o_o_i, conv1d_nlc_local_n_c_o_o_i, conv1d_nlc_local_l_c_o_o_i, conv1d_nlc_local_co_c_o_o_i, conv1d_nlc_local_rl_o_o, conv1d_nlc_local_rc_o_o, conv1d_nlc_local_rl_o_i, conv1d_nlc_local_rc_o_i, conv1d_nlc_local_n_c_o_i, conv1d_nlc_local_l_c_o_i, conv1d_nlc_local_co_c_o_i, conv1d_nlc_local_rl_i, conv1d_nlc_local_rc_i, conv1d_nlc_local_n_c_i, + conv1d_nlc_local_l_c_i, conv1d_nlc_local_co_c_i) + conv1d_nlc_n_o_i, conv1d_nlc_n_i = s[conv1d_nlc].split(conv1d_nlc_n, factor=1) + conv1d_nlc_n_o_o_i, conv1d_nlc_n_o_i = s[conv1d_nlc].split(conv1d_nlc_n_o_i, factor=1) + conv1d_nlc_n_o_o_o, conv1d_nlc_n_o_o_i = s[conv1d_nlc].split(conv1d_nlc_n_o_o_i, factor=1) + conv1d_nlc_l_o_i, conv1d_nlc_l_i = s[conv1d_nlc].split(conv1d_nlc_l, factor=4) + conv1d_nlc_l_o_o_i, conv1d_nlc_l_o_i = s[conv1d_nlc].split(conv1d_nlc_l_o_i, factor=8) + conv1d_nlc_l_o_o_o, conv1d_nlc_l_o_o_i = s[conv1d_nlc].split(conv1d_nlc_l_o_o_i, factor=1) + conv1d_nlc_co_o_i, conv1d_nlc_co_i = s[conv1d_nlc].split(conv1d_nlc_co, factor=2) + conv1d_nlc_co_o_o_i, conv1d_nlc_co_o_i = s[conv1d_nlc].split(conv1d_nlc_co_o_i, factor=16) + conv1d_nlc_co_o_o_o, conv1d_nlc_co_o_o_i = s[conv1d_nlc].split(conv1d_nlc_co_o_o_i, factor=1) + s[conv1d_nlc].reorder(conv1d_nlc_n_o_o_o, conv1d_nlc_l_o_o_o, conv1d_nlc_co_o_o_o, conv1d_nlc_n_o_o_i, conv1d_nlc_l_o_o_i, conv1d_nlc_co_o_o_i, conv1d_nlc_n_o_i, conv1d_nlc_l_o_i, conv1d_nlc_co_o_i, conv1d_nlc_n_i, conv1d_nlc_l_i, conv1d_nlc_co_i) + s[conv1d_nlc_local].compute_at(s[conv1d_nlc], conv1d_nlc_co_o_i) + weight_shared = s.cache_read(weight, "shared", [conv1d_nlc_local]) + weight_shared_ax0, weight_shared_ax1, weight_shared_ax2 = tuple(weight_shared.op.axis) + s[weight_shared].compute_at(s[conv1d_nlc_local], conv1d_nlc_local_rc_o_o) + PadInput_shared = s.cache_read(PadInput, "shared", [conv1d_nlc_local]) + PadInput_shared_ax0, PadInput_shared_ax1, PadInput_shared_ax2 = tuple(PadInput_shared.op.axis) + s[PadInput_shared].compute_at(s[conv1d_nlc_local], conv1d_nlc_local_rc_o_o) + s[PadInput].compute_inline() + conv1d_nlc_n_o_o_o_l_o_o_o_fused_co_o_o_o_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_o_o, conv1d_nlc_l_o_o_o, conv1d_nlc_co_o_o_o) + s[conv1d_nlc].bind(conv1d_nlc_n_o_o_o_l_o_o_o_fused_co_o_o_o_fused, te.thread_axis("blockIdx.x")) + conv1d_nlc_n_o_o_i_l_o_o_i_fused_co_o_o_i_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_o_i, conv1d_nlc_l_o_o_i, conv1d_nlc_co_o_o_i) + s[conv1d_nlc].bind(conv1d_nlc_n_o_o_i_l_o_o_i_fused_co_o_o_i_fused, 
te.thread_axis("vthread")) + conv1d_nlc_n_o_i_l_o_i_fused_co_o_i_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_i, conv1d_nlc_l_o_i, conv1d_nlc_co_o_i) + s[conv1d_nlc].bind(conv1d_nlc_n_o_i_l_o_i_fused_co_o_i_fused, te.thread_axis("threadIdx.x")) + weight_shared_ax0_ax1_fused_ax2_fused = s[weight_shared].fuse(weight_shared_ax0, weight_shared_ax1, weight_shared_ax2) + weight_shared_ax0_ax1_fused_ax2_fused_o, weight_shared_ax0_ax1_fused_ax2_fused_i = s[weight_shared].split(weight_shared_ax0_ax1_fused_ax2_fused, factor=1) + s[weight_shared].vectorize(weight_shared_ax0_ax1_fused_ax2_fused_i) + weight_shared_ax0_ax1_fused_ax2_fused_o_o, weight_shared_ax0_ax1_fused_ax2_fused_o_i = s[weight_shared].split(weight_shared_ax0_ax1_fused_ax2_fused_o, factor=128) + s[weight_shared].bind(weight_shared_ax0_ax1_fused_ax2_fused_o_i, te.thread_axis("threadIdx.x")) + PadInput_shared_ax0_ax1_fused_ax2_fused = s[PadInput_shared].fuse(PadInput_shared_ax0, PadInput_shared_ax1, PadInput_shared_ax2) + PadInput_shared_ax0_ax1_fused_ax2_fused_o, PadInput_shared_ax0_ax1_fused_ax2_fused_i = s[PadInput_shared].split(PadInput_shared_ax0_ax1_fused_ax2_fused, factor=1) + s[PadInput_shared].vectorize(PadInput_shared_ax0_ax1_fused_ax2_fused_i) + PadInput_shared_ax0_ax1_fused_ax2_fused_o_o, PadInput_shared_ax0_ax1_fused_ax2_fused_o_i = s[PadInput_shared].split(PadInput_shared_ax0_ax1_fused_ax2_fused_o, factor=128) + s[PadInput_shared].bind(PadInput_shared_ax0_ax1_fused_ax2_fused_o_i, te.thread_axis("threadIdx.x")) + # s[conv1d_nlc_local].pragma(conv1d_nlc_local_n_c_o_o_o_o, "auto_unroll_max_step", 1024) + # s[conv1d_nlc_local].pragma(conv1d_nlc_local_n_c_o_o_o_o, "unroll_explicit", True) + # fmt: off + print(tvm.lower(s, [inputs, weight, conv1d_nlc]).script()) + tvm.build(s, [inputs, weight, conv1d_nlc], target=TARGET) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_debug_meta_schedule.py b/tests/python/meta_schedule/test_debug_meta_schedule.py new file mode 100644 index 000000000000..b93a01dae737 --- /dev/null +++ b/tests/python/meta_schedule/test_debug_meta_schedule.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring + +from typing import List + +import tvm +from tvm import meta_schedule as ms +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import Postproc +from tvm.meta_schedule.testing import create_te_workload +from tvm.meta_schedule.tune import DefaultCUDA +from tvm.meta_schedule.utils import remove_build_dir +from tvm.target import Target +from tvm.tir import Schedule + + +RPC_HOST = "192.168.6.66" +RPC_PORT = 4445 +RPC_KEY = "jetson-agx-xavier" +TARGET = Target("nvidia/jetson-agx-xavier") +WORKLOAD = "C1D" +POSTPROCS: List[Postproc] = DefaultCUDA._postproc() # pylint: disable=protected-access + + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + import os + + if not os.path.exists("/tmp/perf"): + os.mkdir("/tmp/perf") + with open("/tmp/perf/tir.cu", "w") as f: + f.write(code) + return code + + +def schedule_fn(sch: Schedule): + # pylint: disable=invalid-name,line-too-long,unused-variable + # fmt: off + b0 = sch.get_block(name="PadInput", func_name="main") + b1 = sch.get_block(name="conv1d_nlc", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + l4, l5, l6, l7, l8 = sch.get_loops(block=b1) + v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l14, l15, l16, l17, l18 = sch.split(loop=l4, factors=[v9, v10, v11, v12, v13]) + v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[4, 1, 8, 4, 1]) + l24, l25, l26, l27, l28 = sch.split(loop=l5, factors=[v19, v20, v21, v22, v23]) + v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[4, 1, 16, 1, 2]) + l34, l35, l36, l37, l38 = sch.split(loop=l6, factors=[v29, v30, v31, v32, v33]) + v39, v40, v41 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[1, 1, 3]) + l42, l43, l44 = sch.split(loop=l7, factors=[v39, v40, v41]) + v45, v46, v47 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64, decision=[4, 8, 2]) + l48, l49, l50 = sch.split(loop=l8, factors=[v45, v46, v47]) + sch.reorder(l14, l24, l34, l15, l25, l35, l16, l26, l36, l42, l48, l43, l49, l17, l27, l37, l44, l50, l18, l28, l38) + l51 = sch.fuse(l14, l24, l34) + sch.bind(loop=l51, thread_axis="blockIdx.x") + l52 = sch.fuse(l15, l25, l35) + sch.bind(loop=l52, thread_axis="vthread.x") + l53 = sch.fuse(l16, l26, l36) + sch.bind(loop=l53, thread_axis="threadIdx.x") + + b54 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b54, loop=l48, preserve_unit_loops=True) + l55, l56, l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b54) + l63 = sch.fuse(l60, l61, l62) + v64, v65 = sch.sample_perfect_tile(loop=l63, n=2, max_innermost_factor=4, decision=[1040, 1]) + sch.annotate(block_or_loop=b54, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + + b66 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b66, loop=l48, preserve_unit_loops=True) + l67, l68, l69, l70, l71, l72, l73, l74 = sch.get_loops(block=b66) + l75 = sch.fuse(l72, l73, l74) + v76, v77 = sch.sample_perfect_tile(loop=l75, n=2, max_innermost_factor=4, decision=[1536, 1]) + sch.annotate(block_or_loop=b66, ann_key="meta_schedule.cooperative_fetch", ann_val=v77) + + sch.reverse_compute_at(block=b3,
loop=l53, preserve_unit_loops=True) + sch.compute_inline(block=b0) + # v78 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) + # sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v78) + # fmt: on + return sch + + +def _make_sch() -> Schedule: + prim_func = create_te_workload(WORKLOAD, 0) + prim_func = prim_func.with_attr("global_symbol", "main") + prim_func = prim_func.with_attr("tir.noalias", True) + mod = IRModule({"main": prim_func}) + return Schedule(mod, debug_mask="all") + + +def _apply_postproc(sch: Schedule): + sch.enter_postproc() + ctx = TuneContext(target=TARGET) + for p in POSTPROCS: + p.initialize_with_tune_context(ctx) + assert p.apply(sch) + + +def run_sch(sch: Schedule): + print(sch.mod.script()) + print(sch.trace) + print(tvm.lower(sch.mod).script()) + tvm.build(sch.mod, target=TARGET) + builder = ms.builder.LocalBuilder() + runner = ms.runner.RPCRunner( + rpc_config=ms.runner.RPCConfig( + tracker_host=RPC_HOST, + tracker_port=RPC_PORT, + tracker_key=RPC_KEY, + session_timeout_sec=60, + ), + alloc_repeat=3, + max_workers=5, + ) + (builder_result,) = builder.build( # pylint: disable=unbalanced-tuple-unpacking + [ms.builder.BuilderInput(sch.mod, TARGET)] + ) + if builder_result.error_msg is not None: + print(builder_result.error_msg) + return + try: + runner_input = ms.runner.RunnerInput( + builder_result.artifact_path, + device_type=TARGET.kind.name, + args_info=ms.arg_info.ArgInfo.from_prim_func(sch.mod["main"]), + ) + (runner_future,) = runner.run([runner_input]) # pylint: disable=unbalanced-tuple-unpacking + runner_result = runner_future.result() + if runner_result.error_msg is not None: + print(runner_result.error_msg) + else: + print([float(x) * 1000.0 for x in runner_result.run_secs]) + finally: + remove_build_dir(builder_result.artifact_path) + + +def main(): + sch = schedule_fn(_make_sch()) + _apply_postproc(sch) + run_sch(sch) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_e2e.py b/tests/python/meta_schedule/test_e2e.py new file mode 100644 index 000000000000..0f1dcae0cfac --- /dev/null +++ b/tests/python/meta_schedule/test_e2e.py @@ -0,0 +1,78 @@ +from typing import List, Tuple + +from tvm.meta_schedule.testing.e2e import extract, get_network +from tvm.target import Target + +MODEL_CACHE_DIR = "~/dataset/relay-models" +TASK_CACHE_DIR = "~/dataset/tasks-{target_kind}" + + +def _build_dataset() -> List[Tuple[str, List[int]]]: + network_keys = [] + for name in [ + "resnet_18", + "resnet_50", + "mobilenet_v2", + "mobilenet_v3", + "wide_resnet_50", + "resnext_50", + "densenet_121", + ]: + for batch_size in [1, 4, 8]: + for image_size in [224, 240, 256]: + network_keys.append((name, [batch_size, 3, image_size, image_size])) + # inception-v3 + for name in ["inception_v3"]: + for batch_size in [1, 2, 4]: + for image_size in [299]: + network_keys.append((name, [batch_size, 3, image_size, image_size])) + # resnet3d + for name in ["resnet3d_18"]: + for batch_size in [1, 2, 4]: + for image_size in [112, 128, 144]: + network_keys.append((name, [batch_size, 3, image_size, image_size, 16])) + # bert + for name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]: + for batch_size in [1, 2, 4]: + for seq_length in [64, 128, 256]: + network_keys.append((name, [batch_size, seq_length])) + # dcgan + for name in ["dcgan"]: + for batch_size in [1, 4, 8]: + for image_size in [64]: + network_keys.append((name, [batch_size, 3, image_size, image_size])) + + 
return network_keys + + +def test_import(): + network_keys = _build_dataset() + for i, (name, input_shape) in enumerate(network_keys, 1): + print(f"[{i} / {len(network_keys)}] Import {name}, input_shape = {input_shape}") + get_network(name, input_shape, cache_dir=MODEL_CACHE_DIR) + + +def test_extract(): + network_keys = _build_dataset() + for target_kind in ["llvm", "cuda"]: + for i, (name, input_shape) in enumerate(network_keys, 1): + print( + f"[{i} / {len(network_keys)}] Extract {name} @ {target_kind}, input_shape = {input_shape}" + ) + if name == "resnext_50" and target_kind == "cuda": + continue + mod, params, _ = get_network(name, input_shape, cache_dir=MODEL_CACHE_DIR) + filename = f'{name}-{",".join(str(i) for i in input_shape)}-{target_kind}.json' + extracted_tasks = extract( + filename=filename, + mod=mod, + target=Target(target_kind), + params=params, + cache_dir=TASK_CACHE_DIR.format(target_kind=target_kind), + ) + print(f"{len(extracted_tasks)} task(s) extracted") + + +if __name__ == "__main__": + test_import() + test_extract() diff --git a/tests/python/meta_schedule/test_meta_schedule.py b/tests/python/meta_schedule/test_meta_schedule.py new file mode 100644 index 000000000000..64890f426791 --- /dev/null +++ b/tests/python/meta_schedule/test_meta_schedule.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
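+# +# Usage: tune one TE workload with MetaSchedule's evolutionary search, measuring over RPC, +# e.g. (mirroring run_meta_schedule_cuda.sh): +#   python tests/python/meta_schedule/test_meta_schedule.py --workload C1D --target "nvidia/jetson-agx-xavier" \ +#     --work-dir ~/logs/ms-cuda/C1D/ --rpc-host 192.168.6.66 --rpc-port 4445 \ +#     --rpc-key jetson-agx-xavier --num-trials 2000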
+# pylint: disable=missing-docstring +import argparse +import logging +from os import cpu_count + +import tvm +from tvm import meta_schedule as ms +from tvm import tir +from tvm.meta_schedule.testing import create_te_workload + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu": + parsed.alloc_repeat = 3 + else: + parsed.alloc_repeat = 1 + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=30, + ) + parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False) + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def main(): + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, + alloc_repeat=ARGS.alloc_repeat, + max_workers=ARGS.rpc_workers, + ) + sch: tir.Schedule = ms.tune_tir( + mod=create_te_workload(ARGS.workload, 0), + target=ARGS.target, + config=ms.EvolutionarySearchConfig( + num_trials_per_iter=64, + num_trials_total=ARGS.num_trials, + init_min_unmeasured=50, + ), + runner=runner, + task_name=ARGS.workload, + work_dir=ARGS.work_dir, + num_threads=cpu_count(), + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == "__main__": + main() diff --git a/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py new file mode 100644 index 000000000000..5a2ede204769 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py @@ -0,0 +1,398 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
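+# +# Each test below pairs an input module whose copy block carries the "auto_copy" +# annotation (optionally with "vector_bytes" and "local_stage" hints) with the module +# expected after tir.transform.LowerAutoCopy: coalesced and vectorized global<->shared +# copies, wmma load/store intrinsics for shared<->fragment copies, and padded +# shared.dyn buffers (note the non-unit strides) that avoid bank conflicts.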
+ +import tvm +from tvm import te +from tvm.script import tir as T +import sys +import pytest + + +@tvm.script.ir_module +class Transpose: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([16, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 16): + A_shared_dyn[ax1, ax0] = A[ax0, ax1] + with T.block("B"): + T.block_attr({"auto_copy": 1}) + for ax1, ax0 in T.grid(16, 128): + B[ax1, ax0] = A_shared_dyn[ax1, ax0] + + +@tvm.script.ir_module +class GlobalToShared: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class SharedToGlobal: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax1, ax0 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax1, ax0] + + +@tvm.script.ir_module +class GlobalToSharedWithLocalStage: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16, "local_stage": True}) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class SharedToWmma: + @T.prim_func + def main() -> None: + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): +
A_shared_dyn = T.alloc_buffer([128, 128], dtype="float16", scope="shared.dyn") + A_wmma = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + with T.block("A_wmma"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 128): + A_wmma[ax0, ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToShared: + @T.prim_func + def main() -> None: + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + C_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("C_shared"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 128): + C_shared[ax0, ax1] = C_accum[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToGlobal: + @T.prim_func + def main(c: T.handle) -> None: + C = T.match_buffer(c, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + with T.block("C_global"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + C[bx * 128 + ax0, by * 128 + ax1] = C_accum[ax0, ax1] + +@tvm.script.ir_module +class TransformedGlobalToShared: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy":1, "vector_bytes":16}) + for outer in T.serial(16): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + A_shared_dyn[(((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A[bx * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, by * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + +@tvm.script.ir_module +class TransformedSharedToGlobal: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[129, 1], scope="shared.dyn") + with T.block("A_shared"): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(A_shared_dyn[0 : 128, 0 : 128]) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + T.block_attr({"auto_copy":1, 
"vector_bytes":16}) + for outer in T.serial(16): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + B[bx * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, by * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A_shared_dyn[(((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128, (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128] + +@tvm.script.ir_module +class TransformedGlobalToSharedWithLocalStage: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn") + with T.block("A_shared"): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(A_shared_dyn[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1, "local_stage":True, "vector_bytes":16}) + A_local = T.alloc_buffer([16, 4], dtype="float32", scope="local") + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 16, 1, 1, 1): + for vec in T.vectorized(4): + A_local[ax0 * 16 + ax1 + ax2, (ax3 + ax4) * 4 + vec] = A[((bx % 8 + ax0) * 16 + ax1) * 8 + (ty_1 % 128 + ax2), ((by % 8 + ax3) * 32 + (tx % 32 + ax4)) * 4 + vec] + for serial in T.serial(16): + for vec in T.vectorized(4): + A_shared_dyn[(((serial * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, (((serial * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A_local[(serial * 8 + (tx * 4 + vec) // 128 + ty_1) % 128 // 8 + (((tx * 4 + vec) // 128 + ty_1) % 8 - ty_1 % 128), ((tx * 4 + vec) % 128 // 4 - tx % 32) * 4 + vec % 4] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + +@tvm.script.ir_module +class TransformedSharedToWmma: + @T.prim_func + def main() -> None: + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float16", strides=[136, 1], scope="shared.dyn") + A_wmma = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + with T.block("C_shared"): + T.reads(A_shared_dyn[0 : 128, 0 : 128]) + T.writes(A_wmma[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1}) + for ax00, ax10 in T.grid(8, 8): + with T.block("wmma_load"): + T.reads(A_shared_dyn[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + T.writes(A_wmma[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + src = T.match_buffer(A_shared_dyn[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float16", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + tgt = T.match_buffer(A_wmma[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(tgt.data, 16, 16, 16, tgt.elem_offset // 256 + tgt.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), src.data, 
src.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle")) + +@tvm.script.ir_module +class TransformedWmmaToShared: + @T.prim_func + def main() -> None: + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + C_shared = T.alloc_buffer([128, 128], dtype="float32", strides=[136, 1], scope="shared.dyn") + with T.block("A_wmma"): + T.reads(C_accum[0 : 128, 0 : 128]) + T.writes(C_shared[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1}) + for ax00, ax10 in T.grid(8, 8): + with T.block("wmma_store"): + T.reads(C_accum[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + T.writes(C_shared[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + src = T.match_buffer(C_accum[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + tgt = T.match_buffer(C_shared[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float32", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(src.data, 16, 16, 16, src.elem_offset // 256 + src.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), tgt.data, tgt.elem_offset, s1 * 16, 2, dtype="handle"), s1, "row_major", dtype="handle")) + +@tvm.script.ir_module +class TransformedWmmaToGlobal: + @T.prim_func + def main(C: T.Buffer[(1024, 1024), "float32"]) -> None: + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.reads() + T.writes(C[0 : 1024, 0 : 1024]) + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + T.reads() + T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + with T.block("C_global"): + T.reads(C_accum[0 : 128, 0 : 128]) + T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.block_attr({"auto_copy":1, "vector_bytes":16}) + C_shared_dyn = T.alloc_buffer([16, 128], dtype="float32", strides=[136, 1], scope="shared.dyn") + for ax0_0 in T.serial(8): + for ax1_0 in T.serial(8): + with T.block("wmma_store"): + T.reads(C_accum[ax0_0 * 16 : ax0_0 * 16 + 16, ax1_0 * 16 : ax1_0 * 16 + 16]) + T.writes(C_shared_dyn[(ax0_0 // 8 + bx) % 8 * 16 + ax0_0 % 8 * 16 - ax0_0 % 64 * 16 - bx % 8 * 16 : (ax0_0 // 8 + bx) % 8 * 16 + ax0_0 % 8 * 16 - ax0_0 % 64 * 16 - bx % 8 * 16 + 16, (ax1_0 // 8 + by) % 8 * 128 + ax1_0 % 8 * 16 - by % 8 * 128 : (ax1_0 // 8 + by) % 8 * 128 + ax1_0 % 8 * 16 - by % 8 * 128 + 16]) + src = T.match_buffer(C_accum[ax0_0 * 16 : ax0_0 * 16 + 16, ax1_0 * 16 : ax1_0 * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + tgt = T.match_buffer(C_shared_dyn[(ax0_0 // 8 + bx) % 8 * 16 + ax0_0 % 8 * 16 - ax0_0 % 64 * 16 - bx % 8 * 16 : (ax0_0 // 8 + bx) % 8 * 16 + ax0_0 % 8 * 16 - ax0_0 % 64 * 16 - bx % 8 * 16 + 16, (ax1_0 // 8 + by) % 8 * 128 + ax1_0 % 8 * 16 - by % 8 * 128 : (ax1_0 // 8 + by) % 8 * 128 + ax1_0 % 8 * 16 - by % 8 * 128 + 16], [16, 16], dtype="float32", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + 
T.evaluate(T.tvm_store_matrix_sync(src.data, 16, 16, 16, src.elem_offset // 256 + src.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), tgt.data, tgt.elem_offset, s1 * 16, 2, dtype="handle"), s1, "row_major", dtype="handle")) + for ax0_ax1_ax2_ax3_ax4_ax5_fused_0 in T.serial(2): + for ax0_ax1_ax2_ax3_ax4_ax5_fused_1 in T.thread_binding(8, thread="threadIdx.y"): + for ax0_ax1_ax2_ax3_ax4_ax5_fused_2 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_ax2_ax3_ax4_ax5_fused_3 in T.vectorized(4): + C[((bx % 8 + 0) * 8 + (ax0_0 % 64 + 0)) * 16 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) // 16 // 8 % 16, ((by % 8 + 0) * 8 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) // 16 % 8) * 16 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) % 16] = C_shared_dyn[(0 + 0) * 16 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) // 16 // 8 % 16, (0 * 8 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) // 16 % 8) * 16 + (((ax0_ax1_ax2_ax3_ax4_ax5_fused_0 * 8 + ax0_ax1_ax2_ax3_ax4_ax5_fused_1) * 32 + ax0_ax1_ax2_ax3_ax4_ax5_fused_2) * 4 + ax0_ax1_ax2_ax3_ax4_ax5_fused_3) % 16] + + +def _check(original, transformed): + mod = tvm.tir.transform.LowerAutoCopy()(original) + tvm.ir.assert_structural_equal(mod, transformed, True) + + +def test_coalesce_vectorize(): + _check(GlobalToShared, TransformedGlobalToShared) + + +def test_inverse(): + _check(SharedToGlobal, TransformedSharedToGlobal) + + +def test_local_stage(): + _check(GlobalToSharedWithLocalStage, TransformedGlobalToSharedWithLocalStage) + + +def test_rewrite_shared_to_wmma(): + _check(SharedToWmma, TransformedSharedToWmma) + + +def test_rewrite_wmma_to_shared(): + _check(WmmaToShared, TransformedWmmaToShared) + + +def test_rewrite_wmma_to_global(): + _check(WmmaToGlobal, TransformedWmmaToGlobal) + + +def verify_single_allocation(stmt, alloc_size=None): + num_alloc = [0] + alloc_extents = [] + + def verify(n): + if ( + isinstance(n, tvm.tir.Allocate) + and n.buffer_var.type_annotation.storage_scope == "shared.dyn" + ): + num_alloc[0] += 1 + alloc_extents.append(n.extents[0]) + + tvm.tir.stmt_functor.post_order_visit(stmt, verify) + assert num_alloc[0] == 1 + + if alloc_size: + assert alloc_extents[0] == alloc_size + + +def test_auto_padding(): + mod = tvm.tir.transform.LowerAutoCopy()(Transpose) + mod = tvm.tir.transform.FlattenBuffer()(mod) + verify_single_allocation(mod['main'].body, 16 * 130) + + +if __name__ == "__main__": + test_coalesce_vectorize() + test_inverse() + test_local_stage() + test_rewrite_shared_to_wmma() + test_rewrite_wmma_to_shared() + test_rewrite_wmma_to_global() + test_auto_padding()