-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tests for core PyTorch files #4292
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
a475151
Add test for some utils and fix bug
7f5b4a9
Fix PPO TF test
86a8a8e
Added some more utils tests
3729d0f
Fix typo in docstring
1f06280
Fix saving mean and std with normalizer
d93680c
Tests for encoders
59ba0b6
Visual encoder test
ed5c169
Add decoder test
0adca61
Merge branch 'develop-add-fire' into develop-add-fire-tests
f6135dc
Add check to create_encoders
d0bed89
Add typing to distributions
d228d4b
Tests for distributions
df2a29c
Address comments
503dd46
Add test for get_probs_and_entropy
594cb86
Rename test_distribution to test_distributions
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import pytest | ||
import torch | ||
|
||
from mlagents.trainers.torch.decoders import ValueHeads | ||
|
||
|
||
def test_valueheads(): | ||
stream_names = [f"reward_signal_{num}" for num in range(5)] | ||
input_size = 5 | ||
batch_size = 4 | ||
|
||
# Test default 1 value per head | ||
value_heads = ValueHeads(stream_names, input_size) | ||
input_data = torch.ones((batch_size, input_size)) | ||
value_out, _ = value_heads(input_data) # Note: mean value will be removed shortly | ||
|
||
for stream_name in stream_names: | ||
assert value_out[stream_name].shape == (batch_size,) | ||
|
||
# Test that inputting the wrong size input will throw an error | ||
with pytest.raises(Exception): | ||
value_out = value_heads(torch.ones((batch_size, input_size + 2))) | ||
|
||
# Test multiple values per head (e.g. discrete Q function) | ||
output_size = 4 | ||
value_heads = ValueHeads(stream_names, input_size, output_size) | ||
input_data = torch.ones((batch_size, input_size)) | ||
value_out, _ = value_heads(input_data) | ||
|
||
for stream_name in stream_names: | ||
assert value_out[stream_name].shape == (batch_size, output_size) |
141 changes: 141 additions & 0 deletions
141
ml-agents/mlagents/trainers/tests/torch/test_distributions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import pytest | ||
import torch | ||
|
||
from mlagents.trainers.torch.distributions import ( | ||
GaussianDistribution, | ||
MultiCategoricalDistribution, | ||
GaussianDistInstance, | ||
TanhGaussianDistInstance, | ||
CategoricalDistInstance, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("tanh_squash", [True, False]) | ||
@pytest.mark.parametrize("conditional_sigma", [True, False]) | ||
def test_gaussian_distribution(conditional_sigma, tanh_squash): | ||
torch.manual_seed(0) | ||
hidden_size = 16 | ||
act_size = 4 | ||
sample_embedding = torch.ones((1, 16)) | ||
gauss_dist = GaussianDistribution( | ||
hidden_size, | ||
act_size, | ||
conditional_sigma=conditional_sigma, | ||
tanh_squash=tanh_squash, | ||
) | ||
|
||
# Make sure backprop works | ||
force_action = torch.zeros((1, act_size)) | ||
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3) | ||
|
||
for _ in range(50): | ||
dist_inst = gauss_dist(sample_embedding)[0] | ||
if tanh_squash: | ||
assert isinstance(dist_inst, TanhGaussianDistInstance) | ||
else: | ||
assert isinstance(dist_inst, GaussianDistInstance) | ||
log_prob = dist_inst.log_prob(force_action) | ||
loss = torch.nn.functional.mse_loss(log_prob, -2 * torch.ones(log_prob.shape)) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
for prob in log_prob.flatten(): | ||
assert prob == pytest.approx(-2, abs=0.1) | ||
|
||
|
||
def test_multi_categorical_distribution(): | ||
torch.manual_seed(0) | ||
hidden_size = 16 | ||
act_size = [3, 3, 4] | ||
sample_embedding = torch.ones((1, 16)) | ||
gauss_dist = MultiCategoricalDistribution(hidden_size, act_size) | ||
|
||
# Make sure backprop works | ||
optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3) | ||
|
||
def create_test_prob(size: int) -> torch.Tensor: | ||
test_prob = torch.tensor( | ||
[[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)] | ||
) # High prob for first action | ||
return test_prob.log() | ||
|
||
for _ in range(100): | ||
dist_insts = gauss_dist(sample_embedding, masks=torch.ones((1, sum(act_size)))) | ||
loss = 0 | ||
for i, dist_inst in enumerate(dist_insts): | ||
assert isinstance(dist_inst, CategoricalDistInstance) | ||
log_prob = dist_inst.all_log_prob() | ||
test_log_prob = create_test_prob(act_size[i]) | ||
# Force log_probs to match the high probability for the first action generated by | ||
# create_test_prob | ||
loss += torch.nn.functional.mse_loss(log_prob, test_log_prob) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
for dist_inst, size in zip(dist_insts, act_size): | ||
# Check that the log probs are close to the fake ones that we generated. | ||
test_log_probs = create_test_prob(size) | ||
for _prob, _test_prob in zip( | ||
dist_inst.all_log_prob().flatten().tolist(), | ||
test_log_probs.flatten().tolist(), | ||
): | ||
assert _prob == pytest.approx(_test_prob, abs=0.1) | ||
|
||
# Test masks | ||
masks = [] | ||
for branch in act_size: | ||
masks += [0] * (branch - 1) + [1] | ||
masks = torch.tensor([masks]) | ||
dist_insts = gauss_dist(sample_embedding, masks=masks) | ||
for dist_inst in dist_insts: | ||
log_prob = dist_inst.all_log_prob() | ||
assert log_prob.flatten()[-1] == pytest.approx(0, abs=0.001) | ||
|
||
|
||
def test_gaussian_dist_instance(): | ||
torch.manual_seed(0) | ||
act_size = 4 | ||
dist_instance = GaussianDistInstance( | ||
torch.zeros(1, act_size), torch.ones(1, act_size) | ||
) | ||
action = dist_instance.sample() | ||
assert action.shape == (1, act_size) | ||
for log_prob in dist_instance.log_prob(torch.zeros((1, act_size))).flatten(): | ||
# Log prob of standard normal at 0 | ||
assert log_prob == pytest.approx(-0.919, abs=0.01) | ||
|
||
for ent in dist_instance.entropy().flatten(): | ||
# entropy of standard normal at 0 | ||
assert ent == pytest.approx(2.83, abs=0.01) | ||
|
||
|
||
def test_tanh_gaussian_dist_instance(): | ||
torch.manual_seed(0) | ||
act_size = 4 | ||
dist_instance = GaussianDistInstance( | ||
torch.zeros(1, act_size), torch.ones(1, act_size) | ||
) | ||
for _ in range(10): | ||
action = dist_instance.sample() | ||
assert action.shape == (1, act_size) | ||
assert torch.max(action) < 1.0 and torch.min(action) > -1.0 | ||
|
||
|
||
def test_categorical_dist_instance(): | ||
torch.manual_seed(0) | ||
act_size = 4 | ||
test_prob = torch.tensor( | ||
[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1) | ||
) # High prob for first action | ||
dist_instance = CategoricalDistInstance(test_prob) | ||
|
||
for _ in range(10): | ||
action = dist_instance.sample() | ||
assert action.shape == (1,) | ||
assert action < act_size | ||
|
||
# Make sure the first action as higher probability than the others. | ||
prob_first_action = dist_instance.log_prob(torch.tensor([0])) | ||
|
||
for i in range(1, act_size): | ||
assert dist_instance.log_prob(torch.tensor([i])) < prob_first_action |
110 changes: 110 additions & 0 deletions
110
ml-agents/mlagents/trainers/tests/torch/test_encoders.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import torch | ||
from unittest import mock | ||
import pytest | ||
|
||
from mlagents.trainers.torch.encoders import ( | ||
VectorEncoder, | ||
VectorAndUnnormalizedInputEncoder, | ||
Normalizer, | ||
SimpleVisualEncoder, | ||
ResNetVisualEncoder, | ||
NatureVisualEncoder, | ||
) | ||
|
||
|
||
# This test will also reveal issues with states not being saved in the state_dict. | ||
def compare_models(module_1, module_2): | ||
is_same = True | ||
for key_item_1, key_item_2 in zip( | ||
module_1.state_dict().items(), module_2.state_dict().items() | ||
): | ||
# Compare tensors in state_dict and not the keys. | ||
is_same = torch.equal(key_item_1[1], key_item_2[1]) and is_same | ||
return is_same | ||
|
||
|
||
def test_normalizer(): | ||
input_size = 2 | ||
norm = Normalizer(input_size) | ||
|
||
# These three inputs should mean to 0.5, and variance 2 | ||
# with the steps starting at 1 | ||
vec_input1 = torch.tensor([[1, 1]]) | ||
vec_input2 = torch.tensor([[1, 1]]) | ||
vec_input3 = torch.tensor([[0, 0]]) | ||
norm.update(vec_input1) | ||
norm.update(vec_input2) | ||
norm.update(vec_input3) | ||
|
||
# Test normalization | ||
for val in norm(vec_input1)[0]: | ||
assert val == pytest.approx(0.707, abs=0.001) | ||
|
||
# Test copy normalization | ||
norm2 = Normalizer(input_size) | ||
assert not compare_models(norm, norm2) | ||
norm2.copy_from(norm) | ||
assert compare_models(norm, norm2) | ||
for val in norm2(vec_input1)[0]: | ||
assert val == pytest.approx(0.707, abs=0.001) | ||
|
||
|
||
@mock.patch("mlagents.trainers.torch.encoders.Normalizer") | ||
def test_vector_encoder(mock_normalizer): | ||
mock_normalizer_inst = mock.Mock() | ||
mock_normalizer.return_value = mock_normalizer_inst | ||
input_size = 64 | ||
hidden_size = 128 | ||
num_layers = 3 | ||
normalize = False | ||
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize) | ||
output = vector_encoder(torch.ones((1, input_size))) | ||
assert output.shape == (1, hidden_size) | ||
|
||
normalize = True | ||
vector_encoder = VectorEncoder(input_size, hidden_size, num_layers, normalize) | ||
new_vec = torch.ones((1, input_size)) | ||
vector_encoder.update_normalization(new_vec) | ||
|
||
mock_normalizer.assert_called_with(input_size) | ||
mock_normalizer_inst.update.assert_called_with(new_vec) | ||
|
||
vector_encoder2 = VectorEncoder(input_size, hidden_size, num_layers, normalize) | ||
vector_encoder.copy_normalization(vector_encoder2) | ||
mock_normalizer_inst.copy_from.assert_called_with(mock_normalizer_inst) | ||
|
||
|
||
@mock.patch("mlagents.trainers.torch.encoders.Normalizer") | ||
def test_vector_and_unnormalized_encoder(mock_normalizer): | ||
mock_normalizer_inst = mock.Mock() | ||
mock_normalizer.return_value = mock_normalizer_inst | ||
input_size = 64 | ||
unnormalized_size = 32 | ||
hidden_size = 128 | ||
num_layers = 3 | ||
normalize = True | ||
mock_normalizer_inst.return_value = torch.ones((1, input_size)) | ||
vector_encoder = VectorAndUnnormalizedInputEncoder( | ||
input_size, hidden_size, unnormalized_size, num_layers, normalize | ||
) | ||
# Make sure normalizer is only called on input_size | ||
mock_normalizer.assert_called_with(input_size) | ||
normal_input = torch.ones((1, input_size)) | ||
|
||
unnormalized_input = torch.ones((1, 32)) | ||
output = vector_encoder(normal_input, unnormalized_input) | ||
mock_normalizer_inst.assert_called_with(normal_input) | ||
assert output.shape == (1, hidden_size) | ||
|
||
|
||
@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)]) | ||
@pytest.mark.parametrize( | ||
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder] | ||
) | ||
def test_visual_encoder(vis_class, image_size): | ||
num_outputs = 128 | ||
enc = vis_class(image_size[0], image_size[1], image_size[2], num_outputs) | ||
# Note: NCHW not NHWC | ||
sample_input = torch.ones((1, image_size[2], image_size[0], image_size[1])) | ||
encoding = enc(sample_input) | ||
assert encoding.shape == (1, num_outputs) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why compare only the first item?
key_item_1[*1*]
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's the ZIP,
key_item_X
is a Tuple of (key, value). So we're just comparing the values.