Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[add-fire] Memory class abstraction #4375

Merged
merged 31 commits into from
Aug 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d523c8c
Running LSTM for SAC
Aug 7, 2020
f2873b2
Use correct half of memories
Aug 10, 2020
b97b1e5
Fix policy memory storinig
Aug 11, 2020
cd509dd
Fix SeparateActorCritic and add test
Aug 11, 2020
c66ecba
Merge branch 'develop-add-fire' into develop-add-fire-sac-lst
Aug 12, 2020
07bb4c0
Use loss masks in PPO.
Aug 12, 2020
0a3c795
Proper shape of masks
Aug 12, 2020
2337d15
Proper mask mean for PPO
Aug 12, 2020
1f69102
Fix dtype for actions
Aug 12, 2020
c0a77f7
Proper initialization and SAC masking
Aug 12, 2020
f404834
Experimental amrl layer
Aug 12, 2020
beab310
Add extra FF layer
Aug 12, 2020
6fece65
Faster implementation
Aug 12, 2020
eac1dc9
Add comment
Aug 13, 2020
d2e31aa
Passthrough max
Aug 13, 2020
37feeee
Merge branch 'develop-add-fire' into develop-add-fire-amrl
Aug 14, 2020
bf485a2
Memory size abstraction and fixes
Aug 14, 2020
bd90e29
Fix SeparateActorCritic
Aug 14, 2020
7f4ea51
Fix SeparateActorCritic
Aug 14, 2020
317454a
LSTM class
Aug 18, 2020
4845ed0
Merge branch 'develop-add-fire' into develop-add-fire-memoryclass
Aug 18, 2020
848a875
Fix SeparateActorCritic export
Aug 18, 2020
295094b
Add abstract method to Actor
Aug 18, 2020
72bca86
Fix BC module
Aug 18, 2020
683001f
Remove some comments
Aug 18, 2020
77a47f4
Fix network tests
Aug 18, 2020
35de032
Merge branch 'develop-add-fire' into develop-add-fire-memoryclass
Aug 18, 2020
7dba7bf
Clean up memory_size logic
Aug 18, 2020
1d6a08c
Cleanup, add test
Aug 18, 2020
9bb065c
Properly export memory size
Aug 18, 2020
6eb09d3
Fix exporting again
Aug 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,21 @@ def __init__(
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
# Save the m_size needed for export
self._export_m_size = self.m_size
# m_size needed for training is determined by network, not trainer settings
self.m_size = self.actor_critic.memory_size

self.actor_critic.to("cpu")

@property
def export_memory_size(self) -> int:
    """
    Returns the memory size of the exported ONNX policy. This only includes the memory
    of the Actor and not any auxiliary networks.

    The value is captured in the constructor from ``m_size`` before ``m_size`` is
    overwritten with the memory size reported by the actor-critic network.
    """
    return self._export_m_size

def _split_decision_step(
self, decision_requests: DecisionSteps
) -> Tuple[SplitObservations, np.ndarray]:
Expand Down
19 changes: 19 additions & 0 deletions ml-agents/mlagents/trainers/tests/torch/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
linear_layer,
lstm_layer,
Initialization,
LSTM,
)


Expand Down Expand Up @@ -38,3 +39,21 @@ def test_lstm_layer():
assert torch.all(
torch.eq(param.data[4:8], torch.ones_like(param.data[4:8]))
)


def test_lstm_class():
    """LSTM should report the requested memory size and return correctly shaped tensors."""
    torch.manual_seed(0)
    n_features, mem_size = 12, 64
    n_batch, n_steps = 8, 16
    module = LSTM(n_features, mem_size)
    assert module.memory_size == mem_size

    inputs = torch.ones((n_batch, n_steps, n_features))
    initial_mem = torch.ones((1, n_batch, mem_size))
    outputs, final_mem = module(inputs, initial_mem)

    # The per-step output is as wide as the hidden state, i.e. half of the memory.
    expected_out_shape = (n_batch, n_steps, mem_size // 2)
    expected_mem_shape = (1, n_batch, mem_size)
    assert outputs.shape == expected_out_shape
    assert final_mem.shape == expected_mem_shape
6 changes: 1 addition & 5 deletions ml-agents/mlagents/trainers/tests/torch/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,7 @@ def test_actor_critic(ac_type, lstm):
if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
memories = torch.ones(
(
1,
network_settings.memory.sequence_length,
network_settings.memory.memory_size,
)
(1, network_settings.memory.sequence_length, actor.memory_size)
)
else:
sample_obs = torch.ones((1, obs_size))
Expand Down
4 changes: 1 addition & 3 deletions ml-agents/mlagents/trainers/torch/components/bc/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,7 @@ def _update_batch(

memories = []
if self.policy.use_recurrent:
memories = torch.zeros(
1, self.n_sequences, self.policy.actor_critic.half_mem_size * 2
)
memories = torch.zeros(1, self.n_sequences, self.policy.m_size)

if self.policy.use_vis_obs:
vis_obs = []
Expand Down
65 changes: 65 additions & 0 deletions ml-agents/mlagents/trainers/torch/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import torch
import abc
from typing import Tuple
from enum import Enum


Expand Down Expand Up @@ -82,3 +84,66 @@ def lstm_layer(
forget_bias
)
return lstm


class MemoryModule(torch.nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if the MemoryModule interface is needed if LSTM is the only implementation of it. Do you think there will be more MemoryModules in the future?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes 😛 OFC still experimental

Also things like Transformers, GRU, etc. Not sure how many of these things will make it to master, but I assume we're going to want this for experimentation.

# abc.abstractproperty has been deprecated since Python 3.3; stacking
# property on top of abc.abstractmethod is the supported spelling.
@property
@abc.abstractmethod
def memory_size(self) -> int:
    """
    Size of memory that is required at the start of a sequence.
    Concrete memory modules must override this property.
    """
    pass

@abc.abstractmethod
def forward(
    self, input_tensor: torch.Tensor, memories: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pass a sequence to the memory module. Abstract: this base implementation does
    nothing and concrete memory modules are expected to override it.
    :param input_tensor: Tensor of shape (batch_size, seq_length, size) that represents
        the input.
    :param memories: Tensor of initial memories.
    :return: Tuple of output, final memories.
    """
    pass


class LSTM(MemoryModule):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a test specific for this Module. (unless it is already tested enough in the network_tests)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's tested in both the network tests and the lstm_layer is tested in the layers test, but I can add another test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

"""
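Memory module that implements LSTM. The memory tensor holds the hidden state and the
cell state concatenated along the last dimension, so the underlying LSTM's hidden size
is memory_size // 2 and memory_size is reported as 2 * hidden_size.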
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more details in this comment about the division/multiplication by 2 and why we do it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"""

def __init__(
    self,
    input_size: int,
    memory_size: int,
    num_layers: int = 1,
    forget_bias: float = 1.0,
    kernel_init: Initialization = Initialization.XavierGlorotUniform,
    bias_init: Initialization = Initialization.Zero,
):
    """
    Build the wrapped LSTM layer.
    :param input_size: Forwarded to lstm_layer as its first argument.
    :param memory_size: Size of the combined memory; the hidden size is
        memory_size // 2.
    :param num_layers: Forwarded to lstm_layer.
    :param forget_bias: Forwarded to lstm_layer.
    :param kernel_init: Forwarded to lstm_layer.
    :param bias_init: Forwarded to lstm_layer.
    """
    super().__init__()
    # The incoming memory is shared between the initial hidden state and the
    # initial cell state, so each of the two gets half of memory_size.
    half_memory = memory_size // 2
    self.hidden_size = half_memory
    # The fourth positional argument is presumably lstm_layer's batch_first
    # flag -- TODO confirm against lstm_layer's signature.
    layer_args = (
        input_size,
        half_memory,
        num_layers,
        True,
        forget_bias,
        kernel_init,
        bias_init,
    )
    self.lstm = lstm_layer(*layer_args)

@property
def memory_size(self) -> int:
    """
    Width of the combined memory: twice ``hidden_size``, since the hidden state and
    the cell state are stored side by side.
    """
    doubled = self.hidden_size * 2
    return doubled

def forward(
    self, input_tensor: torch.Tensor, memories: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run a sequence through the wrapped LSTM.
    :param input_tensor: Input sequence passed to the wrapped LSTM unchanged.
    :param memories: Initial memories; split along the last dimension into two
        chunks of width ``hidden_size`` (hidden state first, then cell state).
    :return: Tuple of the LSTM output and the final hidden and cell states
        concatenated along the last dimension.
    """
    # Unpacking into exactly two names requires the split to yield two chunks.
    hidden_in, cell_in = torch.split(memories, self.hidden_size, dim=-1)
    sequence_out, final_state = self.lstm(input_tensor, (hidden_in, cell_in))
    merged_memories = torch.cat(final_state, dim=-1)
    return sequence_out, merged_memories
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/torch/model_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, policy):
else []
)
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.export_memory_size])

# Need to pass all possible inputs since keyword arguments are currently not
# supported by torch.onnx.export()
Expand Down
58 changes: 33 additions & 25 deletions ml-agents/mlagents/trainers/torch/networks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Callable, List, Dict, Tuple, Optional
import attr
import abc

import torch
Expand All @@ -14,7 +13,7 @@
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import lstm_layer
from mlagents.trainers.torch.layers import LSTM

ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[
Expand Down Expand Up @@ -51,9 +50,9 @@ def __init__(
)

if self.use_lstm:
self.lstm = lstm_layer(self.h_size, self.m_size // 2, batch_first=True)
self.lstm = LSTM(self.h_size, self.m_size)
else:
self.lstm = None
self.lstm = None # type: ignore

def update_normalization(self, vec_inputs: List[torch.Tensor]) -> None:
for vec_input, vec_enc in zip(vec_inputs, self.vector_encoders):
Expand All @@ -64,6 +63,10 @@ def copy_normalization(self, other_network: "NetworkBody") -> None:
for n1, n2 in zip(self.vector_encoders, other_network.vector_encoders):
n1.copy_normalization(n2)

@property
def memory_size(self) -> int:
    """
    Memory size reported by this body's LSTM module, or 0 when no LSTM is used.
    """
    if not self.use_lstm:
        return 0
    return self.lstm.memory_size

def forward(
self,
vec_inputs: List[torch.Tensor],
Expand Down Expand Up @@ -99,10 +102,8 @@ def forward(
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)
encoding = encoding.reshape([-1, sequence_length, self.h_size])
memories = torch.split(memories, self.m_size // 2, dim=-1)
encoding, memories = self.lstm(encoding, memories)
encoding = encoding.reshape([-1, self.m_size // 2])
memories = torch.cat(memories, dim=-1)
return encoding, memories


Expand All @@ -127,6 +128,10 @@ def __init__(
encoding_size = network_settings.hidden_units
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)

@property
def memory_size(self) -> int:
    """
    Memory size of this value network, delegated to its network body.
    """
    body = self.network_body
    return body.memory_size

def forward(
self,
vec_inputs: List[torch.Tensor],
Expand Down Expand Up @@ -237,6 +242,14 @@ def get_dist_and_value(
"""
pass

# abc.abstractproperty has been deprecated since Python 3.3; stacking
# property on top of abc.abstractmethod is the supported spelling.
@property
@abc.abstractmethod
def memory_size(self) -> int:
    """
    Returns the size of the memory (same size used as input and output in the other
    methods) used by this Actor.
    """
    pass


class SimpleActor(nn.Module, Actor):
def __init__(
Expand All @@ -252,7 +265,6 @@ def __init__(
self.act_type = act_type
self.act_size = act_size
self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have we decided on using Parameter with require_grad or register_buffer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd go with register_buffer since it's the "PyTorch way". With that said, this one is there purely for ONNX export - would that cause an issue? Maybe we should add a requires_grad = False for this one.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's possible. Make it requires_grad=False is good enough then.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these need to be saved at all actually, since they're not affected by the training process - maybe we can just set them to bare Tensors?

self.memory_size = torch.nn.Parameter(torch.Tensor([0]))
self.is_continuous_int = torch.nn.Parameter(
torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
)
Expand All @@ -262,6 +274,7 @@ def __init__(
self.encoding_size = network_settings.memory.memory_size // 2
else:
self.encoding_size = network_settings.hidden_units

if self.act_type == ActionType.CONTINUOUS:
self.distribution = GaussianDistribution(
self.encoding_size,
Expand All @@ -274,6 +287,10 @@ def __init__(
self.encoding_size, act_size
)

@property
def memory_size(self) -> int:
    """
    Memory size used by this actor, delegated to its network body.
    """
    body = self.network_body
    return body.memory_size

def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
self.network_body.update_normalization(vector_obs)

Expand Down Expand Up @@ -326,7 +343,7 @@ def forward(
sampled_actions,
log_probs,
self.version_number,
self.memory_size,
torch.Tensor([self.network_body.memory_size]),
self.is_continuous_int,
self.act_size_vector,
)
Expand Down Expand Up @@ -400,29 +417,20 @@ def __init__(
# Give the Actor only half the memories. Note we previously validate
# that memory_size must be a multiple of 4.
self.use_lstm = network_settings.memory is not None
if network_settings.memory is not None:
self.half_mem_size = network_settings.memory.memory_size // 2
new_memory_settings = attr.evolve(
network_settings.memory, memory_size=self.half_mem_size
)
use_network_settings = attr.evolve(
network_settings, memory=new_memory_settings
)
else:
use_network_settings = network_settings
self.half_mem_size = 0
super().__init__(
observation_shapes,
use_network_settings,
network_settings,
act_type,
act_size,
conditional_sigma,
tanh_squash,
)
self.stream_names = stream_names
self.critic = ValueNetwork(
stream_names, observation_shapes, use_network_settings
)
self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)

@property
def memory_size(self) -> int:
    """
    Combined memory size: the actor's network body memory plus the critic's memory.
    """
    actor_size = self.network_body.memory_size
    critic_size = self.critic.memory_size
    return actor_size + critic_size

def critic_pass(
self,
Expand All @@ -434,7 +442,7 @@ def critic_pass(
actor_mem, critic_mem = None, None
if self.use_lstm:
# Use only the back half of memories for critic
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, -1)
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
value_outputs, critic_mem_out = self.critic(
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)
Expand All @@ -455,7 +463,7 @@ def get_dist_and_value(
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
if self.use_lstm:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.half_mem_size, dim=-1)
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
else:
critic_mem = None
actor_mem = None
Expand Down