From 723c08750d7c362c00fc5623b7cc21ca2c125877 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Sun, 21 Apr 2024 12:22:39 -0700 Subject: [PATCH] Help clang produce fma instructions --- sgemm.cpp | 54 +++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/sgemm.cpp b/sgemm.cpp index 799723b2ffe789..531e12af361ccf 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -107,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED FUSED MULTIPLY ADD + +/** + * Computes a * b + c. + */ +template +inline U madd(T a, T b, U c) { + return add(mul(a, b), c); +} + +#if defined(__FMA__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> +inline __m256 madd(__m256 a, __m256 b, __m256 c) { + return _mm256_fmadd_ps(a, b, c); +} +#endif +#if defined(__AVX512F__) +template <> +inline __m512 madd(__m512 a, __m512 b, __m512 c) { + return _mm512_fmadd_ps(a, b, c); +} +#endif +#endif + +#if defined(__ARM_FEATURE_FMA) +template <> +inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { + return vfmaq_f32(c, b, a); +} +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) +template <> +inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { + return vfmaq_f16(c, b, a); +} +#endif +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM @@ -198,21 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) { } #endif // __AVX512F__ -//////////////////////////////////////////////////////////////////////////////////////////////////// -// ABSTRACTIONS - -/** - * Computes a * b + c. - * - * This operation will become fused into a single arithmetic instruction - * if the hardware has support for this feature, e.g. Intel Haswell+ (c. - * 2013), AMD Bulldozer+ (c. 2011), etc. - */ -template -inline U madd(T a, T b, U c) { - return add(mul(a, b), c); -} - //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION