Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Apply Updates #453

Merged
merged 1 commit into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 180 additions & 45 deletions src/nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,155 @@ int16_t OUTPUT_WEIGHTS[2 * N_HIDDEN] ALIGN;
int32_t OUTPUT_BIAS;

// Pick the widest SIMD register the compiler target advertises and map the
// generic regi_* names onto the matching 16-bit-lane intrinsics, so the
// accumulator update routines can be written once for all three widths.
//
// UNROLL is the number of accumulator entries handled per outer-loop pass.
// NOTE(review): the values look like 16 registers' worth of 16-bit lanes
// (16 * 32 = 512, 16 * 16 = 256, 16 * 8 = 128) - confirm this stays in sync
// with NUM_REGS.
// NOTE(review): UNROLL is defined twice with the same value in every branch;
// the repetition is harmless (identical redefinition) but looks like residue.
#if defined(__AVX512F__)
// 512-bit registers (AVX-512 Foundation).
#define UNROLL 512
#define UNROLL 512
#define regi_t __m512i
#define regi_load _mm512_load_si512
#define regi_sub _mm512_sub_epi16
#define regi_add _mm512_add_epi16
#define regi_store _mm512_store_si512
#elif defined(__AVX2__)
// 256-bit registers (AVX2).
#define UNROLL 256
#define UNROLL 256
#define regi_t __m256i
#define regi_load _mm256_load_si256
#define regi_sub _mm256_sub_epi16
#define regi_add _mm256_add_epi16
#define regi_store _mm256_store_si256
#else
// Fallback: 128-bit registers (SSE2-style intrinsics).
#define UNROLL 128
#define UNROLL 128
#define regi_t __m128i
#define regi_load _mm_load_si128
#define regi_sub _mm_sub_epi16
#define regi_add _mm_add_epi16
#define regi_store _mm_store_si128
#endif

INLINE void ApplyFeature(acc_t* dest, acc_t* src, int f, const int add) {
for (size_t c = 0; c < N_HIDDEN; c += UNROLL)
for (size_t i = 0; i < UNROLL; i++)
dest[i + c] = src[i + c] + (2 * add - 1) * INPUT_WEIGHTS[f * N_HIDDEN + i + c];
#define NUM_REGS 16

// Applies a batch of feature removals and additions to an accumulator.
//
// For every chunk of UNROLL entries, the chunk of `src` is loaded into
// NUM_REGS vector registers, the INPUT_WEIGHTS row of each feature listed in
// delta->rem is subtracted, the row of each feature listed in delta->add is
// added, and the result is stored to the same chunk of `dest`.
//
// dest  - accumulator receiving the updated values
// src   - accumulator to start from; a whole chunk is loaded before it is
//         stored, and callers in this file pass the same buffer as dest
// delta - feature indices to remove (rem[0..r)) and to add (add[0..a))
//
// NOTE(review): uses the aligned load/store forms, so src, dest and the
// weight rows are assumed to be suitably aligned - confirm against ALIGN.
INLINE void ApplyDelta(acc_t* dest, acc_t* src, Delta* delta) {
  regi_t acc[NUM_REGS];

  const size_t chunks = N_HIDDEN / UNROLL;

  for (size_t chunk = 0; chunk < chunks; ++chunk) {
    const size_t base = chunk * UNROLL;

    const regi_t* in = (regi_t*) (src + base);
    regi_t* out      = (regi_t*) (dest + base);

    // Pull the current chunk into the register block.
    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_load(in + k);

    // Subtract the weight row of every removed feature.
    for (size_t n = 0; n < delta->r; n++) {
      const size_t rowStart = delta->rem[n] * N_HIDDEN + base;
      const regi_t* row     = (regi_t*) (INPUT_WEIGHTS + rowStart);

      for (size_t k = 0; k < NUM_REGS; k++)
        acc[k] = regi_sub(acc[k], row[k]);
    }

    // Add the weight row of every added feature.
    for (size_t n = 0; n < delta->a; n++) {
      const size_t rowStart = delta->add[n] * N_HIDDEN + base;
      const regi_t* row     = (regi_t*) (INPUT_WEIGHTS + rowStart);

      for (size_t k = 0; k < NUM_REGS; k++)
        acc[k] = regi_add(acc[k], row[k]);
    }

    // Write the finished chunk back out.
    for (size_t k = 0; k < NUM_REGS; k++)
      regi_store(out + k, acc[k]);
  }
}

// Computes dest = src - W[f1] + W[f2], where W[f] is the INPUT_WEIGHTS row of
// feature index f, processing UNROLL entries per pass in NUM_REGS registers.
//
// dest - accumulator receiving the result
// src  - accumulator to start from
// f1   - feature index whose weight row is subtracted
// f2   - feature index whose weight row is added
INLINE void ApplySubAdd(acc_t* dest, acc_t* src, int f1, int f2) {
  regi_t acc[NUM_REGS];

  const size_t chunks = N_HIDDEN / UNROLL;

  for (size_t chunk = 0; chunk < chunks; ++chunk) {
    const size_t base = chunk * UNROLL;

    const regi_t* in = (regi_t*) (src + base);
    regi_t* out      = (regi_t*) (dest + base);

    // Weight rows for this chunk.
    const size_t subStart = f1 * N_HIDDEN + base;
    const size_t addStart = f2 * N_HIDDEN + base;
    const regi_t* subRow  = (regi_t*) (INPUT_WEIGHTS + subStart);
    const regi_t* addRow  = (regi_t*) (INPUT_WEIGHTS + addStart);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_load(in + k);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_sub(acc[k], subRow[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_add(acc[k], addRow[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      regi_store(out + k, acc[k]);
  }
}

// Computes dest = src - W[f1] - W[f2] + W[f3], where W[f] is the
// INPUT_WEIGHTS row of feature index f, processing UNROLL entries per pass in
// NUM_REGS registers.
//
// dest   - accumulator receiving the result
// src    - accumulator to start from
// f1, f2 - feature indices whose weight rows are subtracted
// f3     - feature index whose weight row is added
INLINE void ApplySubSubAdd(acc_t* dest, acc_t* src, int f1, int f2, int f3) {
  regi_t acc[NUM_REGS];

  const size_t chunks = N_HIDDEN / UNROLL;

  for (size_t chunk = 0; chunk < chunks; ++chunk) {
    const size_t base = chunk * UNROLL;

    const regi_t* in = (regi_t*) (src + base);
    regi_t* out      = (regi_t*) (dest + base);

    // Weight rows for this chunk.
    const size_t subStartA = f1 * N_HIDDEN + base;
    const size_t subStartB = f2 * N_HIDDEN + base;
    const size_t addStart  = f3 * N_HIDDEN + base;
    const regi_t* subRowA  = (regi_t*) (INPUT_WEIGHTS + subStartA);
    const regi_t* subRowB  = (regi_t*) (INPUT_WEIGHTS + subStartB);
    const regi_t* addRow   = (regi_t*) (INPUT_WEIGHTS + addStart);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_load(in + k);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_sub(acc[k], subRowA[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_sub(acc[k], subRowB[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_add(acc[k], addRow[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      regi_store(out + k, acc[k]);
  }
}

// Computes dest = src - W[f1] - W[f2] + W[f3] + W[f4], where W[f] is the
// INPUT_WEIGHTS row of feature index f, processing UNROLL entries per pass in
// NUM_REGS registers.
//
// dest   - accumulator receiving the result
// src    - accumulator to start from
// f1, f2 - feature indices whose weight rows are subtracted
// f3, f4 - feature indices whose weight rows are added
INLINE void ApplySubSubAddAdd(acc_t* dest, acc_t* src, int f1, int f2, int f3, int f4) {
  regi_t acc[NUM_REGS];

  const size_t chunks = N_HIDDEN / UNROLL;

  for (size_t chunk = 0; chunk < chunks; ++chunk) {
    const size_t base = chunk * UNROLL;

    const regi_t* in = (regi_t*) (src + base);
    regi_t* out      = (regi_t*) (dest + base);

    // Weight rows for this chunk.
    const size_t subStartA = f1 * N_HIDDEN + base;
    const size_t subStartB = f2 * N_HIDDEN + base;
    const size_t addStartA = f3 * N_HIDDEN + base;
    const size_t addStartB = f4 * N_HIDDEN + base;
    const regi_t* subRowA  = (regi_t*) (INPUT_WEIGHTS + subStartA);
    const regi_t* subRowB  = (regi_t*) (INPUT_WEIGHTS + subStartB);
    const regi_t* addRowA  = (regi_t*) (INPUT_WEIGHTS + addStartA);
    const regi_t* addRowB  = (regi_t*) (INPUT_WEIGHTS + addStartB);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_load(in + k);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_sub(acc[k], subRowA[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_sub(acc[k], subRowB[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_add(acc[k], addRowA[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      acc[k] = regi_add(acc[k], addRowB[k]);

    for (size_t k = 0; k < NUM_REGS; k++)
      regi_store(out + k, acc[k]);
  }
}

int OutputLayer(acc_t* stm, acc_t* xstm) {
Expand Down Expand Up @@ -89,6 +227,9 @@ void ResetRefreshTable(AccumulatorKingState* refreshTable) {
// Refreshes an accumulator using a diff from the last known board state
// with proper king bucketing
void RefreshAccumulator(Accumulator* dest, Board* board, const int perspective) {
Delta delta[1];
delta->r = delta->a = 0;

int kingSq = LSB(PieceBB(KING, perspective));
int pBucket = perspective == WHITE ? 0 : 2 * N_KING_BUCKETS;
int kingBucket = KING_BUCKETS[kingSq ^ (56 * perspective)] + N_KING_BUCKETS * (File(kingSq) > 3);
Expand All @@ -103,71 +244,65 @@ void RefreshAccumulator(Accumulator* dest, Board* board, const int perspective)
BitBoard add = curr & ~prev;

while (rem) {
int sq = PopLSB(&rem);
ApplyFeature(state->values, state->values, FeatureIdx(pc, sq, kingSq, perspective), SUB);
int sq = PopLSB(&rem);
delta->rem[delta->r++] = FeatureIdx(pc, sq, kingSq, perspective);
}

while (add) {
int sq = PopLSB(&add);
ApplyFeature(state->values, state->values, FeatureIdx(pc, sq, kingSq, perspective), ADD);
int sq = PopLSB(&add);
delta->add[delta->a++] = FeatureIdx(pc, sq, kingSq, perspective);
}

state->pcs[pc] = curr;
}

ApplyDelta(state->values, state->values, delta);

// Copy in state
memcpy(dest->values[perspective], state->values, sizeof(acc_t) * N_HIDDEN);
}

// Resets an accumulator from pieces on the board
void ResetAccumulator(Accumulator* dest, Board* board, const int perspective) {
int kingSq = LSB(PieceBB(KING, perspective));
acc_t* values = dest->values[perspective];
Delta delta[1];
delta->r = delta->a = 0;

memcpy(values, INPUT_BIASES, sizeof(acc_t) * N_HIDDEN);
int kingSq = LSB(PieceBB(KING, perspective));

BitBoard occ = OccBB(BOTH);
while (occ) {
int sq = PopLSB(&occ);
int pc = board->squares[sq];
int feature = FeatureIdx(pc, sq, kingSq, perspective);

ApplyFeature(values, values, feature, ADD);
int sq = PopLSB(&occ);
int pc = board->squares[sq];
delta->add[delta->a++] = FeatureIdx(pc, sq, kingSq, perspective);
}
}

void ApplyUpdates(Board* board, Move move, int captured, const int view) {
int16_t* output = board->accumulators->values[view];
int16_t* prev = (board->accumulators - 1)->values[view];

const int king = LSB(PieceBB(KING, view));

int f = FeatureIdx(Moving(move), From(move), king, view);
ApplyFeature(output, prev, f, SUB);
acc_t* values = dest->values[perspective];
memcpy(values, INPUT_BIASES, sizeof(acc_t) * N_HIDDEN);
ApplyDelta(values, values, delta);
}

int endPc = !Promo(move) ? Moving(move) : Promo(move);
f = FeatureIdx(endPc, To(move), king, view);
ApplyFeature(output, output, f, ADD);
void ApplyUpdates(Board* board, const Move move, const int captured, const int view) {
acc_t* output = board->accumulators->values[view];
acc_t* prev = (board->accumulators - 1)->values[view];

if (IsCap(move)) {
int movingSide = Moving(move) & 1;
int capturedSq = IsEP(move) ? To(move) - PawnDir(movingSide) : To(move);
f = FeatureIdx(captured, capturedSq, king, view);
const int king = LSB(PieceBB(KING, view));
const int movingSide = Moving(move) & 1;

ApplyFeature(output, output, f, SUB);
}
int from = FeatureIdx(Moving(move), From(move), king, view);
int to = FeatureIdx(Promo(move) ?: Moving(move), To(move), king, view);

if (IsCas(move)) {
int movingSide = Moving(move) & 1;
int rook = Piece(ROOK, movingSide);
int rookFrom = FeatureIdx(Piece(ROOK, movingSide), board->cr[CASTLING_ROOK[To(move)]], king, view);
int rookTo = FeatureIdx(Piece(ROOK, movingSide), CASTLE_ROOK_DEST[To(move)], king, view);

int rookFrom = board->cr[CASTLING_ROOK[To(move)]];
int rookTo = CASTLE_ROOK_DEST[To(move)];
ApplySubSubAddAdd(output, prev, from, rookFrom, to, rookTo);
} else if (IsCap(move)) {
int capSq = IsEP(move) ? To(move) - PawnDir(movingSide) : To(move);
int capturedTo = FeatureIdx(captured, capSq, king, view);

f = FeatureIdx(rook, rookFrom, king, view);
ApplyFeature(output, output, f, SUB);
f = FeatureIdx(rook, rookTo, king, view);
ApplyFeature(output, output, f, ADD);
ApplySubSubAdd(output, prev, from, capturedTo, to);
} else {
ApplySubAdd(output, prev, from, to);
}
}

Expand Down
8 changes: 8 additions & 0 deletions src/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

#include <immintrin.h>

#include "board.h"
#include "types.h"
#include "util.h"
Expand All @@ -30,3 +32,9 @@ void ApplyUpdates(Board* board, Move move, int captured, const int view);

void LoadDefaultNN();
int LoadNetwork(char* path);

// A pending set of accumulator changes: feature indices whose weights are to
// be subtracted (rem) and added (add), each with its own fill count.
typedef struct {
  uint8_t r;    // number of entries filled in rem
  uint8_t a;    // number of entries filled in add
  int rem[32];  // feature indices to remove
  int add[32];  // feature indices to add
} Delta;