Skip to content

Commit

Permalink
progress towards ai rl q-learning bot; lots more to do
Browse files Browse the repository at this point in the history
  • Loading branch information
r3w0p committed Aug 19, 2024
1 parent 02fe9af commit 68c434d
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 2 deletions.
29 changes: 27 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,15 @@ target_link_libraries(model core)

# Library target "user": the user base header plus the bot headers/sources.
# Every file is listed exactly once (the previous list named normal.h and
# normal.cpp twice); bot files are kept in alphabetical order.
add_library(user
    "include/caravan/user/user.h"
    "include/caravan/user/bot/ai.h"
    "include/caravan/user/bot/factory.h"
    "include/caravan/user/bot/friendly.h"
    "include/caravan/user/bot/normal.h"

    "src/caravan/user/bot/ai.cpp"
    "src/caravan/user/bot/factory.cpp"
    "src/caravan/user/bot/friendly.cpp"
    "src/caravan/user/bot/normal.cpp"
)

add_library(view
Expand Down Expand Up @@ -125,6 +127,29 @@ target_link_libraries(caravan
# ---


# --- train.exe
# Executable built from src/caravan/train.cpp.
add_executable(train "src/caravan/train.cpp")

# Project metadata exposed to train.cpp as preprocessor definitions.
target_compile_definitions(train PRIVATE
    CARAVAN_NAME="${PROJECT_NAME}"
    CARAVAN_VERSION="${PROJECT_VERSION}"
    CARAVAN_DESCRIPTION="${PROJECT_DESCRIPTION}"
    CARAVAN_COPYRIGHT="${PROJECT_COPYRIGHT}"
    CARAVAN_URL="${PROJECT_URL}"
)

# Libraries linked into the executable (all PRIVATE).
target_link_libraries(train PRIVATE
    core
    model
    user
    view
    cxxopts
)
# ---


# --- test.exe
enable_testing()

Expand Down
8 changes: 8 additions & 0 deletions include/caravan/core/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ typedef struct GameCommand {
Card board{};
} GameCommand;

// Parameters handed to a bot while it is being trained.
// All fields default to zero.
// NOTE(review): discount/explore/learning are presumably the usual
// Q-learning discount factor, exploration rate and learning rate — confirm.
typedef struct TrainConfig {
    float discount = 0.0F;
    float explore = 0.0F;
    float learning = 0.0F;
    uint32_t episode_max = 0U;  // upper bound used by the episode loop
    uint32_t episode = 0U;      // episode counter
} TrainConfig;

/*
* FUNCTIONS
*/
Expand Down
20 changes: 20 additions & 0 deletions include/caravan/user/bot/ai.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2022-2024 r3w0p
// The following code can be redistributed and/or
// modified under the terms of the GPL-3.0 License.

#ifndef CARAVAN_USER_BOT_AI_H
#define CARAVAN_USER_BOT_AI_H

#include "caravan/user/user.h"

/*
 * Bot user with two entry points for requesting a move: one for normal
 * play and one for training. Which entry point is usable is fixed at
 * construction time by the `train` flag.
 */
class UserBotAI : public UserBot {
public:
    // pn:    name of the player this bot plays as.
    // train: whether the bot is created in training mode.
    explicit UserBotAI(PlayerName pn, bool train);

    // Request a move for normal play.
    std::string request_move(Game *game) override;

    // Request a move during training, using the parameters in `tc`.
    std::string request_move_train(Game *game, TrainConfig *tc);

protected:
    // True when the bot was created in training mode.
    bool train;
};

#endif //CARAVAN_USER_BOT_AI_H
117 changes: 117 additions & 0 deletions src/caravan/train.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) 2022-2024 r3w0p
// The following code can be redistributed and/or
// modified under the terms of the GPL-3.0 License.

#include <iostream>
#include <chrono>
#include <random>
#include <vector>
#include <algorithm>
#include "cxxopts.hpp"
#include "caravan/user/bot/ai.h"

const std::string OPTS_HELP = "h,help";

const std::string KEY_HELP = "help";

const uint8_t FIRST_ABC = 1;
const uint8_t FIRST_DEF = 2;

int main(int argc, char *argv[]) {
UserBotAI *user_abc;
UserBotAI *user_def;
UserBotAI *user_turn;
Game *game;
GameConfig gc;
TrainConfig tc;

std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> distr_first(FIRST_ABC, FIRST_DEF);

uint8_t rand_first;

try {
cxxopts::Options options(CARAVAN_NAME);

options.add_options()
(OPTS_HELP, "Print help instructions.")
;

auto result = options.parse(argc, argv);

float discount = 0.95;
float explore = 0.9;
float learning = 0.75;
uint32_t episode_max = 10;

gc = {
.player_abc_cards = DECK_CARAVAN_MAX,
.player_abc_samples = SAMPLE_DECKS_MAX,
.player_abc_balanced = true,
.player_def_cards = DECK_CARAVAN_MAX,
.player_def_samples = SAMPLE_DECKS_MAX,
.player_def_balanced = true
};

tc = {
.episode_max = episode_max,
.episode = 1
};

user_abc = new UserBotAI(PLAYER_ABC, true);
user_def = new UserBotAI(PLAYER_DEF, true);

for(; tc.episode <= tc.episode_max; tc.episode++) {
// Random first player
rand_first = distr_first(gen);
gc.player_first = rand_first == FIRST_ABC ? PLAYER_ABC : PLAYER_DEF;
user_turn = rand_first == FIRST_ABC ? user_abc : user_def;

// Set training parameters
tc.discount = discount;
tc.explore = (float) (tc.episode_max - (tc.episode - 1)) / (float) tc.episode_max;
tc.learning = learning;

// Start a new game
game = new Game(&gc);

// Take turns until a winner is declared
while(game->get_winner() != NO_PLAYER) {
// TODO borrow logic from other bot to determine if move is
// valid or not; use this to narrow down possible moves that
// can be made per game state; any issues with move when
// passing to game should result in fatal exception
user_turn->request_move_train(game, &tc);

// TODO convert string move to command: take functions to do
// this out of view tui; have single function to make this
// conversion and then pass it to the game

if(user_turn->get_name() == PLAYER_ABC) {
user_turn = user_def;
} else {
user_turn = user_abc;
}
}

// Finish game
game->close();
delete game;
}

} catch (CaravanException &e) {
printf("%s\n", e.what().c_str());
exit(EXIT_FAILURE);

} catch (std::exception &e) {
printf("%s\n", e.what());
exit(EXIT_FAILURE);
}

user_abc->close();
user_def->close();

delete user_abc;
delete user_def;
}
23 changes: 23 additions & 0 deletions src/caravan/user/bot/ai.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) 2022-2024 r3w0p
// The following code can be redistributed and/or
// modified under the terms of the GPL-3.0 License.

#include "caravan/user/bot/ai.h"

// Forwards the player name to the UserBot base and records whether this
// bot instance is in training mode.
UserBotAI::UserBotAI(PlayerName pn, bool train)
    : UserBot(pn),
      train(train) {
    // TODO setup for train vs not train
}

// Move request for normal (non-training) play.
// Throws CaravanFatalException if the bot is closed or was created in
// training mode. The returned move is currently a fixed placeholder.
std::string UserBotAI::request_move(Game *game) {
    if (closed) {
        throw CaravanFatalException("Bot is closed.");
    }

    if (train) {
        throw CaravanFatalException("Bot is in training mode.");
    }

    // TODO
    return "D1";
}

// Move request used while training.
// Throws CaravanFatalException if the bot is closed or was not created in
// training mode. `tc` is not read yet; the returned move is currently a
// fixed placeholder.
std::string UserBotAI::request_move_train(Game *game, TrainConfig *tc) {
    if (closed) {
        throw CaravanFatalException("Bot is closed.");
    }

    if (!train) {
        throw CaravanFatalException("Bot is not in training mode.");
    }

    // TODO
    return "D1";
}
3 changes: 3 additions & 0 deletions src/caravan/user/bot/factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
#include "caravan/user/bot/factory.h"
#include "caravan/user/bot/normal.h"
#include "caravan/user/bot/friendly.h"
#include "caravan/user/bot/ai.h"

const std::string NAME_NORMAL = "normal";
const std::string NAME_FRIENDLY = "friendly";
const std::string NAME_AI = "ai";

UserBot* BotFactory::get(std::string name, PlayerName player_name) {
// Set name to lowercase
Expand All @@ -22,6 +24,7 @@ UserBot* BotFactory::get(std::string name, PlayerName player_name) {
// Return bot that matches name, or fail
if(name == NAME_NORMAL) { return new UserBotNormal(player_name); }
if(name == NAME_FRIENDLY) { return new UserBotFriendly(player_name); }
if(name == NAME_AI) { return new UserBotAI(player_name, false); }
else {
throw CaravanFatalException("Unknown bot name '" + name + "'.");
}
Expand Down

0 comments on commit 68c434d

Please sign in to comment.