Skip to content

Commit

Permalink
better support. Just uses avx2
Browse files Browse the repository at this point in the history
  • Loading branch information
harrisonvanderbyl committed Dec 20, 2023
1 parent dfdcd11 commit aaccf9e
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 11 deletions.
4 changes: 2 additions & 2 deletions SCsub
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ if(env["PLATFORM"] != "win32"):
env_rwkv.add_source_files(env.modules_sources, "*.cpp")

# select the x86 SIMD target via -march
env_rwkv.Append(CCFLAGS=['-march=skylake-avx512'])
env_rwkv.Append(CCFLAGS=['-march=haswell'])

# -fopenmp -flto -fopenmp -funroll-loops -D_GLIBCXX_PARALLEL

Expand Down Expand Up @@ -46,7 +46,7 @@ else:
env_rwkv.Append(CPPPATH=[path + "/rwkv.hpp/include"])

# select the MSVC SIMD target via /arch
env_rwkv.Append(CCFLAGS=['/arch:AVX512'])
env_rwkv.Append(CCFLAGS=['/arch:AVX2'])


env_rwkv.add_source_files(env.modules_sources, "*.cpp")
1 change: 0 additions & 1 deletion rwkv.hpp/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ fi
# build with intel compiler
# check if intel compiler installed
if [ "$(icpx --version)" ]; then
source /opt/intel/oneapi/2024.0/oneapi-vars.sh
icpx -m64 ./rwkv.cpp -I ./include/ -o ./build/rwkv -march=native -std=c++17 $vulk -ffast-math -O3 -pthread

else
Expand Down
15 changes: 10 additions & 5 deletions rwkv.hpp/include/hvml/intrinsics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
#include <cmath>
#define UINT8THREADALLOC 64

#ifdef __AVX512F__ // This macro is defined if AVX-512 is supported
#if defined(__AVX512F__) && defined(HVMLUSEAVX512) // AVX-512 path: taken only when the compiler targets AVX-512F and HVMLUSEAVX512 is also defined
#include <immintrin.h>

#define SIMD_WIDTH 16
#define LOAD(x) _mm512_loadu_ps(x)
#define STORE(x, y) _mm512_storeu_ps(x, y)
#define LOAD(x) _mm512_load_ps(x)
#define STORE(x, y) _mm512_store_ps(x, y)
#define SET1(x) _mm512_set1_ps(x)
#define MULTIPLY(x, y) _mm512_mul_ps(x, y)
#define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
Expand Down Expand Up @@ -79,8 +79,8 @@ SIMDTYPE exp_ps_fill(SIMDTYPE x)
#ifdef __AVX2__
#include <immintrin.h>
#define SIMD_WIDTH 8
#define LOAD(x) _mm256_loadu_ps(x)
#define STORE(x, y) _mm256_storeu_ps(x, y)
#define LOAD(x) _mm256_load_ps(x)
#define STORE(x, y) _mm256_store_ps((float*)x, y)
#define SET1(x) _mm256_set1_ps(x)
#define MULTIPLY(x, y) _mm256_mul_ps(x, y)
#define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
Expand All @@ -93,6 +93,10 @@ SIMDTYPE exp_ps_fill(SIMDTYPE x)
#define MAX(x, y) _mm256_max_ps(x, y)
#define DIVIDE(x, y) _mm256_div_ps(x, y)
#define SIMDTYPE __m256
#if defined(__INTEL_LLVM_COMPILER)
#pragma message("AVX-2 exp is supported")
#define EXP(x) _mm256_exp_ps(x)
#else
#define EXP(x) exp_ps_fill(x)
SIMDTYPE exp_ps_fill(SIMDTYPE x)
{
Expand All @@ -103,6 +107,7 @@ SIMDTYPE exp_ps_fill(SIMDTYPE x)
}
return result;
}
#endif
// emit a compile-time message noting that the AVX2 path was selected
#pragma message("AVX-2 is supported")

Expand Down
49 changes: 46 additions & 3 deletions rwkv.hpp/include/hvml/operations/avx512/matmul8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void dopartial(MatMulJob job) {
auto Ar = job.Ar;
auto bbt = job.bbt;
auto ii = job.ii;
#ifdef __AVX512F__
#if defined(__AVX512F__) && defined(HVMLUSEAVX512)
const auto Ario = _mm512_load_ps(Ar + ii);
const auto Aoioo = _mm512_div_ps(_mm512_load_ps(Ao + ii), Ario);
__m512 zz = _mm512_setzero_ps();
Expand All @@ -118,9 +118,52 @@ void dopartial(MatMulJob job) {
_mm512_store_ps(
(void *)(C + bbt * OUT + ii),
zz * Ario);
#endif
#elif defined(__AVX2__)
for (ulong b = 0; b < 16; b+= 8){
const auto Ario1 = LOAD(Ar + ii+b);
const auto Aoio1 = DIVIDE(LOAD(Ao + ii + b),Ario1);

auto zz1 = SET1(0.0);

for (uint32_t i = ii+b; i < ii + b+8; i += 1) {
auto Aoio = Aoio1[i&7];

const auto IAIN = A + i * IN;

auto sum1 = SET1(0.0);
auto sum2 = SET1(0.0);

for (uint32_t k = 0; k < IN; k += 16) {
// avx2
auto w = _mm256_cvtepu8_epi32(_mm_loadu_si128((__m128i *)(IAIN + k))); // Load the input uint8_t vector
// convert uint32_t to float32x8_t
auto u = _mm256_cvtepi32_ps(w)+Aoio; // Convert the int32 lanes to float and add the offset Aoio
// Load the input float vector
// Perform the multiplication with inp vector
sum1 = MULTADD(u, LOAD(B + bbt * IN + k),sum1);

auto w1 = _mm256_cvtepu8_epi32(_mm_loadu_si128((__m128i *)(IAIN + k + 8))); // Load the input uint8_t vector

auto u1 = _mm256_cvtepi32_ps(w1)+Aoio; // Convert the int32 lanes to float and add the offset Aoio

sum2 = MULTADD(u1, LOAD(B + bbt * IN + k + 8),sum2);

}

sum1 = sum1+sum2;

zz1[i&7]= REDUCE(sum1);


}


STORE(
(void *)(C + bbt * OUT + ii + b),
zz1 * Ario1);
}

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
for (ulong b = 0; b < 16; b+= 4){
const auto Ario1 = LOAD(Ar + ii+b);
const auto Aoio1 = DIVIDE(LOAD(Ao + ii + b),Ario1);
Expand Down
1 change: 1 addition & 0 deletions rwkv.hpp/rwkv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "rwkv.hpp"
#include "sampler/sample.hpp"
#include "tokenizer/tokenizer.hpp"

int main( int argc, char** argv ){

std::cout << "Hello World" << std::endl;
Expand Down

0 comments on commit aaccf9e

Please sign in to comment.