diff --git a/src/ale_interface.hpp b/src/ale_interface.hpp
index 126f4e394..832adf592 100644
--- a/src/ale_interface.hpp
+++ b/src/ale_interface.hpp
@@ -87,7 +87,7 @@ class ALEInterface {
   // when necessary - this method will keep pressing buttons on the
   // game over screen.
   reward_t act(Action action);
- 
+
   // Applies a continuous action to the game and returns the reward. It is the
   // user's responsibility to check if the game has ended and reset
   // when necessary - this method will keep pressing buttons on the
diff --git a/src/python/env.py b/src/python/env.py
index c8c42743e..a1d3cba6b 100644
--- a/src/python/env.py
+++ b/src/python/env.py
@@ -160,21 +160,21 @@ def __init__(
         self.continuous = continuous
         self.continuous_action_threshold = continuous_action_threshold
         if continuous:
-            # We don't need action_set for continuous actions.
-            self._action_set = None
-            # Actions are radius, theta, and fire, where first two are the
-            # parameters of polar coordinates.
-            self._action_space = spaces.Box(
-                np.array([0, -1, 0]).astype(np.float32),
-                np.array([+1, +1, +1]).astype(np.float32),
-            )  # radius, theta, fire. First two are polar coordinates.
+            # We don't need action_set for continuous actions.
+            self._action_set = None
+            # Actions are radius, theta, and fire, where first two are the
+            # parameters of polar coordinates.
+            self._action_space = spaces.Box(
+                np.array([0, -1, 0]).astype(np.float32),
+                np.array([+1, +1, +1]).astype(np.float32),
+            )  # radius, theta, fire. First two are polar coordinates.
         else:
-            self._action_set = (
-                self.ale.getLegalActionSet()
-                if full_action_space
-                else self.ale.getMinimalActionSet()
-            )
-            self._action_space = spaces.Discrete(len(self._action_set))
+            self._action_set = (
+                self.ale.getLegalActionSet()
+                if full_action_space
+                else self.ale.getMinimalActionSet()
+            )
+            self._action_space = spaces.Discrete(len(self._action_set))
 
         # initialize observation space
         if self._obs_type == "ram":
@@ -270,7 +270,7 @@ def step(  # pyright: ignore[reportIncompatibleMethodOverride]
         if self.continuous:
             action = tuple(action)
             if len(action) != 3:
-                raise ValueError('Actions must have 3-dimensions.')
+                raise ValueError("Actions must have 3-dimensions.")
             r, theta, fire = action
             reward += self.ale.actContinuous(r, theta, fire)
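
For reference, a minimal usage sketch of the continuous-action path this patch touches. It is not part of the patch: it assumes an ale-py build that includes the `continuous` keyword from env.py above, Gymnasium >= 1.0 for `gym.register_envs`, and uses "ALE/Breakout-v5" purely as an illustrative env id.

# Sketch only: env id and registration call are assumptions, not part of this diff.
import ale_py
import gymnasium as gym
import numpy as np

gym.register_envs(ale_py)  # expose the ALE/* ids to Gymnasium

env = gym.make("ALE/Breakout-v5", continuous=True)
obs, info = env.reset(seed=0)

# Continuous actions are 3-dimensional: (radius, theta, fire), where the first
# two are polar coordinates inside Box([0, -1, 0], [1, 1, 1]), as in env.py above.
action = np.array([0.8, 0.25, 0.0], dtype=np.float32)
obs, reward, terminated, truncated, info = env.step(action)

env.close()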