[MetaSchedule] Sample-Perfect-Tile (#501)
junrushao committed Nov 5, 2021
1 parent 7469824 commit fc902d3
Showing 12 changed files with 171 additions and 110 deletions.
38 changes: 5 additions & 33 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
@@ -56,7 +56,7 @@ class TaskScheduler(Object):

    def tune(self) -> None:
        """Auto-tuning."""
-        _ffi_api.TaskSchedulerTune(self)  # type: ignore # pylint: disable=no-member
+        _ffi_api.TaskSchedulerTune(self)  # pylint: disable=no-member

    def next_task_id(self) -> int:
        """Fetch the next task id.
@@ -86,7 +86,7 @@ def _set_task_stopped(self, task_id: int) -> None:
        task_id : int
            The task id to be stopped.
        """
-        _ffi_api.TaskSchedulerSetTaskStopped(self, task_id)  # type: ignore # pylint: disable=no-member
+        _ffi_api.TaskSchedulerSetTaskStopped(self, task_id)  # pylint: disable=no-member

    def _is_task_running(self, task_id: int) -> bool:
        """Check whether the task is running.
@@ -101,7 +101,7 @@ def _is_task_running(self, task_id: int) -> bool:
        bool
            Whether the task is running.
        """
-        return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id)  # type: ignore # pylint: disable=no-member
+        return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id)  # pylint: disable=no-member

    def _join_running_task(self, task_id: int) -> None:
        """Wait until the task is finished.
@@ -111,17 +111,7 @@ def _join_running_task(self, task_id: int) -> None:
        task_id : int
            The task id to be joined.
        """
-        _ffi_api.TaskSchedulerJoinRunningTask(self, task_id)  # type: ignore # pylint: disable=no-member
-
-    def _next_task_id(self) -> int:
-        """Fetch the next task id.
-        Returns
-        -------
-        int
-            The next task id.
-        """
-        return _ffi_api.TaskSchedulerNextTaskId(self)  # type: ignore # pylint: disable=no-member
+        _ffi_api.TaskSchedulerJoinRunningTask(self, task_id)  # pylint: disable=no-member


@register_object("meta_schedule.PyTaskScheduler")
@@ -185,7 +175,7 @@ def f_join_running_task(task_id: int) -> None:
            self._join_running_task(task_id)

        self.__init_handle_by_constructor__(
-            _ffi_api.TaskSchedulerPyTaskScheduler,  # type: ignore # pylint: disable=no-member
+            _ffi_api.TaskSchedulerPyTaskScheduler,  # pylint: disable=no-member
            tasks,
            builder,
            runner,
@@ -198,21 +188,3 @@ def f_join_running_task(task_id: int) -> None:
            f_join_running_task,
            f_next_task_id,
        )
-
-    def tune(self) -> None:
-        raise NotImplementedError()
-
-    def _initialize_task(self, task_id: int) -> None:
-        raise _ffi_api.TaskSchedulerInitializeTask(self, task_id)
-
-    def _set_task_stopped(self, task_id: int) -> None:
-        _ffi_api.TaskSchedulerSetTaskStopped(self, task_id)  # type: ignore # pylint: disable=no-member
-
-    def _is_task_running(self, task_id: int) -> bool:
-        return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id)  # type: ignore # pylint: disable=no-member
-
-    def _join_running_task(self, task_id: int) -> None:
-        _ffi_api.TaskSchedulerJoinRunningTask(self, task_id)  # type: ignore # pylint: disable=no-member
-
-    def _next_task_id(self) -> int:
-        return _ffi_api.TaskSchedulerNextTaskId(self)  # type: ignore # pylint: disable=no-member
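With the Python-side overrides above deleted, the hooks live on the C++ side and the user-facing entry points are tune() and next_task_id(). A minimal usage sketch, assuming a TaskScheduler instance has already been constructed (the construction arguments are not fully shown in this diff):

# Hedged sketch: driving tuning through the methods kept above.
# `scheduler` is assumed to be an already-constructed TaskScheduler.
scheduler.tune()  # run auto-tuning across all registered tasks
task_id = scheduler.next_task_id()  # fetch the id of the next task to tune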
67 changes: 67 additions & 0 deletions python/tvm/meta_schedule/testing/relay_workload.py
@@ -186,3 +186,70 @@ def forward(self, inp):
    # Convert torch model to relay module
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
    return mod, params
+
+
+def get_network(
+    name: str,
+    batch_size: int,
+    layout: str = "NHWC",
+    dtype: str = "float32",
+) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]:
+    """Get the symbol definition and random weight of a network"""
+    import tvm.relay.testing  # pylint: disable=import-outside-toplevel
+
+    # meta-schedule prefers NHWC layout
+    if layout == "NHWC":
+        image_shape = (224, 224, 3)
+    elif layout == "NCHW":
+        image_shape = (3, 224, 224)
+    else:
+        raise ValueError("Invalid layout: " + layout)
+
+    input_shape: Tuple[int, int, int, int] = (batch_size,) + image_shape
+    output_shape: Tuple[int, int] = (batch_size, 1000)
+
+    if name.startswith("resnet-"):
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet.get_workload(
+            num_layers=n_layer,
+            batch_size=batch_size,
+            layout=layout,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name.startswith("resnet3d-"):
+        n_layer = int(name.split("-")[1])
+        mod, params = relay.testing.resnet_3d.get_workload(
+            num_layers=n_layer,
+            batch_size=batch_size,
+            layout=layout,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name == "mobilenet":
+        mod, params = relay.testing.mobilenet.get_workload(
+            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
+        )
+    elif name == "squeezenet_v1.1":
+        assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
+        mod, params = relay.testing.squeezenet.get_workload(
+            version="1.1",
+            batch_size=batch_size,
+            dtype=dtype,
+            image_shape=image_shape,
+        )
+    elif name == "inception_v3":
+        input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
+        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
+    elif name == "mxnet":
+        from mxnet.gluon.model_zoo.vision import get_model  # type: ignore # pylint: disable=import-outside-toplevel
+
+        assert layout == "NCHW"
+        block = get_model("resnet50_v1", pretrained=True)
+        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
+        net = mod["main"]
+        net = relay.Function(
+            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
+        )
+        mod = IRModule.from_expr(net)
+    return mod, params, input_shape, output_shape
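A short usage sketch for the new helper; the shape assertions simply restate what the code above computes for the default NHWC layout:

# Hedged usage sketch for get_network (values follow from the code above).
mod, params, input_shape, output_shape = get_network("resnet-18", batch_size=1, layout="NHWC")
assert input_shape == (1, 224, 224, 3)  # (batch_size,) + NHWC image_shape
assert output_shape == (1, 1000)        # 1000-class classifier head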
4 changes: 0 additions & 4 deletions src/tir/schedule/primitive.h
@@ -212,15 +212,11 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
 */
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                            const String& storage_scope);
-
/******** Schedule: Data movement ********/
-
TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref,
                        int read_buffer_index, const String& storage_scope);
-
TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref,
                         int write_buffer_index, const String& storage_scope);
-
/******** Schedule: Compute location ********/
/*!
 * \brief Move a producer block under the specific loop, and regenerate the
1 change: 1 addition & 0 deletions src/tir/schedule/primitive/sampling.cc
@@ -20,6 +20,7 @@
#include <random>

#include "../utils.h"
+#include "tvm/support/random_engine.h"

namespace tvm {
namespace tir {
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_integration.py
@@ -27,7 +27,7 @@
    MetaScheduleContext,
    TaskExtraction,
)
-from tvm.meta_schedule.testing import get_network
+from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.script import tir as T

# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking
10 changes: 6 additions & 4 deletions tests/python/unittest/test_meta_schedule_measure_callback.py
@@ -44,10 +44,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
-        with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
-            with T.init():
-                C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
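The T.axis.remap("SSR", [i, j, k]) sugar that replaces the old block signature binds each loop variable to a block axis. A hedged long-form equivalent, with extents taken from the T.grid above:

# Long form of: vi, vj, vk = T.axis.remap("SSR", [i, j, k])
# "S" marks a spatial axis, "R" a reduction axis.
vi = T.axis.spatial(1024, i)
vj = T.axis.spatial(1024, j)
vk = T.axis.reduce(1024, k)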
10 changes: 6 additions & 4 deletions tests/python/unittest/test_meta_schedule_mutator.py
@@ -40,10 +40,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
-        with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
-            with T.init():
-                C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
60 changes: 38 additions & 22 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -47,10 +47,13 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
-        with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
-            with T.init():
-                C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class DuplicateMatmul:
@@ -60,12 +63,17 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
-        with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
-            with T.init():
-                C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
-        with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
-            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


@tvm.script.ir_module
class TrinityMatmul:
@@ -76,12 +84,19 @@ def main(a: T.handle, d: T.handle) -> None:
        B = T.alloc_buffer((1024, 1024), "float32")
        C = T.alloc_buffer((1024, 1024), "float32")
        D = T.match_buffer(d, (1024, 1024), "float32")
-        with T.block([1024, 1024], "A") as [vi, vj]:
-            B[vi, vj] = A[vi, vj] * 2.0
-        with T.block([1024, 1024], "B") as [vi, vj]:
-            C[vi, vj] = B[vi, vj] + 3.0
-        with T.block([1024, 1024], "C") as [vi, vj]:
-            D[vi, vj] = C[vi, vj] * 5.0
+        for i, j in T.grid(1024, 1024):
+            with T.block("A"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+        for i, j in T.grid(1024, 1024):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                C[vi, vj] = B[vi, vj] + 3.0
+        for i, j in T.grid(1024, 1024):
+            with T.block("C"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                D[vi, vj] = C[vi, vj] * 5.0


@tvm.script.ir_module
class TrinityMatmulProcessedForReference:
@@ -95,20 +110,21 @@ def main(a: T.handle, d: T.handle) -> None:
        # with tir.block("root")
        B = T.alloc_buffer([1024, 1024], dtype="float32")
        for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16):
-            with T.block([1024, 1024], "A") as [vi, vj]:
-                T.bind(vi, i0_0 * 64 + i0_1)
-                T.bind(vj, i1_0 * 16 + i1_1)
+            with T.block("A"):
+                vi = T.axis.S(1024, i0_0 * 64 + i0_1)
+                vj = T.axis.S(1024, i1_0 * 16 + i1_1)
                T.reads([A[vi, vj]])
                T.writes([B[vi, vj]])
                B[vi, vj] = A[vi, vj] * T.float32(2)
        for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16):
-            with T.block([1024, 1024], "C") as [vi, vj]:
-                T.bind(vi, i0_0 * 64 + i0_1)
-                T.bind(vj, i1_0 * 16 + i1_1)
+            with T.block("C"):
+                vi = T.axis.S(1024, i0_0 * 64 + i0_1)
+                vj = T.axis.S(1024, i1_0 * 16 + i1_1)
                T.reads([B[vi, vj]])
                T.writes([D[vi, vj]])
                D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5)


# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
10 changes: 6 additions & 4 deletions tests/python/unittest/test_meta_schedule_postproc.py
@@ -38,10 +38,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
-        with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
-            with T.init():
-                C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
10 changes: 6 additions & 4 deletions tests/python/unittest/test_meta_schedule_schedule_rule.py
@@ -42,10 +42,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (1024, 1024), "float32")
        B = T.match_buffer(b, (1024, 1024), "float32")
        C = T.match_buffer(c, (1024, 1024), "float32")
-        with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
-            with T.init():
-                C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+        for i, j, k in T.grid(1024, 1024, 1024):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# fmt: on
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
@@ -10,6 +10,7 @@
from tvm.runtime import NDArray


+@pytest.mark.skip("Skip because it runs too slowly as a unittest")
@pytest.mark.parametrize(
    "model_name",
    [