Help clang produce fma instructions
jart committed Apr 21, 2024
1 parent 9d4d14c commit 723c087
Showing 1 changed file with 39 additions and 15 deletions.
54 changes: 39 additions & 15 deletions sgemm.cpp
@@ -107,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// VECTORIZED FUSED MULTIPLY ADD
+
+/**
+ * Computes a * b + c.
+ */
+template <typename T, typename U>
+inline U madd(T a, T b, U c) {
+    return add(mul(a, b), c);
+}
+
+#if defined(__FMA__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+template <>
+inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+    return _mm256_fmadd_ps(a, b, c);
+}
+#endif
+#if defined(__AVX512F__)
+template <>
+inline __m512 madd(__m512 a, __m512 b, __m512 c) {
+    return _mm512_fmadd_ps(a, b, c);
+}
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_FMA)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+    return vfmaq_f32(c, b, a);
+}
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+template <>
+inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
+    return vfmaq_f16(c, b, a);
+}
+#endif
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED HORIZONTAL SUM
 
@@ -198,21 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
 }
 #endif // __AVX512F__
 
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// ABSTRACTIONS
-
-/**
- * Computes a * b + c.
- *
- * This operation will become fused into a single arithmetic instruction
- * if the hardware has support for this feature, e.g. Intel Haswell+ (c.
- * 2013), AMD Bulldozer+ (c. 2011), etc.
- */
-template <typename T, typename U>
-inline U madd(T a, T b, U c) {
-    return add(mul(a, b), c);
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
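For context, here is a minimal sketch (not part of the commit) of how a fused madd specialization like the __m256 one above is typically exercised in a dot-product micro-kernel. It uses raw AVX/FMA intrinsics rather than the file's load and horizontal-sum helpers, assumes n is a multiple of 8, and assumes compilation with FMA enabled (e.g. -mfma or -march=native on an FMA-capable target):

#include <immintrin.h>
#include <cstddef>

// Generic fallback: separate multiply and add; the compiler is not required
// to contract these two intrinsics into a single fused instruction.
static inline __m256 madd_generic(__m256 a, __m256 b, __m256 c) {
    return _mm256_add_ps(_mm256_mul_ps(a, b), c);
}

// Specialized form, as in the commit: one guaranteed fused multiply-add.
static inline __m256 madd_fused(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}

// Illustrative dot product built on the fused helper.
float dot(const float *x, const float *y, size_t n) {
    __m256 acc = _mm256_setzero_ps();
    for (size_t i = 0; i < n; i += 8)
        acc = madd_fused(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i), acc);
    // Horizontal reduction of the eight accumulator lanes.
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
    s = _mm_hadd_ps(s, s);
    s = _mm_hadd_ps(s, s);
    return _mm_cvtss_f32(s);
}

Note also the operand order on the ARM side: the NEON intrinsics take the accumulator first (vfmaq_f32(a, b, c) computes a + b * c), which is why the specializations in the diff pass (c, b, a).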
