Skip to content

Commit

Permalink
Update mario_rl_tutorial.py (#2381)
Browse files Browse the repository at this point in the history
* Update mario_rl_tutorial.py
Fixes #1620
---------
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
  • Loading branch information
neuralninja27 authored Jun 6, 2023
1 parent 2284ab2 commit 6e0fd0a
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions intermediate_source/mario_rl_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
# Super Mario environment for OpenAI Gym
import gym_super_mario_bros

from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

######################################################################
# RL Definitions
Expand Down Expand Up @@ -348,7 +350,7 @@ def act(self, state):
class Mario(Mario): # subclassing for continuity
def __init__(self, state_dim, action_dim, save_dir):
    """Set up the agent's replay memory on top of the parent initialisation.

    Args:
        state_dim: forwarded unchanged to the parent ``__init__``.
        action_dim: forwarded unchanged to the parent ``__init__``.
        save_dir: forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(state_dim, action_dim, save_dir)
    # Experiences live in a TensorDict replay buffer backed by a lazily
    # allocated memory-mapped storage created with a size argument of 100000.
    # The earlier ``deque(maxlen=100000)`` assignment was immediately
    # overwritten by this one, so it has been dropped.
    self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000))
    # Number of experiences requested per call to ``recall``.
    self.batch_size = 32

def cache(self, state, next_state, action, reward, done):
Expand All @@ -373,14 +375,15 @@ def first_if_tuple(x):
reward = torch.tensor([reward], device=self.device)
done = torch.tensor([done], device=self.device)

self.memory.append((state, next_state, action, reward, done,))
# self.memory.append((state, next_state, action, reward, done,))
self.memory.add(TensorDict({"state": state, "next_state": next_state, "action": action, "reward": reward, "done": done}, batch_size=[]))

def recall(self):
    """
    Retrieve a batch of experiences from memory.

    Asks ``self.memory`` for ``self.batch_size`` entries, then reads the
    ``"state"``, ``"next_state"``, ``"action"``, ``"reward"`` and ``"done"``
    fields from the sampled batch.

    Returns:
        A 5-tuple ``(state, next_state, action, reward, done)``. ``action``,
        ``reward`` and ``done`` are squeezed before being returned;
        ``state`` and ``next_state`` are returned as sampled.
    """
    # Sampling goes through the replay buffer's own ``sample`` method.
    # The previous ``random.sample(self.memory, ...)`` call only accepts
    # sequences and does not apply to the replay buffer, so it is removed.
    batch = self.memory.sample(self.batch_size)
    state, next_state, action, reward, done = (
        batch.get(key)
        for key in ("state", "next_state", "action", "reward", "done")
    )
    return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()


Expand Down

0 comments on commit 6e0fd0a

Please sign in to comment.