[METAL] Fix codegen for inf and erf (apache#8054)
* [METAL] Fix codegen for inf and erf

Fixed Metal codegen when an `inf` constant is used: the `INFINITY` constant
is now emitted instead of `inf`.
Also, Metal has no built-in `erf` function, so `erf` is now lowered to the
`fast_erf` polynomial approximation from topi. The user will see a warning
message whenever `fast_erf` is generated in place of `erf`.

* Apply comments

* Fix clang-format

* Fix lint
echuraev authored and trevor-m committed Jun 17, 2021
1 parent a39009e commit d02f9a0
Showing 5 changed files with 168 additions and 42 deletions.
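
A minimal end-to-end sketch of the user-visible effect (not taken from the commit; it assumes a macOS machine with a Metal device, and the names A, C, fun and the tolerances are illustrative). Before this change, building such a kernel for target="metal" failed because the generated source contained a bare inf literal and a call to a non-existent erf() function:

import math
import numpy as np
import tvm
from tvm import te

target = "metal"
dev = tvm.metal(0)
n = 32

A = te.placeholder((n,), name="A", dtype="float32")
C = te.compute((n,), lambda i: te.erf(A[i]), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target)  # should log the "fast_erf will be used instead" warning

a_np = np.random.uniform(-2.0, 2.0, n).astype("float32")
a = tvm.nd.array(a_np, dev)
c = tvm.nd.empty((n,), "float32", dev)
fun(a, c)

ref = np.array([math.erf(v) for v in a_np], dtype="float32")
# fast_erf is an approximation, so compare with a loose tolerance
np.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-4, atol=1e-4)
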
86 changes: 46 additions & 40 deletions include/tvm/topi/elemwise.h
@@ -456,54 +456,60 @@ inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp",
}

/*!
* \brief Fast_tanh_float implementation from Eigen
* \brief Fast_erf_float expression from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
* \param arg The input expression.
* \param bits The number of bits in the type.
*/
inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) {
auto plus_4 = make_const(DataType::Float(32), 4.f);
auto minus_4 = make_const(DataType::Float(32), -4.f);
inline PrimExpr fast_erf_float_expr(PrimExpr arg, int bits) {
auto plus_4 = make_const(DataType::Float(bits), 4.f);
auto minus_4 = make_const(DataType::Float(bits), -4.f);

// The monomial coefficients of the numerator polynomial (odd).
auto alpha_1 = make_const(DataType::Float(32), -1.60960333262415e-02f);
auto alpha_3 = make_const(DataType::Float(32), -2.95459980854025e-03f);
auto alpha_5 = make_const(DataType::Float(32), -7.34990630326855e-04f);
auto alpha_7 = make_const(DataType::Float(32), -5.69250639462346e-05f);
auto alpha_9 = make_const(DataType::Float(32), -2.10102402082508e-06f);
auto alpha_11 = make_const(DataType::Float(32), 2.77068142495902e-08f);
auto alpha_13 = make_const(DataType::Float(32), -2.72614225801306e-10f);
auto alpha_1 = make_const(DataType::Float(bits), -1.60960333262415e-02f);
auto alpha_3 = make_const(DataType::Float(bits), -2.95459980854025e-03f);
auto alpha_5 = make_const(DataType::Float(bits), -7.34990630326855e-04f);
auto alpha_7 = make_const(DataType::Float(bits), -5.69250639462346e-05f);
auto alpha_9 = make_const(DataType::Float(bits), -2.10102402082508e-06f);
auto alpha_11 = make_const(DataType::Float(bits), 2.77068142495902e-08f);
auto alpha_13 = make_const(DataType::Float(bits), -2.72614225801306e-10f);

// The monomial coefficients of the denominator polynomial (even).
auto beta_0 = make_const(DataType::Float(32), -1.42647390514189e-02f);
auto beta_2 = make_const(DataType::Float(32), -7.37332916720468e-03f);
auto beta_4 = make_const(DataType::Float(32), -1.68282697438203e-03f);
auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f);
auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f);
auto beta_0 = make_const(DataType::Float(bits), -1.42647390514189e-02f);
auto beta_2 = make_const(DataType::Float(bits), -7.37332916720468e-03f);
auto beta_4 = make_const(DataType::Float(bits), -1.68282697438203e-03f);
auto beta_6 = make_const(DataType::Float(bits), -2.13374055278905e-04f);
auto beta_8 = make_const(DataType::Float(bits), -1.45660718464996e-05f);

// clamp x
auto x = tvm::max(tvm::min(arg, plus_4), minus_4);
auto x2 = x * x;

// Evaluate the numerator polynomial p.
auto p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;

// Evaluate the denominator polynomial p.
auto q = x2 * beta_8 + beta_6;
q = x2 * q + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;

return p / q;
}

/*!
* \brief Fast_erf_float expression from Eigen
*/
inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) {
return compute(
data->shape,
[&](const Array<Var>& i) {
// clamp x
auto x = tvm::max(tvm::min(data(i), plus_4), minus_4);
auto x2 = x * x;

// Evaluate the numerator polynomial p.
auto p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;

// Evaluate the denominator polynomial p.
auto q = x2 * beta_8 + beta_6;
q = x2 * q + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;

return p / q;
},
name, tag);
data->shape, [&](const Array<Var>& i) { return fast_erf_float_expr(data(i), 32); }, name,
tag);
}

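
The refactor above only lifts Eigen's rational approximation out of the compute lambda so a single expression can be reused per element (here with 32 bits, and from the Metal intrinsic rule with 16 or 32); the math is unchanged. A standalone sketch in plain Python (coefficients copied from fast_erf_float_expr above; the helper name fast_erf_py is made up for illustration) that mirrors the Horner evaluation and compares it against math.erf:

import math

# alpha_13 ... alpha_1: numerator coefficients (odd powers), copied from the diff above
ALPHA = [-2.72614225801306e-10, 2.77068142495902e-08, -2.10102402082508e-06,
         -5.69250639462346e-05, -7.34990630326855e-04, -2.95459980854025e-03,
         -1.60960333262415e-02]
# beta_8 ... beta_0: denominator coefficients (even powers)
BETA = [-1.45660718464996e-05, -2.13374055278905e-04, -1.68282697438203e-03,
        -7.37332916720468e-03, -1.42647390514189e-02]

def fast_erf_py(x):
    x = max(min(x, 4.0), -4.0)  # clamp, as in the TIR expression
    x2 = x * x
    p = ALPHA[0]
    for a in ALPHA[1:]:
        p = x2 * p + a          # Horner evaluation of the odd numerator polynomial
    p = x * p
    q = BETA[0]
    for b in BETA[1:]:
        q = x2 * q + b          # Horner evaluation of the even denominator polynomial
    return p / q

for v in (-3.0, -0.5, 0.0, 0.5, 3.0):
    print(v, fast_erf_py(v), math.erf(v))  # the two columns should agree to several decimals
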
20 changes: 20 additions & 0 deletions src/target/source/codegen_metal.cc
@@ -302,6 +302,26 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
}
}

void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
std::ostringstream temp;
if (std::isinf(op->value)) {
if (op->value < 0) {
temp << "-";
}
temp << "INFINITY";
} else if (std::isnan(op->value)) {
temp << "NAN";
} else {
temp << std::scientific << op->value;
if (op->dtype.bits() == 32)
temp << 'f';
else if (op->dtype.bits() == 16)
temp << 'h';
}
MarkConst(temp.str());
os << temp.str();
}

runtime::Module BuildMetal(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
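
A small sketch of what the new FloatImmNode visitor changes in practice. It assumes a Metal-capable machine and that the generated device source can be read back via imported_modules[0].get_source(), as TVM's source-based targets generally allow; the names below are illustrative:

import tvm
from tvm import te

n = 16
A = te.placeholder((n,), name="A", dtype="float32")
C = te.compute((n,), lambda i: tvm.tir.const(float("inf"), "float32"), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target="metal")

metal_src = fun.imported_modules[0].get_source()
# The inf constant is now printed as the INFINITY macro rather than a bare
# "inf" literal, which the Metal compiler rejects.
assert "INFINITY" in metal_src
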
4 changes: 2 additions & 2 deletions src/target/source/codegen_metal.h
@@ -51,8 +51,8 @@ class CodeGenMetal final : public CodeGenC {
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
// reuse parent's function.
using CodeGenC::PrintType;

19 changes: 19 additions & 0 deletions src/target/source/intrin_rule_metal.cc
@@ -22,6 +22,7 @@
* \brief Metal intrinsic rules.
*/
#include <tvm/tir/op_attr_types.h>
#include <tvm/topi/elemwise.h>

#include "../intrin_rule.h"

@@ -90,6 +91,24 @@ TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
TVM_REGISTER_OP("tir.cosh")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

// There is no erf function in Metal. When erf is used, we use fast_erf instead
static PrimExpr DispatchFastErf(const PrimExpr& e) {
LOG(WARNING) << " Metal doesn't have built-in erf function. fast_erf will be used instead.";
const CallNode* call = e.as<CallNode>();
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 1);
PrimExpr arg = call->args[0];
int bits = arg.dtype().bits();
bool isFloat = arg.dtype().is_float();
PrimExpr res;
if (isFloat && (bits == 16 || bits == 32))
res = topi::fast_erf_float_expr(arg, bits);
else
LOG(FATAL) << "Unsupported type in Metal fast_erf";
return res;
}
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchFastErf);

} // namespace intrin
} // namespace codegen
} // namespace tvm
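
A rough way to watch the new lowering rule work, under the same assumptions as the sketch after the codegen_metal.cc hunk (Metal available, device source readable through imported_modules[0].get_source()): building a te.erf kernel for Metal should log the warning once, and the emitted kernel should contain the inlined polynomial rather than a call to erf().

import tvm
from tvm import te

A = te.placeholder((16,), name="A", dtype="float32")
C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A, C], target="metal")  # logs: "Metal doesn't have built-in erf function..."

metal_src = fun.imported_modules[0].get_source()
assert "erf(" not in metal_src  # erf was expanded into the fast_erf polynomial
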
81 changes: 81 additions & 0 deletions tests/python/unittest/test_target_codegen_metal.py
@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import numpy as np
from tvm import topi
import unittest
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_metal_inf_nan():
target = "metal"

def check_inf_nan(dev, n, value, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
inf_value = tvm.tir.const(value, dtype=dtype)
C = te.compute((n,), lambda i: inf_value, name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tx)
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
fun(a, c)

dev = tvm.device(target, 0)

check_inf_nan(dev, 1, -float("inf"), "float32")
check_inf_nan(dev, 1, -float("inf"), "float16")
check_inf_nan(dev, 1, float("inf"), "float32")
check_inf_nan(dev, 1, float("inf"), "float16")
check_inf_nan(dev, 1, float("nan"), "float32")
check_inf_nan(dev, 1, float("nan"), "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_metal_erf():
target = "metal"

def check_erf(dev, n, dtype):
A = te.placeholder((n,), name="A", dtype=dtype)
C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
s = te.create_schedule(C.op)
s[C].bind(s[C].op.axis[0], tx)
fun = tvm.build(s, [A, C], target)
a = tvm.nd.empty((n,), A.dtype, dev)
c = tvm.nd.empty((n,), A.dtype, dev)
# Only need to test compiling here
fun(a, c)

dev = tvm.device(target, 0)

check_erf(dev, 1, "float32")
check_erf(dev, 1, "float16")


if __name__ == "__main__":
test_metal_inf_nan()
test_metal_erf()
