diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index e2a82d396b22..3c5735b17aa5 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -377,6 +377,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): name="batch_matmul_cblas.x86", plevel=15, ) + if "mkl" in target.libs: + strategy.add_implementation( + wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl), + wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl), + name="batch_matmul_mkl.x86", + plevel=15, + ) return strategy diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 4e5f6efc815a..100bdf205165 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -19,7 +19,7 @@ from tvm import te from tvm import autotvm from tvm.autotvm.task.space import SplitEntity -from tvm.contrib import cblas +from tvm.contrib import cblas, mkl from .. import generic from ..util import traverse_inline, get_const_tuple, get_max_power2_factor @@ -137,10 +137,9 @@ def _default_batch_matmul_config(cfg, M, N, K): cfg["tile_y"] = SplitEntity([M // y_bn, y_bn]) -@autotvm.register_topi_compute("batch_matmul_cblas.x86") -def batch_matmul_cblas(cfg, x, y, out_shape=None): +def batch_matmul_blas_common(cfg, x, y, out_shape, lib): """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are - data in batch. + data in batch, using one of BLAS libraries. 
Parameters ---------- @@ -152,6 +151,8 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None): 3-D with shape [batch, N, K] out_shape : tuple or None Shape of the output + lib : A contrib module which implements batch_matmul function + cblas and mkl are supported Returns ------- @@ -168,9 +169,28 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None): assert out_shape[1] == M, "got invalid output shape" assert out_shape[2] == N, "got invalid output shape" cfg.add_flop(XB * M * N * XK * 2) - return cblas.batch_matmul(x, y, False, True) + return lib.batch_matmul(x, y, False, True) + + +@autotvm.register_topi_compute("batch_matmul_cblas.x86") +def batch_matmul_cblas(cfg, x, y, out_shape=None): + """Compute batch_matmul using cblas""" + return batch_matmul_blas_common(cfg, x, y, out_shape, cblas) @autotvm.register_topi_schedule("batch_matmul_cblas.x86") def schedule_batch_matmul_cblas(_, outs): + """Create schedule for batch_matmul_cblas""" + return generic.schedule_extern(outs) + + +@autotvm.register_topi_compute("batch_matmul_mkl.x86") +def batch_matmul_mkl(cfg, x, y, out_shape=None): + """Compute batch_matmul using mkl""" + return batch_matmul_blas_common(cfg, x, y, out_shape, mkl) + + +@autotvm.register_topi_schedule("batch_matmul_mkl.x86") +def schedule_batch_matmul_mkl(_, outs): + """Create schedule for batch_matmul_mkl""" return generic.schedule_extern(outs)