Skip to content

Commit

Permalink
Use unary primitive in utils op (#8466)
Browse files Browse the repository at this point in the history
* Use unary primitive in utils op

* delete unused code

* format

* delete half kernel

* address review

* address review

* rename xpu to primitive

* support half and bfloat16

* fix

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
EsdeathYZH and mergify[bot] authored Jun 29, 2022
1 parent 57869e9 commit 99b1eff
Show file tree
Hide file tree
Showing 20 changed files with 196 additions and 298 deletions.
4 changes: 4 additions & 0 deletions oneflow/core/ep/common/primitive/elementwise_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ namespace primitive {

#define UNARY_LOGICAL_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogicalNot)

// Preprocessor tuple sequence of the "utils" unary ops: the IsInf and IsNan
// element-wise predicates. The CPU and CUDA elementwise-unary factories combine
// this sequence with UTIL_OPS_DATA_TYPE_SEQ (source types) and the device's bool
// type sequence (destination type) when generating their entries.
#define UNARY_UTILS_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsInf) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kIsNan)

} // namespace primitive
} // namespace ep
} // namespace oneflow
Expand Down
14 changes: 14 additions & 0 deletions oneflow/core/ep/common/primitive/unary_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,20 @@ struct UnaryFunctor<device, UnaryOp::kLogicalNot, Dst, Src> {
OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(!src); }
};

template<DeviceType device, typename Src>
struct UnaryFunctor<device, UnaryOp::kIsInf, bool, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(Src src) const { return false; }
};

template<DeviceType device, typename Src>
struct UnaryFunctor<device, UnaryOp::kIsNan, bool, Src> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(Src src) const { return false; }
};

} // namespace primitive
} // namespace ep
} // namespace oneflow
Expand Down
7 changes: 6 additions & 1 deletion oneflow/core/ep/cpu/primitive/elementwise_unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
#include "oneflow/core/ep/common/primitive/elementwise_unary.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/ep/cpu/primitive/unary_functor.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
#include "oneflow/core/ep/cpu/cpu_stream.h"
#include "oneflow/core/ep/cpu/cpu_device.h"

Expand Down Expand Up @@ -92,6 +91,12 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_FLOATING_MATH_OP_SEQ,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ)

// For Utils OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,
CPU_PRIMITIVE_BOOL_TYPE_SEQ)

// For Logical OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_LOGICAL_OP_SEQ, CPU_PRIMITIVE_NATIVE_TYPE_SEQ,
Expand Down
8 changes: 8 additions & 0 deletions oneflow/core/ep/cpu/primitive/type_seq.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,12 @@ limitations under the License.
CPU_PRIMITIVE_FLOAT_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ

// Source element types accepted by the utils ops (IsInf / IsNan) on CPU: the
// int8, uint8, int32, int64, float and double primitive type sequences.
// NOTE(review): the CUDA type_seq.h defines a macro with this exact name but a
// different expansion (it additionally lists float16 and bfloat16). Confirm the
// two headers are never included in the same translation unit, or consider a
// device-specific prefix like the other sequences in this header.
#define UTIL_OPS_DATA_TYPE_SEQ \
CPU_PRIMITIVE_INT8_TYPE_SEQ \
CPU_PRIMITIVE_UINT8_TYPE_SEQ \
CPU_PRIMITIVE_INT32_TYPE_SEQ \
CPU_PRIMITIVE_INT64_TYPE_SEQ \
CPU_PRIMITIVE_FLOAT_TYPE_SEQ \
CPU_PRIMITIVE_DOUBLE_TYPE_SEQ

#endif // ONEFLOW_CORE_EP_CPU_PRIMITIVE_TYPE_SEQ_H_
28 changes: 28 additions & 0 deletions oneflow/core/ep/cpu/primitive/unary_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,34 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTanh, Dst, Src> {
OF_DEVICE_FUNC Dst operator()(Src src) const { return std::tanh(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsInf, bool, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(float src) const { return std::isinf(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsInf, bool, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(double src) const { return std::isinf(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(float src) const { return std::isnan(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(double src) const { return std::isnan(src); }
};

} // namespace primitive
} // namespace ep
} // namespace oneflow
6 changes: 6 additions & 0 deletions oneflow/core/ep/cuda/primitive/elementwise_unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_FLOATING_MATH_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)

// For Utils OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,
CUDA_PRIMITIVE_BOOL_TYPE_SEQ)

// For Logical OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
Expand Down
10 changes: 10 additions & 0 deletions oneflow/core/ep/cuda/primitive/type_seq.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ limitations under the License.
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ

// Source element types accepted by the utils ops (IsInf / IsNan) on CUDA: the
// int8, uint8, int32, int64, float, double, float16 and bfloat16 primitive type
// sequences.
// NOTE(review): the CPU type_seq.h defines a macro with this exact name but a
// shorter expansion (no float16 / bfloat16). Confirm the two headers are never
// included in the same translation unit, or consider a device-specific prefix
// like the other sequences in this header.
#define UTIL_OPS_DATA_TYPE_SEQ \
CUDA_PRIMITIVE_INT8_TYPE_SEQ \
CUDA_PRIMITIVE_UINT8_TYPE_SEQ \
CUDA_PRIMITIVE_INT32_TYPE_SEQ \
CUDA_PRIMITIVE_INT64_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT_TYPE_SEQ \
CUDA_PRIMITIVE_DOUBLE_TYPE_SEQ \
CUDA_PRIMITIVE_FLOAT16_TYPE_SEQ \
CUDA_PRIMITIVE_BFLOAT16_TYPE_SEQ

#endif // WITH_CUDA

#endif // ONEFLOW_CORE_EP_CUDA_PRIMITIVE_TYPE_SEQ_H_
56 changes: 56 additions & 0 deletions oneflow/core/ep/cuda/primitive/unary_functor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,48 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTanh, half, half> {
OF_DEVICE_FUNC half operator()(half src) const { return __float2half(tanhf(__half2float(src))); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, half> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(half src) const { return isinf(__half2float(src)); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(float src) const { return isinf(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(double src) const { return isinf(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, half> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(half src) const { return isnan(__half2float(src)); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, float> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(float src) const { return isnan(src); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, double> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(double src) const { return isnan(src); }
};

#define SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(op) \
template<> \
struct UnaryFunctor<DeviceType::kCUDA, op, half, half> { \
Expand Down Expand Up @@ -105,6 +147,20 @@ SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSoftPlus);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTanh);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kThreshold);

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isinf(__bfloat162float(src)); }
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsNan, bool, nv_bfloat16> {
UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC bool operator()(nv_bfloat16 src) const { return isnan(__bfloat162float(src)); }
};

#endif

} // namespace primitive
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/ep/include/primitive/unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace ep {
namespace primitive {

enum class UnaryOp {
// activation op
kElu,
kCelu,
kRelu,
Expand All @@ -39,7 +40,13 @@ enum class UnaryOp {
kSoftPlus,
kTanh,
kThreshold,

// logical op
kLogicalNot,

// utils op
kIsInf,
kIsNan,
};

}
Expand Down
29 changes: 1 addition & 28 deletions oneflow/user/kernels/activation_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,10 @@ limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/binary_op.h"
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/user/kernels/elementwise_xpu_kernel.h"
#include "oneflow/user/kernels/elementwise_primitive_kernel.h"

namespace oneflow {

namespace {
// Builds a kernel-match predicate that holds when an ElementwiseUnary primitive
// can be created for `op` on the context's device, using the data types of the
// tensors named `input_name` (source) and `output_name` (destination), index 0.
auto UnaryPrimitiveExists(ep::primitive::UnaryOp op, const std::string& output_name,
                          const std::string& input_name) {
  // Everything is captured by value because the predicate is handed back to the
  // caller and evaluated later.
  return hob::make_custom(
      "ElementwiseUnaryPrimitiveExists", [=](const user_op::KernelRegContext& ctx) {
        const user_op::TensorDesc* input_desc = ctx.TensorDesc4ArgNameAndIndex(input_name, 0);
        const user_op::TensorDesc* output_desc = ctx.TensorDesc4ArgNameAndIndex(output_name, 0);
        auto created = ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
            ctx.device_type(), op, input_desc->data_type(), output_desc->data_type());
        // An empty handle means no primitive exists for this combination.
        return static_cast<bool>(created);
      });
}

// Builds a kernel-match predicate that holds when a BroadcastElementwiseBinary
// primitive can be created for `op` on the context's device, using the data
// types of the tensors named `input_a_name` (first source) and `output_name`
// (destination), index 0, with max_num_dims fixed to 1.
auto BinaryPrimitiveExists(ep::primitive::BinaryOp op, const std::string& output_name,
                           const std::string& input_a_name) {
  // Everything is captured by value because the predicate is handed back to the
  // caller and evaluated later.
  return hob::make_custom(
      "BroadcastElementwiseBinaryPrimitiveExists", [=](const user_op::KernelRegContext& ctx) {
        const user_op::TensorDesc* lhs_desc = ctx.TensorDesc4ArgNameAndIndex(input_a_name, 0);
        const user_op::TensorDesc* output_desc = ctx.TensorDesc4ArgNameAndIndex(output_name, 0);
        auto created =
            ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(
                ctx.device_type(), op, lhs_desc->data_type(), output_desc->data_type(),
                1 /*max_num_dims*/);
        // An empty handle means no primitive exists for this combination.
        return static_cast<bool>(created);
      });
}
} // namespace

REGISTER_USER_KERNEL("elu")
.SetCreateFn([]() {
return user_op::NewOpKernel<UnaryPrimitiveKernel>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,56 +25,6 @@ limitations under the License.
#include "oneflow/core/kernel/cuda_graph_support.h"

namespace oneflow {
// Primary template of the per-device launcher for element-wise unary functors.
// Only the call operator is declared here; each device provides the definition
// through a specialization (the CPU one follows below).
//   stream   - execution stream handed through from the kernel context
//   elem_cnt - number of elements in `input_a` and `out`
//   out      - destination buffer
//   input_a  - source buffer
//   functor  - callable applied to each input element
template<DeviceType device_type, typename FunctorT, typename OutputT, typename InputA>
struct UnaryElemwiseXpuLauncher final {
void operator()(ep::Stream* stream, int64_t elem_cnt, OutputT* out, const InputA* input_a,
FunctorT functor);
};

template<typename FunctorT, typename OutputT, typename InputA>
struct UnaryElemwiseXpuLauncher<DeviceType::kCPU, FunctorT, OutputT, InputA> final {
void operator()(ep::Stream* stream, int64_t elem_cnt, OutputT* out, const InputA* input_a,
FunctorT functor) {
FOR_RANGE(int64_t, i, 0, elem_cnt) { out[i] = functor(input_a[i]); }
}
};

// Element-wise unary user-op kernel. On every Compute call it builds a functor
// through `FunctorCreateFn`, then hands the input/output buffers to
// UnaryElemwiseXpuLauncher for the given device type, which applies the functor
// to each element.
template<DeviceType device_type, typename FunctorT, typename OutputT, typename InputA>
class UnaryElemwiseXpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
 public:
  OF_DISALLOW_COPY_AND_MOVE(UnaryElemwiseXpuKernel);
  UnaryElemwiseXpuKernel() = default;
  ~UnaryElemwiseXpuKernel() = default;

  // FunctorCreateFn - factory producing the functor from the compute context
  // output_name     - arg name of the output tensor (index 0 is used)
  // input_a_name    - arg name of the input tensor (index 0 is used)
  UnaryElemwiseXpuKernel(
      std::function<FunctorT(user_op::KernelComputeContext* ctx)> FunctorCreateFn,
      const std::string& output_name, const std::string& input_a_name)
      : FunctorCreateFn(FunctorCreateFn), output_name(output_name), input_a_name(input_a_name) {}

  // Functor factory; invoked once per Compute call.
  std::function<FunctorT(user_op::KernelComputeContext* ctx)> FunctorCreateFn;

 private:
  using user_op::OpKernel::Compute;

  // Looks up both tensors, requires their shapes to be equal, and launches the
  // functor over all elements of the input.
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex(input_a_name, 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex(output_name, 0);

    const ShapeView in_shape = in->shape_view();
    const ShapeView out_shape = out->shape_view();
    CHECK_EQ(in_shape, out_shape);

    const InputA* src = in->dptr<InputA>();
    OutputT* dst = out->mut_dptr<OutputT>();
    const int64_t count = in_shape.elem_cnt();

    using Launcher = UnaryElemwiseXpuLauncher<device_type, FunctorT, OutputT, InputA>;
    Launcher()(ctx->stream(), count, dst, src, FunctorCreateFn(ctx));
  }

  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }

  std::string output_name;
  std::string input_a_name;
};

class UnaryPrimitiveKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
public:
Expand Down Expand Up @@ -164,18 +114,32 @@ class BinaryPrimitiveKernel final : public user_op::OpKernel, public user_op::Cu
PrimitiveFactoryFuncType primitive_factory_func_;
};

// Registers a UnaryElemwiseXpuKernel for `kernel_name` on `device`.
//   functor         - functor template, instantiated as functor<out_dtype>
//   out_dtype       - element type written to the tensor named `out_name`
//   input_a_dtype   - element type read from the tensor named `input_a_name`
//   create_function - factory passed to the kernel to build the functor
// The match predicate tests the *input* tensor, so it must compare against
// input_a_dtype: the kernel reads that tensor as input_a_dtype. Comparing against
// out_dtype (as before) is only equivalent when both types are the same and would
// select this kernel for the wrong inputs whenever they differ.
#define REGISTER_UNARY_ELEMWISE_USER_KERNEL(device, kernel_name, functor, out_dtype, \
                                            input_a_dtype, create_function, out_name, \
                                            input_a_name) \
  REGISTER_USER_KERNEL(kernel_name) \
      .SetCreateFn([]() { \
        return user_op::NewOpKernel< \
            UnaryElemwiseXpuKernel<device, functor<out_dtype>, out_dtype, input_a_dtype>>( \
            create_function, out_name, input_a_name); \
      }) \
      .SetIsMatchedHob( \
          (user_op::HobDeviceType() == device) \
          && (user_op::HobDataType(input_a_name, 0) == GetDataType<input_a_dtype>::value));
namespace {
// Builds a kernel-match predicate that holds when an ElementwiseUnary primitive
// can be created for `op` on the context's device, using the data types of the
// tensors named `input_name` (source) and `output_name` (destination), index 0.
auto UnaryPrimitiveExists(ep::primitive::UnaryOp op, const std::string& output_name,
                          const std::string& input_name) {
  // Everything is captured by value because the predicate is handed back to the
  // caller and evaluated later.
  return hob::make_custom(
      "ElementwiseUnaryPrimitiveExists", [=](const user_op::KernelRegContext& ctx) {
        const user_op::TensorDesc* input_desc = ctx.TensorDesc4ArgNameAndIndex(input_name, 0);
        const user_op::TensorDesc* output_desc = ctx.TensorDesc4ArgNameAndIndex(output_name, 0);
        auto created = ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
            ctx.device_type(), op, input_desc->data_type(), output_desc->data_type());
        // An empty handle means no primitive exists for this combination.
        return static_cast<bool>(created);
      });
}

// Builds a kernel-match predicate that holds when a BroadcastElementwiseBinary
// primitive can be created for `op` on the context's device, using the data
// types of the tensors named `input_a_name` (first source) and `output_name`
// (destination), index 0, with max_num_dims fixed to 1.
auto BinaryPrimitiveExists(ep::primitive::BinaryOp op, const std::string& output_name,
                           const std::string& input_a_name) {
  // Everything is captured by value because the predicate is handed back to the
  // caller and evaluated later.
  return hob::make_custom(
      "BroadcastElementwiseBinaryPrimitiveExists", [=](const user_op::KernelRegContext& ctx) {
        const user_op::TensorDesc* lhs_desc = ctx.TensorDesc4ArgNameAndIndex(input_a_name, 0);
        const user_op::TensorDesc* output_desc = ctx.TensorDesc4ArgNameAndIndex(output_name, 0);
        auto created =
            ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(
                ctx.device_type(), op, lhs_desc->data_type(), output_desc->data_type(),
                1 /*max_num_dims*/);
        // An empty handle means no primitive exists for this combination.
        return static_cast<bool>(created);
      });
}
} // namespace

} // namespace oneflow

Expand Down
34 changes: 0 additions & 34 deletions oneflow/user/kernels/elementwise_xpu_kernel.cuh

This file was deleted.

Loading

0 comments on commit 99b1eff

Please sign in to comment.