
[Dev] Fix illegal pass order #243

Merged 4 commits on Nov 11, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 8847ba to 7b325a
5 changes: 4 additions & 1 deletion bitblas/__init__.py
@@ -144,7 +144,10 @@ def remove_tvm_path(path):
    try_inline,  # noqa: F401
    try_inline_contiguous_spatial,  # noqa: F401
)

from .relax import (
    ApplyDefaultSchedule,  # noqa: F401
    ApplyFastTuning,  # noqa: F401
)
from . import testing  # noqa: F401
from .utils import auto_detect_nvidia_target, apply_transform_on_input  # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul  # noqa: F401
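For context on the new re-export, here is a minimal sketch (not part of this PR's diff) of how ApplyFastTuning and ApplyDefaultSchedule are typically chained on a translated Relax module, in the spirit of examples/relax_end2end.py. The pass arguments and the target string are assumptions for illustration; check bitblas/relax for the exact signatures.

# Sketch only: BitBLAS fast tuning first, then dlight's default GPU schedules
# as a fallback for any remaining TIR functions.
import bitblas  # noqa: F401  (importing bitblas selects its bundled 3rdparty/tvm)
from bitblas import tvm as tvm
from tvm import dlight as dl
from bitblas.relax import ApplyDefaultSchedule, ApplyFastTuning

target = tvm.target.Target("cuda")  # assumed target


def schedule_relax_module(relax_mod):
    with target:
        # Argument names below are assumed for illustration.
        relax_mod = ApplyFastTuning(target=target)(relax_mod)
        relax_mod = ApplyDefaultSchedule(dl.gpu.Fallback())(relax_mod)
    return relax_mod

The with target: scope mirrors the example script; the dlight schedule rules pick up the current target from that context.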
40 changes: 14 additions & 26 deletions examples/relax_end2end.py
@@ -18,20 +18,16 @@
import numpy as np
import os
from typing import Dict
import numpy as np # type: ignore
import time
import bitblas
from bitblas import tvm as tvm
import tvm
from tvm import relay, relax, runtime, transform
from tvm.ir.module import IRModule
from tvm.relax.testing import relay_translator, nn
from tvm.target.target import Target
from tvm import dlight as dl
from tvm import relay
import tvm.relay.testing
from tvm.ir.module import IRModule
from bitblas.relax import ApplyDefaultSchedule, ApplyFastTuning

fname = os.path.basename(__file__)
fname = os.path.splitext(fname)[0]
# get current file path
@@ -41,6 +37,7 @@

bitblas.set_log_level("Debug")


def write_code(code, path, fname):
    global count
    fname = str(count) + "." + fname
@@ -80,16 +77,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
    input_shape = (batch_size,) + image_shape
    output_shape = (batch_size, 1000)

    if name.startswith("resnet-"):
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer,
            batch_size=batch_size,
            layout=layout,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name.startswith("resnet3d-"):
    if name.startswith("resnet-") or name.startswith("resnet3d-"):
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer,
@@ -100,8 +88,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
        )
    elif name == "mobilenet":
        mod, params = relay.testing.mobilenet.get_workload(
            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
        )
            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape)
    elif name == "squeezenet_v1.1":
        assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
        mod, params = relay.testing.squeezenet.get_workload(
@@ -115,8 +102,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == "mlp":
        mod, params = relay.testing.mlp.get_workload(
            batch_size=batch_size, image_shape=image_shape, dtype=dtype
        )
            batch_size=batch_size, image_shape=image_shape, dtype=dtype)

    return mod, params, input_shape, output_shape

@@ -133,9 +119,8 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32"):
relay_mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)


def apply_opt_before_tuning(
    relay_mod: IRModule, params: Dict[str, runtime.NDArray], target: Target
):
def apply_opt_before_tuning(relay_mod: IRModule, params: Dict[str, runtime.NDArray],
                            target: Target):
    with transform.PassContext(opt_level=3):
        main_func = relay_mod["main"]
        bind_main_func = relay.build_module.bind_params_by_name(main_func, params)
@@ -157,7 +142,12 @@ def apply_opt_before_tuning(
        write_mod(relay_mod, log_path, "FoldConstant")

        # opt_level=2 and select_impl_strategy are required for avoiding winograd lowering
        relax_mod = relay_translator.from_relay(relay_mod["main"], opt_level=2, target=target, append_op_attrs=True, select_impl_strategy="first")
        relax_mod = relay_translator.from_relay(
            relay_mod["main"],
            opt_level=2,
            target=target,
            append_op_attrs=True,
            select_impl_strategy="first")
        write_mod(relax_mod, log_path, "relay_translator_relax")
        relax_mod = relax.transform.AnnotateTIROpPattern()(relax_mod)
        write_mod(relax_mod, log_path, "AnnotateTIROpPattern")
@@ -199,7 +189,6 @@ def apply_opt_before_tuning(
ex = relax.build(relax_mod, target)
write_code(ex.mod.imported_modules[0].imported_modules[0].get_source(), log_path, "tmp.cu")


device = tvm.cuda(0)
vm = relax.VirtualMachine(ex, device)

@@ -218,10 +207,9 @@ def apply_opt_before_tuning(

start = time.time()

for i in range(10):
for _ in range(10):
vm["main"](*input_args)


device.sync()

end = time.time()