Skip to content

Commit

Permalink
test sphinx action
Browse files Browse the repository at this point in the history
  • Loading branch information
hesic73 committed Feb 29, 2024
1 parent ea80823 commit aff5471
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
24 changes: 12 additions & 12 deletions gomoku_rl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,22 +163,22 @@ def reset(self, env_indices: torch.Tensor | None = None):
self.last_move[env_indices] = -1

def step(
self, action: torch.Tensor, env_indices: torch.Tensor | None = None
self, action: torch.Tensor, env_mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""Performs actions in specified environments and updates their states.
"""Performs actions in specified environments and updates their states based on the provided action tensor. If an environment mask is provided, only the environments corresponding to `True` values in the mask are updated; otherwise, all environments are updated.
Args:
action (torch.Tensor): Actions to be performed, linearly indexed.
env_indices (torch.Tensor | None, optional): Indices of environments to update. Updates all if None.
action (torch.Tensor): 1D positions to place a stone, one per environment. Shape: (E,)
            env_mask (torch.Tensor | None, optional): Boolean mask to select environments for updating. If `None`, updates all. Shape should match environments.
Returns:
tuple[torch.Tensor, torch.Tensor]: (done_statuses, invalid_actions) where:
done_statuses: Boolean tensor indicating if the game ended in each environment.
invalid_actions: Boolean tensor indicating if the action was invalid in each environment.
tuple[torch.Tensor, torch.Tensor]: A tuple containing two tensors:
- done_statuses: Boolean tensor with `True` where games ended.
- invalid_actions: Boolean tensor with `True` for invalid actions in environments.
"""

if env_indices is None:
env_indices = torch.ones_like(action, dtype=torch.bool)
if env_mask is None:
env_mask = torch.ones_like(action, dtype=torch.bool)

board_1d_view = self.board.view(self.num_envs, -1)

Expand All @@ -187,7 +187,7 @@ def step(
action,
] # (E,)

nop = (values_on_board != 0) | (~env_indices) # (E,)
nop = (values_on_board != 0) | (~env_mask) # (E,)
inc = torch.logical_not(nop).long() # (E,)
piece = torch.where(self.turn == 0, 1, -1)
board_1d_view[
Expand All @@ -208,11 +208,11 @@ def step(
self.turn = (self.turn + inc) % 2
self.last_move = torch.where(nop, self.last_move, action)

return self.done & env_indices, nop & env_indices
return self.done & env_mask, nop & env_mask

def get_encoded_board(self) -> torch.Tensor:
"""Encodes the current board state into a tensor format suitable for neural network input.
Returns:
torch.Tensor: Encoded board state, shaped (E, 3, B, B), with separate channels for the current player's stones, the opponent's stones, and the last move.
"""
Expand Down
13 changes: 13 additions & 0 deletions gomoku_rl/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ def make_transition(
tensordict_t: TensorDict,
tensordict_t_plus_1: TensorDict,
) -> TensorDict:
"""
Constructs a transition tensor dictionary for a two-player game by integrating the game state and actions from three consecutive time steps (t-1, t, and t+1).
Args:
tensordict_t_minus_1 (TensorDict): A tensor dictionary containing the game state and associated information at time t-1.
tensordict_t (TensorDict): A tensor dictionary containing the game state and associated information at time t.
tensordict_t_plus_1 (TensorDict): A tensor dictionary containing the game state and associated information at time t+1.
Returns:
TensorDict: A new tensor dictionary representing the transition from time t-1 to t+1.
The function calculates rewards based on the win status at times t and t+1, and flags the transition as done if the game ends at either time t or t+1. The resulting tensor dictionary is structured to facilitate learning from this transition in reinforcement learning algorithms.
"""
# if a player wins at time t, its opponent cannot win immediately after reset
reward: torch.Tensor = (
tensordict_t.get("win").float() -
Expand Down

0 comments on commit aff5471

Please sign in to comment.