Skip to content

Commit

Permalink
Force different nodes for new and deprecated action output (#4705)
Browse files Browse the repository at this point in the history
* fix export node missing bug by forcing different nodes for new and deprecated action output

* fix dynamic axis
  • Loading branch information
dongruoping authored Dec 4, 2020
1 parent 0d3e10d commit fa3e093
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
8 changes: 4 additions & 4 deletions ml-agents/mlagents/trainers/torch/action_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
continuous_out, discrete_out, action_out_deprecated = None, None, None
if self.action_spec.continuous_size > 0 and dists.continuous is not None:
continuous_out = dists.continuous.exported_model_output()
action_out_deprecated = continuous_out
action_out_deprecated = dists.continuous.exported_model_output()
if self.action_spec.discrete_size > 0 and dists.discrete is not None:
discrete_out = [
discrete_out_list = [
discrete_dist.exported_model_output()
for discrete_dist in dists.discrete
]
discrete_out = torch.cat(discrete_out, dim=1)
action_out_deprecated = discrete_out
discrete_out = torch.cat(discrete_out_list, dim=1)
action_out_deprecated = torch.cat(discrete_out_list, dim=1)
# deprecated action field does not support hybrid action
if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
action_out_deprecated = None
Expand Down
7 changes: 4 additions & 3 deletions ml-agents/mlagents/trainers/torch/model_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,18 @@ def __init__(self, policy):
+ [f"visual_observation_{i}" for i in range(self.policy.vis_obs_size)]
+ ["action_masks", "memories"]
)
self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}

self.output_names = ["version_number", "memory_size"]
if self.policy.behavior_spec.action_spec.continuous_size > 0:
self.output_names += [
"continuous_actions",
"continuous_action_output_shape",
]
self.dynamic_axes.update({"continuous_actions": {0: "batch"}})
if self.policy.behavior_spec.action_spec.discrete_size > 0:
self.output_names += ["discrete_actions", "discrete_action_output_shape"]
self.dynamic_axes.update({"discrete_actions": {0: "batch"}})
if (
self.policy.behavior_spec.action_spec.continuous_size == 0
or self.policy.behavior_spec.action_spec.discrete_size == 0
Expand All @@ -89,9 +92,7 @@ def __init__(self, policy):
"is_continuous_control",
"action_output_shape",
]

self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
self.dynamic_axes.update({"action": {0: "batch"}})
self.dynamic_axes.update({"action": {0: "batch"}})

def export_policy_model(self, output_filepath: str) -> None:
"""
Expand Down

0 comments on commit fa3e093

Please sign in to comment.