Skip to content

Commit

Permalink
1.15.4: update to use Chessnut package (#291)
Browse files Browse the repository at this point in the history
1.15.4: update to use Chessnut package
  • Loading branch information
bovard authored Oct 30, 2024
1 parent 2f2ebd7 commit 00422e8
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 50 deletions.
2 changes: 1 addition & 1 deletion kaggle_environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .main import http_request
from . import errors

__version__ = "1.15.3"
__version__ = "1.15.4"

__all__ = ["Agent", "environments", "errors", "evaluate", "http_request",
"make", "register", "utils", "__version__",
Expand Down
165 changes: 118 additions & 47 deletions kaggle_environments/envs/chess/chess.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,107 @@
import chess
import random
import json
from os import path
from collections import defaultdict

from Chessnut import Game

ERROR = "ERROR"
DONE = "DONE"
INACTIVE = "INACTIVE"
ACTIVE = "ACTIVE"


def random_agent(obs):
"""
Selects a random legal move from the board.
Returns:
"""
Selects a random legal move from the board.
Returns:
A string representing the chosen move in UCI notation (e.g., "e2e4").
"""
board = obs.board
board_obj = chess.Board(board)
moves = list(board_obj.legal_moves)
return random.choice(moves).uci()
"""
game = Game(obs.board)
moves = list(game.get_moves())
return random.choice(moves)


def king_shuffle_agent(obs):
"""Moves the king pawn and then shuffles the king."""
game = Game(obs.board)
moves = list(game.get_moves())

to_move = ["e7e5", "e2e4", "e8e7", "e7e8", "e1e2", "e2e1"]
for move in to_move:
if move in moves:
return move

# If no other moves are possible, pick a random legal move (or return None)
return random.choice(moves) if moves else None


agents = {"random": random_agent, "king_shuffle": king_shuffle_agent}


def sufficient_material(pieces):
"""Checks if the given pieces are sufficient for checkmate."""
if pieces['q'] > 0 or pieces['r'] > 0 or pieces['p'] > 0:
return True
if pieces['n'] + pieces['b'] >= 3:
return True
# TODO: they have to be opposite color bishops
# b/b or n/b can checkmate, n/n cannot
if pieces['n'] < 2:
return True

return False


def is_insufficient_material(board):
white_pieces = defaultdict(int)
black_pieces = defaultdict(int)

for square in range(64):
piece = board.get_piece(square)
if piece:
if piece.isupper():
white_pieces[piece.lower()] += 1
else:
black_pieces[piece.lower()] += 1

if not sufficient_material(
white_pieces) and not sufficient_material(black_pieces):
return True

return False


def square_str_to_int(square_str):
"""Converts a square string (e.g., "a2") to an integer index (0-63)."""
if len(square_str) != 2 or not (
'a' <= square_str[0] <= 'h' and '1' <= square_str[1] <= '8'):
raise ValueError("Invalid square string")

col = ord(square_str[0]) - ord('a') # Get column index (0-7)
row = int(square_str[1]) - 1 # Get row index (0-7)
return row * 8 + col


def is_pawn_move_or_capture(board, move):
move = move.lower()
if board.get_piece(square_str_to_int(move[2:4])).lower() == "p":
return True
if board.get_piece(square_str_to_int(move[:2])) != " ":
return True
return False


seen_positions = defaultdict(int)
pawn_or_capture_move_count = 0

agents = {"random": random_agent}

def interpreter(state, env):
global seen_positions
global pawn_or_capture_move_count
if env.done:
seen_positions = defaultdict(int)
pawn_or_capture_move_count = 0
return state

# Isolate the active and inactive agents.
Expand All @@ -36,74 +115,66 @@ def interpreter(state, env):
# The board is shared, only update the first state.
board = state[0].observation.board

# Create a chess board object from the FEN string
board_obj = chess.Board(board)
# Create a chessnut game object from the FEN string
game = Game(board)

# Get the action (move) from the agent
action = active.action

if action and is_pawn_move_or_capture(game.board, action):
pawn_or_capture_move_count = 0
else:
pawn_or_capture_move_count += 1

# Check if the move is legal
try:
move = chess.Move.from_uci(action)
if not board_obj.is_legal(move):
raise ValueError("Illegal move")
except:
game.apply_move(action)
except BaseException:
active.status = ERROR
active.reward = -1
inactive.status = DONE
return state

# Make the move
board_obj.push(move)
board_str = game.get_fen().split(" ", maxsplit=1)[0]
seen_positions[board_str] += 1

# Update the board in the observation
state[0].observation.board = board_obj.fen()
state[1].observation.board = board_obj.fen()
state[0].observation.board = game.get_fen()
state[1].observation.board = game.get_fen()

# Check for game end conditions
if board_obj.is_checkmate():
active.reward = 1
if pawn_or_capture_move_count == 100 or is_insufficient_material(
game.board):
active.status = DONE
inactive.reward = -1
inactive.status = DONE
elif board_obj.is_stalemate() or board_obj.is_insufficient_material() or board_obj.is_game_over():
elif seen_positions[board_str] >= 3 or game.status == Game.STALEMATE:
active.status = DONE
inactive.status = DONE
elif game.status == Game.CHECKMATE:
active.reward = 1
active.status = DONE
inactive.reward = -1
inactive.status = DONE

else:
# Switch turns
active.status = INACTIVE
inactive.status = ACTIVE

return state


def renderer(state, env):
board_str = state[0].observation.board
board_obj = chess.Board(board_str)

# Unicode characters for chess pieces
piece_symbols = {
'P': '♙', 'R': '♖', 'N': '♘', 'B': '♗', 'Q': '♕', 'K': '♔',
'p': '♟', 'r': '♜', 'n': '♞', 'b': '♝', 'q': '♛', 'k': '♚',
'.': ' ' # Empty square
}

board_repr = ""
for square in chess.SQUARES:
piece = board_obj.piece_at(square)
if piece:
board_repr += piece_symbols[piece.symbol()]
else:
board_repr += piece_symbols['.']
if chess.square_file(square) == 7: # End of a rank
board_repr += "\n"
board_fen = state[0].observation.board
game = Game(board_fen)
return game.board

return board_repr

jsonpath = path.abspath(path.join(path.dirname(__file__), "chess.json"))
with open(jsonpath) as f:
specification = json.load(f)


def html_renderer():
jspath = path.abspath(path.join(path.dirname(__file__), "chess.js"))
with open(jspath) as f:
return f.read()
with open(jspath) as g:
return g.read()
10 changes: 9 additions & 1 deletion kaggle_environments/envs/chess/test_chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,12 @@ def test_chess_inits():
env.run(["random", "random"])
json = env.toJSON()
assert json["name"] == "chess"
assert json["statuses"] == ["ERROR", "DONE"]
assert json["statuses"] == ["DONE", "DONE"]

def test_chess_three_fold():
env = make("chess", debug=True)
env.run(["king_shuffle", "king_shuffle"])
json = env.toJSON()
assert json["name"] == "chess"
assert json["statuses"] == ["DONE", "DONE"]
assert json["rewards"] == [0, 0]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_version(rel_path):
"transformers >= 4.33.1",
"scipy >= 1.11.2",
"shimmy >= 1.2.1",
"chess >= 1.11.0",
"Chessnut >= 0.3.1",
],
packages=find_packages(),
include_package_data=True,
Expand Down

0 comments on commit 00422e8

Please sign in to comment.