Commit 6dbba73
[change] Remove concatenate in discrete action probabilities to improve inference performance (#3598)

Ervin T authored Mar 11, 2020
1 parent e91a8cc commit 6dbba73
Showing 5 changed files with 50 additions and 38 deletions.
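For orientation, here is a minimal sketch of the pattern this commit removes versus the one it introduces. Branch sizes, batch size, and variable names are invented for illustration; this is plain NumPy, not the actual ML-Agents graph code.

```python
import numpy as np

act_size = [3, 4, 2]          # hypothetical discrete action branches
batch_size = 8
rng = np.random.default_rng(0)

# One logits array per branch, as produced by separate policy heads.
policy_branches = [rng.standard_normal((batch_size, n)) for n in act_size]

# Before: concatenate the branches, then slice them apart again downstream.
concatenated = np.concatenate(policy_branches, axis=1)          # shape (8, 9)
idx = [0] + list(np.cumsum(act_size))                           # [0, 3, 7, 9]
resliced = [concatenated[:, idx[i]:idx[i + 1]] for i in range(len(act_size))]

# After: keep the list of per-branch arrays and pass it along directly,
# so the concat/slice round trip never enters the inference graph.
for before, after in zip(resliced, policy_branches):
    assert np.array_equal(before, after)
```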
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -49,6 +49,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- `DecisionRequester` has been made internal (you can still use the DecisionRequesterComponent from the inspector). `RepeatAction` was renamed `TakeActionsBetweenDecisions` for clarity. (#3555)
- The `IFloatProperties` interface has been removed.
- Fix #3579.
- Improved inference performance for models with multiple action branches. (#3598)
- Fixed an issue when using GAIL with less than `batch_size` number of demonstrations. (#3591)
- The interfaces to the `SideChannel` classes (on C# and python) have changed to use new `IncomingMessage` and `OutgoingMessage` classes. These should make reading and writing data to the channel easier. (#3596)

5 changes: 2 additions & 3 deletions ml-agents/mlagents/trainers/distributions.py
@@ -214,12 +214,11 @@ def _create_policy_branches(
kernel_initializer=ModelUtils.scaled_init(0.01),
)
)
unmasked_log_probs = tf.concat(policy_branches, axis=1)
return unmasked_log_probs
return policy_branches

def _get_masked_actions_probs(
self,
unmasked_log_probs: tf.Tensor,
unmasked_log_probs: List[tf.Tensor],
act_size: List[int],
action_masks: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor, np.ndarray]:
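With this change, `_create_policy_branches` returns the list of per-branch logit tensors and `_get_masked_actions_probs` accepts a `List[tf.Tensor]` rather than one concatenated tensor. A rough NumPy sketch of the per-branch masking idea that consumes such a list (the TensorFlow version of this logic is `ModelUtils.create_discrete_action_masking_layer` in models.py below; the helper names and the EPSILON value here are illustrative):

```python
import numpy as np

EPSILON = 1e-7  # small constant; the actual value used by ML-Agents may differ

def softmax(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def masked_branch_probs(branch_logits, branch_mask):
    """Softmax one branch, zero out masked actions, renormalize."""
    raw = (softmax(branch_logits) + EPSILON) * branch_mask
    return raw / raw.sum(axis=1, keepdims=True)

# Hypothetical inputs: two branches, batch of 2, one action masked out.
logits = [np.zeros((2, 3)), np.zeros((2, 2))]
masks = [np.array([[1., 1., 0.], [1., 1., 1.]]), np.ones((2, 2))]
probs = [masked_branch_probs(l, m) for l, m in zip(logits, masks)]
```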
36 changes: 25 additions & 11 deletions ml-agents/mlagents/trainers/models.py
@@ -456,25 +456,39 @@ def get_encoder_for_type(encoder_type: EncoderType) -> EncoderFunction:
)

@staticmethod
def create_discrete_action_masking_layer(all_logits, action_masks, action_size):
def break_into_branches(
concatenated_logits: tf.Tensor, action_size: List[int]
) -> List[tf.Tensor]:
"""
Takes a concatenated set of logits that represent multiple discrete action branches
and breaks it up into one Tensor per branch.
:param concatenated_logits: Tensor that represents the concatenated action branches
:param action_size: List of ints containing the number of possible actions for each branch.
:return: A List of Tensors containing one tensor per branch.
"""
action_idx = [0] + list(np.cumsum(action_size))
branched_logits = [
concatenated_logits[:, action_idx[i] : action_idx[i + 1]]
for i in range(len(action_size))
]
return branched_logits

@staticmethod
def create_discrete_action_masking_layer(
branches_logits: List[tf.Tensor],
action_masks: tf.Tensor,
action_size: List[int],
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""
Creates a masking layer for the discrete actions
:param all_logits: The concatenated unnormalized action probabilities for all branches
:param branches_logits: A List of the unnormalized action probabilities for each branch
:param action_masks: The mask for the logits. Must be of dimension [None x total_number_of_action]
:param action_size: A list containing the number of possible actions for each branch
:return: The action output dimension [batch_size, num_branches], the concatenated
normalized probs (after softmax)
and the concatenated normalized log probs
"""
action_idx = [0] + list(np.cumsum(action_size))
branches_logits = [
all_logits[:, action_idx[i] : action_idx[i + 1]]
for i in range(len(action_size))
]
branch_masks = [
action_masks[:, action_idx[i] : action_idx[i + 1]]
for i in range(len(action_size))
]
branch_masks = ModelUtils.break_into_branches(action_masks, action_size)
raw_probs = [
tf.multiply(tf.nn.softmax(branches_logits[k]) + EPSILON, branch_masks[k])
for k in range(len(action_size))
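The slicing in `break_into_branches` is simple cumulative-index arithmetic, and the masking layer now uses it only where a genuinely flat input (such as the action mask) still needs splitting. A small self-contained NumPy sketch of the same idea, with made-up shapes and values:

```python
import numpy as np

def break_into_branches(concatenated, act_size):
    """Split a [batch, sum(act_size)] array into one array per branch."""
    idx = [0] + list(np.cumsum(act_size))
    return [concatenated[:, idx[i]:idx[i + 1]] for i in range(len(act_size))]

act_size = [3, 2]
# Flat mask for a batch of 2 agents over 3 + 2 = 5 discrete actions.
action_masks = np.array([[1, 1, 0, 1, 1],
                         [1, 1, 1, 0, 1]], dtype=np.float32)

branch_masks = break_into_branches(action_masks, act_size)
assert branch_masks[0].shape == (2, 3) and branch_masks[1].shape == (2, 2)
```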
8 changes: 7 additions & 1 deletion ml-agents/mlagents/trainers/ppo/optimizer.py
@@ -169,8 +169,14 @@ def _create_dc_critic(
dtype=tf.float32,
name="old_probabilities",
)

# Break old log probs into separate branches
old_log_prob_branches = ModelUtils.break_into_branches(
self.all_old_log_probs, self.policy.act_size
)

_, _, old_normalized_logits = ModelUtils.create_discrete_action_masking_layer(
self.all_old_log_probs, self.policy.action_masks, self.policy.act_size
old_log_prob_branches, self.policy.action_masks, self.policy.act_size
)

action_idx = [0] + list(np.cumsum(self.policy.act_size))
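In the PPO optimizer the old action probabilities are still fed in as one flat placeholder of width `sum(act_size)`, so they are broken into branches on the graph side before the masking layer is applied. A self-contained NumPy sketch of that split on a hypothetical stored buffer (not the actual graph code):

```python
import numpy as np

act_size = [3, 2]
idx = [0] + list(np.cumsum(act_size))

# Hypothetical flat buffer of old (log) probabilities, one row per step.
all_old_log_probs = np.random.randn(4, sum(act_size)).astype(np.float32)

# Equivalent of ModelUtils.break_into_branches on the flat placeholder:
old_log_prob_branches = [
    all_old_log_probs[:, idx[i]:idx[i + 1]] for i in range(len(act_size))
]
# Each branch is then handed to the per-branch masking layer.
assert [b.shape[1] for b in old_log_prob_branches] == act_size
```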
38 changes: 15 additions & 23 deletions ml-agents/mlagents/trainers/sac/optimizer.py
@@ -232,8 +232,9 @@ def _create_losses(

for name in stream_names:
if discrete:
_branched_mpq1 = self._apply_as_branches(
self.policy_network.q1_pheads[name] * discrete_action_probs
_branched_mpq1 = ModelUtils.break_into_branches(
self.policy_network.q1_pheads[name] * discrete_action_probs,
self.act_size,
)
branched_mpq1 = tf.stack(
[
@@ -243,8 +244,9 @@
)
_q1_p_mean = tf.reduce_mean(branched_mpq1, axis=0)

_branched_mpq2 = self._apply_as_branches(
self.policy_network.q2_pheads[name] * discrete_action_probs
_branched_mpq2 = ModelUtils.break_into_branches(
self.policy_network.q2_pheads[name] * discrete_action_probs,
self.act_size,
)
branched_mpq2 = tf.stack(
[
@@ -282,11 +284,11 @@ def _create_losses(

if discrete:
# We need to break up the Q functions by branch, and update them individually.
branched_q1_stream = self._apply_as_branches(
self.policy.selected_actions * q1_streams[name]
branched_q1_stream = ModelUtils.break_into_branches(
self.policy.selected_actions * q1_streams[name], self.act_size
)
branched_q2_stream = self._apply_as_branches(
self.policy.selected_actions * q2_streams[name]
branched_q2_stream = ModelUtils.break_into_branches(
self.policy.selected_actions * q2_streams[name], self.act_size
)

# Reduce each branch into scalar
@@ -344,7 +346,9 @@ def _create_losses(
self.ent_coef = tf.exp(self.log_ent_coef)
if discrete:
# We also have to do a different entropy and target_entropy per branch.
branched_per_action_ent = self._apply_as_branches(per_action_entropy)
branched_per_action_ent = ModelUtils.break_into_branches(
per_action_entropy, self.act_size
)
branched_ent_sums = tf.stack(
[
tf.reduce_sum(_lp, axis=1, keep_dims=True) + _te
Expand All @@ -364,8 +368,8 @@ def _create_losses(
# Same with policy loss, we have to do the loss per branch and average them,
# so that larger branches don't get more weight.
# The equivalent KL divergence from Eq 10 of Haarnoja et al. is also pi*log(pi) - Q
branched_q_term = self._apply_as_branches(
discrete_action_probs * self.policy_network.q1_p
branched_q_term = ModelUtils.break_into_branches(
discrete_action_probs * self.policy_network.q1_p, self.act_size
)

branched_policy_loss = tf.stack(
@@ -444,18 +448,6 @@ def _create_losses(

self.entropy = self.policy_network.entropy

def _apply_as_branches(self, concat_logits: tf.Tensor) -> List[tf.Tensor]:
"""
Takes in a concatenated set of logits and breaks it up into a list of non-concatenated logits, one per
action branch
"""
action_idx = [0] + list(np.cumsum(self.act_size))
branches_logits = [
concat_logits[:, action_idx[i] : action_idx[i + 1]]
for i in range(len(self.act_size))
]
return branches_logits

def _create_sac_optimizer_ops(self) -> None:
"""
Creates the Adam optimizers and update ops for SAC, including
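The SAC optimizer now uses the shared `ModelUtils.break_into_branches` (passing `self.act_size` explicitly) instead of its private `_apply_as_branches` helper, which is deleted. The surrounding pattern stays the same throughout: split a flat `[batch, sum(act_size)]` product into branches, reduce each branch over its actions, and average across branches so that branches with more actions do not dominate the loss. A NumPy sketch of that reduction pattern under made-up shapes:

```python
import numpy as np

act_size = [3, 2]
idx = [0] + list(np.cumsum(act_size))
batch = 4
rng = np.random.default_rng(1)

# Hypothetical flat per-action quantity, e.g. probs * Q-values, shape (4, 5).
flat = rng.standard_normal((batch, sum(act_size)))

# Split into branches, reduce each branch over its actions, then average
# across branches so a 3-action branch and a 2-action branch weigh equally.
branches = [flat[:, idx[i]:idx[i + 1]] for i in range(len(act_size))]
per_branch = np.stack([b.sum(axis=1, keepdims=True) for b in branches])  # (2, 4, 1)
averaged = per_branch.mean(axis=0)                                       # (4, 1)
```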
