diff --git a/kaggle_environments/envs/chess/chess.py b/kaggle_environments/envs/chess/chess.py index d62c1f2a..91fc109a 100644 --- a/kaggle_environments/envs/chess/chess.py +++ b/kaggle_environments/envs/chess/chess.py @@ -1,3 +1,4 @@ +import math import random import json from os import path @@ -11,6 +12,29 @@ ACTIVE = "ACTIVE" WHITE = "white" +OPENINGS = [ + "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", + "rnbqkbnr/1p2pp1p/p2p2p1/8/2PNP3/8/PP3PPP/RNBQKB1R w KQkq - 0 6", + "r1b1kb1r/ppppq1pp/2n2n2/1B2p3/4N3/5N2/PPPPQPPP/R1B1K2R w KQkq - 3 7", + "rnbqkb1r/p2ppppp/5n2/2pP4/2p5/2N5/PP2PPPP/R1BQKBNR w KQkq - 0 5", + "rnbqk1nr/p1p1bppp/1p2p3/3pP3/3P4/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5", + "r2qk1nr/ppp2pp1/2np3p/2b1p3/2B1P1b1/2PP1N2/PP3PPP/RNBQ1RK1 w kq - 0 7", + "rn1qk1nr/pp2ppbp/3p2p1/2p5/2PP2b1/2N1PN2/PP3PPP/R1BQKB1R w KQkq c6 0 6", + "rnbqkbnr/1p2pp1p/p2p2p1/8/2PNP3/8/PP3PPP/RNBQKB1R w KQkq - 0 6", +] + + +MOVES = [ + "", + "e2e4 c7c5 g1f3 d7d6 d2d4 c5d4 f3d4 a7a6 c2c4 g7g6", + "e2e4 e7e5 g1f3 b8c6 f1b5 f7f5 b1c3 f5e4 c3e4 g8f6 d1e2 d8e7", + "d2d4 g8f6 c2c4 c7c5 d4d5 b7b5 b1c3 b5c4", + "e2e4 e7e6 d2d4 d7d5 b1c3 f8e7 e4e5 b7b6", + "e2e4 e7e5 g1f3 b8c6 f1c4 f8c5 e1g1 d7d6 c2c3 c8g4 d2d3 h7h6", + "d2d4 g7g6 c2c4 f8g7 b1c3 d7d6 g1f3 c8g4 e2e3 c7c5", + "e2e4 c7c5 g1f3 d7d6 d2d4 c5d4 f3d4 a7a6 c2c4 g7g6" +] + def random_agent(obs): """ @@ -20,7 +44,7 @@ def random_agent(obs): """ game = Game(obs.board) moves = list(game.get_moves()) - return random.choice(moves) + return random.choice(moves) if moves else None def king_shuffle_agent(obs): @@ -142,20 +166,25 @@ def square_str_to_int(square_str): seen_positions = defaultdict(int) game_one_complete = False +game_start_position = math.floor(random.random() * len(OPENINGS)) def interpreter(state, env): global seen_positions global game_one_complete + global game_start_position if env.done: game_one_complete = False seen_positions = defaultdict(int) + game_start_position = math.floor(random.random() * len(OPENINGS)) + state[0].observation.board = OPENINGS[game_start_position] + state[1].observation.board = OPENINGS[game_start_position] return state if state[0].status == ACTIVE and state[1].status == ACTIVE: # set up new game state[0].observation.mark, state[1].observation.mark = state[1].observation.mark, state[0].observation.mark - state[0].observation.board = Game().get_fen() - state[1].observation.board = Game().get_fen() + state[0].observation.board = OPENINGS[game_start_position] + state[1].observation.board = OPENINGS[game_start_position] state[0].status = ACTIVE if state[0].observation.mark == WHITE else INACTIVE state[0].status = ACTIVE if state[0].observation.mark == WHITE else INACTIVE return state @@ -203,7 +232,7 @@ def interpreter(state, env): inactive.status = terminal_state game_one_complete = True elif seen_positions[board_str] >= 3 or game.status == Game.STALEMATE: - active.status = terminal_state + active.status = terminal_state inactive.status = terminal_state game_one_complete = True elif game.status == Game.CHECKMATE: