Skip to content

Commit

Permalink
Added more stable test.
Browse files: browse the repository at this point in the history
  • Loading branch information
cmard committed Nov 15, 2021
1 parent 98da4b1 commit 604d7c1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
29 changes: 21 additions & 8 deletions ml-agents/mlagents/trainers/tests/torch/test_action_model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,7 @@


def create_action_model(inp_size, act_size, deterministic=False):
mask = torch.ones([1, act_size * 2])
mask = torch.ones([1, act_size ** 2])
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
action_model = ActionModel(inp_size, action_spec, deterministic=deterministic)
return action_model, mask
Expand Down Expand Up @@ -45,13 +45,14 @@ def test_sample_action():

def test_deterministic_sample_action():
inp_size = 4
act_size = 2
act_size = 8
action_model, masks = create_action_model(inp_size, act_size, deterministic=True)
sample_inp = torch.ones((1, inp_size))
dists = action_model._get_dists(sample_inp, masks=masks)
agent_action1 = action_model._sample_action(dists)
agent_action2 = action_model._sample_action(dists)
agent_action3 = action_model._sample_action(dists)

assert torch.equal(agent_action1.continuous_tensor, agent_action2.continuous_tensor)
assert torch.equal(agent_action1.continuous_tensor, agent_action3.continuous_tensor)
assert torch.equal(agent_action1.discrete_tensor, agent_action2.discrete_tensor)
Expand All @@ -63,14 +64,26 @@ def test_deterministic_sample_action():
agent_action1 = action_model._sample_action(dists)
agent_action2 = action_model._sample_action(dists)
agent_action3 = action_model._sample_action(dists)
assert not torch.equal(

chance_counter = 0

if not torch.equal(
agent_action1.continuous_tensor, agent_action2.continuous_tensor
)
assert not torch.equal(
):
chance_counter += 1

if not torch.equal(
agent_action1.continuous_tensor, agent_action3.continuous_tensor
)
assert not torch.equal(agent_action1.discrete_tensor, agent_action2.discrete_tensor)
assert not torch.equal(agent_action1.discrete_tensor, agent_action3.discrete_tensor)
):
chance_counter += 1

assert chance_counter > 1
chance_counter = 0
if not torch.equal(agent_action1.discrete_tensor, agent_action2.discrete_tensor):
chance_counter += 1
if not torch.equal(agent_action1.discrete_tensor, agent_action3.discrete_tensor):
chance_counter += 1
assert chance_counter > 1


def test_get_probs_and_entropy():
Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/torch/distributions.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -225,6 +225,7 @@ def _mask_branch(
# We do -1 * tensor + constant instead of constant - tensor because it seems
# Barracuda might swap the inputs of a "Sub" operation
logits = logits * allow_mask - 1e8 * block_mask

return logits

def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
Expand Down

0 comments on commit 604d7c1

Please sign in to comment.