Skip to content

Commit

Permalink
Merge pull request #5 from listenlink/fp16
Browse files Browse the repository at this point in the history
fp16 blas implementation patch
  • Loading branch information
gongzg committed Feb 8, 2017
2 parents 4d46444 + e1c3c1a commit 996fb92
Show file tree
Hide file tree
Showing 17 changed files with 609 additions and 33 deletions.
447 changes: 447 additions & 0 deletions include/external/clBLAS.h

Large diffs are not rendered by default.

14 changes: 10 additions & 4 deletions include/isaac/common/numeric_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@

#include <stdexcept>
#include "isaac/exception/api.h"

namespace isaac
{

// Placeholder for a 16-bit floating-point scalar type.
// The class is deliberately incomplete: it declares no data members and no
// arithmetic, and is provided only so that code instantiated per scalar type
// also compiles for `half`.
class half
{
public:
  half() {}
};
enum numeric_type
{
INVALID_NUMERIC_TYPE = 0,
Expand All @@ -41,7 +45,7 @@ enum numeric_type
UINT_TYPE,
LONG_TYPE,
ULONG_TYPE,
// HALF_TYPE,
HALF_TYPE,
FLOAT_TYPE,
DOUBLE_TYPE
};
Expand All @@ -59,7 +63,7 @@ inline std::string to_string(numeric_type const & type)
case UINT_TYPE: return "uint";
case LONG_TYPE: return "long";
case ULONG_TYPE: return "ulong";
// case HALF_TYPE : return "half";
case HALF_TYPE : return "half";
case FLOAT_TYPE : return "float";
case DOUBLE_TYPE : return "double";
default : throw unknown_datatype(type);
Expand All @@ -68,6 +72,7 @@ inline std::string to_string(numeric_type const & type)

inline numeric_type numeric_type_from_string(std::string const & name)
{
if(name=="float16") return HALF_TYPE;
if(name=="float32") return FLOAT_TYPE;
if(name=="float64") return DOUBLE_TYPE;
throw std::invalid_argument("Invalid datatype: " + name);
Expand All @@ -81,7 +86,7 @@ inline unsigned int size_of(numeric_type type)
case UCHAR_TYPE:
case CHAR_TYPE: return 1;

// case HALF_TYPE:
case HALF_TYPE:
case USHORT_TYPE:
case SHORT_TYPE: return 2;

Expand Down Expand Up @@ -128,6 +133,7 @@ template<> struct to_numeric_type<int> { static const numeric_type value = INT_T
// Explicit specializations mapping a host C++ scalar type to its numeric_type tag.
template<>
struct to_numeric_type<unsigned int>
{
  static const numeric_type value = UINT_TYPE;
};

template<>
struct to_numeric_type<long>
{
  static const numeric_type value = LONG_TYPE;
};

template<>
struct to_numeric_type<unsigned long>
{
  static const numeric_type value = ULONG_TYPE;
};

template<>
struct to_numeric_type<half>
{
  static const numeric_type value = HALF_TYPE;
};

template<>
struct to_numeric_type<float>
{
  static const numeric_type value = FLOAT_TYPE;
};

template<>
struct to_numeric_type<double>
{
  static const numeric_type value = DOUBLE_TYPE;
};

Expand Down
1 change: 1 addition & 0 deletions include/isaac/value_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class ISAACAPI value_scalar
ISAAC_INSTANTIATE(unsigned long)
ISAAC_INSTANTIATE(long long)
ISAAC_INSTANTIATE(unsigned long long)
ISAAC_INSTANTIATE(half)
ISAAC_INSTANTIATE(float)
ISAAC_INSTANTIATE(double)
#undef ISAAC_INSTANTIATE
Expand Down
102 changes: 101 additions & 1 deletion lib/api/blas/clBLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,23 @@ extern "C"
return clblasSuccess; \
}

// MAKE_AXPY(H, sc::HALF_TYPE, cl_half)
MAKE_AXPY(S, sc::FLOAT_TYPE, cl_float)
MAKE_AXPY(D, sc::DOUBLE_TYPE, cl_double)

// Half-precision AXPY: submits y := alpha*x + y for execution.
// x and y are built as N-element HALF_TYPE arrays over the caller's cl_mem
// handles, with offx/offy and incx/incy forwarded as the arrays' offset and
// increment arguments. The scalar alpha is taken as cl_float.
// Returns clblasSuccess after execute() returns.
clblasStatus clblasHaxpy(size_t N, cl_float alpha,
                         const cl_mem mx, size_t offx, int incx,
                         cl_mem my, size_t offy, int incy,
                         cl_uint numCommandQueues, cl_command_queue *commandQueues,
                         cl_uint numEventsInWaitList, const cl_event *eventWaitList,
                         cl_event *events)
{
  const sc::int_t length = static_cast<sc::int_t>(N);
  const sc::int_t xStart = static_cast<sc::int_t>(offx);
  const sc::int_t yStart = static_cast<sc::int_t>(offy);

  // NOTE(review): the `false` flag presumably makes the Buffer a non-owning wrapper - confirm.
  sc::array x(length, sc::HALF_TYPE, sc::driver::Buffer(mx, false), xStart, incx);
  sc::array y(length, sc::HALF_TYPE, sc::driver::Buffer(my, false), yStart, incy);

  execute(sc::assign(y, alpha*x + y),
          y.context(),
          numCommandQueues, commandQueues,
          numEventsInWaitList, eventWaitList, events);
  return clblasSuccess;
}

//SCAL
#define MAKE_SCAL(TYPE_CHAR, TYPE_ISAAC, TYPE_CL) \
clblasStatus clblas ## TYPE_CHAR ## scal(size_t N, TYPE_CL alpha,\
Expand All @@ -98,9 +112,20 @@ extern "C"
return clblasSuccess;\
}

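// MAKE_SCAL(H, sc::HALF_TYPE, cl_half) is left disabled: the macro would declare alpha as cl_half, whereas the hand-written clblasHscal below takes alpha as cl_float.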
MAKE_SCAL(S, sc::FLOAT_TYPE, cl_float)
MAKE_SCAL(D, sc::DOUBLE_TYPE, cl_double)

// Half-precision SCAL: submits x := alpha*x for execution.
// x is built as an N-element HALF_TYPE array over the caller's cl_mem handle,
// with offx and incx forwarded as the array's offset and increment arguments.
// The scalar alpha is taken as cl_float.
// Returns clblasSuccess after execute() returns.
clblasStatus clblasHscal(size_t N, cl_float alpha,
                         cl_mem mx, size_t offx, int incx,
                         cl_uint numCommandQueues, cl_command_queue *commandQueues,
                         cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *events)
{
  const sc::int_t length = static_cast<sc::int_t>(N);
  const sc::int_t xStart = static_cast<sc::int_t>(offx);

  // NOTE(review): the `false` flag presumably makes the Buffer a non-owning wrapper - confirm.
  sc::array x(length, sc::HALF_TYPE, sc::driver::Buffer(mx, false), xStart, incx);

  execute(sc::assign(x, alpha*x),
          x.context(),
          numCommandQueues, commandQueues,
          numEventsInWaitList, eventWaitList, events);
  return clblasSuccess;
}

//COPY
#define MAKE_COPY(TYPE_CHAR, TYPE_ISAAC, TYPE_CL)\
clblasStatus clblas ## TYPE_CHAR ## copy(size_t N,\
Expand All @@ -115,6 +140,7 @@ extern "C"
return clblasSuccess;\
}

MAKE_COPY(H, sc::HALF_TYPE, cl_half)
MAKE_COPY(S, sc::FLOAT_TYPE, cl_float)
MAKE_COPY(D, sc::DOUBLE_TYPE, cl_double)

Expand All @@ -134,6 +160,7 @@ extern "C"
return clblasSuccess; \
}

MAKE_DOT(H, sc::HALF_TYPE, cl_half)
MAKE_DOT(S, sc::FLOAT_TYPE, cl_float)
MAKE_DOT(D, sc::DOUBLE_TYPE, cl_double)

Expand All @@ -155,10 +182,10 @@ extern "C"
return clblasSuccess;\
}

MAKE_ASUM(H, sc::HALF_TYPE, cl_half)
MAKE_ASUM(S, sc::FLOAT_TYPE, cl_float)
MAKE_ASUM(D, sc::DOUBLE_TYPE, cl_double)


//*****************
//BLAS2
//*****************
Expand Down Expand Up @@ -193,6 +220,33 @@ extern "C"
MAKE_GEMV(S, sc::FLOAT_TYPE, cl_float)
MAKE_GEMV(D, sc::DOUBLE_TYPE, cl_double)

// Half-precision GEMV: submits y := alpha*op(A)*x + beta*y for execution,
// where op(A) is A or its transpose depending on transA.
// A, x and y are HALF_TYPE arrays built over the caller's cl_mem handles;
// alpha and beta are taken as cl_float.
// Row-major input is handled by swapping M and N and flipping the transpose
// flag, after which the matrix is treated as an M x N array with leading
// dimension lda.
// Returns clblasSuccess after execute() returns.
clblasStatus clblasHgemv(clblasOrder order, clblasTranspose transA,
                         size_t M, size_t N,
                         cl_float alpha, const cl_mem mA, size_t offA, size_t lda,
                         const cl_mem mx, size_t offx, int incx,
                         cl_float beta, cl_mem my, size_t offy, int incy,
                         cl_uint numCommandQueues, cl_command_queue *commandQueues,
                         cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *events)
{
  if(order==clblasRowMajor){
    std::swap(M, N);
    transA = (transA==clblasTrans||transA==clblasConjTrans)?clblasNoTrans:clblasTrans;
  }

  // Single transpose predicate used for both the vector lengths and the
  // choice of operation, so the two can never disagree. The data is real,
  // so ConjTrans is handled exactly like Trans.
  const bool transposed = (transA==clblasTrans || transA==clblasConjTrans);

  sc::array A(static_cast<sc::int_t>(M), static_cast<sc::int_t>(N), sc::HALF_TYPE,
              sc::driver::Buffer(mA, false), static_cast<sc::int_t>(offA), static_cast<sc::int_t>(lda));

  // x has as many entries as op(A) has columns, y as many as op(A) has rows.
  sc::int_t sx = static_cast<sc::int_t>(N);
  sc::int_t sy = static_cast<sc::int_t>(M);
  if(transposed)
    std::swap(sx, sy);
  sc::array x(sx, sc::HALF_TYPE, sc::driver::Buffer(mx, false), static_cast<sc::int_t>(offx), incx);
  sc::array y(sy, sc::HALF_TYPE, sc::driver::Buffer(my, false), static_cast<sc::int_t>(offy), incy);

  sc::driver::Context const & context = A.context();
  if(transposed)
    execute(sc::assign(y, alpha*dot(A.T, x) + beta*y), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
  else
    execute(sc::assign(y, alpha*dot(A, x) + beta*y), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
  return clblasSuccess;
}

//*****************
//BLAS3
//*****************
Expand Down Expand Up @@ -246,6 +300,52 @@ extern "C"
MAKE_GEMM(S, sc::FLOAT_TYPE, cl_float)
MAKE_GEMM(D, sc::DOUBLE_TYPE, cl_double)

// Half-precision GEMM: submits C := alpha*op(A)*op(B) + beta*C for execution,
// where op(X) is X or its transpose depending on transA / transB.
// A, B and C are HALF_TYPE arrays built over the caller's cl_mem handles;
// alpha and beta are taken as cl_float.
// Row-major input is handled by swapping the roles of A and B (buffers,
// offsets, leading dimensions, transpose flags) and of M and N.
// When K==1 (and M>1, N>1) the product is submitted as an outer product of
// two strided vectors instead of a matrix product.
// Returns clblasSuccess after execute() returns.
clblasStatus clblasHgemm(clblasOrder order, clblasTranspose transA, clblasTranspose transB,
                         size_t M, size_t N, size_t K,
                         cl_float alpha, const cl_mem cmA, size_t offA, size_t lda,
                         const cl_mem cmB, size_t offB, size_t ldb, cl_float beta,
                         cl_mem mC, size_t offC, size_t ldc,
                         cl_uint numCommandQueues, cl_command_queue *commandQueues,
                         cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *events)
{
  cl_mem mA = cmA;
  cl_mem mB = cmB;
  if(order==clblasRowMajor){
    std::swap(mA, mB);
    std::swap(offA, offB);
    std::swap(lda, ldb);
    std::swap(M, N);
    std::swap(transA, transB);
  }

  // The data is real, so ConjTrans is handled exactly like Trans. Computing
  // the predicates once keeps every use below consistent.
  const bool transposedA = (transA==clblasTrans || transA==clblasConjTrans);
  const bool transposedB = (transB==clblasTrans || transB==clblasConjTrans);

  if(K==1 && M>1 && N>1){
    // With K==1, op(A) is a single column and op(B) a single row. Their
    // elements are contiguous when stored along the leading dimension and
    // lda/ldb apart otherwise. The stride of B previously tested only
    // clblasTrans, so clblasConjTrans picked ldb instead of 1.
    const sc::int_t incA = transposedA ? static_cast<sc::int_t>(lda) : 1;
    const sc::int_t incB = transposedB ? 1 : static_cast<sc::int_t>(ldb);
    sc::array A(static_cast<sc::int_t>(M), sc::HALF_TYPE, sc::driver::Buffer(mA, false), static_cast<sc::int_t>(offA), incA);
    sc::array B(static_cast<sc::int_t>(N), sc::HALF_TYPE, sc::driver::Buffer(mB, false), static_cast<sc::int_t>(offB), incB);
    sc::array C(static_cast<sc::int_t>(M), static_cast<sc::int_t>(N), sc::HALF_TYPE, sc::driver::Buffer(mC, false), static_cast<sc::int_t>(offC), static_cast<sc::int_t>(ldc));
    execute(sc::assign(C, alpha*sc::outer(A, B) + beta*C), C.context(), numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
    return clblasSuccess;
  }

  // Stored shapes of A and B: (M x K) and (K x N), swapped when transposed.
  sc::int_t As1 = static_cast<sc::int_t>(M), As2 = static_cast<sc::int_t>(K);
  sc::int_t Bs1 = static_cast<sc::int_t>(K), Bs2 = static_cast<sc::int_t>(N);
  if(transposedA) std::swap(As1, As2);
  if(transposedB) std::swap(Bs1, Bs2);

  /*Struct*/
  sc::array A(As1, As2, sc::HALF_TYPE, sc::driver::Buffer(mA, false), static_cast<sc::int_t>(offA), static_cast<sc::int_t>(lda));
  sc::array B(Bs1, Bs2, sc::HALF_TYPE, sc::driver::Buffer(mB, false), static_cast<sc::int_t>(offB), static_cast<sc::int_t>(ldb));
  sc::array C(static_cast<sc::int_t>(M), static_cast<sc::int_t>(N), sc::HALF_TYPE, sc::driver::Buffer(mC, false), static_cast<sc::int_t>(offC), static_cast<sc::int_t>(ldc));
  sc::driver::Context const & context = C.context();

  /*Operation*/
  if(transposedA && transposedB)
    execute(sc::assign(C, alpha*dot(A.T, B.T) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
  else if(transposedA)
    execute(sc::assign(C, alpha*dot(A.T, B) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
  else if(transposedB)
    execute(sc::assign(C, alpha*dot(A, B.T) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
  else
    execute(sc::assign(C, alpha*dot(A, B) + beta*C), context, numCommandQueues, commandQueues, numEventsInWaitList, eventWaitList, events);
  return clblasSuccess;
}


#undef DOT

}
9 changes: 6 additions & 3 deletions lib/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ INSTANTIATE(long);\
INSTANTIATE(unsigned long);\
INSTANTIATE(long long);\
INSTANTIATE(unsigned long long);\
INSTANTIATE(half);\
INSTANTIATE(float);\
INSTANTIATE(double);

Expand Down Expand Up @@ -212,6 +213,7 @@ INSTANTIATE(long);
INSTANTIATE(unsigned long);
INSTANTIATE(long long);
INSTANTIATE(unsigned long long);
INSTANTIATE(half);
INSTANTIATE(float);
INSTANTIATE(double);
#undef INSTANTIATE
Expand Down Expand Up @@ -418,6 +420,7 @@ TYPE scalar::cast() const
HANDLE_CASE(ULONG_TYPE, uint64);
HANDLE_CASE(FLOAT_TYPE, float32);
HANDLE_CASE(DOUBLE_TYPE, float64);

default: throw unknown_datatype(dtype_);
}
#undef HANDLE_CASE
Expand Down Expand Up @@ -479,7 +482,6 @@ std::ostream & operator<<(std::ostream & os, scalar const & s)
case UINT_TYPE: return os << static_cast<unsigned int>(s);
case LONG_TYPE: return os << static_cast<long>(s);
case ULONG_TYPE: return os << static_cast<unsigned long>(s);
// case HALF_TYPE: return os << static_cast<half>(s);
case FLOAT_TYPE: return os << static_cast<float>(s);
case DOUBLE_TYPE: return os << static_cast<double>(s);
default: throw unknown_datatype(s.dtype());
Expand Down Expand Up @@ -574,7 +576,7 @@ expression_tree OPNAME (array_base const & x) \
expression_tree OPNAME (expression_tree const & x) \
{ return expression_tree(x, invalid_node(), op_element(UNARY_ARITHMETIC, OP), &x.context(), x.dtype(), x.shape()); }

DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE)?FABS_TYPE:ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR((x.dtype()==FLOAT_TYPE || x.dtype()==DOUBLE_TYPE || x.dtype()==HALF_TYPE)?FABS_TYPE:ABS_TYPE, abs)
DEFINE_ELEMENT_UNARY_OPERATOR(ACOS_TYPE, acos)
DEFINE_ELEMENT_UNARY_OPERATOR(ASIN_TYPE, asin)
DEFINE_ELEMENT_UNARY_OPERATOR(ATAN_TYPE, atan)
Expand Down Expand Up @@ -647,7 +649,7 @@ inline operation_type casted(numeric_type dtype)
case UINT_TYPE: return CAST_UINT_TYPE;
case LONG_TYPE: return CAST_LONG_TYPE;
case ULONG_TYPE: return CAST_ULONG_TYPE;
// case FLOAT_TYPE: return CAST_HALF_TYPE;
case HALF_TYPE: return CAST_HALF_TYPE;
case FLOAT_TYPE: return CAST_FLOAT_TYPE;
case DOUBLE_TYPE: return CAST_DOUBLE_TYPE;
default: throw unknown_datatype(dtype);
Expand Down Expand Up @@ -1025,6 +1027,7 @@ INSTANTIATE(long);
INSTANTIATE(unsigned long);
INSTANTIATE(long long);
INSTANTIATE(unsigned long long);
INSTANTIATE(half);
INSTANTIATE(float);
INSTANTIATE(double);

Expand Down
1 change: 0 additions & 1 deletion lib/driver/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ void Kernel::setArg(unsigned int index, value_scalar const & scal)
case UINT_TYPE: setArg(index, scal.values().uint32); break;
case LONG_TYPE: setArg(index, scal.values().int64); break;
case ULONG_TYPE: setArg(index, scal.values().uint64); break;
//case HALF_TYPE: setArg(index, scal.values().float16); break;
case FLOAT_TYPE: setArg(index, scal.values().float32); break;
case DOUBLE_TYPE: setArg(index, scal.values().float64); break;
default: throw unknown_datatype(scal.dtype());
Expand Down
2 changes: 2 additions & 0 deletions lib/jit/generation/elementwise_1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ std::string elementwise_1d::generate_impl(std::string const & suffix, expression
case driver::CUDA:
stream << "#include \"vector.h\"" << std::endl; break;
case driver::OPENCL:
if(tree.dtype()==HALF_TYPE)
stream << "#pragma OPENCL EXTENSION cl_khr_fp16: enable" << std::endl;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl; break;
}

Expand Down
2 changes: 2 additions & 0 deletions lib/jit/generation/elementwise_2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ std::string elementwise_2d::generate_impl(std::string const & suffix, expression
case driver::CUDA:
stream << "#include \"vector.h\"" << std::endl; break;
case driver::OPENCL:
if(tree.dtype()==HALF_TYPE)
stream << "#pragma OPENCL EXTENSION cl_khr_fp16: enable" << std::endl;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl; break;
}

Expand Down
11 changes: 7 additions & 4 deletions lib/jit/generation/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ std::string gemm::generate_impl(std::string const & suffix, expression_tree cons
numeric_type dtype = tree.dtype();
std::string sdtype = to_string(dtype);
std::string vdtype = append_width(sdtype, vwidth_);
std::string abdtype = (sdtype == "half")? "float" : sdtype;

//////////////////
/// DECLARATIONS
Expand All @@ -231,18 +232,19 @@ std::string gemm::generate_impl(std::string const & suffix, expression_tree cons
switch(backend)
{
case driver::OPENCL:
if(tree.dtype()==HALF_TYPE)
stream << "#pragma OPENCL EXTENSION cl_khr_fp16: enable" << std::endl;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl;
break;
default:
break;
}

stream << "$KERNEL void gemm" << suffix << "($SIZE_T M, $SIZE_T N, $SIZE_T K, "
<< "$GLOBAL " << sdtype << "* C, $SIZE_T ldc, $SIZE_T offc, $SIZE_T Cstride1, "
<< sdtype << " alpha,"
<< abdtype << " alpha,"
<< "$GLOBAL " << sdtype << "* A, $SIZE_T lda, $SIZE_T offa, $SIZE_T Astride1,"
<< "$GLOBAL " << sdtype << "* B, $SIZE_T ldb, $SIZE_T offb, $SIZE_T Bstride1,"
<< sdtype << " beta)"
<< abdtype << " beta)"
<< std::endl;
stream << "{" << std::endl;
stream.inc_tab();
Expand Down Expand Up @@ -625,8 +627,9 @@ std::string gemm::generate_impl(std::string const & suffix, expression_tree cons
stream << "$KERNEL void reduce" << suffix << "($SIZE_T M, $SIZE_T N, $SIZE_T D, "
<< "$GLOBAL " << sdtype << "* Z, $SIZE_T Zld,"
<< "$GLOBAL " << sdtype << "* C, $SIZE_T ldc, $SIZE_T Cstart, $SIZE_T Cstride,"
<< sdtype << " beta)"
<< abdtype << " beta)"
<< std::endl;

stream << "{" << std::endl;
stream.inc_tab();

Expand Down
3 changes: 3 additions & 0 deletions lib/jit/generation/reduce_1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
case driver::CUDA:
stream << "#include \"vector.h\"" << std::endl; break;
case driver::OPENCL:
if(tree.dtype()==HALF_TYPE)
stream << "#pragma OPENCL EXTENSION cl_khr_fp16: enable" << std::endl;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << ",1,1)))" << std::endl; break;
}
stream << "$KERNEL void prod" << suffix << "($SIZE_T N, $GLOBAL char* tmp," << tools::join(kernel_arguments(device, symbols, tree), ", ") << ")" << std::endl;
Expand Down Expand Up @@ -244,6 +246,7 @@ std::string reduce_1d::generate_impl(std::string const & suffix, expression_tree
stream << "}" << std::endl;
stream.dec_tab();
stream << "}" << std::endl;
// std::cout<<"reduce 1d: "<<stream.str()<<std::endl;

return stream.str();
}
Expand Down
2 changes: 2 additions & 0 deletions lib/jit/generation/reduce_2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ std::string reduce_2d::generate_impl(std::string const & suffix, expression_tree
stream << "#include \"vector.h\"" << std::endl;
break;
case driver::OPENCL:
if(tree.dtype()==HALF_TYPE)
stream << "#pragma OPENCL EXTENSION cl_khr_fp16: enable" << std::endl;
stream << " __attribute__((reqd_work_group_size(" << ls0_ << "," << ls1_ << ",1)))" << std::endl;
break;
}
Expand Down
2 changes: 2 additions & 0 deletions lib/jit/syntax/expression/operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ std::string to_string(operation_type type)
case CAST_UINT_TYPE : return "(uint)";
case CAST_LONG_TYPE : return "(long)";
case CAST_ULONG_TYPE : return "(ulong)";
case CAST_HALF_TYPE: return "(half)";
case CAST_FLOAT_TYPE : return "(float)";
case CAST_DOUBLE_TYPE : return "(double)";

Expand Down Expand Up @@ -150,6 +151,7 @@ bool is_cast(operation_type op)
|| op == CAST_UINT_TYPE
|| op == CAST_LONG_TYPE
|| op == CAST_ULONG_TYPE
|| op == CAST_HALF_TYPE
|| op == CAST_FLOAT_TYPE
|| op == CAST_DOUBLE_TYPE
;
Expand Down
Loading

0 comments on commit 996fb92

Please sign in to comment.