Skip to content

Commit

Permalink
[MKL] Fix offloading of batch_matmul to MKL (#6752)
Browse files Browse the repository at this point in the history
* fix mkl offloading of batch matmul

* name fix and add doc

* add doc for lib arg

Co-authored-by: masa <masa@pop-os.localdomain>
  • Loading branch information
masahi and masa authored Oct 25, 2020
1 parent 1831c17 commit 2d8ac1d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
name="batch_matmul_cblas.x86",
plevel=15,
)
if "mkl" in target.libs:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl),
wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl),
name="batch_matmul_mkl.x86",
plevel=15,
)
return strategy


Expand Down
30 changes: 25 additions & 5 deletions python/tvm/topi/x86/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas
from tvm.contrib import cblas, mkl
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

Expand Down Expand Up @@ -137,10 +137,9 @@ def _default_batch_matmul_config(cfg, M, N, K):
cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y, out_shape=None):
def batch_matmul_blas_common(cfg, x, y, out_shape, lib):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
data in batch, using one of BLAS libraries.
Parameters
----------
Expand All @@ -152,6 +151,8 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output
lib : A contrib module which implements batch_matmul funtion
cblas and mkl are supported
Returns
-------
Expand All @@ -168,9 +169,28 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
cfg.add_flop(XB * M * N * XK * 2)
return cblas.batch_matmul(x, y, False, True)
return lib.batch_matmul(x, y, False, True)


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y, out_shape=None):
"""Compute batch_matmul using cblas"""
return batch_matmul_blas_common(cfg, x, y, out_shape, cblas)


@autotvm.register_topi_schedule("batch_matmul_cblas.x86")
def schedule_batch_matmul_cblas(_, outs):
"""Create schedule for batch_matmul_cblas"""
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("batch_matmul_mkl.x86")
def batch_matmul_mkl(cfg, x, y, out_shape=None):
"""Compute batch_matmul using mkl"""
return batch_matmul_blas_common(cfg, x, y, out_shape, mkl)


@autotvm.register_topi_schedule("batch_matmul_mkl.x86")
def schedule_batch_matmul_mkl(_, outs):
"""Create schedule for batch_matmul_mul"""
return generic.schedule_extern(outs)

0 comments on commit 2d8ac1d

Please sign in to comment.