[cuDNN] Add support for log_softmax (apache#8369)
* log_softmax strategy and cudnn impl

* add log_softmax cudnn test

* silence terrible pylint suggestion

* fix typo
altanh authored and ylc committed Jan 13, 2022
1 parent 3bf580b commit 8f8d1f2
Showing 10 changed files with 156 additions and 73 deletions.
26 changes: 26 additions & 0 deletions python/tvm/contrib/cudnn.py
@@ -582,3 +582,29 @@ def softmax(x, axis=-1):
        ),
        name="y",
    )


def log_softmax(x, axis=-1):
    """Compute log_softmax using CuDNN

    Parameters
    ----------
    x : tvm.te.Tensor
        The input tensor

    axis : int
        The axis to compute log softmax over

    Returns
    -------
    ret : tvm.te.Tensor
        The result tensor
    """
    return te.extern(
        x.shape,
        [x],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.cudnn.log_softmax.forward", ins[0], outs[0], axis
        ),
        name="y",
    )
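For reviewers, a minimal usage sketch of the new wrapper (assuming a TVM build with CUDA and cuDNN enabled; this mirrors the test added in test_cudnn.py below rather than anything in this file):

import numpy as np
import tvm
from tvm import te
from tvm.contrib import cudnn

# Lower log_softmax over the last axis to the cuDNN packed call.
x = te.placeholder((32, 10), dtype="float32", name="x")
y = cudnn.log_softmax(x, axis=-1)
s = te.create_schedule(y.op)
f = tvm.build(s, [x, y], target="cuda --host=llvm", name="log_softmax")

dev = tvm.cuda(0)
a = tvm.nd.array(np.random.uniform(size=(32, 10)).astype("float32"), dev)
b = tvm.nd.array(np.zeros((32, 10), dtype="float32"), dev)
f(a, b)  # b now holds the log-softmax of a along the last axis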
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/_nn.py
@@ -48,7 +48,7 @@


# log_softmax
reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
reg.register_strategy("nn.log_softmax", strategy.log_softmax_strategy)
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


22 changes: 17 additions & 5 deletions python/tvm/relay/op/strategy/cuda.py
@@ -101,11 +101,23 @@ def fast_softmax_strategy_cuda(attrs, inputs, out_type, target):
    return strategy


@schedule_log_softmax.register(["cuda", "gpu"])
def schedule_log_softmax_cuda(attrs, outs, target):
    """scheudle log_softmax for cuda"""
    with target:
        return topi.cuda.schedule_softmax(outs)
@log_softmax_strategy.register(["cuda", "gpu"])
def log_softmax_strategy_cuda(attrs, inputs, out_type, target):
    """log_softmax cuda strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.log_softmax),
        wrap_topi_schedule(topi.cuda.schedule_softmax),
        name="log_softmax.cuda",
    )
    if target.kind.name == "cuda" and "cudnn" in target.libs:
        strategy.add_implementation(
            wrap_compute_softmax(topi.cuda.log_softmax_cudnn),
            wrap_topi_schedule(topi.cuda.schedule_log_softmax_cudnn),
            name="log_softmax.cudnn",
            plevel=15,
        )
    return strategy
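The plevel=15 entry is what lets the cuDNN kernel take precedence: OpStrategy picks the applicable implementation with the highest priority level, and add_implementation defaults to plevel=10, so "log_softmax.cudnn" outranks "log_softmax.cuda" whenever cudnn appears in the target libs. A hedged sketch of how that choice plays out at build time (plain Relay API, not part of this diff):

import tvm
from tvm import relay

x = relay.var("x", shape=(8, 100), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.log_softmax(x))

# Plain CUDA target: only "log_softmax.cuda" applies.
lib_cuda = relay.build(mod, target="cuda")

# With cuDNN in the target libs, "log_softmax.cudnn" wins on plevel.
lib_cudnn = relay.build(mod, target="cuda -libs=cudnn")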


@schedule_lrn.register(["cuda", "gpu"])
15 changes: 9 additions & 6 deletions python/tvm/relay/op/strategy/generic.py
@@ -181,12 +181,15 @@ def fast_softmax_strategy(attrs, inputs, out_type, target):
    return strategy


# log_softmax
@generic_func
def schedule_log_softmax(attrs, outs, target):
    """Schedule log_softmax op"""
    with target:
        return topi.generic.schedule_softmax(outs)
@override_native_generic_func("log_softmax_strategy")
def log_softmax_strategy(attrs, inputs, out_type, target):
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.log_softmax),
        wrap_topi_schedule(topi.generic.schedule_softmax),
        name="log_softmax.generic",
    )
    return strategy


# lrn
15 changes: 10 additions & 5 deletions python/tvm/relay/op/strategy/hls.py
@@ -68,11 +68,16 @@ def softmax_strategy_hls(attrs, inputs, out_type, target):
    return strategy


@schedule_log_softmax.register("hls")
def schedule_log_softmax_hls(attrs, inputs, out_type, target):
    """schedule log_softmax for hls"""
    with target:
        return topi.hls.schedule_softmax(outs)
@log_softmax_strategy.register("hls")
def log_softmax_strategy_hls(attrs, inputs, out_type, target):
    """log_softmax hls strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.log_softmax),
        wrap_topi_schedule(topi.hls.schedule_softmax),
        name="log_softmax.hls",
    )
    return strategy


@override_native_generic_func("conv2d_strategy")
15 changes: 10 additions & 5 deletions python/tvm/relay/op/strategy/x86.py
@@ -91,11 +91,16 @@ def fast_softmax_strategy_cpu(attrs, inputs, out_type, target):
    return strategy


@schedule_log_softmax.register("cpu")
def schedule_log_softmax_cpu(attrs, outs, target):
    """schedule log_softmax op for x86"""
    with target:
        return topi.x86.schedule_softmax(outs)
@log_softmax_strategy.register("cpu")
def log_softmax_strategy_cpu(attrs, inputs, out_type, target):
    """log_softmax x86 strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.log_softmax),
        wrap_topi_schedule(topi.x86.schedule_softmax),
        name="log_softmax.x86",
    )
    return strategy


@conv2d_strategy.register("cpu")
10 changes: 10 additions & 0 deletions python/tvm/topi/cuda/softmax.py
@@ -164,3 +164,13 @@ def softmax_cudnn(x, axis=-1):
def schedule_softmax_cudnn(outs):
    """Schedule for softmax cudnn op"""
    return generic.schedule_extern(outs)


def log_softmax_cudnn(x, axis=-1):
    """Perform log_softmax on the data using cudnn"""
    return cudnn.log_softmax(x, axis)


def schedule_log_softmax_cudnn(outs):
    """Schedule for log_softmax cudnn op"""
    return generic.schedule_extern(outs)
5 changes: 3 additions & 2 deletions python/tvm/topi/nn/softmax.py
@@ -123,7 +123,7 @@ def _normalize(exp, expsum, *indices):


@tvm.te.tag_scope(tag="log_softmax_output")
def log_softmax(x):
def log_softmax(x, axis=-1):
    """Perform log softmax activation on the data

    Parameters
@@ -136,8 +136,9 @@ def log_softmax(x):
    output : tvm.te.Tensor
        2-D output with same shape
    """

    assert len(x.shape) == 2, "only support 2-dim log softmax"
    # pylint: disable=R1714
    assert axis == -1 or axis == len(x.shape) - 1, "only support last axis log softmax"
    m, n = x.shape
    k = te.reduce_axis((0, n), name="k")
    max_elem = te.compute((m,), lambda i: tvm.te.max(x[i, k], axis=k))
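The rest of the hunk keeps the existing 2-D compute. For reference, a numpy sketch of the numerically stable formulation that compute expresses, log_softmax(x)[i, j] = x[i, j] - max_k x[i, k] - log(sum_k exp(x[i, k] - max_k x[i, k])):

import numpy as np

def log_softmax_ref(x):
    # Subtract the row max before exponentiating, matching the
    # max_elem reduction in the TE compute above.
    m = np.max(x, axis=-1, keepdims=True)
    return x - m - np.log(np.sum(np.exp(x - m), axis=-1, keepdims=True))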
91 changes: 48 additions & 43 deletions src/runtime/contrib/cudnn/softmax.cc
@@ -31,54 +31,59 @@ namespace contrib {

using namespace runtime;

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      DLTensor* x = args[0];
      DLTensor* y = args[1];
      int axis = args[2];
      int ndim = x->ndim;
      int64_t* shape = x->shape;
      if (axis < 0) axis += ndim;
      ICHECK(axis >= 0 && axis < ndim);
void softmax_impl(cudnnSoftmaxAlgorithm_t alg, TVMArgs args, TVMRetValue* ret) {
  DLTensor* x = args[0];
  DLTensor* y = args[1];
  int axis = args[2];
  int ndim = x->ndim;
  int64_t* shape = x->shape;
  if (axis < 0) axis += ndim;
  ICHECK(axis >= 0 && axis < ndim);

      CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
      entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
  entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);

      // Set mode and shape descriptor
      if (axis == ndim - 1) {
        int64_t N = 1;
        for (int i = 0; i < ndim - 1; ++i) {
          N *= shape[i];
        }
        entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
        CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc,
                                              CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
                                              static_cast<int>(N),
                                              static_cast<int>(shape[ndim - 1]), 1, 1));
      } else {
        int64_t pre_axis_dim = 1;
        int64_t post_axis_dim = 1;
        for (int i = 0; i < ndim; ++i) {
          if (i < axis) {
            pre_axis_dim *= shape[i];
          } else if (i > axis) {
            post_axis_dim *= shape[i];
          }
        }
        entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
        CUDNN_CALL(cudnnSetTensor4dDescriptor(
            entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
            entry_ptr->softmax_entry.data_type, static_cast<int>(pre_axis_dim),
            static_cast<int>(shape[axis]), static_cast<int>(post_axis_dim), 1));
  // Set mode and shape descriptor
  if (axis == ndim - 1) {
    int64_t N = 1;
    for (int i = 0; i < ndim - 1; ++i) {
      N *= shape[i];
    }
    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE;
    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW,
                                          entry_ptr->softmax_entry.data_type, static_cast<int>(N),
                                          static_cast<int>(shape[ndim - 1]), 1, 1));
  } else {
    int64_t pre_axis_dim = 1;
    int64_t post_axis_dim = 1;
    for (int i = 0; i < ndim; ++i) {
      if (i < axis) {
        pre_axis_dim *= shape[i];
      } else if (i > axis) {
        post_axis_dim *= shape[i];
      }
    }
    entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL;
    CUDNN_CALL(cudnnSetTensor4dDescriptor(
        entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type,
        static_cast<int>(pre_axis_dim), static_cast<int>(shape[axis]),
        static_cast<int>(post_axis_dim), 1));
  }

  auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
  auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
  CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, alg, entry_ptr->softmax_entry.mode, alpha,
                                 entry_ptr->softmax_entry.shape_desc, x->data, beta,
                                 entry_ptr->softmax_entry.shape_desc, y->data));
}

      auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type);
      auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type);
      CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, CUDNN_SOFTMAX_ACCURATE,
                                     entry_ptr->softmax_entry.mode, alpha,
                                     entry_ptr->softmax_entry.shape_desc, x->data, beta,
                                     entry_ptr->softmax_entry.shape_desc, y->data));
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      softmax_impl(CUDNN_SOFTMAX_ACCURATE, args, ret);
    });

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.log_softmax.forward")
    .set_body([](TVMArgs args, TVMRetValue* ret) { softmax_impl(CUDNN_SOFTMAX_LOG, args, ret); });

} // namespace contrib
} // namespace tvm
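Both registrations now route through the shared softmax_impl and differ only in the cudnnSoftmaxAlgorithm_t passed to cudnnSoftmaxForward: CUDNN_SOFTMAX_ACCURATE shifts by the per-slice max before exponentiating, and CUDNN_SOFTMAX_LOG returns the logarithm of that result directly. As a small sketch, the packed functions can also be looked up straight from the TVM registry (standard API, not part of this diff):

import tvm

f_softmax = tvm.get_global_func("tvm.contrib.cudnn.softmax.forward")
f_log_softmax = tvm.get_global_func("tvm.contrib.cudnn.log_softmax.forward")
# Each expects (input tensor, output tensor, axis), matching how
# softmax_impl unpacks its TVMArgs.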
28 changes: 22 additions & 6 deletions tests/python/contrib/test_cudnn.py
@@ -176,30 +176,40 @@ def test_conv3d():
    verify_conv3d("float32", "float32", tensor_format=0, groups=2)


def verify_softmax(shape, axis, dtype="float32"):
def verify_softmax(shape, axis, dtype="float32", log_softmax=False):
    cudnn_op = cudnn.log_softmax if log_softmax else cudnn.softmax
    testing_op = (
        tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
    )

    A = te.placeholder(shape, dtype=dtype, name="A")
    B = cudnn.softmax(A, axis)
    B = cudnn_op(A, axis)
    s = te.create_schedule([B.op])

    dev = tvm.cuda(0)
    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = tvm.topi.testing.softmax_python(a_np)
    b_np = testing_op(a_np)
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    f = tvm.build(s, [A, B], target="cuda --host=llvm", name="softmax")
    f(a, b)
    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3)


def verify_softmax_4d(shape, dtype="float32"):
def verify_softmax_4d(shape, dtype="float32", log_softmax=False):
    cudnn_op = cudnn.log_softmax if log_softmax else cudnn.softmax
    testing_op = (
        tvm.topi.testing.log_softmax_python if log_softmax else tvm.topi.testing.softmax_python
    )

    A = te.placeholder(shape, dtype=dtype, name="A")
    B = cudnn.softmax(A, axis=1)
    B = cudnn_op(A, axis=1)
    s = te.create_schedule([B.op])

    dev = tvm.cuda(0)
    n, c, h, w = shape
    a_np = np.random.uniform(size=shape).astype(dtype)
    b_np = tvm.topi.testing.softmax_python(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
    b_np = testing_op(a_np.transpose(0, 2, 3, 1).reshape(h * w, c))
    b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
@@ -217,6 +227,12 @@ def test_softmax():
    verify_softmax_4d((1, 16, 256, 256))
    verify_softmax_4d((1, 16, 256, 256), "float64")

    verify_softmax((32, 10), -1, log_softmax=True)
    verify_softmax((3, 4), -1, log_softmax=True)
    verify_softmax((1, 5), -1, "float64", log_softmax=True)
    verify_softmax_4d((1, 16, 256, 256), log_softmax=True)
    verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True)


test_kwargs_default_2d = {
"tensor_format": 0,
Expand Down
