diff --git a/intermediate_source/mario_rl_tutorial.py b/intermediate_source/mario_rl_tutorial.py
index 8d02f3daf3..eb46feb2ad 100755
--- a/intermediate_source/mario_rl_tutorial.py
+++ b/intermediate_source/mario_rl_tutorial.py
@@ -53,6 +53,8 @@
 # Super Mario environment for OpenAI Gym
 import gym_super_mario_bros
 
+from tensordict import TensorDict
+from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
 
 ######################################################################
 # RL Definitions
@@ -348,7 +350,7 @@ def act(self, state):
 class Mario(Mario):  # subclassing for continuity
     def __init__(self, state_dim, action_dim, save_dir):
         super().__init__(state_dim, action_dim, save_dir)
-        self.memory = deque(maxlen=100000)
+        self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000))
         self.batch_size = 32
 
     def cache(self, state, next_state, action, reward, done):
@@ -373,14 +375,15 @@ def first_if_tuple(x):
         reward = torch.tensor([reward], device=self.device)
         done = torch.tensor([done], device=self.device)
 
-        self.memory.append((state, next_state, action, reward, done,))
+        # self.memory.append((state, next_state, action, reward, done,))
+        self.memory.add(TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done}, batch_size=[]))
 
     def recall(self):
         """
         Retrieve a batch of experiences from memory
         """
-        batch = random.sample(self.memory, self.batch_size)
-        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
+        batch = self.memory.sample(self.batch_size)
+        state, next_state, action, reward, done = (batch.get(key) for key in ("state", "next_state", "action", "reward", "done"))
         return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
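
For reviewers unfamiliar with torchrl, here is a minimal, self-contained sketch (not part of the diff) of how the replay buffer introduced above behaves: single transitions are added as TensorDicts with batch_size=[], and sample() returns one TensorDict with a leading batch dimension, so recall() can pull each key out directly. The tensor shapes below are illustrative only, not the tutorial's actual observation shapes.

    # Illustrative sketch of the TensorDictReplayBuffer usage in the diff above.
    # Shapes are made up for demonstration purposes.
    import torch
    from tensordict import TensorDict
    from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

    buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000))

    # Add single transitions; batch_size=[] marks each TensorDict as one unbatched sample.
    for _ in range(64):
        buffer.add(TensorDict({
            "state": torch.rand(4, 84, 84),
            "next_state": torch.rand(4, 84, 84),
            "action": torch.tensor([1]),
            "reward": torch.tensor([0.5]),
            "done": torch.tensor([False]),
        }, batch_size=[]))

    # Sampling returns a single TensorDict with a leading batch dimension.
    batch = buffer.sample(32)
    print(batch.get("state").shape)   # torch.Size([32, 4, 84, 84])
    print(batch.get("action").shape)  # torch.Size([32, 1])

A practical note on the storage choice: LazyMemmapStorage keeps the buffered tensors in memory-mapped files on disk rather than as Python objects in RAM, which is the main difference in behavior from the original deque-based memory.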