diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 7173e5b8db63e..4fd4bbb444eb6 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -389,6 +389,33 @@ def matmul_strategy_cpu(attrs, inputs, out_type, target):
             name="matmul.generic",
         )
 
+    same_type = inputs[0].dtype == inputs[1].dtype == out_type.dtype
+    dtype = inputs[0].dtype
+    u8s8s32 = dtype == "uint8" and inputs[1].dtype == "int8" and out_type.dtype == "int32"
+    if "cblas" in target.libs:
+        with SpecializedCondition(same_type and dtype in ["float32", "float64"]):
+            strategy.add_implementation(
+                wrap_compute_matmul(topi.x86.matmul_cblas),
+                wrap_topi_schedule(topi.x86.schedule_matmul_cblas),
+                name="matmul_cblas.x86",
+                plevel=13,
+            )
+    if "mkl" in target.libs:
+        with SpecializedCondition(same_type and dtype in ["float32", "float64"] or u8s8s32):
+            strategy.add_implementation(
+                wrap_compute_matmul(topi.x86.matmul_mkl),
+                wrap_topi_schedule(topi.x86.schedule_matmul_mkl),
+                name="matmul_mkl.x86",
+                plevel=14,
+            )
+    if "mkldnn" in target.libs:
+        with SpecializedCondition(same_type and dtype == "float32"):
+            strategy.add_implementation(
+                wrap_compute_matmul(topi.x86.matmul_mkldnn),
+                wrap_topi_schedule(topi.x86.schedule_matmul_mkldnn),
+                name="matmul_mkldnn.x86",
+                plevel=15,
+            )
     return strategy
 
 
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index 4fed4c16464ef..48a6440b36d11 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -281,8 +281,8 @@ def _callback(op):
     return s
 
 
-def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
-    """Compute dense using a BLAS library"""
+def matmul_blas_common(cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed, lib):
+    """Compute matmul/dense using a BLAS library"""
     M, K = get_const_tuple(data.shape)
     N, _ = get_const_tuple(weight.shape)
     if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
@@ -290,63 +290,110 @@ def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
     if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype == "int32":
         if not hasattr(lib, "matmul_u8s8s32"):
             raise NotImplementedError(
-                f"Dense with {lib.__name__} for {data.dtype} is not supported "
+                f"Matmul/Dense with {lib.__name__} for {data.dtype} is not supported "
                 "(matmulu8s8s32 not imlemented)"
             )
-        C = lib.matmul_u8s8s32(data, weight, False, True, dtype=out_dtype)
+        C = lib.matmul_u8s8s32(data, weight, data_transposed, weight_transposed, dtype=out_dtype)
     elif data.dtype == "float32" or data.dtype == "float64":
-        C = lib.matmul(data, weight, False, True)
+        C = lib.matmul(data, weight, data_transposed, weight_transposed)
     else:
-        raise NotImplementedError(f"Dense with {lib.__name__} for {data.dtype} is not supported")
+        raise NotImplementedError(f"Matmul/Dense with {lib.__name__} for {data.dtype} is not supported")
 
     if bias is not None:
         C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype), tag=tag.BROADCAST)
     return C
 
 
+def schedule_matmul_blas_common(outs):
+    """Default matmul schedule for BLAS library"""
+    s = te.create_schedule([x.op for x in outs])
+    te.schedule.AutoInlineInjective(s)
+
+    for out in outs:
+        if "dense" not in out.op.tag and "matmul" not in out.op.tag:
+            schedule_injective_from_existing(s, out)
+    return s
+
+
 @autotvm.register_topi_compute("dense_cblas.x86")
 def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense using a cblas"""
-    return dense_blas_common(cfg, data, weight, bias, out_dtype, cblas)
+    return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, cblas)
 
 
 @autotvm.register_topi_schedule("dense_cblas.x86")
 def schedule_dense_cblas(_, outs):
     """Create schedule for dense_cblas"""
-    return generic.schedule_extern(outs)
+    return schedule_matmul_blas_common(outs)
 
 
 @autotvm.register_topi_compute("dense_mkl.x86")
 def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense using mkl"""
-    return dense_blas_common(cfg, data, weight, bias, out_dtype, mkl)
+    return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkl)
 
 
 @autotvm.register_topi_schedule("dense_mkl.x86")
 def schedule_dense_mkl(_, outs):
     """Create schedule for dense_mkl"""
-    # return generic.schedule_extern(outs)
-    s = te.create_schedule([x.op for x in outs])
-    te.schedule.AutoInlineInjective(s)
-
-    def _callback(op):
-        if "broadcast" in op.tag or "injective" in op.tag or "elemwise" in op.tag:
-            schedule_injective_from_existing(s, op.output(0))
-
-    # traverse_inline(s, outs[0].op, _callback)
-    for out in outs:
-        if "dense" not in out.op.name:
-            schedule_injective_from_existing(s, out)
-    return s
+    return schedule_matmul_blas_common(outs)
 
 
 @autotvm.register_topi_compute("dense_mkldnn.x86")
 def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None):
     """Compute dense using mkldnn"""
-    return dense_blas_common(cfg, data, weight, bias, out_dtype, mkldnn)
+    return matmul_blas_common(cfg, data, weight, bias, out_dtype, False, True, mkldnn)
 
 
 @autotvm.register_topi_schedule("dense_mkldnn.x86")
 def schedule_dense_mkldnn(_, outs):
     """Create schedule for dense_mkldnn"""
-    return generic.schedule_extern(outs)
+    return schedule_matmul_blas_common(outs)
+
+
+@autotvm.register_topi_compute("matmul_cblas.x86")
+def matmul_cblas(
+    cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
+):
+    """Compute matmul using a cblas"""
+    return matmul_blas_common(
+        cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed, cblas
+    )
+
+
+@autotvm.register_topi_schedule("matmul_cblas.x86")
+def schedule_matmul_cblas(_, outs):
+    """Create schedule for matmul_cblas"""
+    return schedule_matmul_blas_common(outs)
+
+
+@autotvm.register_topi_compute("matmul_mkl.x86")
+def matmul_mkl(
+    cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
+):
+    """Compute matmul using mkl"""
+    return matmul_blas_common(
+        cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed, mkl
+    )
+
+
+@autotvm.register_topi_schedule("matmul_mkl.x86")
+def schedule_matmul_mkl(_, outs):
+    """Create schedule for matmul_mkl"""
+    return schedule_matmul_blas_common(outs)
+
+
+@autotvm.register_topi_compute("matmul_mkldnn.x86")
+def matmul_mkldnn(
+    cfg, data, weight, bias=None, out_dtype=None, data_transposed=False, weight_transposed=False
+):
+    """Compute matmul using mkldnn"""
+    return matmul_blas_common(
+        cfg, data, weight, bias, out_dtype, data_transposed, weight_transposed, mkldnn
+    )
+
+
+@autotvm.register_topi_schedule("matmul_mkldnn.x86")
+def schedule_matmul_mkldnn(_, outs):
+    """Create schedule for matmul_mkldnn"""
+    return schedule_matmul_blas_common(outs)
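
Usage note: the sketch below drives the new matmul_cblas compute and schedule directly from TE. It assumes a TVM build with USE_CBLAS enabled; the shapes, dtypes, and target string are illustrative only, and arguments are passed positionally since the autotvm wrappers expect positional arguments.

import numpy as np
import tvm
from tvm import te, topi

# Illustrative shapes: C(M, N) = A(M, K) x B(N, K)^T, i.e. the dense-style
# layout, selected via data_transposed=False, weight_transposed=True.
M, K, N = 32, 64, 128
data = te.placeholder((M, K), name="data", dtype="float32")
weight = te.placeholder((N, K), name="weight", dtype="float32")

target = tvm.target.Target("llvm -libs=cblas")
with target:
    # Positional args mirror the new signature: (data, weight, bias,
    # out_dtype, data_transposed, weight_transposed); cfg is supplied
    # by the autotvm dispatch context.
    C = topi.x86.matmul_cblas(data, weight, None, "float32", False, True)
    s = topi.x86.schedule_matmul_cblas([C])
    func = tvm.build(s, [data, weight, C], target=target)

dev = tvm.cpu()
a = tvm.nd.array(np.random.rand(M, K).astype("float32"), dev)
b = tvm.nd.array(np.random.rand(N, K).astype("float32"), dev)
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
func(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() @ b.asnumpy().T, rtol=1e-5)

When several BLAS libraries appear in target.libs, the strategy registrations above let the highest plevel win, so matmul_mkldnn (15) is preferred over matmul_mkl (14) and matmul_cblas (13) whenever its SpecializedCondition holds.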