Skip to content

Commit

Permalink
[Fix] Fix topi.rms_norm with float32 upscale (#16091)
Browse files Browse the repository at this point in the history
This PR fixes the `topi.rms_norm` with upscale to float32, for large reduction dimension of computation on float16.
  • Loading branch information
cyx-6 authored Nov 9, 2023
1 parent 99336c3 commit 42de91f
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 34 deletions.
28 changes: 12 additions & 16 deletions include/tvm/topi/nn/rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,31 @@ using namespace tvm::te;
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
* \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
* \param axis The axis to normalize over.
* \param epsilon The epsilon value to avoid division by zero.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias,
const Array<Integer>& axis, double epsilon, std::string name = "T_rms_norm",
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
double epsilon, std::string name = "T_rms_norm",
std::string tag = kInjective) {
const auto& data_type = data->dtype;
const auto& weight_type = weight.defined() ? weight->dtype : data_type;
ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
const auto& bias_type = bias.defined() ? bias->dtype : data_type;
ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";

auto square = multiply(data, data);
const auto& data_fp32 = cast(data, DataType::Float(32));
const auto& weight_fp32 = cast(weight, DataType::Float(32));

auto square = multiply(data_fp32, data_fp32);
auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);

auto ndim = data->shape.size();
auto ndim = data_fp32->shape.size();
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_extent = make_const(data->dtype, 1);
auto reduce_extent = make_const(data_fp32->dtype, 1);
for (int i : real_axis) {
reduce_extent *= data->shape[i];
reduce_extent *= data_fp32->shape[i];
}
auto rms_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices;
Expand All @@ -78,15 +77,12 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& b
}
}
auto output =
data(indices) * weight(reduce_indices) *
data_fp32(indices) * weight_fp32(reduce_indices) *
tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
if (bias.defined()) {
output += bias(reduce_indices);
}
return output;
};
auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
return rms_norm;
auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
return cast(rms_norm, data_type);
}

} // namespace nn
Expand Down
7 changes: 2 additions & 5 deletions python/tvm/topi/nn/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .. import cpp


def rms_norm(data, weight, bias, axis, epsilon=1e-5):
def rms_norm(data, weight, axis, epsilon=1e-5):
"""Root mean square normalization operator. The output will have the same data type as input.
Parameters
Expand All @@ -29,9 +29,6 @@ def rms_norm(data, weight, bias, axis, epsilon=1e-5):
weight: tvm.te.Tensor
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
bias: tvm.te.Tensor
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
axis : list of int
Axis over the normalization applied
Expand All @@ -43,4 +40,4 @@ def rms_norm(data, weight, bias, axis, epsilon=1e-5):
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.rms_norm(data, weight, bias, axis, epsilon)
return cpp.nn.rms_norm(data, weight, axis, epsilon)
9 changes: 5 additions & 4 deletions python/tvm/topi/testing/rms_norm_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
def rms_norm_python(data, weight, axis, epsilon=1e-5):
"""Root mean square normalization operator in Python.
Parameters
Expand All @@ -44,8 +44,9 @@ def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
result : np.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
dtype = data.dtype
data = data.astype("float32")
weight = weight.astype("float32")
square_mean = np.mean(np.square(data), axis, keepdims=True)
result = data * weight / np.sqrt(square_mean + epsilon)
if bias is not None:
result += bias
return result
return result.astype(dtype)
2 changes: 1 addition & 1 deletion src/topi/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal

/* Ops from nn/rms_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::rms_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
*rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
});

} // namespace topi
Expand Down
14 changes: 6 additions & 8 deletions tests/python/topi/python/test_topi_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,31 @@
# only test on llvm because schedule is missing
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize(
"shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,))]
"shape,axis",
[([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,)), ([2, 8192], (1,))],
)
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3, atol=1e-4):
shape_te = [te.var(v[0]) if isinstance(v, tuple) else v for v in shape]
scale_shape_te = [shape_te[dim] for dim in axis]
data = te.placeholder(shape_te, dtype=dtype, name="data")
weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
bias = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
B = topi.nn.rms_norm(data, weight, bias, axis, episilon)
B = topi.nn.rms_norm(data, weight, axis, episilon)

shape_np = [v[1] if isinstance(v, tuple) else v for v in shape]
scale_shape_np = [shape_np[dim] for dim in axis]
data_np = np.random.uniform(size=shape_np).astype(dtype)
weight_np = np.random.uniform(size=scale_shape_np).astype(dtype)
bias_np = np.random.uniform(size=scale_shape_np).astype(dtype)
b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, episilon)
b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, episilon)

with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
s = s_func([B])
data_tvm = tvm.nd.array(data_np, dev)
weight_tvm = tvm.nd.array(weight_np, dev)
bias_tvm = tvm.nd.array(bias_np, dev)
b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev)
f = tvm.build(s, [data, weight, bias, B], target)
f(data_tvm, weight_tvm, bias_tvm, b_tvm)
f = tvm.build(s, [data, weight, B], target)
f(data_tvm, weight_tvm, b_tvm)
tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)


Expand Down

0 comments on commit 42de91f

Please sign in to comment.