Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change BrainParametersProto to support ActionSpec #4579

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ public static BrainParametersProto ToProto(this BrainParameters bp, string name,
{
var brainParametersProto = new BrainParametersProto
{
VectorActionSize = { bp.VectorActionSize },
VectorActionSpaceType = (SpaceTypeProto)bp.VectorActionSpaceType,
VectorActionSizeDeprecated = { bp.VectorActionSize },
VectorActionSpaceTypeDeprecated = (SpaceTypeProto)bp.VectorActionSpaceType,
BrainName = name,
IsTraining = isTraining
};
if (bp.VectorActionDescriptions != null)
{
brainParametersProto.VectorActionDescriptions.AddRange(bp.VectorActionDescriptions);
brainParametersProto.VectorActionDescriptionsDeprecated.AddRange(bp.VectorActionDescriptions);
}
return brainParametersProto;
}
Expand All @@ -126,13 +126,13 @@ public static BrainParametersProto ToBrainParametersProto(this ActionSpec action
};
if (actionSpec.NumContinuousActions > 0)
{
brainParametersProto.VectorActionSize.Add(actionSpec.NumContinuousActions);
brainParametersProto.VectorActionSpaceType = SpaceTypeProto.Continuous;
brainParametersProto.VectorActionSizeDeprecated.Add(actionSpec.NumContinuousActions);
brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Continuous;
}
else if (actionSpec.NumDiscreteActions > 0)
{
brainParametersProto.VectorActionSize.AddRange(actionSpec.BranchSizes);
brainParametersProto.VectorActionSpaceType = SpaceTypeProto.Discrete;
brainParametersProto.VectorActionSizeDeprecated.AddRange(actionSpec.BranchSizes);
brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Discrete;
}

// TODO handle ActionDescriptions?
Expand All @@ -148,9 +148,9 @@ public static BrainParameters ToBrainParameters(this BrainParametersProto bpp)
{
var bp = new BrainParameters
{
VectorActionSize = bpp.VectorActionSize.ToArray(),
VectorActionDescriptions = bpp.VectorActionDescriptions.ToArray(),
VectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceType
VectorActionSize = bpp.VectorActionSizeDeprecated.ToArray(),
VectorActionDescriptions = bpp.VectorActionDescriptionsDeprecated.ToArray(),
VectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceTypeDeprecated
};
return bp;
}
Expand Down
348 changes: 298 additions & 50 deletions com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/BrainParameters.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public bool CompressedChannelMapping {
public const int HybridActionsFieldNumber = 4;
private bool hybridActions_;
/// <summary>
/// support for mixed (discrete + continuous) actions
/// support for hybrid action spaces (discrete + continuous)
/// </summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool HybridActions {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,56 @@ builtin___float = float
builtin___int = int


class ActionSpecProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
num_continuous_actions = ... # type: builtin___int
num_discrete_actions = ... # type: builtin___int
discrete_branch_sizes = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
action_descriptions = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text]

def __init__(self,
*,
num_continuous_actions : typing___Optional[builtin___int] = None,
num_discrete_actions : typing___Optional[builtin___int] = None,
discrete_branch_sizes : typing___Optional[typing___Iterable[builtin___int]] = None,
action_descriptions : typing___Optional[typing___Iterable[typing___Text]] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> ActionSpecProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def ClearField(self, field_name: typing_extensions___Literal[u"action_descriptions",u"discrete_branch_sizes",u"num_continuous_actions",u"num_discrete_actions"]) -> None: ...
else:
def ClearField(self, field_name: typing_extensions___Literal[u"action_descriptions",b"action_descriptions",u"discrete_branch_sizes",b"discrete_branch_sizes",u"num_continuous_actions",b"num_continuous_actions",u"num_discrete_actions",b"num_discrete_actions"]) -> None: ...

class BrainParametersProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
vector_action_size = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
vector_action_descriptions = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text]
vector_action_space_type = ... # type: mlagents_envs___communicator_objects___space_type_pb2___SpaceTypeProto
vector_action_size_deprecated = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[builtin___int]
vector_action_descriptions_deprecated = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text]
vector_action_space_type_deprecated = ... # type: mlagents_envs___communicator_objects___space_type_pb2___SpaceTypeProto
brain_name = ... # type: typing___Text
is_training = ... # type: builtin___bool

@property
def action_spec(self) -> ActionSpecProto: ...

def __init__(self,
*,
vector_action_size : typing___Optional[typing___Iterable[builtin___int]] = None,
vector_action_descriptions : typing___Optional[typing___Iterable[typing___Text]] = None,
vector_action_space_type : typing___Optional[mlagents_envs___communicator_objects___space_type_pb2___SpaceTypeProto] = None,
vector_action_size_deprecated : typing___Optional[typing___Iterable[builtin___int]] = None,
vector_action_descriptions_deprecated : typing___Optional[typing___Iterable[typing___Text]] = None,
vector_action_space_type_deprecated : typing___Optional[mlagents_envs___communicator_objects___space_type_pb2___SpaceTypeProto] = None,
brain_name : typing___Optional[typing___Text] = None,
is_training : typing___Optional[builtin___bool] = None,
action_spec : typing___Optional[ActionSpecProto] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> BrainParametersProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def ClearField(self, field_name: typing_extensions___Literal[u"brain_name",u"is_training",u"vector_action_descriptions",u"vector_action_size",u"vector_action_space_type"]) -> None: ...
def HasField(self, field_name: typing_extensions___Literal[u"action_spec"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_spec",u"brain_name",u"is_training",u"vector_action_descriptions_deprecated",u"vector_action_size_deprecated",u"vector_action_space_type_deprecated"]) -> None: ...
else:
def ClearField(self, field_name: typing_extensions___Literal[u"brain_name",b"brain_name",u"is_training",b"is_training",u"vector_action_descriptions",b"vector_action_descriptions",u"vector_action_size",b"vector_action_size",u"vector_action_space_type",b"vector_action_space_type"]) -> None: ...
def HasField(self, field_name: typing_extensions___Literal[u"action_spec",b"action_spec"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"action_spec",b"action_spec",u"brain_name",b"brain_name",u"is_training",b"is_training",u"vector_action_descriptions_deprecated",b"vector_action_descriptions_deprecated",u"vector_action_size_deprecated",b"vector_action_size_deprecated",u"vector_action_space_type_deprecated",b"vector_action_space_type_deprecated"]) -> None: ...
8 changes: 5 additions & 3 deletions ml-agents-envs/mlagents_envs/mock_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def __init__(

def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
bp = BrainParametersProto(
vector_action_size=[2],
vector_action_descriptions=["", ""],
vector_action_space_type=discrete if self.is_discrete else continuous,
vector_action_size_deprecated=[2],
vector_action_descriptions_deprecated=["", ""],
vector_action_space_type_deprecated=discrete
if self.is_discrete
else continuous,
brain_name=self.brain_name,
is_training=True,
)
Expand Down
6 changes: 3 additions & 3 deletions ml-agents-envs/mlagents_envs/rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ def behavior_spec_from_proto(
observation_shape = [tuple(obs.shape) for obs in agent_info.observations]
action_type = (
ActionType.DISCRETE
if brain_param_proto.vector_action_space_type == 0
if brain_param_proto.vector_action_space_type_deprecated == 0
else ActionType.CONTINUOUS
)
if action_type == ActionType.CONTINUOUS:
action_shape: Union[
int, Tuple[int, ...]
] = brain_param_proto.vector_action_size[0]
] = brain_param_proto.vector_action_size_deprecated[0]
else:
action_shape = tuple(brain_param_proto.vector_action_size)
action_shape = tuple(brain_param_proto.vector_action_size_deprecated)
return BehaviorSpec(observation_shape, action_type, action_shape)


Expand Down
8 changes: 4 additions & 4 deletions ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,17 +408,17 @@ def test_action_masking_continuous():
def test_agent_behavior_spec_from_proto():
agent_proto = generate_list_agent_proto(1, [(3,), (4,)])[0]
bp = BrainParametersProto()
bp.vector_action_size.extend([5, 4])
bp.vector_action_space_type = 0
bp.vector_action_size_deprecated.extend([5, 4])
bp.vector_action_space_type_deprecated = 0
behavior_spec = behavior_spec_from_proto(bp, agent_proto)
assert behavior_spec.is_action_discrete()
assert not behavior_spec.is_action_continuous()
assert behavior_spec.observation_shapes == [(3,), (4,)]
assert behavior_spec.discrete_action_branches == (5, 4)
assert behavior_spec.action_size == 2
bp = BrainParametersProto()
bp.vector_action_size.extend([6])
bp.vector_action_space_type = 1
bp.vector_action_size_deprecated.extend([6])
bp.vector_action_space_type_deprecated = 1
behavior_spec = behavior_spec_from_proto(bp, agent_proto)
assert not behavior_spec.is_action_discrete()
assert behavior_spec.is_action_continuous()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,11 @@ def record_demo(use_discrete, num_visual=0, num_vector=1):
agent_info_protos = env.demonstration_protos[BRAIN_NAME]
meta_data_proto = DemonstrationMetaProto()
brain_param_proto = BrainParametersProto(
vector_action_size=[2] if use_discrete else [1],
vector_action_descriptions=[""],
vector_action_space_type=discrete if use_discrete else continuous,
vector_action_size_deprecated=[2] if use_discrete else [1],
vector_action_descriptions_deprecated=[""],
vector_action_space_type_deprecated=discrete
if use_discrete
else continuous,
brain_name=BRAIN_NAME,
is_training=True,
)
Expand Down
8 changes: 5 additions & 3 deletions ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,11 @@ def record_demo(use_discrete, num_visual=0, num_vector=1):
agent_info_protos = env.demonstration_protos[BRAIN_NAME]
meta_data_proto = DemonstrationMetaProto()
brain_param_proto = BrainParametersProto(
vector_action_size=[2] if use_discrete else [1],
vector_action_descriptions=[""],
vector_action_space_type=discrete if use_discrete else continuous,
vector_action_size_deprecated=[2] if use_discrete else [1],
vector_action_descriptions_deprecated=[""],
vector_action_space_type_deprecated=discrete
if use_discrete
else continuous,
brain_name=BRAIN_NAME,
is_training=True,
)
Expand Down
Loading