Skip to content

Commit

Permalink
[add-fire] Add tests and fix issues with Policy (#4372)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ervin T authored Aug 18, 2020
1 parent c3fae3a commit 71e7b17
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 18 deletions.
46 changes: 28 additions & 18 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple, Optional
import numpy as np
import torch

Expand Down Expand Up @@ -82,7 +82,9 @@ def __init__(

self.actor_critic.to("cpu")

def split_decision_step(self, decision_requests):
def _split_decision_step(
self, decision_requests: DecisionSteps
) -> Tuple[SplitObservations, np.ndarray]:
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
mask = None
if not self.use_continuous_act:
Expand All @@ -91,7 +93,7 @@ def split_decision_step(self, decision_requests):
mask = torch.as_tensor(
1 - np.concatenate(decision_requests.action_mask, axis=1)
)
return vec_vis_obs.vector_observations, vec_vis_obs.visual_observations, mask
return vec_vis_obs, mask

def update_normalization(self, vector_obs: np.ndarray) -> None:
"""
Expand All @@ -105,13 +107,15 @@ def update_normalization(self, vector_obs: np.ndarray) -> None:
@timed
def sample_actions(
self,
vec_obs,
vis_obs,
masks=None,
memories=None,
seq_len=1,
all_log_probs=False,
):
vec_obs: List[torch.Tensor],
vis_obs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
all_log_probs: bool = False,
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
"""
:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
"""
Expand All @@ -137,14 +141,18 @@ def sample_actions(
)

def evaluate_actions(
self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
):
self,
vec_obs: torch.Tensor,
vis_obs: torch.Tensor,
actions: torch.Tensor,
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
dists, value_heads, _ = self.actor_critic.get_dist_and_value(
vec_obs, vis_obs, masks, memories, seq_len
)
if len(actions.shape) <= 2:
actions = actions.unsqueeze(-1)
action_list = [actions[..., i] for i in range(actions.shape[2])]
action_list = [actions[..., i] for i in range(actions.shape[-1])]
log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)

return log_probs, entropies, value_heads
Expand All @@ -159,9 +167,11 @@ def evaluate(
:param decision_requests: DecisionStep object containing inputs.
:return: Outputs from network as defined by self.inference_dict.
"""
vec_obs, vis_obs, masks = self.split_decision_step(decision_requests)
vec_obs = [torch.as_tensor(vec_obs)]
vis_obs = [torch.as_tensor(vis_ob) for vis_ob in vis_obs]
vec_vis_obs, masks = self._split_decision_step(decision_requests)
vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)]
vis_obs = [
torch.as_tensor(vis_ob) for vis_ob in vec_vis_obs.visual_observations
]
memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(
0
)
Expand Down
150 changes: 150 additions & 0 deletions ml-agents/mlagents/trainers/tests/torch/test_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest

import torch
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.settings import TrainerSettings, NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils

# Action-space value passed as `vector_action_space` for continuous policies;
# also the expected action width in the continuous evaluate test.
VECTOR_ACTION_SPACE = 2
# Value passed as `vector_obs_space` when building the mock behavior spec.
VECTOR_OBS_SPACE = 8
# Action-space value passed as `vector_action_space` for discrete policies;
# its length is the expected action width in the discrete evaluate test.
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
# NOTE(review): not referenced by the tests visible here — confirm still needed.
BUFFER_INIT_SAMPLES = 32
# Number of agents requested when creating decision steps for `evaluate`.
NUM_AGENTS = 12
# NOTE(review): not referenced by the tests visible here — confirm still needed.
EPSILON = 1e-7


def create_policy_mock(
    dummy_config: TrainerSettings,
    use_rnn: bool = False,
    use_discrete: bool = True,
    use_visual: bool = False,
    seed: int = 0,
) -> TorchPolicy:
    """
    Build a TorchPolicy on top of a mock behavior spec for use in tests.

    :param dummy_config: Trainer settings to use; modified in place
        (``keep_checkpoints`` and ``network_settings.memory`` are overwritten).
    :param use_rnn: When True, default memory settings are attached to the network.
    :param use_discrete: Selects the discrete action space instead of the continuous one.
    :param use_visual: Forwarded to the mock behavior spec builder.
    :param seed: Seed handed to the TorchPolicy constructor.
    :return: The constructed policy.
    """
    if use_discrete:
        action_space = DISCRETE_ACTION_SPACE
    else:
        action_space = VECTOR_ACTION_SPACE
    behavior_spec = mb.setup_test_behavior_specs(
        use_discrete,
        use_visual,
        vector_action_space=action_space,
        vector_obs_space=VECTOR_OBS_SPACE,
    )

    # The caller's settings object is reused and mutated, not copied.
    dummy_config.keep_checkpoints = 3
    if use_rnn:
        dummy_config.network_settings.memory = NetworkSettings.MemorySettings()
    else:
        dummy_config.network_settings.memory = None

    return TorchPolicy(seed, behavior_spec, dummy_config)


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_policy_evaluate(rnn, visual, discrete):
    """
    Check that ``TorchPolicy.evaluate`` returns an action array with one row
    per agent and the expected action width, for every combination of
    rnn / visual / discrete settings.
    """
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    # Only the decision steps are needed; the terminal steps are discarded.
    decision_step, _ = mb.create_steps_from_behavior_spec(
        policy.behavior_spec, num_agents=NUM_AGENTS
    )

    run_out = policy.evaluate(decision_step, list(decision_step.agent_id))
    if discrete:
        # This comparison used to be a bare expression without `assert`, so the
        # discrete case could never fail. It is now an actual check.
        assert run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))
    else:
        assert run_out["action"].shape == (NUM_AGENTS, VECTOR_ACTION_SPACE)


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_evaluate_actions(rnn, visual, discrete):
    """
    Run ``TorchPolicy.evaluate_actions`` on a simulated 64-step rollout and
    verify the shapes of the returned log probs, entropies and value heads.
    """
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)

    # One visual observation tensor per visual encoder of the network body.
    vis_obs = [
        ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        for idx, _ in enumerate(policy.actor_critic.network_body.visual_encoders)
    ]
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])

    if policy.use_continuous_act:
        actions = ModelUtils.list_to_tensor(buffer["actions"]).unsqueeze(-1)
    else:
        actions = ModelUtils.list_to_tensor(buffer["actions"], dtype=torch.long)

    # Take the stored memory at the start of every sequence.
    sequence_starts = range(0, len(buffer["memory"]), policy.sequence_length)
    memories = [
        ModelUtils.list_to_tensor(buffer["memory"][start])
        for start in sequence_starts
    ]
    if memories:
        memories = torch.stack(memories).unsqueeze(0)

    log_probs, entropy, values = policy.evaluate_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        actions=actions,
        memories=memories,
        seq_len=policy.sequence_length,
    )

    expected_shape = (64, policy.behavior_spec.action_size)
    assert log_probs.shape == expected_shape
    assert entropy.shape == expected_shape
    for value_estimate in values.values():
        assert value_estimate.shape == (64,)


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_sample_actions(rnn, visual, discrete):
    """
    Run ``TorchPolicy.sample_actions`` on a simulated 64-step rollout and
    verify the shapes of the log probs, entropies, value heads and memories.
    """
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)

    # One visual observation tensor per visual encoder of the network body.
    vis_obs = [
        ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        for idx, _ in enumerate(policy.actor_critic.network_body.visual_encoders)
    ]
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])

    # Take the stored memory at the start of every sequence.
    sequence_starts = range(0, len(buffer["memory"]), policy.sequence_length)
    memories = [
        ModelUtils.list_to_tensor(buffer["memory"][start])
        for start in sequence_starts
    ]
    if memories:
        memories = torch.stack(memories).unsqueeze(0)

    outputs = policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        memories=memories,
        seq_len=policy.sequence_length,
        # Request per-action log probs only for discrete policies.
        all_log_probs=not policy.use_continuous_act,
    )
    sampled_actions, log_probs, entropies, sampled_values, memories = outputs

    if discrete:
        expected_log_probs_shape = (
            64,
            sum(policy.behavior_spec.discrete_action_branches),
        )
    else:
        expected_log_probs_shape = (64, policy.behavior_spec.action_shape)
    assert log_probs.shape == expected_log_probs_shape

    assert entropies.shape == (64, policy.behavior_spec.action_size)
    for value_estimate in sampled_values.values():
        assert value_estimate.shape == (64,)

    if rnn:
        assert memories.shape == (1, 1, policy.m_size)

0 comments on commit 71e7b17

Please sign in to comment.