diff --git a/ml-agents-envs/mlagents_envs/base_env.py b/ml-agents-envs/mlagents_envs/base_env.py index e4df43c182..fe2c159b4a 100644 --- a/ml-agents-envs/mlagents_envs/base_env.py +++ b/ml-agents-envs/mlagents_envs/base_env.py @@ -244,6 +244,43 @@ def empty(spec: "BehaviorSpec") -> "TerminalSteps": ) +class ActionTuple: + """ + An object whose fields correspond to actions of different types. + Continuous and discrete actions are numpy arrays of type float32 and + int32, respectively and are type checked on construction. + Dimensions are of (n_agents, continuous_size) and (n_agents, discrete_size), + respectively. + """ + + def __init__(self, continuous: np.ndarray, discrete: np.ndarray): + if continuous.dtype != np.float32: + continuous = continuous.astype(np.float32, copy=False) + self._continuous = continuous + + if discrete.dtype != np.int32: + discrete = discrete.astype(np.int32, copy=False) + self._discrete = discrete + + @property + def continuous(self) -> np.ndarray: + return self._continuous + + @property + def discrete(self) -> np.ndarray: + return self._discrete + + @staticmethod + def create_continuous(continuous: np.ndarray) -> "ActionTuple": + discrete = np.zeros((continuous.shape[0], 0), dtype=np.int32) + return ActionTuple(continuous, discrete) + + @staticmethod + def create_discrete(discrete: np.ndarray) -> "ActionTuple": + continuous = np.zeros((discrete.shape[0], 0), dtype=np.float32) + return ActionTuple(continuous, discrete) + + class ActionSpec(NamedTuple): """ A NamedTuple containing utility functions and information about the action spaces @@ -287,62 +324,61 @@ def discrete_size(self) -> int: """ return len(self.discrete_branches) - def empty_action(self, n_agents: int) -> np.ndarray: + def empty_action(self, n_agents: int) -> ActionTuple: """ - Generates a numpy array corresponding to an empty action (all zeros) + Generates ActionTuple corresponding to an empty action (all zeros) for a number of agents. :param n_agents: The number of agents that will have actions generated """ - if self.is_continuous(): - return np.zeros((n_agents, self.continuous_size), dtype=np.float32) - return np.zeros((n_agents, self.discrete_size), dtype=np.int32) + continuous = np.zeros((n_agents, self.continuous_size), dtype=np.float32) + discrete = np.zeros((n_agents, self.discrete_size), dtype=np.int32) + return ActionTuple(continuous, discrete) - def random_action(self, n_agents: int) -> np.ndarray: + def random_action(self, n_agents: int) -> ActionTuple: """ - Generates a numpy array corresponding to a random action (either discrete + Generates ActionTuple corresponding to a random action (either discrete or continuous) for a number of agents. 
:param n_agents: The number of agents that will have actions generated """ - if self.is_continuous(): - action = np.random.uniform( - low=-1.0, high=1.0, size=(n_agents, self.continuous_size) - ).astype(np.float32) - else: - branch_size = self.discrete_branches - action = np.column_stack( + continuous = np.random.uniform( + low=-1.0, high=1.0, size=(n_agents, self.continuous_size) + ) + discrete = np.zeros((n_agents, self.discrete_size), dtype=np.int32) + if self.discrete_size > 0: + discrete = np.column_stack( [ np.random.randint( 0, - branch_size[i], # type: ignore + self.discrete_branches[i], # type: ignore size=(n_agents), dtype=np.int32, ) for i in range(self.discrete_size) ] ) - return action + return ActionTuple(continuous, discrete) def _validate_action( - self, actions: np.ndarray, n_agents: int, name: str - ) -> np.ndarray: + self, actions: ActionTuple, n_agents: int, name: str + ) -> ActionTuple: """ Validates that action has the correct action dim for the correct number of agents and ensures the type. """ - if self.continuous_size > 0: - _size = self.continuous_size - else: - _size = self.discrete_size - _expected_shape = (n_agents, _size) - if actions.shape != _expected_shape: + _expected_shape = (n_agents, self.continuous_size) + if self.continuous_size > 0 and actions.continuous.shape != _expected_shape: + raise UnityActionException( + f"The behavior {name} needs a continuous input of dimension " + f"{_expected_shape} for (<number of agents>, <action size>) but " + f"received input of dimension {actions.continuous.shape}" + ) + _expected_shape = (n_agents, self.discrete_size) + if self.discrete_size > 0 and actions.discrete.shape != _expected_shape: raise UnityActionException( - f"The behavior {name} needs an input of dimension " + f"The behavior {name} needs a discrete input of dimension " f"{_expected_shape} for (<number of agents>, <action size>) but " - f"received input of dimension {actions.shape}" + f"received input of dimension {actions.discrete.shape}" ) - _expected_type = np.float32 if self.is_continuous() else np.int32 - if actions.dtype != _expected_type: - actions = actions.astype(_expected_type) return actions @staticmethod @@ -420,27 +456,30 @@ def behavior_specs(self) -> MappingType[str, BehaviorSpec]: """ @abstractmethod - def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None: + def set_actions(self, behavior_name: BehaviorName, action: ActionTuple) -> None: """ Sets the action for all of the agents in the simulation for the next step. The Actions must be in the same order as the order received in the DecisionSteps. :param behavior_name: The name of the behavior the agents are part of - :param action: A two dimensional np.ndarray corresponding to the action - (either int or float) + :param action: An ActionTuple of continuous and/or discrete actions. + Actions are np.arrays with dimensions (n_agents, continuous_size) and + (n_agents, discrete_size), respectively. """ @abstractmethod def set_action_for_agent( - self, behavior_name: BehaviorName, agent_id: AgentId, action: np.ndarray + self, behavior_name: BehaviorName, agent_id: AgentId, action: ActionTuple ) -> None: """ Sets the action for one of the agents in the simulation for the next step.
:param behavior_name: The name of the behavior the agent is part of :param agent_id: The id of the agent the action is set for - :param action: A one dimensional np.ndarray corresponding to the action - (either int or float) + :param action: An ActionTuple of continuous and/or discrete actions. + Actions are np.arrays with dimensions (1, continuous_size) and + (1, discrete_size), respectively. Note that the leading dimension of 1 is because + this action is meant for a single agent. """ @abstractmethod diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 63f21c3dc8..8ef6217449 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -18,6 +18,7 @@ DecisionSteps, TerminalSteps, BehaviorSpec, + ActionTuple, BehaviorName, AgentId, BehaviorMapping, @@ -236,7 +237,7 @@ def __init__( self._env_state: Dict[str, Tuple[DecisionSteps, TerminalSteps]] = {} self._env_specs: Dict[str, BehaviorSpec] = {} - self._env_actions: Dict[str, np.ndarray] = {} + self._env_actions: Dict[str, ActionTuple] = {} self._is_first_message = True self._update_behavior_specs(aca_output) @@ -336,7 +337,7 @@ def _assert_behavior_exists(self, behavior_name: str) -> None: f"agent group in the environment" ) - def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None: + def set_actions(self, behavior_name: BehaviorName, action: ActionTuple) -> None: self._assert_behavior_exists(behavior_name) if behavior_name not in self._env_state: return @@ -346,7 +347,7 @@ def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None: self._env_actions[behavior_name] = action def set_action_for_agent( - self, behavior_name: BehaviorName, agent_id: AgentId, action: np.ndarray + self, behavior_name: BehaviorName, agent_id: AgentId, action: ActionTuple ) -> None: self._assert_behavior_exists(behavior_name) if behavior_name not in self._env_state: @@ -366,7 +367,10 @@ def set_action_for_agent( agent_id ) ) from ie - self._env_actions[behavior_name][index] = action + if action_spec.continuous_size > 0: + self._env_actions[behavior_name].continuous[index] = action.continuous[0, :] + if action_spec.discrete_size > 0: + self._env_actions[behavior_name].discrete[index] = action.discrete[0, :] def get_steps( self, behavior_name: BehaviorName @@ -410,7 +414,7 @@ def _close(self, timeout: Optional[int] = None) -> None: @timed def _generate_step_input( - self, vector_action: Dict[str, np.ndarray] + self, vector_action: Dict[str, ActionTuple] ) -> UnityInputProto: rl_in = UnityRLInputProto() for b in vector_action: @@ -418,7 +422,12 @@ def _generate_step_input( if n_agents == 0: continue for i in range(n_agents): - action = AgentActionProto(vector_actions=vector_action[b][i]) + # TODO: This check will be removed when the proto supports hybrid actions + if vector_action[b].continuous.shape[1] > 0: + _act = vector_action[b].continuous[i] + else: + _act = vector_action[b].discrete[i] + action = AgentActionProto(vector_actions=_act) rl_in.agent_actions[b].value.extend([action]) rl_in.command = STEP rl_in.side_channel = bytes( diff --git a/ml-agents-envs/mlagents_envs/tests/test_envs.py b/ml-agents-envs/mlagents_envs/tests/test_envs.py index 68fb34ea14..071c052d71 100755 --- a/ml-agents-envs/mlagents_envs/tests/test_envs.py +++ b/ml-agents-envs/mlagents_envs/tests/test_envs.py @@ -2,7 +2,7 @@ import pytest from mlagents_envs.environment import UnityEnvironment -from mlagents_envs.base_env import DecisionSteps,
TerminalSteps +from mlagents_envs.base_env import DecisionSteps, TerminalSteps, ActionTuple from mlagents_envs.exception import UnityEnvironmentException, UnityActionException from mlagents_envs.mock_communicator import MockCommunicator @@ -99,7 +99,9 @@ def test_step(mock_communicator, mock_launcher): env.set_actions("RealFakeBrain", spec.action_spec.empty_action(n_agents - 1)) decision_steps, terminal_steps = env.get_steps("RealFakeBrain") n_agents = len(decision_steps) - env.set_actions("RealFakeBrain", spec.action_spec.empty_action(n_agents) - 1) + _empty_act = spec.action_spec.empty_action(n_agents) + next_action = ActionTuple(_empty_act.continuous - 1, _empty_act.discrete - 1) + env.set_actions("RealFakeBrain", next_action) env.step() env.close() diff --git a/ml-agents-envs/mlagents_envs/tests/test_steps.py b/ml-agents-envs/mlagents_envs/tests/test_steps.py index 83ce2dd668..5eb1066bcd 100644 --- a/ml-agents-envs/mlagents_envs/tests/test_steps.py +++ b/ml-agents-envs/mlagents_envs/tests/test_steps.py @@ -81,24 +81,35 @@ def test_specs(): assert specs.discrete_branches == () assert specs.discrete_size == 0 assert specs.continuous_size == 3 - assert specs.empty_action(5).shape == (5, 3) - assert specs.empty_action(5).dtype == np.float32 + assert specs.empty_action(5).continuous.shape == (5, 3) + assert specs.empty_action(5).continuous.dtype == np.float32 specs = ActionSpec.create_discrete((3,)) assert specs.discrete_branches == (3,) assert specs.discrete_size == 1 assert specs.continuous_size == 0 - assert specs.empty_action(5).shape == (5, 1) - assert specs.empty_action(5).dtype == np.int32 + assert specs.empty_action(5).discrete.shape == (5, 1) + assert specs.empty_action(5).discrete.dtype == np.int32 + + specs = ActionSpec(3, (3,)) + assert specs.continuous_size == 3 + assert specs.discrete_branches == (3,) + assert specs.discrete_size == 1 + assert specs.empty_action(5).continuous.shape == (5, 3) + assert specs.empty_action(5).continuous.dtype == np.float32 + assert specs.empty_action(5).discrete.shape == (5, 1) + assert specs.empty_action(5).discrete.dtype == np.int32 def test_action_generator(): # Continuous action_len = 30 specs = ActionSpec.create_continuous(action_len) - zero_action = specs.empty_action(4) + zero_action = specs.empty_action(4).continuous assert np.array_equal(zero_action, np.zeros((4, action_len), dtype=np.float32)) - random_action = specs.random_action(4) + print(specs.random_action(4)) + random_action = specs.random_action(4).continuous + print(random_action) assert random_action.dtype == np.float32 assert random_action.shape == (4, action_len) assert np.min(random_action) >= -1 @@ -107,10 +118,10 @@ def test_action_generator(): # Discrete action_shape = (10, 20, 30) specs = ActionSpec.create_discrete(action_shape) - zero_action = specs.empty_action(4) + zero_action = specs.empty_action(4).discrete assert np.array_equal(zero_action, np.zeros((4, len(action_shape)), dtype=np.int32)) - random_action = specs.random_action(4) + random_action = specs.random_action(4).discrete assert random_action.dtype == np.int32 assert random_action.shape == (4, len(action_shape)) assert np.min(random_action) >= 0 diff --git a/ml-agents/mlagents/trainers/agent_processor.py b/ml-agents/mlagents/trainers/agent_processor.py index ec8d7bb0d0..08efc991fa 100644 --- a/ml-agents/mlagents/trainers/agent_processor.py +++ b/ml-agents/mlagents/trainers/agent_processor.py @@ -2,6 +2,7 @@ from typing import List, Dict, TypeVar, Generic, Tuple, Any, Union from collections import defaultdict, 
Counter import queue +import numpy as np from mlagents_envs.base_env import ( DecisionSteps, @@ -129,12 +130,19 @@ def _process_step( done = terminated # Since this is an ongoing step interrupted = step.interrupted if terminated else False # Add the outputs of the last eval - action = stored_take_action_outputs["action"][idx] + action_dict = stored_take_action_outputs["action"] + action: Dict[str, np.ndarray] = {} + for act_type, act_array in action_dict.items(): + action[act_type] = act_array[idx] if self.policy.use_continuous_act: action_pre = stored_take_action_outputs["pre_action"][idx] else: action_pre = None - action_probs = stored_take_action_outputs["log_probs"][idx] + action_probs_dict = stored_take_action_outputs["log_probs"] + action_probs: Dict[str, np.ndarray] = {} + for prob_type, prob_array in action_probs_dict.items(): + action_probs[prob_type] = prob_array[idx] + action_mask = stored_decision_step.action_mask prev_action = self.policy.retrieve_previous_action([global_id])[0, :] experience = AgentExperience( diff --git a/ml-agents/mlagents/trainers/buffer.py b/ml-agents/mlagents/trainers/buffer.py index 673da15921..48a8ab1914 100644 --- a/ml-agents/mlagents/trainers/buffer.py +++ b/ml-agents/mlagents/trainers/buffer.py @@ -22,7 +22,7 @@ class AgentBuffer(dict): class AgentBufferField(list): """ - AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to his + AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to its AgentBufferField with the append method. """ diff --git a/ml-agents/mlagents/trainers/demo_loader.py b/ml-agents/mlagents/trainers/demo_loader.py index da24d7a6bf..5c518d47b8 100644 --- a/ml-agents/mlagents/trainers/demo_loader.py +++ b/ml-agents/mlagents/trainers/demo_loader.py @@ -66,7 +66,14 @@ def make_demo_buffer( for i, obs in enumerate(split_obs.visual_observations): demo_raw_buffer["visual_obs%d" % i].append(obs) demo_raw_buffer["vector_obs"].append(split_obs.vector_observations) - demo_raw_buffer["actions"].append(current_pair_info.action_info.vector_actions) + if behavior_spec.action_spec.is_continuous(): + demo_raw_buffer["continuous_action"].append( + current_pair_info.action_info.vector_actions + ) + else: + demo_raw_buffer["discrete_action"].append( + current_pair_info.action_info.vector_actions + ) demo_raw_buffer["prev_action"].append(previous_action) if next_done: demo_raw_buffer.resequence_and_append( diff --git a/ml-agents/mlagents/trainers/env_manager.py b/ml-agents/mlagents/trainers/env_manager.py index 613147b663..ec520ac522 100644 --- a/ml-agents/mlagents/trainers/env_manager.py +++ b/ml-agents/mlagents/trainers/env_manager.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod +import numpy as np + from typing import List, Dict, NamedTuple, Iterable, Tuple from mlagents_envs.base_env import ( DecisionSteps, TerminalSteps, BehaviorSpec, BehaviorName, + ActionTuple, ) from mlagents_envs.side_channel.stats_side_channel import EnvironmentStats @@ -12,6 +15,7 @@ from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue from mlagents.trainers.action_info import ActionInfo from mlagents_envs.logging_util import get_logger +from mlagents_envs.exception import UnityActionException AllStepResult = Dict[BehaviorName, Tuple[DecisionSteps, TerminalSteps]] AllGroupSpec = Dict[BehaviorName, BehaviorSpec] @@ -143,3 +147,21 @@ def _process_step_infos(self, step_infos: List[EnvironmentStep]) -> int: step_info.environment_stats, step_info.worker_id ) return 
len(step_infos) + + @staticmethod + def action_tuple_from_numpy_dict(action_dict: Dict[str, np.ndarray]) -> ActionTuple: + if "continuous_action" in action_dict: + continuous = action_dict["continuous_action"] + if "discrete_action" in action_dict: + discrete = action_dict["discrete_action"] + action_tuple = ActionTuple(continuous, discrete) + else: + action_tuple = ActionTuple.create_continuous(continuous) + elif "discrete_action" in action_dict: + discrete = action_dict["discrete_action"] + action_tuple = ActionTuple.create_discrete(discrete) + else: + raise UnityActionException( + "The action dict must contain entries for either continuous_action or discrete_action." + ) + return action_tuple diff --git a/ml-agents/mlagents/trainers/optimizer/tf_optimizer.py b/ml-agents/mlagents/trainers/optimizer/tf_optimizer.py index eb0533ba19..f4e432366a 100644 --- a/ml-agents/mlagents/trainers/optimizer/tf_optimizer.py +++ b/ml-agents/mlagents/trainers/optimizer/tf_optimizer.py @@ -62,7 +62,9 @@ def get_trajectory_value_estimates( [self.value_heads, self.policy.memory_out, self.memory_out], feed_dict ) prev_action = ( - batch["actions"][-1] if not self.policy.use_continuous_act else None + batch["discrete_action"][-1] + if not self.policy.use_continuous_act + else None ) else: value_estimates = self.sess.run(self.value_heads, feed_dict) diff --git a/ml-agents/mlagents/trainers/policy/policy.py b/ml-agents/mlagents/trainers/policy/policy.py index 6de0e4794d..aca91299d3 100644 --- a/ml-agents/mlagents/trainers/policy/policy.py +++ b/ml-agents/mlagents/trainers/policy/policy.py @@ -49,12 +49,7 @@ def __init__( 1 for shape in behavior_spec.observation_shapes if len(shape) == 3 ) self.use_continuous_act = self.behavior_spec.action_spec.is_continuous() - # This line will be removed in the ActionBuffer change - self.num_branches = ( - self.behavior_spec.action_spec.continuous_size - + self.behavior_spec.action_spec.discrete_size - ) - self.previous_action_dict: Dict[str, np.array] = {} + self.previous_action_dict: Dict[str, np.ndarray] = {} self.memory_dict: Dict[str, np.ndarray] = {} self.normalize = trainer_settings.network_settings.normalize self.use_recurrent = self.network_settings.memory is not None @@ -108,24 +103,28 @@ def remove_memories(self, agent_ids): if agent_id in self.memory_dict: self.memory_dict.pop(agent_id) - def make_empty_previous_action(self, num_agents): + def make_empty_previous_action(self, num_agents: int) -> np.ndarray: """ Creates empty previous action for use with RNNs and discrete control :param num_agents: Number of agents. :return: Numpy array of zeros. 
""" - return np.zeros((num_agents, self.num_branches), dtype=np.int) + return np.zeros( + (num_agents, self.behavior_spec.action_spec.discrete_size), dtype=np.int32 + ) def save_previous_action( - self, agent_ids: List[str], action_matrix: Optional[np.ndarray] + self, agent_ids: List[str], action_dict: Dict[str, np.ndarray] ) -> None: - if action_matrix is None: + if action_dict is None or "discrete_action" not in action_dict: return for index, agent_id in enumerate(agent_ids): - self.previous_action_dict[agent_id] = action_matrix[index, :] + self.previous_action_dict[agent_id] = action_dict["discrete_action"][ + index, : + ] def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray: - action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int) + action_matrix = self.make_empty_previous_action(len(agent_ids)) for index, agent_id in enumerate(agent_ids): if agent_id in self.previous_action_dict: action_matrix[index, :] = self.previous_action_dict[agent_id] diff --git a/ml-agents/mlagents/trainers/policy/tf_policy.py b/ml-agents/mlagents/trainers/policy/tf_policy.py index 3650565710..6e7db076ff 100644 --- a/ml-agents/mlagents/trainers/policy/tf_policy.py +++ b/ml-agents/mlagents/trainers/policy/tf_policy.py @@ -241,6 +241,7 @@ def evaluate( feed_dict[self.prev_action] = self.retrieve_previous_action( global_agent_ids ) + feed_dict[self.memory_in] = self.retrieve_memories(global_agent_ids) feed_dict = self.fill_eval_dict(feed_dict, decision_requests) run_out = self._execute_model(feed_dict, self.inference_dict) @@ -270,6 +271,14 @@ def get_action( ) self.save_memories(global_agent_ids, run_out.get("memory_out")) + # For Compatibility with buffer changes for hybrid action support + if "log_probs" in run_out: + run_out["log_probs"] = {"action_probs": run_out["log_probs"]} + if "action" in run_out: + if self.behavior_spec.action_spec.is_continuous(): + run_out["action"] = {"continuous_action": run_out["action"]} + else: + run_out["action"] = {"discrete_action": run_out["action"]} return ActionInfo( action=run_out.get("action"), value=run_out.get("value"), diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index 35760964e9..c51c015fa9 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -16,7 +16,8 @@ SeparateActorCritic, GlobalSteps, ) -from mlagents.trainers.torch.utils import ModelUtils + +from mlagents.trainers.torch.utils import ModelUtils, AgentAction, ActionLogProbs EPSILON = 1e-7 # Small value to avoid divide by zero @@ -122,15 +123,13 @@ def sample_actions( masks: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, seq_len: int = 1, - all_log_probs: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]: """ :param vec_obs: List of vector observations. :param vis_obs: List of visual observations. :param masks: Loss masks for RNN, else None. :param memories: Input memories when using RNN, else None. :param seq_len: Sequence length when using RNN. - :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action. :return: Tuple of actions, log probabilities (dependent on all_log_probs), entropies, and output memories, all as Torch Tensors. 
""" @@ -144,37 +143,36 @@ def sample_actions( vec_obs, vis_obs, masks, memories, seq_len ) action_list = self.actor_critic.sample_action(dists) - log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy( + log_probs_list, entropies, all_logs_list = ModelUtils.get_probs_and_entropy( action_list, dists ) - actions = torch.stack(action_list, dim=-1) - if self.use_continuous_act: - actions = actions[:, :, 0] - else: - actions = actions[:, 0, :] + actions = AgentAction.create(action_list, self.behavior_spec.action_spec) + log_probs = ActionLogProbs.create( + log_probs_list, self.behavior_spec.action_spec, all_logs_list + ) # Use the sum of entropy across actions, not the mean entropy_sum = torch.sum(entropies, dim=1) - return ( - actions, - all_logs if all_log_probs else log_probs, - entropy_sum, - memories, - ) + return (actions, log_probs, entropy_sum, memories) def evaluate_actions( self, vec_obs: torch.Tensor, vis_obs: torch.Tensor, - actions: torch.Tensor, + actions: AgentAction, masks: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, seq_len: int = 1, - ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: + ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]: dists, value_heads, _ = self.actor_critic.get_dist_and_value( vec_obs, vis_obs, masks, memories, seq_len ) - action_list = [actions[..., i] for i in range(actions.shape[-1])] - log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists) + action_list = actions.to_tensor_list() + log_probs_list, entropies, _ = ModelUtils.get_probs_and_entropy( + action_list, dists + ) + log_probs = ActionLogProbs.create( + log_probs_list, self.behavior_spec.action_spec + ) # Use the sum of entropy across actions, not the mean entropy_sum = torch.sum(entropies, dim=1) return log_probs, entropy_sum, value_heads @@ -203,10 +201,12 @@ def evaluate( action, log_probs, entropy, memories = self.sample_actions( vec_obs, vis_obs, masks=masks, memories=memories ) - run_out["action"] = ModelUtils.to_numpy(action) - run_out["pre_action"] = ModelUtils.to_numpy(action) - # Todo - make pre_action difference - run_out["log_probs"] = ModelUtils.to_numpy(log_probs) + action_dict = action.to_numpy_dict() + run_out["action"] = action_dict + run_out["pre_action"] = ( + action_dict["continuous_action"] if self.use_continuous_act else None + ) + run_out["log_probs"] = log_probs.to_numpy_dict() run_out["entropy"] = ModelUtils.to_numpy(entropy) run_out["learning_rate"] = 0.0 if self.use_recurrent: diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_tf.py b/ml-agents/mlagents/trainers/ppo/optimizer_tf.py index 05ce4503c8..7c4c86cd40 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_tf.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_tf.py @@ -338,9 +338,12 @@ def _construct_feed_dict( if self.policy.output_pre is not None and "actions_pre" in mini_batch: feed_dict[self.policy.output_pre] = mini_batch["actions_pre"] else: - feed_dict[self.policy.output] = mini_batch["actions"] - if self.policy.use_recurrent: - feed_dict[self.policy.prev_action] = mini_batch["prev_action"] + if self.policy.use_continuous_act: # For hybrid action buffer support + feed_dict[self.policy.output] = mini_batch["continuous_action"] + else: + feed_dict[self.policy.output] = mini_batch["discrete_action"] + if self.policy.use_recurrent: + feed_dict[self.policy.prev_action] = mini_batch["prev_action"] feed_dict[self.policy.action_masks] = mini_batch["action_mask"] if "vector_obs" in mini_batch: 
feed_dict[self.policy.vector_in] = mini_batch["vector_obs"] diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 8660d820aa..8dc3b6efdf 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -7,7 +7,7 @@ from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer from mlagents.trainers.settings import TrainerSettings, PPOSettings -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch.utils import ModelUtils, AgentAction, ActionLogProbs class TorchPPOOptimizer(TorchOptimizer): @@ -102,7 +102,6 @@ def ppo_policy_loss( advantage = advantages.unsqueeze(-1) decay_epsilon = self.hyperparameters.epsilon - r_theta = torch.exp(log_probs - old_log_probs) p_opt_a = r_theta * advantage p_opt_b = ( @@ -135,10 +134,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) - if self.policy.use_continuous_act: - actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) - else: - actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) + actions = AgentAction.from_dict(batch) memories = [ ModelUtils.list_to_tensor(batch["memory"][i]) @@ -164,6 +160,8 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=memories, seq_len=self.policy.sequence_length, ) + old_log_probs = ActionLogProbs.from_dict(batch).flatten() + log_probs = log_probs.flatten() loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) value_loss = self.ppo_value_loss( values, old_values, returns, decay_eps, loss_masks @@ -171,7 +169,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: policy_loss = self.ppo_policy_loss( ModelUtils.list_to_tensor(batch["advantages"]), log_probs, - ModelUtils.list_to_tensor(batch["action_probs"]), + old_log_probs, loss_masks, ) loss = ( diff --git a/ml-agents/mlagents/trainers/sac/optimizer_tf.py b/ml-agents/mlagents/trainers/sac/optimizer_tf.py index 95860219e2..e9d341193e 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_tf.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_tf.py @@ -608,9 +608,11 @@ def _construct_feed_dict( feed_dict[self.rewards_holders[name]] = batch[f"{name}_rewards"] if self.policy.use_continuous_act: - feed_dict[self.policy_network.external_action_in] = batch["actions"] + feed_dict[self.policy_network.external_action_in] = batch[ + "continuous_action" + ] else: - feed_dict[policy.output] = batch["actions"] + feed_dict[policy.output] = batch["discrete_action"] if self.policy.use_recurrent: feed_dict[policy.prev_action] = batch["prev_action"] feed_dict[policy.action_masks] = batch["action_mask"] diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index 21d34d8a02..df5bbaeff9 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -8,7 +8,7 @@ from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.settings import NetworkSettings from mlagents.trainers.torch.networks import ValueNetwork -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch.utils import ModelUtils, AgentAction, ActionLogProbs from 
mlagents.trainers.buffer import AgentBuffer from mlagents_envs.timers import timed from mlagents.trainers.exception import UnityTrainerException @@ -231,7 +231,7 @@ def sac_q_loss( def sac_value_loss( self, - log_probs: torch.Tensor, + log_probs: ActionLogProbs, values: Dict[str, torch.Tensor], q1p_out: Dict[str, torch.Tensor], q2p_out: Dict[str, torch.Tensor], @@ -245,7 +245,7 @@ def sac_value_loss( if not discrete: min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name]) else: - action_probs = log_probs.exp() + action_probs = log_probs.all_discrete_tensor.exp() _branched_q1p = ModelUtils.break_into_branches( q1p_out[name] * action_probs, self.act_size ) @@ -278,7 +278,7 @@ def sac_value_loss( for name in values.keys(): with torch.no_grad(): v_backup = min_policy_qs[name] - torch.sum( - _ent_coef * log_probs, dim=1 + _ent_coef * log_probs.continuous_tensor, dim=1 ) value_loss = 0.5 * ModelUtils.masked_mean( torch.nn.functional.mse_loss(values[name], v_backup), loss_masks @@ -286,7 +286,8 @@ def sac_value_loss( value_losses.append(value_loss) else: branched_per_action_ent = ModelUtils.break_into_branches( - log_probs * log_probs.exp(), self.act_size + log_probs.all_discrete_tensor * log_probs.all_discrete_tensor.exp(), + self.act_size, ) # We have to do entropy bonus per action branch branched_ent_bonus = torch.stack( @@ -312,7 +313,7 @@ def sac_value_loss( def sac_policy_loss( self, - log_probs: torch.Tensor, + log_probs: ActionLogProbs, q1p_outs: Dict[str, torch.Tensor], loss_masks: torch.Tensor, discrete: bool, @@ -321,12 +322,14 @@ def sac_policy_loss( mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0) if not discrete: mean_q1 = mean_q1.unsqueeze(1) - batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1) + batch_policy_loss = torch.mean( + _ent_coef * log_probs.continuous_tensor - mean_q1, dim=1 + ) policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks) else: - action_probs = log_probs.exp() + action_probs = log_probs.all_discrete_tensor.exp() branched_per_action_ent = ModelUtils.break_into_branches( - log_probs * action_probs, self.act_size + log_probs.all_discrete_tensor * action_probs, self.act_size ) branched_q_term = ModelUtils.break_into_branches( mean_q1 * action_probs, self.act_size @@ -345,18 +348,21 @@ def sac_policy_loss( return policy_loss def sac_entropy_loss( - self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool + self, log_probs: ActionLogProbs, loss_masks: torch.Tensor, discrete: bool ) -> torch.Tensor: if not discrete: with torch.no_grad(): - target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1) + target_current_diff = torch.sum( + log_probs.continuous_tensor + self.target_entropy, dim=1 + ) entropy_loss = -1 * ModelUtils.masked_mean( self._log_ent_coef * target_current_diff, loss_masks ) else: with torch.no_grad(): branched_per_action_ent = ModelUtils.break_into_branches( - log_probs * log_probs.exp(), self.act_size + log_probs.all_discrete_tensor * log_probs.all_discrete_tensor.exp(), + self.act_size, ) target_current_diff_branched = torch.stack( [ @@ -411,10 +417,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])] act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) - if self.policy.use_continuous_act: - actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1) - else: - actions = 
ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) + actions = AgentAction.from_dict(batch) memories_list = [ ModelUtils.list_to_tensor(batch["memory"][i]) @@ -470,18 +473,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: masks=act_masks, memories=memories, seq_len=self.policy.sequence_length, - all_log_probs=not self.policy.use_continuous_act, ) value_estimates, _ = self.policy.actor_critic.critic_pass( vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length ) if self.policy.use_continuous_act: - squeezed_actions = actions.squeeze(-1) + squeezed_actions = actions.continuous_tensor # Only need grad for q1, as that is used for policy. q1p_out, q2p_out = self.value_network( vec_obs, vis_obs, - sampled_actions, + sampled_actions.continuous_tensor, memories=q_memories, sequence_length=self.policy.sequence_length, q2_grad=False, @@ -510,8 +512,8 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: memories=q_memories, sequence_length=self.policy.sequence_length, ) - q1_stream = self._condense_q_streams(q1_out, actions) - q2_stream = self._condense_q_streams(q2_out, actions) + q1_stream = self._condense_q_streams(q1_out, actions.discrete_tensor) + q2_stream = self._condense_q_streams(q2_out, actions.discrete_tensor) with torch.no_grad(): target_values, _ = self.target_network( diff --git a/ml-agents/mlagents/trainers/simple_env_manager.py b/ml-agents/mlagents/trainers/simple_env_manager.py index 5055b96d22..5e60571dd2 100644 --- a/ml-agents/mlagents/trainers/simple_env_manager.py +++ b/ml-agents/mlagents/trainers/simple_env_manager.py @@ -28,7 +28,8 @@ def _step(self) -> List[EnvironmentStep]: self.previous_all_action_info = all_action_info for brain_name, action_info in all_action_info.items(): - self.env.set_actions(brain_name, action_info.action) + _action = EnvManager.action_tuple_from_numpy_dict(action_info.action) + self.env.set_actions(brain_name, _action) self.env.step() all_step_result = self._generate_all_results() diff --git a/ml-agents/mlagents/trainers/subprocess_env_manager.py b/ml-agents/mlagents/trainers/subprocess_env_manager.py index c688beaa57..c205d2cb8b 100644 --- a/ml-agents/mlagents/trainers/subprocess_env_manager.py +++ b/ml-agents/mlagents/trainers/subprocess_env_manager.py @@ -144,7 +144,10 @@ def _generate_all_results() -> AllStepResult: all_action_info = req.payload for brain_name, action_info in all_action_info.items(): if len(action_info.action) != 0: - env.set_actions(brain_name, action_info.action) + _action = EnvManager.action_tuple_from_numpy_dict( + action_info.action + ) + env.set_actions(brain_name, _action) env.step() all_step_result = _generate_all_results() # The timers in this process are independent from all the processes and the "main" process diff --git a/ml-agents/mlagents/trainers/tests/mock_brain.py b/ml-agents/mlagents/trainers/tests/mock_brain.py index 2f5241a7a1..078193b8de 100644 --- a/ml-agents/mlagents/trainers/tests/mock_brain.py +++ b/ml-agents/mlagents/trainers/tests/mock_brain.py @@ -77,17 +77,22 @@ def make_fake_trajectory( steps_list = [] action_size = action_spec.discrete_size + action_spec.continuous_size - action_probs = np.ones( - int(np.sum(action_spec.discrete_branches) + action_spec.continuous_size), - dtype=np.float32, - ) + action_probs = { + "action_probs": np.ones( + int(np.sum(action_spec.discrete_branches) + action_spec.continuous_size), + dtype=np.float32, + ) + } for _i in range(length - 1): obs = [] for _shape in observation_shapes: 
obs.append(np.ones(_shape, dtype=np.float32)) reward = 1.0 done = False - action = np.zeros(action_size, dtype=np.float32) + if action_spec.is_continuous(): + action = {"continuous_action": np.zeros(action_size, dtype=np.float32)} + else: + action = {"discrete_action": np.zeros(action_size, dtype=np.float32)} action_pre = np.zeros(action_size, dtype=np.float32) action_mask = ( [ @@ -97,7 +102,11 @@ def make_fake_trajectory( if action_spec.is_discrete() else None ) - prev_action = np.ones(action_size, dtype=np.float32) + if action_spec.is_discrete(): + prev_action = np.ones(action_size, dtype=np.int32) + else: + prev_action = np.ones(action_size, dtype=np.float32) + max_step = False memory = np.ones(memory_size, dtype=np.float32) agent_id = "test_agent" diff --git a/ml-agents/mlagents/trainers/tests/simple_test_envs.py b/ml-agents/mlagents/trainers/tests/simple_test_envs.py index c53e42c95a..5f1c076d92 100644 --- a/ml-agents/mlagents/trainers/tests/simple_test_envs.py +++ b/ml-agents/mlagents/trainers/tests/simple_test_envs.py @@ -4,6 +4,7 @@ from mlagents_envs.base_env import ( ActionSpec, + ActionTuple, BaseEnv, BehaviorSpec, DecisionSteps, @@ -58,6 +59,7 @@ def __init__( else: action_spec = ActionSpec.create_continuous(action_size) self.behavior_spec = BehaviorSpec(self._make_obs_spec(), action_spec) + self.action_spec = action_spec self.action_size = action_size self.names = brain_names self.positions: Dict[str, List[float]] = {} @@ -114,11 +116,13 @@ def get_steps(self, behavior_name): def _take_action(self, name: str) -> bool: deltas = [] - for _act in self.action[name][0]: - if self.discrete: - deltas.append(1 if _act else -1) - else: - deltas.append(_act) + _act = self.action[name] + if self.action_spec.discrete_size > 0: + for _disc in _act.discrete[0]: + deltas.append(1 if _disc else -1) + if self.action_spec.continuous_size > 0: + for _cont in _act.continuous[0]: + deltas.append(_cont) for i, _delta in enumerate(deltas): _delta = clamp(_delta, -self.step_size, self.step_size) self.positions[name][i] += _delta @@ -281,8 +285,12 @@ def __init__( def step(self) -> None: super().step() for name in self.names: + if self.discrete: + action = self.action[name].discrete + else: + action = self.action[name].continuous self.demonstration_protos[name] += proto_from_steps_and_action( - self.step_result[name][0], self.step_result[name][1], self.action[name] + self.step_result[name][0], self.step_result[name][1], action ) self.demonstration_protos[name] = self.demonstration_protos[name][ -self.n_demos : @@ -293,7 +301,15 @@ def solve(self) -> None: for _ in range(self.n_demos): for name in self.names: if self.discrete: - self.action[name] = [[1]] if self.goal[name] > 0 else [[0]] + self.action[name] = ActionTuple( + np.array([], dtype=np.float32), + np.array( + [[1]] if self.goal[name] > 0 else [[0]], dtype=np.int32 + ), + ) else: - self.action[name] = [[float(self.goal[name])]] + self.action[name] = ActionTuple( + np.array([[float(self.goal[name])]], dtype=np.float32), + np.array([], dtype=np.int32), + ) self.step() diff --git a/ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py b/ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py index 3d45c8c645..4804c46eeb 100644 --- a/ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py +++ b/ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py @@ -48,7 +48,7 @@ def test_take_action_returns_action_info_when_available(): behavior_spec = basic_behavior_spec() policy = FakePolicy(test_seed, behavior_spec, 
TrainerSettings(), "output") policy_eval_out = { - "action": np.array([1.0], dtype=np.float32), + "action": {"continuous_action": np.array([1.0], dtype=np.float32)}, "memory_out": np.array([[2.5]], dtype=np.float32), "value": np.array([1.1], dtype=np.float32), } diff --git a/ml-agents/mlagents/trainers/tests/test_agent_processor.py b/ml-agents/mlagents/trainers/tests/test_agent_processor.py index 399f89fd4f..a452f3ae06 100644 --- a/ml-agents/mlagents/trainers/tests/test_agent_processor.py +++ b/ml-agents/mlagents/trainers/tests/test_agent_processor.py @@ -20,9 +20,7 @@ def create_mock_policy(): mock_policy = mock.Mock() mock_policy.reward_signals = {} mock_policy.retrieve_memories.return_value = np.zeros((1, 1), dtype=np.float32) - mock_policy.retrieve_previous_action.return_value = np.zeros( - (1, 1), dtype=np.float32 - ) + mock_policy.retrieve_previous_action.return_value = np.zeros((1, 1), dtype=np.int32) return mock_policy @@ -39,11 +37,11 @@ def test_agentprocessor(num_vis_obs): ) fake_action_outputs = { - "action": [0.1, 0.1], + "action": {"continuous_action": [0.1, 0.1]}, "entropy": np.array([1.0], dtype=np.float32), "learning_rate": 1.0, "pre_action": [0.1, 0.1], - "log_probs": [0.1, 0.1], + "log_probs": {"continuous_log_probs": [0.1, 0.1]}, } mock_decision_steps, mock_terminal_steps = mb.create_mock_steps( num_agents=2, @@ -51,7 +49,7 @@ def test_agentprocessor(num_vis_obs): action_spec=ActionSpec.create_continuous(2), ) fake_action_info = ActionInfo( - action=[0.1, 0.1], + action={"continuous_action": [0.1, 0.1]}, value=[0.1, 0.1], outputs=fake_action_outputs, agent_ids=mock_decision_steps.agent_id, @@ -101,11 +99,11 @@ def test_agent_deletion(): ) fake_action_outputs = { - "action": [0.1], + "action": {"continuous_action": [0.1]}, "entropy": np.array([1.0], dtype=np.float32), "learning_rate": 1.0, "pre_action": [0.1], - "log_probs": [0.1], + "log_probs": {"continuous_log_probs": [0.1]}, } mock_decision_step, mock_terminal_step = mb.create_mock_steps( num_agents=1, @@ -119,7 +117,7 @@ def test_agent_deletion(): done=True, ) fake_action_info = ActionInfo( - action=[0.1], + action={"continuous_action": [0.1]}, value=[0.1], outputs=fake_action_outputs, agent_ids=mock_decision_step.agent_id, @@ -139,7 +137,9 @@ def test_agent_deletion(): processor.add_experiences( mock_decision_step, mock_terminal_step, _ep, fake_action_info ) - add_calls.append(mock.call([get_global_agent_id(_ep, 0)], [0.1])) + add_calls.append( + mock.call([get_global_agent_id(_ep, 0)], {"continuous_action": [0.1]}) + ) processor.add_experiences( mock_done_decision_step, mock_done_terminal_step, _ep, fake_action_info ) @@ -178,11 +178,11 @@ def test_end_episode(): ) fake_action_outputs = { - "action": [0.1], + "action": {"continuous_action": [0.1]}, "entropy": np.array([1.0], dtype=np.float32), "learning_rate": 1.0, "pre_action": [0.1], - "log_probs": [0.1], + "log_probs": {"continuous_log_probs": [0.1]}, } mock_decision_step, mock_terminal_step = mb.create_mock_steps( num_agents=1, @@ -190,7 +190,7 @@ def test_end_episode(): action_spec=ActionSpec.create_continuous(2), ) fake_action_info = ActionInfo( - action=[0.1], + action={"continuous_action": [0.1]}, value=[0.1], outputs=fake_action_outputs, agent_ids=mock_decision_step.agent_id, diff --git a/ml-agents/mlagents/trainers/tests/test_demo_loader.py b/ml-agents/mlagents/trainers/tests/test_demo_loader.py index d0a94d12ec..21c2750d88 100644 --- a/ml-agents/mlagents/trainers/tests/test_demo_loader.py +++ b/ml-agents/mlagents/trainers/tests/test_demo_loader.py @@ 
-32,7 +32,10 @@ def test_load_demo(): assert len(pair_infos) == total_expected _, demo_buffer = demo_to_buffer(path_prefix + "/test.demo", 1, BEHAVIOR_SPEC) - assert len(demo_buffer["actions"]) == total_expected - 1 + assert ( + len(demo_buffer["continuous_action"]) == total_expected - 1 + or len(demo_buffer["discrete_action"]) == total_expected - 1 + ) def test_load_demo_dir(): @@ -44,7 +47,10 @@ def test_load_demo_dir(): assert len(pair_infos) == total_expected _, demo_buffer = demo_to_buffer(path_prefix + "/test_demo_dir", 1, BEHAVIOR_SPEC) - assert len(demo_buffer["actions"]) == total_expected - 1 + assert ( + len(demo_buffer["continuous_action"]) == total_expected - 1 + or len(demo_buffer["discrete_action"]) == total_expected - 1 + ) def test_demo_mismatch(): diff --git a/ml-agents/mlagents/trainers/tests/test_trajectory.py b/ml-agents/mlagents/trainers/tests/test_trajectory.py index 5d6bf7e683..4c6c6cb50a 100644 --- a/ml-agents/mlagents/trainers/tests/test_trajectory.py +++ b/ml-agents/mlagents/trainers/tests/test_trajectory.py @@ -40,7 +40,7 @@ def test_trajectory_to_agentbuffer(): "masks", "done", "actions_pre", - "actions", + "continuous_action", "action_probs", "action_mask", "prev_action", diff --git a/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py b/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py index b5bb7b0289..bee095f92a 100644 --- a/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py +++ b/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py @@ -80,13 +80,14 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None: with torch.no_grad(): _, log_probs1, _, _ = policy1.sample_actions( - vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True + vec_obs, vis_obs, masks=masks, memories=memories ) _, log_probs2, _, _ = policy2.sample_actions( - vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True + vec_obs, vis_obs, masks=masks, memories=memories ) - - np.testing.assert_array_equal(log_probs1, log_probs2) + np.testing.assert_array_equal( + log_probs1.all_discrete_tensor, log_probs2.all_discrete_tensor + ) @pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_policy.py b/ml-agents/mlagents/trainers/tests/torch/test_policy.py index bea74c71ae..e4c319bd70 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_policy.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_policy.py @@ -4,7 +4,7 @@ from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.tests import mock_brain as mb from mlagents.trainers.settings import TrainerSettings, NetworkSettings -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch.utils import ModelUtils, AgentAction VECTOR_ACTION_SPACE = 2 VECTOR_OBS_SPACE = 8 @@ -53,9 +53,15 @@ def test_policy_evaluate(rnn, visual, discrete): run_out = policy.evaluate(decision_step, list(decision_step.agent_id)) if discrete: - run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE)) + run_out["action"]["discrete_action"].shape == ( + NUM_AGENTS, + len(DISCRETE_ACTION_SPACE), + ) else: - assert run_out["action"].shape == (NUM_AGENTS, VECTOR_ACTION_SPACE) + assert run_out["action"]["continuous_action"].shape == ( + NUM_AGENTS, + VECTOR_ACTION_SPACE, + ) @pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) @@ -68,10 +74,7 @@ def test_evaluate_actions(rnn, visual, discrete): buffer = 
mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size) vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])] act_masks = ModelUtils.list_to_tensor(buffer["action_mask"]) - if policy.use_continuous_act: - actions = ModelUtils.list_to_tensor(buffer["actions"]).unsqueeze(-1) - else: - actions = ModelUtils.list_to_tensor(buffer["actions"], dtype=torch.long) + agent_action = AgentAction.from_dict(buffer) vis_obs = [] for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors): vis_ob = ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx]) @@ -88,7 +91,7 @@ def test_evaluate_actions(rnn, visual, discrete): vec_obs, vis_obs, masks=act_masks, - actions=actions, + actions=agent_action, memories=memories, seq_len=policy.sequence_length, ) @@ -97,7 +100,7 @@ def test_evaluate_actions(rnn, visual, discrete): else: _size = policy.behavior_spec.action_spec.continuous_size - assert log_probs.shape == (64, _size) + assert log_probs.flatten().shape == (64, _size) assert entropy.shape == (64,) for val in values.values(): assert val.shape == (64,) @@ -132,15 +135,17 @@ def test_sample_actions(rnn, visual, discrete): masks=act_masks, memories=memories, seq_len=policy.sequence_length, - all_log_probs=not policy.use_continuous_act, ) if discrete: - assert log_probs.shape == ( + assert log_probs.all_discrete_tensor.shape == ( 64, sum(policy.behavior_spec.action_spec.discrete_branches), ) else: - assert log_probs.shape == (64, policy.behavior_spec.action_spec.continuous_size) + assert log_probs.continuous_tensor.shape == ( + 64, + policy.behavior_spec.action_spec.continuous_size, + ) assert entropies.shape == (64,) if rnn: diff --git a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py b/ml-agents/mlagents/trainers/tests/torch/test_ppo.py index 3e2314016e..f7c9d885bc 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_ppo.py @@ -77,7 +77,14 @@ def test_ppo_optimizer_update(dummy_config, rnn, visual, discrete): # NOTE: In TensorFlow, the log_probs are saved as one for every discrete action, whereas # in PyTorch it is saved as the total probability per branch. So we need to modify the # log prob in the fake buffer here. - update_buffer["action_probs"] = np.ones_like(update_buffer["actions"]) + if discrete: + update_buffer["discrete_log_probs"] = np.ones_like( + update_buffer["discrete_action"] + ) + else: + update_buffer["continuous_log_probs"] = np.ones_like( + update_buffer["continuous_action"] + ) return_stats = optimizer.update( update_buffer, num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length, @@ -122,7 +129,14 @@ def test_ppo_optimizer_update_curiosity( # NOTE: In TensorFlow, the log_probs are saved as one for every discrete action, whereas # in PyTorch it is saved as the total probability per branch. So we need to modify the # log prob in the fake buffer here. 
- update_buffer["action_probs"] = np.ones_like(update_buffer["actions"]) + if discrete: + update_buffer["discrete_log_probs"] = np.ones_like( + update_buffer["discrete_action"] + ) + else: + update_buffer["continuous_log_probs"] = np.ones_like( + update_buffer["continuous_action"] + ) optimizer.update( update_buffer, num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length, @@ -147,6 +161,9 @@ def test_ppo_optimizer_update_gail(gail_dummy_config, dummy_config): # noqa: F8 update_buffer["extrinsic_value_estimates"] = update_buffer["environment_rewards"] update_buffer["gail_returns"] = update_buffer["environment_rewards"] update_buffer["gail_value_estimates"] = update_buffer["environment_rewards"] + update_buffer["continuous_log_probs"] = np.ones_like( + update_buffer["continuous_action"] + ) optimizer.update( update_buffer, num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length, @@ -163,7 +180,9 @@ def test_ppo_optimizer_update_gail(gail_dummy_config, dummy_config): # noqa: F8 # NOTE: In TensorFlow, the log_probs are saved as one for every discrete action, whereas # in PyTorch it is saved as the total probability per branch. So we need to modify the # log prob in the fake buffer here. - update_buffer["action_probs"] = np.ones_like(update_buffer["actions"]) + update_buffer["continuous_log_probs"] = np.ones_like( + update_buffer["continuous_action"] + ) optimizer.update( update_buffer, num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length, diff --git a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py b/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py index 16b9ced0c6..ca82392be9 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py @@ -87,7 +87,7 @@ def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> for _ in range(200): curiosity_rp.update(buffer) prediction = curiosity_rp._network.predict_action(buffer)[0] - target = torch.tensor(buffer["actions"][0]) + target = torch.tensor(buffer["continuous_action"][0]) error = torch.mean((prediction - target) ** 2).item() assert error < 0.001 diff --git a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py b/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py index 8b01fbd467..a7cceadee9 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py @@ -16,7 +16,13 @@ def create_agent_buffer( np.random.normal(size=shape).astype(np.float32) for shape in behavior_spec.observation_shapes ] - action = behavior_spec.action_spec.random_action(1)[0, :] + action_buffer = behavior_spec.action_spec.random_action(1) + action = {} + if behavior_spec.action_spec.continuous_size > 0: + action["continuous_action"] = action_buffer.continuous + if behavior_spec.action_spec.discrete_size > 0: + action["discrete_action"] = action_buffer.discrete + for _ in range(number): curr_split_obs = SplitObservations.from_observations(curr_observations) next_split_obs = SplitObservations.from_observations(next_observations) @@ -27,7 +33,8 @@ def create_agent_buffer( ) buffer["vector_obs"].append(curr_split_obs.vector_observations) buffer["next_vector_in"].append(next_split_obs.vector_observations) - buffer["actions"].append(action) + for _act_type, _act in 
action.items(): + buffer[_act_type].append(_act[0, :]) buffer["reward"].append(np.ones(1, dtype=np.float32) * reward) buffer["masks"].append(np.ones(1, dtype=np.float32)) buffer["done"] = np.zeros(number, dtype=np.float32) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py index c52bfd1802..1e39afae04 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py @@ -159,13 +159,15 @@ def test_get_probs_and_entropy(): log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy( action_list, dist_list ) - assert log_probs.shape == (1, 2, 2) + for lp in log_probs: + assert lp.shape == (1, 2) assert entropies.shape == (1, 2, 2) - assert all_probs is None + assert all_probs == [] - for log_prob in log_probs.flatten(): + for log_prob in log_probs: # Log prob of standard normal at 0 - assert log_prob == pytest.approx(-0.919, abs=0.01) + for lp in log_prob.flatten(): + assert lp == pytest.approx(-0.919, abs=0.01) for ent in entropies.flatten(): # entropy of standard normal at 0 @@ -182,10 +184,11 @@ def test_get_probs_and_entropy(): log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy( action_list, dist_list ) - assert all_probs.shape == (1, len(dist_list * act_size)) + for all_prob in all_probs: + assert all_prob.shape == (1, act_size) assert entropies.shape == (1, len(dist_list)) # Make sure the first action has high probability than the others. - assert log_probs.flatten()[0] > log_probs.flatten()[1] + assert log_probs[0] > log_probs[1] def test_masked_mean(): diff --git a/ml-agents/mlagents/trainers/tf/components/bc/module.py b/ml-agents/mlagents/trainers/tf/components/bc/module.py index 1cace3294c..ef829c73b1 100644 --- a/ml-agents/mlagents/trainers/tf/components/bc/module.py +++ b/ml-agents/mlagents/trainers/tf/components/bc/module.py @@ -106,8 +106,8 @@ def _update_batch( self.policy.batch_size_ph: n_sequences, self.policy.sequence_length_ph: self.policy.sequence_length, } - feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"] if self.policy.behavior_spec.action_spec.is_discrete(): + feed_dict[self.model.action_in_expert] = mini_batch_demo["discrete_action"] feed_dict[self.policy.action_masks] = np.ones( ( self.n_sequences * self.policy.sequence_length, @@ -115,6 +115,10 @@ def _update_batch( ), dtype=np.float32, ) + else: + feed_dict[self.model.action_in_expert] = mini_batch_demo[ + "continuous_action" + ] if self.policy.vec_obs_size > 0: feed_dict[self.policy.vector_in] = mini_batch_demo["vector_obs"] for i, _ in enumerate(self.policy.visual_in): diff --git a/ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/signal.py b/ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/signal.py index 63b5453ba1..48a01d0f34 100644 --- a/ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/signal.py +++ b/ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/signal.py @@ -42,7 +42,7 @@ def __init__(self, policy: TFPolicy, settings: CuriositySettings): def evaluate_batch(self, mini_batch: AgentBuffer) -> RewardSignalResult: feed_dict: Dict[tf.Tensor, Any] = { - self.policy.batch_size_ph: len(mini_batch["actions"]), + self.policy.batch_size_ph: len(mini_batch["vector_obs"]), self.policy.sequence_length_ph: self.policy.sequence_length, } if self.policy.use_vec_obs: @@ -56,9 +56,9 @@ def evaluate_batch(self, mini_batch: AgentBuffer) -> RewardSignalResult: 
feed_dict[self.model.next_visual_in[i]] = _next_obs if self.policy.use_continuous_act: - feed_dict[self.policy.selected_actions] = mini_batch["actions"] + feed_dict[self.policy.selected_actions] = mini_batch["continuous_action"] else: - feed_dict[self.policy.output] = mini_batch["actions"] + feed_dict[self.policy.output] = mini_batch["discrete_action"] unscaled_reward = self.policy.sess.run( self.model.intrinsic_reward, feed_dict=feed_dict ) @@ -82,9 +82,9 @@ def prepare_update( policy.mask_input: mini_batch["masks"], } if self.policy.use_continuous_act: - feed_dict[policy.selected_actions] = mini_batch["actions"] + feed_dict[policy.selected_actions] = mini_batch["continuous_action"] else: - feed_dict[policy.output] = mini_batch["actions"] + feed_dict[policy.output] = mini_batch["discrete_action"] if self.policy.use_vec_obs: feed_dict[policy.vector_in] = mini_batch["vector_obs"] feed_dict[self.model.next_vector_in] = mini_batch["next_vector_in"] diff --git a/ml-agents/mlagents/trainers/tf/components/reward_signals/gail/signal.py b/ml-agents/mlagents/trainers/tf/components/reward_signals/gail/signal.py index baeec49ac6..d2b585314e 100644 --- a/ml-agents/mlagents/trainers/tf/components/reward_signals/gail/signal.py +++ b/ml-agents/mlagents/trainers/tf/components/reward_signals/gail/signal.py @@ -57,7 +57,7 @@ def __init__(self, policy: TFPolicy, settings: GAILSettings): def evaluate_batch(self, mini_batch: AgentBuffer) -> RewardSignalResult: feed_dict: Dict[tf.Tensor, Any] = { - self.policy.batch_size_ph: len(mini_batch["actions"]), + self.policy.batch_size_ph: len(mini_batch["vector_obs"]), self.policy.sequence_length_ph: self.policy.sequence_length, } if self.model.use_vail: @@ -71,9 +71,9 @@ def evaluate_batch(self, mini_batch: AgentBuffer) -> RewardSignalResult: feed_dict[self.policy.visual_in[i]] = _obs if self.policy.use_continuous_act: - feed_dict[self.policy.selected_actions] = mini_batch["actions"] + feed_dict[self.policy.selected_actions] = mini_batch["continuous_action"] else: - feed_dict[self.policy.output] = mini_batch["actions"] + feed_dict[self.policy.output] = mini_batch["discrete_action"] feed_dict[self.model.done_policy_holder] = np.array( mini_batch["done"] ).flatten() @@ -106,11 +106,16 @@ def prepare_update( if self.model.use_vail: feed_dict[self.model.use_noise] = [1] - feed_dict[self.model.action_in_expert] = np.array(mini_batch_demo["actions"]) if self.policy.use_continuous_act: - feed_dict[policy.selected_actions] = mini_batch["actions"] + feed_dict[policy.selected_actions] = mini_batch["continuous_action"] + feed_dict[self.model.action_in_expert] = np.array( + mini_batch_demo["continuous_action"] + ) else: - feed_dict[policy.output] = mini_batch["actions"] + feed_dict[policy.output] = mini_batch["discrete_action"] + feed_dict[self.model.action_in_expert] = np.array( + mini_batch_demo["discrete_action"] + ) if self.policy.use_vis_obs > 0: for i in range(len(policy.visual_in)): diff --git a/ml-agents/mlagents/trainers/torch/components/bc/module.py b/ml-agents/mlagents/trainers/torch/components/bc/module.py index 7622dc7722..e8511c009d 100644 --- a/ml-agents/mlagents/trainers/torch/components/bc/module.py +++ b/ml-agents/mlagents/trainers/torch/components/bc/module.py @@ -5,7 +5,7 @@ from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.demo_loader import demo_to_buffer from mlagents.trainers.settings import BehavioralCloningSettings, ScheduleType -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch.utils 
import ModelUtils, AgentAction, ActionLogProbs class BCModule: @@ -98,12 +98,20 @@ def update(self) -> Dict[str, np.ndarray]: update_stats = {"Losses/Pretraining Loss": np.mean(batch_losses)} return update_stats - def _behavioral_cloning_loss(self, selected_actions, log_probs, expert_actions): + def _behavioral_cloning_loss( + self, + selected_actions: AgentAction, + log_probs: ActionLogProbs, + expert_actions: torch.Tensor, + ) -> torch.Tensor: if self.policy.use_continuous_act: - bc_loss = torch.nn.functional.mse_loss(selected_actions, expert_actions) + bc_loss = torch.nn.functional.mse_loss( + selected_actions.continuous_tensor, expert_actions + ) else: log_prob_branches = ModelUtils.break_into_branches( - log_probs, self.policy.act_size + log_probs.all_discrete_tensor, + self.policy.behavior_spec.action_spec.discrete_branches, ) bc_loss = torch.mean( torch.stack( @@ -130,10 +138,12 @@ def _update_batch( vec_obs = [ModelUtils.list_to_tensor(mini_batch_demo["vector_obs"])] act_masks = None if self.policy.use_continuous_act: - expert_actions = ModelUtils.list_to_tensor(mini_batch_demo["actions"]) + expert_actions = ModelUtils.list_to_tensor( + mini_batch_demo["continuous_action"] + ) else: raw_expert_actions = ModelUtils.list_to_tensor( - mini_batch_demo["actions"], dtype=torch.long + mini_batch_demo["discrete_action"], dtype=torch.long ) expert_actions = ModelUtils.actions_to_onehot( raw_expert_actions, self.policy.act_size @@ -164,16 +174,15 @@ def _update_batch( else: vis_obs = [] - selected_actions, all_log_probs, _, _ = self.policy.sample_actions( + selected_actions, log_probs, _, _ = self.policy.sample_actions( vec_obs, vis_obs, masks=act_masks, memories=memories, seq_len=self.policy.sequence_length, - all_log_probs=True, ) bc_loss = self._behavioral_cloning_loss( - selected_actions, all_log_probs, expert_actions + selected_actions, log_probs, expert_actions ) self.optimizer.zero_grad() bc_loss.backward() diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py index c6608c84f9..dd2c028605 100644 --- a/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py +++ b/ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py @@ -9,7 +9,7 @@ from mlagents.trainers.settings import CuriositySettings from mlagents_envs.base_env import BehaviorSpec -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch.utils import ModelUtils, AgentAction from mlagents.trainers.torch.networks import NetworkBody from mlagents.trainers.torch.layers import LinearEncoder, linear_layer from mlagents.trainers.settings import NetworkSettings, EncoderType @@ -148,13 +148,13 @@ def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor: Uses the current state embedding and the action of the mini_batch to predict the next state embedding. 
""" + actions = AgentAction.from_dict(mini_batch) if self._action_spec.is_continuous(): - action = ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float) + action = actions.continuous_tensor else: action = torch.cat( ModelUtils.actions_to_onehot( - ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long), - self._action_spec.discrete_branches, + actions.discrete_tensor, self._action_spec.discrete_branches ), dim=1, ) @@ -170,11 +170,9 @@ def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor: action prediction (given the current and next state). """ predicted_action = self.predict_action(mini_batch) + actions = AgentAction.from_dict(mini_batch) if self._action_spec.is_continuous(): - sq_difference = ( - ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float) - - predicted_action - ) ** 2 + sq_difference = (actions.continuous_tensor - predicted_action) ** 2 sq_difference = torch.sum(sq_difference, dim=1) return torch.mean( ModelUtils.dynamic_partition( @@ -186,8 +184,7 @@ def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor: else: true_action = torch.cat( ModelUtils.actions_to_onehot( - ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long), - self._action_spec.discrete_branches, + actions.discrete_tensor, self._action_spec.discrete_branches ), dim=1, ) diff --git a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py b/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py index 496e3cdfc0..6d4daed5fd 100644 --- a/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py +++ b/ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py @@ -8,7 +8,7 @@ ) from mlagents.trainers.settings import GAILSettings from mlagents_envs.base_env import BehaviorSpec -from mlagents.trainers.torch.utils import ModelUtils +from mlagents.trainers.torch.utils import ModelUtils, AgentAction from mlagents.trainers.torch.networks import NetworkBody from mlagents.trainers.torch.layers import linear_layer, Initialization from mlagents.trainers.settings import NetworkSettings, EncoderType @@ -109,9 +109,7 @@ def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor: Creates the action Tensor. In continuous case, corresponds to the action. In the discrete case, corresponds to the concatenation of one hot action Tensors. """ - return self._action_flattener.forward( - torch.as_tensor(mini_batch["actions"], dtype=torch.float) - ) + return self._action_flattener.forward(AgentAction.from_dict(mini_batch)) def get_state_inputs( self, mini_batch: AgentBuffer diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py index fad8636363..0108dfbf75 100644 --- a/ml-agents/mlagents/trainers/torch/utils.py +++ b/ml-agents/mlagents/trainers/torch/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, NamedTuple, Dict from mlagents.torch_utils import torch, nn import numpy as np @@ -15,6 +15,201 @@ from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance +class AgentAction(NamedTuple): + """ + A NamedTuple containing the tensor for continuous actions and list of tensors for + discrete actions. Utility functions provide numpy <=> tensor conversions to be + sent as actions to the environment manager as well as used by the optimizers. 
+ :param continuous_tensor: Torch tensor corresponding to continuous actions
+ :param discrete_list: List of Torch tensors each corresponding to discrete actions
+ """
+
+ continuous_tensor: torch.Tensor
+ discrete_list: List[torch.Tensor]
+
+ @property
+ def discrete_tensor(self):
+ """
+ Returns the discrete action list as a stacked tensor
+ """
+ return torch.stack(self.discrete_list, dim=-1)
+
+ def to_numpy_dict(self) -> Dict[str, np.ndarray]:
+ """
+ Returns a Dict of np arrays with an entry corresponding to the continuous action
+ and an entry corresponding to the discrete action. "continuous_action" and
+ "discrete_action" are added to the agent's buffer individually to maintain a flat buffer.
+ """
+ array_dict: Dict[str, np.ndarray] = {}
+ if self.continuous_tensor is not None:
+ array_dict["continuous_action"] = ModelUtils.to_numpy(
+ self.continuous_tensor
+ )
+ if self.discrete_list is not None:
+ array_dict["discrete_action"] = ModelUtils.to_numpy(
+ self.discrete_tensor[:, 0, :]
+ )
+ return array_dict
+
+ def to_tensor_list(self) -> List[torch.Tensor]:
+ """
+ Returns the tensors in the AgentAction as a flat List of torch Tensors. This will be removed
+ when the ActionModel is merged.
+ """
+ tensor_list: List[torch.Tensor] = []
+ if self.continuous_tensor is not None:
+ tensor_list.append(self.continuous_tensor)
+ if self.discrete_list is not None:
+ tensor_list += (
+ self.discrete_list
+ ) # Note this is different for ActionLogProbs
+ return tensor_list
+
+ @staticmethod
+ def create(
+ tensor_list: List[torch.Tensor], action_spec: ActionSpec
+ ) -> "AgentAction":
+ """
+ A static method that converts a list of torch Tensors into an AgentAction using the ActionSpec.
+ This will change (and may be removed) in the ActionModel.
+ """
+ continuous: torch.Tensor = None
+ discrete: List[torch.Tensor] = None # type: ignore
+ _offset = 0
+ if action_spec.continuous_size > 0:
+ continuous = tensor_list[0]
+ _offset = 1
+ if action_spec.discrete_size > 0:
+ discrete = tensor_list[_offset:]
+ return AgentAction(continuous, discrete)
+
+ @staticmethod
+ def from_dict(buff: Dict[str, np.ndarray]) -> "AgentAction":
+ """
+ A static method that accesses continuous and discrete action fields in an AgentBuffer
+ and constructs the corresponding AgentAction from the retrieved np arrays.
+ """
+ continuous: torch.Tensor = None
+ discrete: List[torch.Tensor] = None # type: ignore
+ if "continuous_action" in buff:
+ continuous = ModelUtils.list_to_tensor(buff["continuous_action"])
+ if "discrete_action" in buff:
+ discrete_tensor = ModelUtils.list_to_tensor(
+ buff["discrete_action"], dtype=torch.long
+ )
+ discrete = [
+ discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
+ ]
+ return AgentAction(continuous, discrete)
+
+
+class ActionLogProbs(NamedTuple):
+ """
+ A NamedTuple containing the tensor for continuous log probs and list of tensors for
+ discrete log probs of individual actions as well as all the log probs for an entire branch.
+ Utility functions provide numpy <=> tensor conversions to be used by the optimizers.
+ :param continuous_tensor: Torch tensor corresponding to log probs of continuous actions
+ :param discrete_list: List of Torch tensors each corresponding to log probs of the discrete actions that were
+ sampled.
+ :param all_discrete_list: List of Torch tensors each corresponding to all log probs of
+ a discrete action branch, even the discrete actions that were not sampled.
+ Each tensor in all_discrete_list holds the log probabilities of one full discrete branch.
+ """
+
+ continuous_tensor: torch.Tensor
+ discrete_list: List[torch.Tensor]
+ all_discrete_list: Optional[List[torch.Tensor]]
+
+ @property
+ def discrete_tensor(self):
+ """
+ Returns the discrete log probs list as a stacked tensor
+ """
+ return torch.stack(self.discrete_list, dim=-1)
+
+ @property
+ def all_discrete_tensor(self):
+ """
+ Returns the discrete log probs of each branch as a tensor
+ """
+ return torch.cat(self.all_discrete_list, dim=1)
+
+ def to_numpy_dict(self) -> Dict[str, np.ndarray]:
+ """
+ Returns a Dict of np arrays with an entry corresponding to the continuous log probs
+ and an entry corresponding to the discrete log probs. "continuous_log_probs" and
+ "discrete_log_probs" are added to the agent's buffer individually to maintain a flat buffer.
+ """
+ array_dict: Dict[str, np.ndarray] = {}
+ if self.continuous_tensor is not None:
+ array_dict["continuous_log_probs"] = ModelUtils.to_numpy(
+ self.continuous_tensor
+ )
+ if self.discrete_list is not None:
+ array_dict["discrete_log_probs"] = ModelUtils.to_numpy(self.discrete_tensor)
+ return array_dict
+
+ def _to_tensor_list(self) -> List[torch.Tensor]:
+ """
+ Returns the tensors in the ActionLogProbs as a flat List of torch Tensors. This
+ is private and serves as a utility for self.flatten()
+ """
+ tensor_list: List[torch.Tensor] = []
+ if self.continuous_tensor is not None:
+ tensor_list.append(self.continuous_tensor)
+ if self.discrete_list is not None:
+ tensor_list.append(
+ self.discrete_tensor
+ ) # Note this is different for AgentActions
+ return tensor_list
+
+ def flatten(self) -> torch.Tensor:
+ """
+ A utility method that returns all log probs in ActionLogProbs as a flattened tensor.
+ This is useful for algorithms like PPO which can treat all log probs in the same way.
+ """
+ return torch.cat(self._to_tensor_list(), dim=1)
+
+ @staticmethod
+ def create(
+ log_prob_list: List[torch.Tensor],
+ action_spec: ActionSpec,
+ all_log_prob_list: List[torch.Tensor] = None,
+ ) -> "ActionLogProbs":
+ """
+ A static method that converts a list of torch Tensors into an ActionLogProbs using the ActionSpec.
+ This will change (and may be removed) in the ActionModel.
+ """
+ continuous: torch.Tensor = None
+ discrete: List[torch.Tensor] = None # type: ignore
+ _offset = 0
+ if action_spec.continuous_size > 0:
+ continuous = log_prob_list[0]
+ _offset = 1
+ if action_spec.discrete_size > 0:
+ discrete = log_prob_list[_offset:]
+ return ActionLogProbs(continuous, discrete, all_log_prob_list)
+
+ @staticmethod
+ def from_dict(buff: Dict[str, np.ndarray]) -> "ActionLogProbs":
+ """
+ A static method that accesses continuous and discrete log probs fields in an AgentBuffer
+ and constructs the corresponding ActionLogProbs from the retrieved np arrays.
+ """
+ continuous: torch.Tensor = None
+ discrete: List[torch.Tensor] = None # type: ignore
+ if "continuous_log_probs" in buff:
+ continuous = ModelUtils.list_to_tensor(buff["continuous_log_probs"])
+ if "discrete_log_probs" in buff:
+ discrete_tensor = ModelUtils.list_to_tensor(buff["discrete_log_probs"])
+ discrete = [
+ discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
+ ]
+ return ActionLogProbs(continuous, discrete, None)
+
+
 class ModelUtils:
 # Minimum supported side for each encoder type. If refactoring an encoder, please
 # adjust these also.
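# For illustration only (not part of the patch): a minimal usage sketch of the new
# AgentAction / ActionLogProbs helpers above and the flat buffer keys they read,
# assuming this branch of ml-agents and PyTorch are importable. The buffer dict,
# shapes, and values are illustrative assumptions.
import numpy as np
from mlagents.trainers.torch.utils import AgentAction, ActionLogProbs

fake_buffer = {
    "continuous_action": np.zeros((5, 2), dtype=np.float32),    # (n_agents, continuous_size)
    "discrete_action": np.zeros((5, 3), dtype=np.int32),        # (n_agents, num_branches)
    "continuous_log_probs": np.zeros((5, 2), dtype=np.float32),
    "discrete_log_probs": np.zeros((5, 3), dtype=np.float32),
}
agent_action = AgentAction.from_dict(fake_buffer)
assert agent_action.continuous_tensor.shape == (5, 2)
assert len(agent_action.discrete_list) == 3      # one tensor per discrete branch
log_probs = ActionLogProbs.from_dict(fake_buffer)
assert log_probs.flatten().shape == (5, 5)       # continuous and discrete log probs concatenated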
@@ -36,13 +231,13 @@ def flattened_size(self) -> int: else: return sum(self._specs.discrete_branches) - def forward(self, action: torch.Tensor) -> torch.Tensor: + def forward(self, action: AgentAction) -> torch.Tensor: if self._specs.is_continuous(): - return action + return action.continuous_tensor else: return torch.cat( ModelUtils.actions_to_onehot( - torch.as_tensor(action, dtype=torch.long), + torch.as_tensor(action.discrete_tensor, dtype=torch.long), self._specs.discrete_branches, ), dim=1, @@ -270,7 +465,7 @@ def dynamic_partition( @staticmethod def get_probs_and_entropy( action_list: List[torch.Tensor], dists: List[DistInstance] - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[List[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: log_probs_list = [] all_probs_list = [] entropies_list = [] @@ -280,15 +475,10 @@ def get_probs_and_entropy( entropies_list.append(action_dist.entropy()) if isinstance(action_dist, DiscreteDistInstance): all_probs_list.append(action_dist.all_log_prob()) - log_probs = torch.stack(log_probs_list, dim=-1) entropies = torch.stack(entropies_list, dim=-1) if not all_probs_list: - log_probs = log_probs.squeeze(-1) entropies = entropies.squeeze(-1) - all_probs = None - else: - all_probs = torch.cat(all_probs_list, dim=-1) - return log_probs, entropies, all_probs + return log_probs_list, entropies, all_probs_list @staticmethod def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: diff --git a/ml-agents/mlagents/trainers/trajectory.py b/ml-agents/mlagents/trainers/trajectory.py index 1eb1e55993..dd1130fe6b 100644 --- a/ml-agents/mlagents/trainers/trajectory.py +++ b/ml-agents/mlagents/trainers/trajectory.py @@ -1,4 +1,4 @@ -from typing import List, NamedTuple +from typing import List, NamedTuple, Dict import numpy as np from mlagents.trainers.buffer import AgentBuffer @@ -8,8 +8,8 @@ class AgentExperience(NamedTuple): obs: List[np.ndarray] reward: float done: bool - action: np.ndarray - action_probs: np.ndarray + action: Dict[str, np.ndarray] + action_probs: Dict[str, np.ndarray] action_pre: np.ndarray # TODO: Remove this action_mask: np.ndarray prev_action: np.ndarray @@ -107,12 +107,13 @@ def to_agentbuffer(self) -> AgentBuffer: agent_buffer_trajectory["done"].append(exp.done) # Add the outputs of the last eval if exp.action_pre is not None: - actions_pre = exp.action_pre - agent_buffer_trajectory["actions_pre"].append(actions_pre) + agent_buffer_trajectory["actions_pre"].append(exp.action_pre) - # value is a dictionary from name of reward to value estimate of the value head - agent_buffer_trajectory["actions"].append(exp.action) - agent_buffer_trajectory["action_probs"].append(exp.action_probs) + # Adds the log prob and action of continuous/discrete separately + for act_type, act_array in exp.action.items(): + agent_buffer_trajectory[act_type].append(act_array) + for log_type, log_array in exp.action_probs.items(): + agent_buffer_trajectory[log_type].append(log_array) # Store action masks if necessary. Note that 1 means active, while # in AgentExperience False means active. @@ -122,10 +123,14 @@ def to_agentbuffer(self) -> AgentBuffer: else: # This should never be needed unless the environment somehow doesn't supply the # action mask in a discrete space. 
+ + if "discrete_action" in exp.action: + action_shape = exp.action["discrete_action"].shape + else: + action_shape = exp.action["continuous_action"].shape agent_buffer_trajectory["action_mask"].append( - np.ones(exp.action_probs.shape, dtype=np.float32), padding_value=1 + np.ones(action_shape, dtype=np.float32), padding_value=1 ) - agent_buffer_trajectory["prev_action"].append(exp.prev_action) agent_buffer_trajectory["environment_rewards"].append(exp.reward)