From aff547122d68100384eb78f63eedda89bbe937df Mon Sep 17 00:00:00 2001
From: hesic73
Date: Thu, 29 Feb 2024 13:27:03 -0800
Subject: [PATCH] test sphinx action

---
 gomoku_rl/core.py | 24 ++++++++++++------------
 gomoku_rl/env.py  | 13 +++++++++++++
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/gomoku_rl/core.py b/gomoku_rl/core.py
index 5c87038..3416ae7 100644
--- a/gomoku_rl/core.py
+++ b/gomoku_rl/core.py
@@ -163,22 +163,22 @@ def reset(self, env_indices: torch.Tensor | None = None):
         self.last_move[env_indices] = -1
 
     def step(
-        self, action: torch.Tensor, env_indices: torch.Tensor | None = None
+        self, action: torch.Tensor, env_mask: torch.Tensor | None = None
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Performs actions in specified environments and updates their states.
+        """Performs one action per environment and updates the selected environments' states. If an environment mask is provided, only the environments with `True` values in the mask are updated; otherwise, all environments are updated.
 
         Args:
-            action (torch.Tensor): Actions to be performed, linearly indexed.
-            env_indices (torch.Tensor | None, optional): Indices of environments to update. Updates all if None.
+            action (torch.Tensor): Flattened board positions at which to place a stone, one per environment. Shape: (E,)
+            env_mask (torch.Tensor | None, optional): Boolean mask selecting the environments to update. If `None`, all environments are updated. Shape: (E,)
 
         Returns:
-            tuple[torch.Tensor, torch.Tensor]: (done_statuses, invalid_actions) where:
-                done_statuses: Boolean tensor indicating if the game ended in each environment.
-                invalid_actions: Boolean tensor indicating if the action was invalid in each environment.
+            tuple[torch.Tensor, torch.Tensor]: A tuple containing two tensors:
+                - done_statuses: Boolean tensor with `True` where games ended.
+                - invalid_actions: Boolean tensor with `True` where the attempted action was invalid.
 
         """
-        if env_indices is None:
-            env_indices = torch.ones_like(action, dtype=torch.bool)
+        if env_mask is None:
+            env_mask = torch.ones_like(action, dtype=torch.bool)
 
         board_1d_view = self.board.view(self.num_envs, -1)
 
@@ -187,7 +187,7 @@ def step(
             action,
         ]  # (E,)
 
-        nop = (values_on_board != 0) | (~env_indices)  # (E,)
+        nop = (values_on_board != 0) | (~env_mask)  # (E,)
         inc = torch.logical_not(nop).long()  # (E,)
         piece = torch.where(self.turn == 0, 1, -1)
         board_1d_view[
@@ -208,11 +208,11 @@ def step(
         self.turn = (self.turn + inc) % 2
         self.last_move = torch.where(nop, self.last_move, action)
 
-        return self.done & env_indices, nop & env_indices
+        return self.done & env_mask, nop & env_mask
 
     def get_encoded_board(self) -> torch.Tensor:
         """Encodes the current board state into a tensor format suitable for neural network input.
-        
+
         Returns:
             torch.Tensor: Encoded board state, shaped (E, 3, B, B), with separate channels for the current player's stones, the opponent's stones, and the last move.
         """
diff --git a/gomoku_rl/env.py b/gomoku_rl/env.py
index a79a5fa..5556288 100644
--- a/gomoku_rl/env.py
+++ b/gomoku_rl/env.py
@@ -25,6 +25,19 @@ def make_transition(
     tensordict_t: TensorDict,
     tensordict_t_plus_1: TensorDict,
 ) -> TensorDict:
+    """
+    Constructs a transition tensor dictionary for a two-player game by combining the game states and actions from three consecutive time steps (t-1, t, and t+1).
+
+    Args:
+        tensordict_t_minus_1 (TensorDict): Game state and associated information at time t-1.
+        tensordict_t (TensorDict): Game state and associated information at time t.
+        tensordict_t_plus_1 (TensorDict): Game state and associated information at time t+1.
+
+    Returns:
+        TensorDict: A new tensor dictionary representing the transition from time t-1 to t+1.
+
+    The reward is computed from the win status at times t and t+1, and the transition is marked as done if the game ends at either time t or t+1. The resulting tensor dictionary is structured for use by reinforcement learning algorithms.
+    """
     # if a player wins at time t, its opponent cannot win immediately after reset
     reward: torch.Tensor = (
         tensordict_t.get("win").float() -
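
A minimal usage sketch of the renamed mask parameter follows (illustration only, not part of the patch). The class name `Gomoku`, the import path, and the constructor arguments are assumptions; the `step(action, env_mask=...)` signature and its `(done, invalid)` return contract are taken from the hunks above.

    import torch

    from gomoku_rl.core import Gomoku  # assumed class name and import path

    envs = Gomoku(num_envs=4, board_size=15)  # assumed constructor arguments
    envs.reset()

    # One flattened board position per environment, shape (E,).
    action = torch.randint(0, 15 * 15, (4,))

    # Update only environments 0 and 2; environments 1 and 3 are untouched.
    env_mask = torch.tensor([True, False, True, False])

    done, invalid = envs.step(action, env_mask=env_mask)

    # Both return values are AND-ed with env_mask, so masked-out
    # environments always report done=False and invalid=False.
    assert not done[~env_mask].any()
    assert not invalid[~env_mask].any()

Passing a boolean mask rather than an index tensor keeps every per-environment tensor at a fixed shape (E,), which is why the patch computes `nop` and both return values with element-wise boolean operations instead of index-based gathers.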