-
Notifications
You must be signed in to change notification settings - Fork 641
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
wip: still needs batch logic for act and tdmp
- Loading branch information
1 parent
8c56770
commit ba91976
Showing
11 changed files
with
240 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from abc import abstractmethod | ||
from collections import deque | ||
|
||
import torch | ||
from torch import Tensor, nn | ||
|
||
|
||
class AbstractPolicy(nn.Module):
    """Base class for policies that may plan several actions from a single observation.

    Subclasses implement `update` and `select_action`, and must define an `n_action_steps` attribute before
    the module is first called. `forward` then hands out the planned actions one at a time.

    NOTE(review): `@abstractmethod` is only enforced by `abc.ABCMeta`. This class derives from `nn.Module`
    alone, so the decorators presumably serve as documentation rather than a hard instantiation check.
    """

    @abstractmethod
    def update(self, replay_buffer, step):
        """One step of the policy's learning algorithm."""

    def save(self, fp):
        """Serialize this module's `state_dict` to `fp` using `torch.save`."""
        torch.save(self.state_dict(), fp)

    def load(self, fp):
        """Load a `state_dict` from `fp` and copy it into this module's parameters.

        The checkpoint is deserialized onto the CPU so that a file written on an accelerator can still be
        read on a host without one; `load_state_dict` then copies the values into the existing parameters.
        """
        # NOTE(security): `torch.load` unpickles the file; only load checkpoints from trusted sources.
        state = torch.load(fp, map_location="cpu")
        self.load_state_dict(state)

    @abstractmethod
    def select_action(self, observation) -> Tensor:
        """Select an action (or trajectory of actions) based on an observation during rollout.

        Should return a (batch_size, n_action_steps, *) tensor of actions.
        """

    def forward(self, *args, **kwargs):
        """Inference step that makes multi-step policies compatible with their single-step environments.

        WARNING: In general, this should not be overridden.

        Consider a "policy" that observes the environment then charts a course of N actions to take. To make
        this fit into the formalism of a TorchRL environment, we view it as being effectively a policy that
        (1) makes an observation and prepares a queue of actions, (2) consumes that queue when queried,
        regardless of the environment observation, (3) repopulates the action queue when empty. This method
        handles the aforementioned logic so that the subclass doesn't have to.

        This method effectively wraps the `select_action` method of the subclass. The following assumptions
        are made:
        1. The `select_action` method returns a Tensor of actions with shape (B, H, *) where B is the batch
           size, H is the action trajectory horizon and * is the action dimensions.
        2. Prior to the `select_action` method being called, there is an `n_action_steps` instance attribute
           defined.

        All positional and keyword arguments are forwarded unchanged to `select_action`.

        Returns:
            The next queued action, with shape (B, *).

        Raises:
            RuntimeError: If the instance has no `n_action_steps` attribute.
        """
        n_action_steps_attr = "n_action_steps"
        if not hasattr(self, n_action_steps_attr):
            raise RuntimeError(f"Underlying policy must have an `{n_action_steps_attr}` attribute")
        if not hasattr(self, "_action_queue"):
            # Created lazily on first call; its capacity is fixed to the value of `n_action_steps` at that time.
            self._action_queue = deque([], maxlen=getattr(self, n_action_steps_attr))
        if not self._action_queue:
            # Move the horizon axis first so that iterating yields one (B, *) action per step.
            planned = self.select_action(*args, **kwargs).transpose(0, 1)
            # A bounded deque evicts from the left when over-filled, which would discard the *earliest*
            # planned actions. Keep the first `maxlen` steps explicitly so the plan is executed in order.
            self._action_queue.extend(planned[: self._action_queue.maxlen])

        return self._action_queue.popleft()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.