Skip to content

Commit

Permalink
[Feature] flexible batch_locked for jumanji
Browse files Browse the repository at this point in the history
ghstack-source-id: 383470ab68a0ff84009d7152e0d39f29083bb10d
Pull Request resolved: #2382
  • Loading branch information
vmoens committed Aug 8, 2024
1 parent c71c44c commit fed07f5
Showing 1 changed file with 39 additions and 17 deletions.
56 changes: 39 additions & 17 deletions torchrl/envs/libs/jumanji.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def __init__(
self.categorical_action_encoding = categorical_action_encoding
if env is not None:
kwargs["env"] = env
batch_locked = kwargs.pop("batch_locked", kwargs.get("batch_size") is not None)
super().__init__(**kwargs)
self._batch_locked = batch_locked

def _build_env(
self,
Expand Down Expand Up @@ -486,17 +488,17 @@ def _set_seed(self, seed):
raise Exception("Jumanji requires an integer seed.")
self.key = jax.random.PRNGKey(seed)

def read_state(self, state, batch_size=None):
    """Convert a Jumanji state object into torchrl's encoded state.

    Args:
        state: the (possibly batched) Jumanji state pytree/object.
        batch_size (optional): batch size to use for the conversion.
            Defaults to ``self.batch_size`` when ``None``, which preserves
            the previous behavior for batch-locked envs.

    Returns:
        the state dict encoded through ``self.state_spec["state"]``.
    """
    # Hoist the default instead of an inline conditional in the call for
    # readability; behavior is identical to passing self.batch_size.
    if batch_size is None:
        batch_size = self.batch_size
    state_dict = _object_to_tensordict(state, self.device, batch_size)
    return self.state_spec["state"].encode(state_dict)

def read_obs(self, obs, batch_size=None):
    """Convert a Jumanji observation into torchrl format.

    Args:
        obs: observation returned by the Jumanji env; either an array-like
            (list / jax ndarray / numpy ndarray) or a structured object.
        batch_size (optional): batch size used when converting a structured
            observation. Defaults to ``self.batch_size`` when ``None``
            (backward compatible for batch-locked envs).

    Returns:
        the observation processed by the parent class's ``read_obs``.
    """
    from jax import numpy as jnp

    if isinstance(obs, (list, jnp.ndarray, np.ndarray)):
        # Plain arrays carry their own shape; no batch-size handling needed.
        obs_dict = _ndarray_to_tensor(obs).to(self.device)
    else:
        if batch_size is None:
            batch_size = self.batch_size
        obs_dict = _object_to_tensordict(obs, self.device, batch_size)
    return super().read_obs(obs_dict)

def render(
Expand Down Expand Up @@ -579,25 +581,35 @@ def render(

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
import jax
if self.batch_locked:
batch_size = self.batch_size
else:
batch_size = tensordict.batch_size

# prepare inputs
_state_example = self._state_example
if not self.batch_locked and _state_example.batch_size != tensordict.batch_size:
_state_example = _state_example.expand(tensordict.batch_size)
else:
_state_example = self._state_example

state = _tensordict_to_object(tensordict.get("state"), self._state_example)
action = self.read_action(tensordict.get("action"))

# flatten batch size into vector
state = _tree_flatten(state, self.batch_size)
action = _tree_flatten(action, self.batch_size)
state = _tree_flatten(state, batch_size)
action = _tree_flatten(action, batch_size)

# jax vectorizing map on env.step
state, timestep = jax.vmap(self._env.step)(state, action)

# reshape batch size from vector
state = _tree_reshape(state, self.batch_size)
timestep = _tree_reshape(timestep, self.batch_size)
state = _tree_reshape(state, batch_size)
timestep = _tree_reshape(timestep, batch_size)

# collect outputs
state_dict = self.read_state(state)
obs_dict = self.read_obs(timestep.observation)
state_dict = self.read_state(state, batch_size=batch_size)
obs_dict = self.read_obs(timestep.observation, batch_size=batch_size)
reward = self.read_reward(np.asarray(timestep.reward))
done = timestep.step_type == self.lib.types.StepType.LAST
done = _ndarray_to_tensor(done).view(torch.bool).to(self.device)
Expand All @@ -622,25 +634,35 @@ def _reset(
import jax
from jax import numpy as jnp

if self.batch_locked:
numel = self.numel()
batch_size = self.batch_size
else:
numel = tensordict.numel()
batch_size = tensordict.batch_size

# generate random keys
self.key, *keys = jax.random.split(self.key, self.numel() + 1)
self.key, *keys = jax.random.split(self.key, numel + 1)

# jax vectorizing map on env.reset
state, timestep = jax.vmap(self._env.reset)(jnp.stack(keys))

# reshape batch size from vector
state = _tree_reshape(state, self.batch_size)
timestep = _tree_reshape(timestep, self.batch_size)
state = _tree_reshape(state, batch_size)
timestep = _tree_reshape(timestep, batch_size)

# collect outputs
state_dict = self.read_state(state)
obs_dict = self.read_obs(timestep.observation)
done_td = self.full_done_spec.zero()
state_dict = self.read_state(state, batch_size=batch_size)
obs_dict = self.read_obs(timestep.observation, batch_size=batch_size)
if not self.batch_locked:
done_td = self.full_done_spec.zero(batch_size)
else:
done_td = self.full_done_spec.zero()

# build results
tensordict_out = TensorDict(
source=obs_dict,
batch_size=self.batch_size,
batch_size=batch_size,
device=self.device,
)
tensordict_out.update(done_td)
Expand Down

0 comments on commit fed07f5

Please sign in to comment.