Add catch to SARSA and Q-learning tests and fix a bug by which the algorithms didn't handle games with chance start nodes.

PiperOrigin-RevId: 359502727
Change-Id: Iaf8f9ce9d640d33d2765c7e9b41a13c9aef23fb7
Satyaki Upadhyay authored and open_spiel@google.com committed Feb 25, 2021
1 parent 4de00c5 commit 2f2e0e7
Showing 6 changed files with 86 additions and 15 deletions.
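
The fix targets games whose initial state is a chance node rather than a decision node; catch, added to both example tests below, appears to be such a game (the ball's starting column is drawn by chance). As a minimal illustration of the repaired flow, assuming only the API visible in this diff (LoadGame, the solver constructor, RunIteration), a training loop on such a game might look like this sketch:

#include <memory>

#include "open_spiel/algorithms/tabular_q_learning.h"
#include "open_spiel/spiel.h"

// Sketch: train the fixed solver on a game whose initial state is a chance
// node. Before this commit, RunIteration() started the update loop at the
// raw initial state even when that state was a chance node.
void TrainOnChanceStartGame() {
  std::shared_ptr<const open_spiel::Game> game = open_spiel::LoadGame("catch");
  open_spiel::algorithms::TabularQLearningSolver solver(game);
  for (int i = 0; i < 100000; ++i) {
    // RunIteration() now resolves chance via SampleUntilNextStateOrTerminal()
    // before any Q-update, as shown in the diff below.
    solver.RunIteration();
  }
}
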
17 changes: 11 additions & 6 deletions open_spiel/algorithms/tabular_q_learning.cc
@@ -67,6 +67,15 @@ Action TabularQLearningSolver::SampleActionFromEpsilonGreedyPolicy(
return GetBestAction(state, min_utility);
}

void TabularQLearningSolver::SampleUntilNextStateOrTerminal(State* state) {
// Repeatedly sample while chance node, so that we end up at a decision node
while (state->IsChanceNode() && !state->IsTerminal()) {
vector<Action> legal_actions = state->LegalActions();
state->ApplyAction(
legal_actions[absl::Uniform<int>(rng_, 0, legal_actions.size())]);
}
}

TabularQLearningSolver::TabularQLearningSolver(std::shared_ptr<const Game> game)
: game_(game),
depth_limit_(kDefaultDepthLimit),
@@ -95,6 +104,7 @@ void TabularQLearningSolver::RunIteration() {
const double min_utility = game_->MinUtility();
// Choose start state
std::unique_ptr<State> curr_state = game_->NewInitialState();
SampleUntilNextStateOrTerminal(curr_state.get());

while (!curr_state->IsTerminal()) {
const Player player = curr_state->CurrentPlayer();
@@ -104,12 +114,7 @@
SampleActionFromEpsilonGreedyPolicy(*(curr_state.get()), min_utility);

std::unique_ptr<State> next_state = curr_state->Child(curr_action);
// Repeatedly sample while chance node, so that we end up at a decision node
while (next_state->IsChanceNode() && !next_state->IsTerminal()) {
vector<Action> legal_actions = next_state->LegalActions();
next_state->ApplyAction(
legal_actions[absl::Uniform<int>(rng_, 0, legal_actions.size())]);
}
SampleUntilNextStateOrTerminal(next_state.get());

const double reward = next_state->Rewards()[player];
// Next q-value in perspective of player to play at curr_state (important
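
The new SampleUntilNextStateOrTerminal helper above resolves chance nodes by sampling uniformly over LegalActions(). For comparison only, and not part of this commit: OpenSpiel chance nodes also expose their outcome probabilities through State::ChanceOutcomes(), so a hypothetical variant could sample from that distribution instead; the two coincide whenever the chance outcomes are uniform.

#include <random>
#include <utility>
#include <vector>

#include "open_spiel/spiel.h"

// Hypothetical variant of SampleUntilNextStateOrTerminal (not in this
// commit): draw each chance action from the game's own outcome distribution
// via State::ChanceOutcomes() instead of uniformly over LegalActions().
void SampleChanceFromDistribution(open_spiel::State* state, std::mt19937* rng) {
  while (state->IsChanceNode() && !state->IsTerminal()) {
    const std::vector<std::pair<open_spiel::Action, double>> outcomes =
        state->ChanceOutcomes();
    std::vector<double> probs;
    probs.reserve(outcomes.size());
    for (const auto& outcome : outcomes) probs.push_back(outcome.second);
    std::discrete_distribution<int> dist(probs.begin(), probs.end());
    state->ApplyAction(outcomes[dist(*rng)].first);
  }
}
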
4 changes: 4 additions & 0 deletions open_spiel/algorithms/tabular_q_learning.h
@@ -62,6 +62,10 @@ class TabularQLearningSolver {
Action SampleActionFromEpsilonGreedyPolicy(const State& state,
double min_utility);

// Moves a chance node to the next decision/terminal node by sampling from
// the legal actions repeatedly
void SampleUntilNextStateOrTerminal(State* state);

std::shared_ptr<const Game> game_;
int depth_limit_;
double epsilon_;
17 changes: 11 additions & 6 deletions open_spiel/algorithms/tabular_sarsa.cc
@@ -61,6 +61,15 @@ Action TabularSarsaSolver::SampleActionFromEpsilonGreedyPolicy(
return GetBestAction(state, min_utility);
}

void TabularSarsaSolver::SampleUntilNextStateOrTerminal(State* state) {
// Repeatedly sample while chance node, so that we end up at a decision node
while (state->IsChanceNode() && !state->IsTerminal()) {
vector<Action> legal_actions = state->LegalActions();
state->ApplyAction(
legal_actions[absl::Uniform<int>(rng_, 0, legal_actions.size())]);
}
}

TabularSarsaSolver::TabularSarsaSolver(std::shared_ptr<const Game> game)
: game_(game),
depth_limit_(kDefaultDepthLimit),
@@ -89,6 +98,7 @@ void TabularSarsaSolver::RunIteration() {
double min_utility = game_->MinUtility();
// Choose start state
std::unique_ptr<State> curr_state = game_->NewInitialState();
SampleUntilNextStateOrTerminal(curr_state.get());

Player player = curr_state->CurrentPlayer();
// Sample action from the state using an epsilon-greedy policy
@@ -97,12 +107,7 @@

while (!curr_state->IsTerminal()) {
std::unique_ptr<State> next_state = curr_state->Child(curr_action);
// Repeatedly sample while chance node, so that we end up at a decision node
while (next_state->IsChanceNode() && !next_state->IsTerminal()) {
vector<Action> legal_actions = next_state->LegalActions();
next_state->ApplyAction(
legal_actions[absl::Uniform<int>(rng_, 0, legal_actions.size())]);
}
SampleUntilNextStateOrTerminal(next_state.get());
const double reward = next_state->Rewards()[player];

const Action next_action =
4 changes: 4 additions & 0 deletions open_spiel/algorithms/tabular_sarsa.h
@@ -60,6 +60,10 @@ class TabularSarsaSolver {
Action SampleActionFromEpsilonGreedyPolicy(const State& state,
double min_utility);

// Moves a chance node to the next decision/terminal node by sampling from
// the legal actions repeatedly
void SampleUntilNextStateOrTerminal(State* state);

std::shared_ptr<const Game> game_;
int depth_limit_;
double epsilon_;
27 changes: 27 additions & 0 deletions open_spiel/examples/tabular_q_learning_example.cc
@@ -68,7 +68,34 @@ void SolveTicTacToe() {
SPIEL_CHECK_EQ(state->Rewards()[1], 0);
}

void SolveCatch() {
std::shared_ptr<const Game> game = open_spiel::LoadGame("catch");
open_spiel::algorithms::TabularQLearningSolver tabular_q_learning_solver(
game);

int training_iter = 100000;
while (training_iter-- > 0) {
tabular_q_learning_solver.RunIteration();
}
const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
tabular_q_learning_solver.GetQValueTable();

int eval_iter = 1000;
int total_reward = 0;
while (eval_iter-- > 0) {
std::unique_ptr<State> state = game->NewInitialState();
while (!state->IsTerminal()) {
Action optimal_action = GetOptimalAction(q_values, state);
state->ApplyAction(optimal_action);
total_reward += state->Rewards()[0];
}
}

SPIEL_CHECK_GT(total_reward, 0);
}

int main(int argc, char** argv) {
SolveTicTacToe();
SolveCatch();
return 0;
}
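
Both example files call a GetOptimalAction helper defined earlier in each file, outside the hunks shown here. A plausible sketch of what such a helper does, assumed rather than copied from the repository: pick the legal action whose learned Q-value is largest for the current state string.

#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "open_spiel/spiel.h"

// Hypothetical sketch of the GetOptimalAction helper the examples rely on;
// not the commit's exact code. Returns the legal action with the highest
// learned Q-value, treating missing table entries as 0.
open_spiel::Action GetOptimalAction(
    const absl::flat_hash_map<std::pair<std::string, open_spiel::Action>,
                              double>& q_values,
    const std::unique_ptr<open_spiel::State>& state) {
  const std::vector<open_spiel::Action> legal_actions = state->LegalActions();
  open_spiel::Action best_action = legal_actions[0];
  double best_value = -std::numeric_limits<double>::infinity();
  for (const open_spiel::Action action : legal_actions) {
    const auto it = q_values.find({state->ToString(), action});
    const double value = (it == q_values.end()) ? 0.0 : it->second;
    if (value > best_value) {
      best_value = value;
      best_action = action;
    }
  }
  return best_action;
}
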
32 changes: 29 additions & 3 deletions open_spiel/examples/tabular_sarsa_example.cc
@@ -47,15 +47,15 @@ Action GetOptimalAction(

void SolveTicTacToe() {
std::shared_ptr<const Game> game = open_spiel::LoadGame("tic_tac_toe");
open_spiel::algorithms::TabularSarsaSolver sarsa_solver(game);
open_spiel::algorithms::TabularSarsaSolver tabular_sarsa_solver(game);

int iter = 100000;
while (iter-- > 0) {
sarsa_solver.RunIteration();
tabular_sarsa_solver.RunIteration();
}

const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
sarsa_solver.GetQValueTable();
tabular_sarsa_solver.GetQValueTable();
std::unique_ptr<State> state = game->NewInitialState();
while (!state->IsTerminal()) {
Action optimal_action = GetOptimalAction(q_values, state);
@@ -67,7 +67,33 @@
SPIEL_CHECK_EQ(state->Rewards()[1], 0);
}

void SolveCatch() {
std::shared_ptr<const Game> game = open_spiel::LoadGame("catch");
open_spiel::algorithms::TabularSarsaSolver tabular_sarsa_solver(game);

int training_iter = 100000;
while (training_iter-- > 0) {
tabular_sarsa_solver.RunIteration();
}
const absl::flat_hash_map<std::pair<std::string, Action>, double>& q_values =
tabular_sarsa_solver.GetQValueTable();

int eval_iter = 1000;
int total_reward = 0;
while (eval_iter-- > 0) {
std::unique_ptr<State> state = game->NewInitialState();
while (!state->IsTerminal()) {
Action optimal_action = GetOptimalAction(q_values, state);
state->ApplyAction(optimal_action);
total_reward += state->Rewards()[0];
}
}

SPIEL_CHECK_GT(total_reward, 0);
}

int main(int argc, char** argv) {
SolveTicTacToe();
SolveCatch();
return 0;
}
