Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Between search optimizations #418

Merged
merged 3 commits into from
Nov 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/bench.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ char* benchmarks[] = {
void Bench() {
Board board;
SearchParams params = {.depth = 13, .multiPV = 1, .hitrate = 1000, .max = INT_MAX};
ThreadData* threads = CreatePool(1);

CreatePool(1);

Move bestMoves[NUM_BENCH_POSITIONS];
int scores[NUM_BENCH_POSITIONS];
Expand All @@ -50,11 +51,11 @@ void Bench() {
ParseFen(benchmarks[i], &board);

TTClear();
ResetThreadPool(threads);
InitPool(&board, &params, threads);
ResetThreadPool();
InitPool(&board, &params);

params.start = GetTimeMS();
BestMove(&board, &params, threads);
BestMove(&board, &params);
times[i] = GetTimeMS() - params.start;

SearchResults* results = &threads->results;
Expand Down
4 changes: 2 additions & 2 deletions src/nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ int Predict(Board* board) {
return OutputLayer(stm, xstm);
}

void ResetRefreshTable(Board* board) {
void ResetRefreshTable(AccumulatorKingState* refreshTable[2]) {
for (int c = WHITE; c <= BLACK; c++) {
for (int b = 0; b < 2 * N_KING_BUCKETS; b++) {
AccumulatorKingState* state = &board->refreshTable[c][b];
AccumulatorKingState* state = &refreshTable[c][b];
memcpy(state->values, INPUT_BIASES, sizeof(int16_t) * N_HIDDEN);
memset(state->pcs, 0, sizeof(BitBoard) * 12);
}
Expand Down
2 changes: 1 addition & 1 deletion src/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

int Predict(Board* board);
int OutputLayer(Accumulator stm, Accumulator xstm);
void ResetRefreshTable(Board* board);
void ResetRefreshTable(AccumulatorKingState* refreshTable[2]);
void RefreshAccumulator(Accumulator accumulator, Board* board, const int perspective);
void ResetAccumulator(Accumulator output, Board* board, const int perspective);

Expand Down
15 changes: 6 additions & 9 deletions src/search.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,22 @@ void* UCISearch(void* arg) {

Board* board = args->board;
SearchParams* params = args->params;
ThreadData* threads = args->threads;

BestMove(board, params, threads);
BestMove(board, params);

free(args);
return NULL;
}

void BestMove(Board* board, SearchParams* params, ThreadData* threads) {
void BestMove(Board* board, SearchParams* params) {
Move bestMove;
if ((bestMove = TBRootProbe(board))) {
while (PONDERING)
;

printf("bestmove %s\n", MoveToStr(bestMove, board));
} else {
pthread_t pthreads[threads->count];
InitPool(board, params, threads);
InitPool(board, params);

params->stopped = 0;
TTUpdate();
Expand Down Expand Up @@ -133,9 +131,8 @@ void* Search(void* arg) {
int beta = CHECKMATE;

board->acc = 0;
ResetRefreshTable(board);
RefreshAccumulator(board->accumulators[WHITE][board->acc], board, WHITE);
RefreshAccumulator(board->accumulators[BLACK][board->acc], board, BLACK);
RefreshAccumulator(board->accumulators[WHITE][0], board, WHITE);
RefreshAccumulator(board->accumulators[BLACK][0], board, BLACK);

data->contempt[WHITE] = data->contempt[BLACK] = 0;
SetContempt(data->contempt, board->stm);
Expand Down Expand Up @@ -482,7 +479,7 @@ int Negamax(int alpha, int beta, int depth, int cutnode, ThreadData* thread, PV*

while ((move = NextMove(&moves, board, skipQuiets))) {
int64_t startingNodeCount = data->nodes;

if (isRoot && MoveSearchedByMultiPV(thread, move)) continue;
if (isRoot && !MoveSearchable(params, move)) continue;

Expand Down
2 changes: 1 addition & 1 deletion src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
void InitPruningAndReductionTables();

void* UCISearch(void* arg);
void BestMove(Board* board, SearchParams* params, ThreadData* threads);
void BestMove(Board* board, SearchParams* params);
void* Search(void* arg);
int Negamax(int alpha, int beta, int depth, int cutnode, ThreadData* thread, PV* pv);
int Quiesce(int alpha, int beta, ThreadData* thread);
Expand Down
54 changes: 32 additions & 22 deletions src/thread.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,19 @@

#include "thread.h"

#include <pthread.h>
#include <stdlib.h>
#include <string.h>

#include "eval.h"
#include "nn.h"
#include "search.h"
#include "types.h"
#include "util.h"

ThreadData* threads = NULL;
pthread_t* pthreads = NULL;

void* AlignedMalloc(int size) {
void* mem = malloc(size + ALIGN_ON + sizeof(void*));
void** ptr = (void**) ((uintptr_t) (mem + ALIGN_ON + sizeof(void*)) & ~(ALIGN_ON - 1));
Expand All @@ -36,24 +41,32 @@ void AlignedFree(void* ptr) {
}

// Initialize a pool of `count` threads.
// NOTE(review): this span is a rendered diff hunk without +/- markers, so the
// pre-change and post-change lines of CreatePool appear interleaved. Which
// line belongs to which side is inferred from the thread.h hunk, which lists
// both `ThreadData* CreatePool(int count)` and `void CreatePool(int count)`.
ThreadData* CreatePool(int count) {
ThreadData* threads = malloc(count * sizeof(ThreadData));
void CreatePool(int count) {
// Release any pool left over from an earlier call before allocating anew.
if (threads) FreeThreads();
if (pthreads) free(pthreads);

// Zero-initialized storage for the per-thread data and the pthread handles.
// NOTE(review): neither calloc result is checked before use.
threads = calloc(count, sizeof(ThreadData));
pthreads = calloc(count, sizeof(pthread_t));

for (int i = 0; i < count; i++) {
// allow reference to one another
threads[i].idx = i;
threads[i].threads = threads;
threads[i].count = count;
threads[i].results.depth = 0;
// Room for MAX_SEARCH_PLY + 1 Accumulator entries per color.
threads[i].accumulators[WHITE] = (Accumulator*) AlignedMalloc(sizeof(Accumulator) * (MAX_SEARCH_PLY + 1));
threads[i].accumulators[BLACK] = (Accumulator*) AlignedMalloc(sizeof(Accumulator) * (MAX_SEARCH_PLY + 1));
}
// Room for 2 * N_KING_BUCKETS AccumulatorKingState entries per color.
threads[i].refreshTable[WHITE] =
(AccumulatorKingState*) AlignedMalloc(sizeof(AccumulatorKingState) * 2 * N_KING_BUCKETS);
threads[i].refreshTable[BLACK] =
(AccumulatorKingState*) AlignedMalloc(sizeof(AccumulatorKingState) * 2 * N_KING_BUCKETS);

return threads;
// Per the nn.c hunk, ResetRefreshTable copies INPUT_BIASES into each entry's
// `values` and zeroes its `pcs`.
ResetRefreshTable(threads[i].refreshTable);

// set the moves arr as an offset of 2 into searchMoves
threads[i].data.moves = &threads[i].data.searchMoves[2];
}
}

// initialize a pool prepping to start a search
void InitPool(Board* board, SearchParams* params, ThreadData* threads) {
void InitPool(Board* board, SearchParams* params) {
for (int i = 0; i < threads->count; i++) {
threads[i].params = params;

Expand All @@ -66,27 +79,18 @@ void InitPool(Board* board, SearchParams* params, ThreadData* threads) {
threads[i].data.ply = 0;
threads[i].data.tbhits = 0;

// empty unneeded data
memset(&threads[i].data.skipMove, 0, sizeof(threads[i].data.skipMove));
memset(&threads[i].data.evals, 0, sizeof(threads[i].data.evals));
memset(&threads[i].data.tm, 0, sizeof(threads[i].data.tm));
memset(&threads[i].data.de, 0, sizeof(threads[i].data.de));

// set the moves arr as an offset of 2
threads[i].data.moves = &threads[i].data.searchMoves[2];

memset(&threads[i].scores, 0, sizeof(threads[i].scores));
memset(&threads[i].bestMoves, 0, sizeof(threads[i].bestMoves));
memset(&threads[i].pvs, 0, sizeof(threads[i].pvs));

// need full copies of the board
memcpy(&threads[i].board, board, sizeof(Board));
threads[i].board.accumulators[WHITE] = threads[i].accumulators[WHITE];
threads[i].board.accumulators[BLACK] = threads[i].accumulators[BLACK];
threads[i].board.refreshTable[WHITE] = threads[i].refreshTable[WHITE];
threads[i].board.refreshTable[BLACK] = threads[i].refreshTable[BLACK];
}
}

void ResetThreadPool(ThreadData* threads) {
void ResetThreadPool() {
for (int i = 0; i < threads->count; i++) {
threads[i].results.depth = 0;

Expand All @@ -104,27 +108,33 @@ void ResetThreadPool(ThreadData* threads) {
memset(&threads[i].data.hh, 0, sizeof(threads[i].data.hh));
memset(&threads[i].data.ch, 0, sizeof(threads[i].data.ch));
memset(&threads[i].data.th, 0, sizeof(threads[i].data.th));

memset(&threads[i].scores, 0, sizeof(threads[i].scores));
memset(&threads[i].bestMoves, 0, sizeof(threads[i].bestMoves));
memset(&threads[i].pvs, 0, sizeof(threads[i].pvs));
}
}

// Frees every thread's AlignedMalloc'd buffers (accumulators and refresh
// tables, both colors) via AlignedFree, then frees the pool array itself.
// NOTE(review): rendered diff hunk without +/- markers; the two signature
// lines are the parameterized and parameterless variants of the same function
// (both prototypes appear in the thread.h hunk).
// NOTE(review): `threads` is not reset to NULL after free(). CreatePool guards
// its call with `if (threads)` and reassigns the pointer right away; confirm
// no other caller relies on the pointer afterwards.
void FreeThreads(ThreadData* threads) {
void FreeThreads() {
// The pool size is read from the first entry's `count` field.
for (int i = 0; i < threads->count; i++) {
AlignedFree(threads[i].accumulators[WHITE]);
AlignedFree(threads[i].accumulators[BLACK]);
AlignedFree(threads[i].refreshTable[WHITE]);
AlignedFree(threads[i].refreshTable[BLACK]);
}

free(threads);
}

// Sum the node counts (`data.nodes`) of every thread in the pool and return
// the total. The pool size is read from the first entry's `count` field, so
// the pool pointer is dereferenced unconditionally.
// NOTE(review): rendered diff hunk without +/- markers; the two signature
// lines are the parameterized and parameterless variants of the same function
// (both prototypes appear in the thread.h hunk).
uint64_t NodesSearched(ThreadData* threads) {
uint64_t NodesSearched() {
uint64_t nodes = 0;
for (int i = 0; i < threads->count; i++) nodes += threads[i].data.nodes;

return nodes;
}

uint64_t TBHits(ThreadData* threads) {
uint64_t TBHits() {
uint64_t tbhits = 0;
for (int i = 0; i < threads->count; i++) tbhits += threads[i].data.tbhits;

Expand Down
17 changes: 11 additions & 6 deletions src/thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@

#include "types.h"

// NOTE(review): rendered diff hunk without +/- markers. The first six
// prototypes pass or return the pool explicitly; the prototypes after the
// extern declarations are their counterparts that take no pool argument.
ThreadData* CreatePool(int count);
void InitPool(Board* board, SearchParams* params, ThreadData* threads);
void ResetThreadPool(ThreadData* threads);
void FreeThreads(ThreadData* threads);
uint64_t NodesSearched(ThreadData* threads);
uint64_t TBHits(ThreadData* threads);
#include <pthread.h>

// Pool storage; thread.c defines both and initializes them to NULL.
extern ThreadData* threads;
extern pthread_t* pthreads;

// Allocates the pool (thread data and pthread handles) for `count` threads.
void CreatePool(int count);
// Initialize the pool prepping to start a search on `board` with `params`.
void InitPool(Board* board, SearchParams* params);
// Resets per-thread state between games; only part of the body is visible in
// the thread.c hunk.
void ResetThreadPool();
// Frees each thread's aligned buffers and then the pool array.
void FreeThreads();
// Sum of `data.nodes` over all threads in the pool.
uint64_t NodesSearched();
// Sum of `data.tbhits` over all threads in the pool.
uint64_t TBHits();

#endif
6 changes: 2 additions & 4 deletions src/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ typedef struct {
Accumulator values;
} ALIGN AccumulatorKingState;

typedef AccumulatorKingState AccumulatorRefreshTable[2][2 * N_KING_BUCKETS] ALIGN;

typedef struct {
BitBoard pcs, sqs;
} Threat;
Expand All @@ -69,7 +67,7 @@ typedef struct {
uint64_t zobrist; // zobrist hash of the position

Accumulator* accumulators[2];
AccumulatorRefreshTable refreshTable;
AccumulatorKingState* refreshTable[2];

int squares[64]; // piece per square
BitBoard occupancies[3]; // 0 - white pieces, 1 - black pieces, 2 - both
Expand Down Expand Up @@ -163,6 +161,7 @@ struct ThreadData {
int count, idx, multiPV, depth;

Accumulator* accumulators[2];
AccumulatorKingState* refreshTable[2];

ThreadData* threads;
jmp_buf exit;
Expand All @@ -181,7 +180,6 @@ struct ThreadData {
typedef struct {
Board* board;
SearchParams* params;
ThreadData* threads;
} SearchArgs;

// Move generation storage
Expand Down
22 changes: 10 additions & 12 deletions src/uci.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void RootMoves(SimpleMoveList* moves, Board* board) {
}

// uci "go" command
void ParseGo(char* in, SearchParams* params, Board* board, ThreadData* threads) {
void ParseGo(char* in, SearchParams* params, Board* board) {
in += 3;

params->depth = MAX_SEARCH_PLY;
Expand Down Expand Up @@ -170,7 +170,6 @@ void ParseGo(char* in, SearchParams* params, Board* board, ThreadData* threads)
SearchArgs* args = malloc(sizeof(SearchArgs));
args->board = board;
args->params = params;
args->threads = threads;

// start the search!
pthread_t searchThread;
Expand All @@ -179,7 +178,7 @@ void ParseGo(char* in, SearchParams* params, Board* board, ThreadData* threads)
}

// uci "position" command
void ParsePosition(char* in, Board* board, ThreadData* threads) {
void ParsePosition(char* in, Board* board) {
in += 9;
char* ptrChar = in;

Expand Down Expand Up @@ -248,7 +247,7 @@ void UCILoop() {
Board board;
ParseFen(START_FEN, &board);

ThreadData* threads = CreatePool(1);
CreatePool(1);
SearchParams searchParameters = {.quit = 0};

setbuf(stdin, NULL);
Expand All @@ -260,13 +259,13 @@ void UCILoop() {
if (!strncmp(in, "isready", 7)) {
printf("readyok\n");
} else if (!strncmp(in, "position", 8)) {
ParsePosition(in, &board, threads);
ParsePosition(in, &board);
} else if (!strncmp(in, "ucinewgame", 10)) {
ParsePosition("position startpos\n", &board, threads);
ParsePosition("position startpos\n", &board);
TTClear();
ResetThreadPool(threads);
ResetThreadPool();
} else if (!strncmp(in, "go", 2)) {
ParseGo(in, &searchParameters, &board, threads);
ParseGo(in, &searchParameters, &board);
} else if (!strncmp(in, "stop", 4)) {
PONDERING = 0;
searchParameters.stopped = 1;
Expand Down Expand Up @@ -354,8 +353,7 @@ void UCILoop() {
printf("info string set Hash to value %d (%" PRId64 " bytes)\n", mb, bytesAllocated);
} else if (!strncmp(in, "setoption name Threads value ", 29)) {
int n = GetOptionIntValue(in);
FreeThreads(threads);
threads = CreatePool(max(1, min(256, n)));
CreatePool(max(1, min(256, n)));
printf("info string set Threads to value %d\n", n);
} else if (!strncmp(in, "setoption name SyzygyPath value ", 32)) {
int success = tb_init(in + 32);
Expand All @@ -382,9 +380,9 @@ void UCILoop() {
printf("info string set UCI_Chess960 to value %s\n", CHESS_960 ? "true" : "false");
printf("info string Resetting board...\n");

ParsePosition("position startpos\n", &board, threads);
ParsePosition("position startpos\n", &board);
TTClear();
ResetThreadPool(threads);
ResetThreadPool();
} else if (!strncmp(in, "setoption name MoveOverhead value ", 34)) {
MOVE_OVERHEAD = min(10000, max(100, GetOptionIntValue(in)));
} else if (!strncmp(in, "setoption name Contempt value ", 30)) {
Expand Down
4 changes: 2 additions & 2 deletions src/uci.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ extern int CONTEMPT;

void RootMoves(SimpleMoveList* moves, Board* board);

void ParseGo(char* in, SearchParams* params, Board* board, ThreadData* threads);
void ParsePosition(char* in, Board* board, ThreadData* threads);
void ParseGo(char* in, SearchParams* params, Board* board);
void ParsePosition(char* in, Board* board);
void PrintUCIOptions();

int ReadLine(char* in);
Expand Down