[add-fire] Memory class abstraction #4375
mlagents/trainers/torch/layers.py (file path inferred from the imports)
@@ -1,4 +1,6 @@

```python
import torch
import abc
from typing import Tuple
from enum import Enum
```
@@ -82,3 +84,66 @@ def lstm_layer(

```python
            forget_bias
        )
    return lstm


class MemoryModule(torch.nn.Module):
    @abc.abstractproperty
    def memory_size(self) -> int:
        """
        Size of memory that is required at the start of a sequence.
        """
        pass

    @abc.abstractmethod
    def forward(
        self, input_tensor: torch.Tensor, memories: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Pass a sequence to the memory module.
        :input_tensor: Tensor of shape (batch_size, seq_length, size) that represents the input.
        :memories: Tensor of initial memories.
        :return: Tuple of output, final memories.
        """
        pass


class LSTM(MemoryModule):
```
Comment: Maybe add a test specific for this Module. (unless it is already tested enough in the network_tests)
Reply: It's tested in both the network tests and the lstm_layer is tested in the layers test, but I can add another test.
Reply: Added.
```python
    """
    Memory module that implements LSTM.
    """
```

Comment: Can you add more details in this comment about the division/multiplication by 2 and why we do it?
Reply: 👍
```python
    def __init__(
        self,
        input_size: int,
        memory_size: int,
        num_layers: int = 1,
        forget_bias: float = 1.0,
        kernel_init: Initialization = Initialization.XavierGlorotUniform,
        bias_init: Initialization = Initialization.Zero,
    ):
        super().__init__()
        # We set hidden size to half of memory_size since the initial memory
        # will be divided between the hidden state and initial cell state.
        self.hidden_size = memory_size // 2
        self.lstm = lstm_layer(
            input_size,
            self.hidden_size,
            num_layers,
            True,
            forget_bias,
            kernel_init,
            bias_init,
        )

    @property
    def memory_size(self) -> int:
        return 2 * self.hidden_size

    def forward(
        self, input_tensor: torch.Tensor, memories: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        h0, c0 = torch.split(memories, self.hidden_size, dim=-1)
        hidden = (h0, c0)
        lstm_out, hidden_out = self.lstm(input_tensor, hidden)
        output_mem = torch.cat(hidden_out, dim=-1)
        return lstm_out, output_mem
```
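For reference, a minimal usage sketch (not part of the diff) of how the flat memory tensor is split into hidden and cell state and re-concatenated; the tensor sizes are made up, and the shapes assume num_layers=1 with a batch-first layout as in lstm_layer:

```python
import torch

batch_size, seq_len, input_size, memory_size = 4, 16, 32, 256

lstm = LSTM(input_size, memory_size)    # hidden_size == memory_size // 2 == 128
assert lstm.memory_size == memory_size  # reported as 2 * hidden_size

inputs = torch.zeros(batch_size, seq_len, input_size)
# One flat tensor carries both halves of the recurrent state,
# concatenated along the last dimension: (num_layers, batch, 2 * hidden_size).
memories = torch.zeros(1, batch_size, memory_size)

output, next_memories = lstm(inputs, memories)
print(output.shape)         # (4, 16, 128) -> (batch, seq, hidden_size)
print(next_memories.shape)  # (1, 4, 256)  -> same layout as the input memories
```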
mlagents/trainers/torch/networks.py (file path inferred from context)
@@ -1,5 +1,4 @@

```python
from typing import Callable, List, Dict, Tuple, Optional
import attr
import abc

import torch
```
@@ -14,7 +13,7 @@

```python
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import lstm_layer
from mlagents.trainers.torch.layers import LSTM

ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[
```
@@ -51,9 +50,9 @@ def __init__(

```python
        )

        if self.use_lstm:
            self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
            self.lstm = LSTM(self.h_size, self.m_size)
        else:
            self.lstm = None
            self.lstm = None  # type: ignore

    def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
        for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
```
@@ -64,6 +63,10 @@ def copy_normalization(self, other_network: "NetworkBody") -> None:

```python
        for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders):
            n1.copy_normalization(n2)

    @property
    def memory_size(self) -> int:
        return self.lstm.memory_size if self.use_lstm else 0

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
```
@@ -99,10 +102,8 @@ def forward(

```python
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            memories = torch.split(memories, self.m_size // 2, dim=-1)
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
            memories = torch.cat(memories, dim=-1)
        return encoding, memories
```
@@ -127,6 +128,10 @@ def __init__(

```python
        encoding_size = network_settings.hidden_units
        self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size

    def forward(
        self,
        vec_inputs: List[torch.Tensor],
```
@@ -237,6 +242,14 @@ def get_dist_and_value(

```python
        """
        pass

    @abc.abstractproperty
    def memory_size(self):
        """
        Returns the size of the memory (same size used as input and output in the other
        methods) used by this Actor.
        """
        pass


class SimpleActor(nn.Module, Actor):
    def __init__(
```
@@ -252,7 +265,6 @@ def __init__(

```python
        self.act_type = act_type
        self.act_size = act_size
        self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
```
Comment: Have we decided on using Parameter with require_grad or register_buffer?
Reply: I'd go with register_buffer since it's the "PyTorch way". With that said, this one is there purely for ONNX export - would that cause an issue? Maybe we should add a requires_grad = False for this one.
Reply: That's possible. Making it requires_grad=False is good enough then.
Reply: I don't think these need to be saved at all actually, since they're not affected by the training process - maybe we can just set them to bare Tensors?
```python
        self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
        self.is_continuous_int = torch.nn.Parameter(
            torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
        )
```
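For context on the thread above, a small sketch (not from the PR) contrasting the three options discussed for constants that only exist so they appear in the exported ONNX graph; the class name and grouping here are purely illustrative:

```python
import torch


class ExportConstants(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Option 1: Parameter with requires_grad=False - still shows up in
        # state_dict and parameters(), but the optimizer will not update it.
        self.version_number = torch.nn.Parameter(torch.Tensor([2.0]), requires_grad=False)
        # Option 2: registered buffer - the "PyTorch way" for non-trainable state;
        # saved in state_dict and moved by .to()/.cuda(), but not a parameter.
        self.register_buffer("is_continuous_int", torch.Tensor([1]))
        # Option 3: bare tensor attribute - neither saved nor moved with the module,
        # which may be acceptable if it is only read at export time.
        self.memory_size_vector = torch.Tensor([0])
```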
@@ -262,6 +274,7 @@ def __init__(

```python
            self.encoding_size = network_settings.memory.memory_size // 2
        else:
            self.encoding_size = network_settings.hidden_units

        if self.act_type == ActionType.CONTINUOUS:
            self.distribution = GaussianDistribution(
                self.encoding_size,
```
@@ -274,6 +287,10 @@ def __init__(

```python
                self.encoding_size, act_size
            )

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size

    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
        self.network_body.update_normalization(vector_obs)
```
@@ -326,7 +343,7 @@ def forward(

```python
            sampled_actions,
            log_probs,
            self.version_number,
            self.memory_size,
            torch.Tensor([self.network_body.memory_size]),
            self.is_continuous_int,
            self.act_size_vector,
        )
```
@@ -400,29 +417,20 @@ def __init__(

```python
        # Give the Actor only half the memories. Note we previously validate
        # that memory_size must be a multiple of 4.
        self.use_lstm = network_settings.memory is not None
        if network_settings.memory is not None:
            self.half_mem_size = network_settings.memory.memory_size // 2
            new_memory_settings = attr.evolve(
                network_settings.memory, memory_size=self.half_mem_size
            )
            use_network_settings = attr.evolve(
                network_settings, memory=new_memory_settings
            )
        else:
            use_network_settings = network_settings
            self.half_mem_size = 0
        super().__init__(
            observation_shapes,
            use_network_settings,
            network_settings,
            act_type,
            act_size,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        self.critic = ValueNetwork(
            stream_names, observation_shapes, use_network_settings
        )
        self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size + self.critic.memory_size

    def critic_pass(
        self,
```
@@ -434,7 +442,7 @@ def critic_pass(

```python
        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for critic
            actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1)
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        value_outputs, critic_mem_out = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
```
@@ -455,7 +463,7 @@ def get_dist_and_value(

```python
    ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
        if self.use_lstm:
            # Use only the back half of memories for critic and actor
            actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1)
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
        else:
            critic_mem = None
            actor_mem = None
```
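For illustration, some hypothetical numbers showing how the combined memory is split between the two network bodies after this change (the sizes are made up; the front half goes to the actor, the back half to the critic):

```python
import torch

settings_memory_size = 128                       # memory_size from NetworkSettings
combined_memory_size = 2 * settings_memory_size  # actor LSTM memory + critic LSTM memory

batch_size = 4
memories = torch.zeros(1, batch_size, combined_memory_size)

actor_mem, critic_mem = torch.split(memories, combined_memory_size // 2, dim=-1)
print(actor_mem.shape)   # torch.Size([1, 4, 128])
print(critic_mem.shape)  # torch.Size([1, 4, 128])
```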
Comment: Not sure if the MemoryModule interface is needed if LSTM is the only implementation of it. Do you think there will be more MemoryModules in the future?
Reply: Yes 😛 OFC still experimental. Also things like Transformers, GRU, etc. Not sure how many of these things will make it to master, but I assume we're going to want this for experimentation.
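As a thought experiment only (nothing like this is in the PR), another MemoryModule implementation could wrap torch.nn.GRU; since a GRU has no cell state, the whole memory tensor can serve as the hidden state directly:

```python
import torch
from typing import Tuple


class GRU(MemoryModule):
    """Speculative GRU-backed memory module; assumes MemoryModule from layers.py is in scope."""

    def __init__(self, input_size: int, memory_size: int, num_layers: int = 1):
        super().__init__()
        self._memory_size = memory_size
        self.gru = torch.nn.GRU(input_size, memory_size, num_layers, batch_first=True)

    @property
    def memory_size(self) -> int:
        return self._memory_size

    def forward(
        self, input_tensor: torch.Tensor, memories: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # memories is used directly as the initial hidden state:
        # shape (num_layers, batch, memory_size).
        output, hidden_out = self.gru(input_tensor, memories)
        return output, hidden_out
```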