Skip to content

Commit

Permalink
Optimize Apply Updates (#453)
Browse files Browse the repository at this point in the history
Bench: 4873539

https://godbolt.org/z/PcMxTWf7c

STC

ELO   | 12.00 +- 5.56 (95%)
SPRT  | 10.0+0.10s Threads=1 Hash=8MB
LLR   | 2.94 (-2.94, 2.94) [0.00, 3.00]
GAMES | N: 7128 W: 1821 L: 1575 D: 3732
  • Loading branch information
jhonnold authored Jan 23, 2023
1 parent 7f804cd commit 181229c
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 45 deletions.
225 changes: 180 additions & 45 deletions src/nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,155 @@ int16_t OUTPUT_WEIGHTS[2 * N_HIDDEN] ALIGN;
int32_t OUTPUT_BIAS;

// SIMD register abstraction used by the accumulator update routines.
// Picks the widest vector ISA the compiler is targeting and defines:
//   UNROLL                 - int16 lanes handled per outer-loop chunk, i.e.
//                            16 registers' worth of regi_t
//   regi_t                 - vector register type
//   regi_load / regi_store - aligned load / store of one register
//   regi_add / regi_sub    - lane-wise 16-bit add / subtract
//
// The 512-bit path requires AVX512BW in addition to AVX512F: the 16-bit lane
// operations _mm512_add_epi16 / _mm512_sub_epi16 are AVX512BW instructions,
// so an AVX512F-only target must fall through to the AVX2 path instead.
#if defined(__AVX512F__) && defined(__AVX512BW__)
#define UNROLL     512
#define regi_t     __m512i
#define regi_load  _mm512_load_si512
#define regi_sub   _mm512_sub_epi16
#define regi_add   _mm512_add_epi16
#define regi_store _mm512_store_si512
#elif defined(__AVX2__)
#define UNROLL     256
#define regi_t     __m256i
#define regi_load  _mm256_load_si256
#define regi_sub   _mm256_sub_epi16
#define regi_add   _mm256_add_epi16
#define regi_store _mm256_store_si256
#else
#define UNROLL     128
#define regi_t     __m128i
#define regi_load  _mm_load_si128
#define regi_sub   _mm_sub_epi16
#define regi_add   _mm_add_epi16
#define regi_store _mm_store_si128
#endif

INLINE void ApplyFeature(acc_t* dest, acc_t* src, int f, const int add) {
for (size_t c = 0; c < N_HIDDEN; c += UNROLL)
for (size_t i = 0; i < UNROLL; i++)
dest[i + c] = src[i + c] + (2 * add - 1) * INPUT_WEIGHTS[f * N_HIDDEN + i + c];
#define NUM_REGS 16

// Applies a batch of feature changes to an accumulator:
//   dest = src - (weight rows of delta->rem[0 .. r)) + (weight rows of delta->add[0 .. a))
// The work is done one UNROLL-wide chunk at a time. Each chunk is loaded into
// NUM_REGS vector registers, every removal and then every addition is folded
// in, and the chunk is stored exactly once. Because a chunk is fully loaded
// before it is stored, dest and src may be the same buffer.
// NOTE(review): lanes beyond the last whole chunk are never touched, so this
// presumably relies on N_HIDDEN being a multiple of UNROLL - confirm.
// NOTE(review): regi_load/regi_store are the aligned variants; assumes src,
// dest and every weight row are suitably aligned for regi_t - confirm.
INLINE void ApplyDelta(acc_t* dest, acc_t* src, Delta* delta) {
  regi_t acc[NUM_REGS];

  for (size_t chunk = 0; chunk < N_HIDDEN / UNROLL; chunk++) {
    const size_t base = chunk * UNROLL;

    const regi_t* in = (const regi_t*) (src + base);
    regi_t* out      = (regi_t*) (dest + base);

    // Pull the current chunk of the source accumulator into registers.
    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_load(in + k);
    }

    // Subtract the weight row of every removed feature.
    for (size_t n = 0; n < delta->r; n++) {
      const size_t rowStart = delta->rem[n] * N_HIDDEN + base;
      const regi_t* row     = (const regi_t*) &INPUT_WEIGHTS[rowStart];
      for (size_t k = 0; k < NUM_REGS; k++) {
        acc[k] = regi_sub(acc[k], row[k]);
      }
    }

    // Add the weight row of every added feature.
    for (size_t n = 0; n < delta->a; n++) {
      const size_t rowStart = delta->add[n] * N_HIDDEN + base;
      const regi_t* row     = (const regi_t*) &INPUT_WEIGHTS[rowStart];
      for (size_t k = 0; k < NUM_REGS; k++) {
        acc[k] = regi_add(acc[k], row[k]);
      }
    }

    // Write the finished chunk back out.
    for (size_t k = 0; k < NUM_REGS; k++) {
      regi_store(out + k, acc[k]);
    }
  }
}

// Fixed-shape accumulator update with one removal and one addition:
//   dest = src - row(f1) + row(f2)
// where row(f) is the N_HIDDEN-wide slice of INPUT_WEIGHTS starting at
// f * N_HIDDEN. Processed in UNROLL-wide chunks held in NUM_REGS registers,
// with a single load and a single store per chunk.
// NOTE(review): presumably relies on N_HIDDEN being a multiple of UNROLL and
// on all buffers being aligned for regi_t - confirm.
INLINE void ApplySubAdd(acc_t* dest, acc_t* src, int f1, int f2) {
  regi_t acc[NUM_REGS];

  for (size_t chunk = 0; chunk < N_HIDDEN / UNROLL; chunk++) {
    const size_t base = chunk * UNROLL;

    const regi_t* in = (const regi_t*) (src + base);
    regi_t* out      = (regi_t*) (dest + base);

    // Feature rows for this chunk: f1 is removed, f2 is added.
    const size_t subStart = f1 * N_HIDDEN + base;
    const size_t addStart = f2 * N_HIDDEN + base;
    const regi_t* subRow  = (const regi_t*) &INPUT_WEIGHTS[subStart];
    const regi_t* addRow  = (const regi_t*) &INPUT_WEIGHTS[addStart];

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_load(in + k);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_sub(acc[k], subRow[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_add(acc[k], addRow[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      regi_store(out + k, acc[k]);
    }
  }
}

// Fixed-shape accumulator update with two removals and one addition:
//   dest = src - row(f1) - row(f2) + row(f3)
// where row(f) is the N_HIDDEN-wide slice of INPUT_WEIGHTS starting at
// f * N_HIDDEN. Processed in UNROLL-wide chunks held in NUM_REGS registers,
// with a single load and a single store per chunk.
// NOTE(review): presumably relies on N_HIDDEN being a multiple of UNROLL and
// on all buffers being aligned for regi_t - confirm.
INLINE void ApplySubSubAdd(acc_t* dest, acc_t* src, int f1, int f2, int f3) {
  regi_t acc[NUM_REGS];

  for (size_t chunk = 0; chunk < N_HIDDEN / UNROLL; chunk++) {
    const size_t base = chunk * UNROLL;

    const regi_t* in = (const regi_t*) (src + base);
    regi_t* out      = (regi_t*) (dest + base);

    // Feature rows for this chunk: f1 and f2 are removed, f3 is added.
    const size_t subStartA = f1 * N_HIDDEN + base;
    const size_t subStartB = f2 * N_HIDDEN + base;
    const size_t addStart  = f3 * N_HIDDEN + base;
    const regi_t* subRowA  = (const regi_t*) &INPUT_WEIGHTS[subStartA];
    const regi_t* subRowB  = (const regi_t*) &INPUT_WEIGHTS[subStartB];
    const regi_t* addRow   = (const regi_t*) &INPUT_WEIGHTS[addStart];

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_load(in + k);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_sub(acc[k], subRowA[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_sub(acc[k], subRowB[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_add(acc[k], addRow[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      regi_store(out + k, acc[k]);
    }
  }
}

// Fixed-shape accumulator update with two removals and two additions:
//   dest = src - row(f1) - row(f2) + row(f3) + row(f4)
// where row(f) is the N_HIDDEN-wide slice of INPUT_WEIGHTS starting at
// f * N_HIDDEN. Processed in UNROLL-wide chunks held in NUM_REGS registers,
// with a single load and a single store per chunk.
// NOTE(review): presumably relies on N_HIDDEN being a multiple of UNROLL and
// on all buffers being aligned for regi_t - confirm.
INLINE void ApplySubSubAddAdd(acc_t* dest, acc_t* src, int f1, int f2, int f3, int f4) {
  regi_t acc[NUM_REGS];

  for (size_t chunk = 0; chunk < N_HIDDEN / UNROLL; chunk++) {
    const size_t base = chunk * UNROLL;

    const regi_t* in = (const regi_t*) (src + base);
    regi_t* out      = (regi_t*) (dest + base);

    // Feature rows for this chunk: f1 and f2 are removed, f3 and f4 are added.
    const size_t subStartA = f1 * N_HIDDEN + base;
    const size_t subStartB = f2 * N_HIDDEN + base;
    const size_t addStartA = f3 * N_HIDDEN + base;
    const size_t addStartB = f4 * N_HIDDEN + base;
    const regi_t* subRowA  = (const regi_t*) &INPUT_WEIGHTS[subStartA];
    const regi_t* subRowB  = (const regi_t*) &INPUT_WEIGHTS[subStartB];
    const regi_t* addRowA  = (const regi_t*) &INPUT_WEIGHTS[addStartA];
    const regi_t* addRowB  = (const regi_t*) &INPUT_WEIGHTS[addStartB];

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_load(in + k);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_sub(acc[k], subRowA[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_sub(acc[k], subRowB[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_add(acc[k], addRowA[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      acc[k] = regi_add(acc[k], addRowB[k]);
    }

    for (size_t k = 0; k < NUM_REGS; k++) {
      regi_store(out + k, acc[k]);
    }
  }
}

int OutputLayer(acc_t* stm, acc_t* xstm) {
Expand Down Expand Up @@ -89,6 +227,9 @@ void ResetRefreshTable(AccumulatorKingState* refreshTable) {
// Refreshes an accumulator using a diff from the last known board state
// with proper king bucketing
void RefreshAccumulator(Accumulator* dest, Board* board, const int perspective) {
Delta delta[1];
delta->r = delta->a = 0;

int kingSq = LSB(PieceBB(KING, perspective));
int pBucket = perspective == WHITE ? 0 : 2 * N_KING_BUCKETS;
int kingBucket = KING_BUCKETS[kingSq ^ (56 * perspective)] + N_KING_BUCKETS * (File(kingSq) > 3);
Expand All @@ -103,71 +244,65 @@ void RefreshAccumulator(Accumulator* dest, Board* board, const int perspective)
BitBoard add = curr & ~prev;

while (rem) {
int sq = PopLSB(&rem);
ApplyFeature(state->values, state->values, FeatureIdx(pc, sq, kingSq, perspective), SUB);
int sq = PopLSB(&rem);
delta->rem[delta->r++] = FeatureIdx(pc, sq, kingSq, perspective);
}

while (add) {
int sq = PopLSB(&add);
ApplyFeature(state->values, state->values, FeatureIdx(pc, sq, kingSq, perspective), ADD);
int sq = PopLSB(&add);
delta->add[delta->a++] = FeatureIdx(pc, sq, kingSq, perspective);
}

state->pcs[pc] = curr;
}

ApplyDelta(state->values, state->values, delta);

// Copy in state
memcpy(dest->values[perspective], state->values, sizeof(acc_t) * N_HIDDEN);
}

// Resets an accumulator from pieces on the board
void ResetAccumulator(Accumulator* dest, Board* board, const int perspective) {
int kingSq = LSB(PieceBB(KING, perspective));
acc_t* values = dest->values[perspective];
Delta delta[1];
delta->r = delta->a = 0;

memcpy(values, INPUT_BIASES, sizeof(acc_t) * N_HIDDEN);
int kingSq = LSB(PieceBB(KING, perspective));

BitBoard occ = OccBB(BOTH);
while (occ) {
int sq = PopLSB(&occ);
int pc = board->squares[sq];
int feature = FeatureIdx(pc, sq, kingSq, perspective);

ApplyFeature(values, values, feature, ADD);
int sq = PopLSB(&occ);
int pc = board->squares[sq];
delta->add[delta->a++] = FeatureIdx(pc, sq, kingSq, perspective);
}
}

void ApplyUpdates(Board* board, Move move, int captured, const int view) {
int16_t* output = board->accumulators->values[view];
int16_t* prev = (board->accumulators - 1)->values[view];

const int king = LSB(PieceBB(KING, view));

int f = FeatureIdx(Moving(move), From(move), king, view);
ApplyFeature(output, prev, f, SUB);
acc_t* values = dest->values[perspective];
memcpy(values, INPUT_BIASES, sizeof(acc_t) * N_HIDDEN);
ApplyDelta(values, values, delta);
}

int endPc = !Promo(move) ? Moving(move) : Promo(move);
f = FeatureIdx(endPc, To(move), king, view);
ApplyFeature(output, output, f, ADD);
void ApplyUpdates(Board* board, const Move move, const int captured, const int view) {
acc_t* output = board->accumulators->values[view];
acc_t* prev = (board->accumulators - 1)->values[view];

if (IsCap(move)) {
int movingSide = Moving(move) & 1;
int capturedSq = IsEP(move) ? To(move) - PawnDir(movingSide) : To(move);
f = FeatureIdx(captured, capturedSq, king, view);
const int king = LSB(PieceBB(KING, view));
const int movingSide = Moving(move) & 1;

ApplyFeature(output, output, f, SUB);
}
int from = FeatureIdx(Moving(move), From(move), king, view);
int to = FeatureIdx(Promo(move) ?: Moving(move), To(move), king, view);

if (IsCas(move)) {
int movingSide = Moving(move) & 1;
int rook = Piece(ROOK, movingSide);
int rookFrom = FeatureIdx(Piece(ROOK, movingSide), board->cr[CASTLING_ROOK[To(move)]], king, view);
int rookTo = FeatureIdx(Piece(ROOK, movingSide), CASTLE_ROOK_DEST[To(move)], king, view);

int rookFrom = board->cr[CASTLING_ROOK[To(move)]];
int rookTo = CASTLE_ROOK_DEST[To(move)];
ApplySubSubAddAdd(output, prev, from, rookFrom, to, rookTo);
} else if (IsCap(move)) {
int capSq = IsEP(move) ? To(move) - PawnDir(movingSide) : To(move);
int capturedTo = FeatureIdx(captured, capSq, king, view);

f = FeatureIdx(rook, rookFrom, king, view);
ApplyFeature(output, output, f, SUB);
f = FeatureIdx(rook, rookTo, king, view);
ApplyFeature(output, output, f, ADD);
ApplySubSubAdd(output, prev, from, capturedTo, to);
} else {
ApplySubAdd(output, prev, from, to);
}
}

Expand Down
8 changes: 8 additions & 0 deletions src/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

#include <immintrin.h>

#include "board.h"
#include "types.h"
#include "util.h"
Expand All @@ -30,3 +32,9 @@ void ApplyUpdates(Board* board, Move move, int captured, const int view);

void LoadDefaultNN();
int LoadNetwork(char* path);

// A pending batch of accumulator changes: feature indices whose weight rows
// are to be subtracted (rem) and added (add), with the number of valid
// entries in each list. Each list holds at most 32 entries.
typedef struct {
  uint8_t r;   // number of valid entries in rem
  uint8_t a;   // number of valid entries in add
  int rem[32]; // feature indices to remove
  int add[32]; // feature indices to add
} Delta;

0 comments on commit 181229c

Please sign in to comment.