diff --git a/torchrl/envs/libs/jumanji.py b/torchrl/envs/libs/jumanji.py index dbbc980e8cc..4ceaa35f56d 100644 --- a/torchrl/envs/libs/jumanji.py +++ b/torchrl/envs/libs/jumanji.py @@ -343,7 +343,9 @@ def __init__( self.categorical_action_encoding = categorical_action_encoding if env is not None: kwargs["env"] = env + batch_locked = kwargs.pop("batch_locked", kwargs.get("batch_size") is not None) super().__init__(**kwargs) + self._batch_locked = batch_locked def _build_env( self, @@ -486,17 +488,17 @@ def _set_seed(self, seed): raise Exception("Jumanji requires an integer seed.") self.key = jax.random.PRNGKey(seed) - def read_state(self, state): - state_dict = _object_to_tensordict(state, self.device, self.batch_size) + def read_state(self, state, batch_size=None): + state_dict = _object_to_tensordict(state, self.device, self.batch_size if batch_size is None else batch_size) return self.state_spec["state"].encode(state_dict) - def read_obs(self, obs): + def read_obs(self, obs, batch_size=None): from jax import numpy as jnp if isinstance(obs, (list, jnp.ndarray, np.ndarray)): obs_dict = _ndarray_to_tensor(obs).to(self.device) else: - obs_dict = _object_to_tensordict(obs, self.device, self.batch_size) + obs_dict = _object_to_tensordict(obs, self.device, self.batch_size if batch_size is None else batch_size) return super().read_obs(obs_dict) def render( @@ -579,25 +581,35 @@ def render( def _step(self, tensordict: TensorDictBase) -> TensorDictBase: import jax + if self.batch_locked: + batch_size = self.batch_size + else: + batch_size = tensordict.batch_size # prepare inputs + _state_example = self._state_example + if not self.batch_locked and _state_example.batch_size != tensordict.batch_size: + _state_example = _state_example.expand(tensordict.batch_size) + else: + _state_example = self._state_example + state = _tensordict_to_object(tensordict.get("state"), self._state_example) action = self.read_action(tensordict.get("action")) # flatten batch size into vector - state = _tree_flatten(state, self.batch_size) - action = _tree_flatten(action, self.batch_size) + state = _tree_flatten(state, batch_size) + action = _tree_flatten(action, batch_size) # jax vectorizing map on env.step state, timestep = jax.vmap(self._env.step)(state, action) # reshape batch size from vector - state = _tree_reshape(state, self.batch_size) - timestep = _tree_reshape(timestep, self.batch_size) + state = _tree_reshape(state, batch_size) + timestep = _tree_reshape(timestep, batch_size) # collect outputs - state_dict = self.read_state(state) - obs_dict = self.read_obs(timestep.observation) + state_dict = self.read_state(state, batch_size=batch_size) + obs_dict = self.read_obs(timestep.observation, batch_size=batch_size) reward = self.read_reward(np.asarray(timestep.reward)) done = timestep.step_type == self.lib.types.StepType.LAST done = _ndarray_to_tensor(done).view(torch.bool).to(self.device) @@ -622,25 +634,35 @@ def _reset( import jax from jax import numpy as jnp + if self.batch_locked: + numel = self.numel() + batch_size = self.batch_size + else: + numel = tensordict.numel() + batch_size = tensordict.batch_size + # generate random keys - self.key, *keys = jax.random.split(self.key, self.numel() + 1) + self.key, *keys = jax.random.split(self.key, numel + 1) # jax vectorizing map on env.reset state, timestep = jax.vmap(self._env.reset)(jnp.stack(keys)) # reshape batch size from vector - state = _tree_reshape(state, self.batch_size) - timestep = _tree_reshape(timestep, self.batch_size) + state = _tree_reshape(state, batch_size) + timestep = _tree_reshape(timestep, batch_size) # collect outputs - state_dict = self.read_state(state) - obs_dict = self.read_obs(timestep.observation) - done_td = self.full_done_spec.zero() + state_dict = self.read_state(state, batch_size=batch_size) + obs_dict = self.read_obs(timestep.observation, batch_size=batch_size) + if not self.batch_locked: + done_td = self.full_done_spec.zero(batch_size) + else: + done_td = self.full_done_spec.zero() # build results tensordict_out = TensorDict( source=obs_dict, - batch_size=self.batch_size, + batch_size=batch_size, device=self.device, ) tensordict_out.update(done_td)