-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
Copy pathBrainParameters.cs
204 lines (187 loc) · 7.43 KB
/
BrainParameters.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
using System;
using UnityEngine;
using UnityEngine.Serialization;
using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Policies
{
/// <summary>
/// (Deprecated) The kind of action space an Agent used before actions were described by
/// <c>ActionSpec</c>. An Agent can now use continuous and discrete actions together.
/// </summary>
[Obsolete("Continuous and discrete actions on the same Agent are now supported; see ActionSpec.")]
public enum SpaceType
{
    /// <summary>
    /// Discrete actions: each action is chosen from a fixed number of options.
    /// </summary>
    Discrete = 0,

    /// <summary>
    /// Continuous actions: each action is a float value.
    /// </summary>
    Continuous = 1,
}
/// <summary>
/// Holds information about the brain. It defines what are the inputs and outputs of the
/// decision process.
/// </summary>
/// <remarks>
/// Set brain parameters for an <see cref="Agent"/> instance using the
/// <seealso cref="BehaviorParameters"/> component attached to the agent's [GameObject].
///
/// The action layout is stored in <see cref="ActionSpec"/>. The deprecated
/// <c>VectorActionSize</c> / <c>VectorActionSpaceType</c> fields are kept in sync with it
/// whenever the spec is assigned and on every serialization callback.
///
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
/// </remarks>
[Serializable]
public class BrainParameters : ISerializationCallbackReceiver
{
    /// <summary>
    /// The number of the observations that are added in
    /// <see cref="Agent.CollectObservations(Sensors.VectorSensor)"/>
    /// </summary>
    /// <value>
    /// The length of the vector containing observation values.
    /// </value>
    [FormerlySerializedAs("vectorObservationSize")]
    public int VectorObservationSize = 1;

    /// <summary>
    /// Stacking refers to concatenating the observations across multiple frames. This field
    /// indicates the number of frames to concatenate across.
    /// </summary>
    [FormerlySerializedAs("numStackedVectorObservations")]
    [Range(1, 50)] public int NumStackedVectorObservations = 1;

    // Serialized backing store for ActionSpec. Starts with zero continuous actions and
    // null branch sizes; UpdateToActionSpec treats that state as "not yet populated".
    [SerializeField]
    internal ActionSpec m_ActionSpec = new ActionSpec(0, null);

    /// <summary>
    /// The specification of the Actions for the BrainParameters.
    /// </summary>
    /// <remarks>
    /// Assigning copies the continuous action count and the branch-size array reference
    /// (the array itself is not copied) into the backing field, then refreshes the
    /// deprecated action fields.
    /// </remarks>
    public ActionSpec ActionSpec
    {
        get { return m_ActionSpec; }
        set
        {
            m_ActionSpec.NumContinuousActions = value.NumContinuousActions;
            m_ActionSpec.BranchSizes = value.BranchSizes;
            SyncDeprecatedActionFields();
        }
    }

    /// <summary>
    /// (Deprecated) The number of possible actions.
    /// </summary>
    /// <remarks>The size specified is interpreted differently depending on whether
    /// the agent uses the continuous or the discrete actions. It is set to <c>null</c>
    /// when the ActionSpec mixes continuous and discrete actions, since a single array
    /// cannot describe that layout.</remarks>
    /// <value>
    /// For the continuous actions: the length of the float vector that represents
    /// the action.
    /// For the discrete actions: the number of branches.
    /// </value>
    [Obsolete("VectorActionSize has been deprecated, please use ActionSpec instead.")]
    [FormerlySerializedAs("vectorActionSize")]
    public int[] VectorActionSize = new[] { 1 };

    /// <summary>
    /// The list of strings describing what the actions correspond to.
    /// </summary>
    [FormerlySerializedAs("vectorActionDescriptions")]
    public string[] VectorActionDescriptions;

    /// <summary>
    /// (Deprecated) Defines if the action is discrete or continuous.
    /// </summary>
    [Obsolete("VectorActionSpaceType has been deprecated, please use ActionSpec instead.")]
    [FormerlySerializedAs("vectorActionSpaceType")]
    public SpaceType VectorActionSpaceType = SpaceType.Discrete;

    // Set once UpdateToActionSpec has run, so the one-time migration from the deprecated
    // fields into m_ActionSpec is not repeated on later serialization callbacks.
    [SerializeField]
    [HideInInspector]
    internal bool hasUpgradedBrainParametersWithActionSpec;

    /// <summary>
    /// (Deprecated) The number of actions specified by this Brain.
    /// </summary>
    /// <value>
    /// The continuous action count if it is greater than zero; otherwise
    /// <c>ActionSpec.NumDiscreteActions</c>.
    /// </value>
    [Obsolete("NumActions has been deprecated, please use ActionSpec instead.")]
    public int NumActions
    {
        get
        {
            return ActionSpec.NumContinuousActions > 0 ? ActionSpec.NumContinuousActions : ActionSpec.NumDiscreteActions;
        }
    }

    /// <summary>
    /// Deep clones the BrainParameter object.
    /// </summary>
    /// <remarks>
    /// All arrays (action descriptions, branch sizes, deprecated action sizes) are copied,
    /// and <c>null</c> arrays stay <c>null</c> in the clone. The upgrade flag is copied, so
    /// a not-yet-migrated instance still migrates its deprecated fields later.
    /// </remarks>
    /// <returns> A new BrainParameter object with the same values as the original.</returns>
    public BrainParameters Clone()
    {
        // Disable deprecation warnings so we can read/write the old fields.
#pragma warning disable CS0618
        return new BrainParameters
        {
            VectorObservationSize = VectorObservationSize,
            NumStackedVectorObservations = NumStackedVectorObservations,
            VectorActionDescriptions = CloneOrNull(VectorActionDescriptions),
            // Copy BranchSizes so the clone does not share the array with the original.
            ActionSpec = new ActionSpec(ActionSpec.NumContinuousActions, CloneOrNull(ActionSpec.BranchSizes)),
            // Assigned after ActionSpec so these values override the ones produced by the
            // ActionSpec setter's sync step. VectorActionSize is null for hybrid specs.
            VectorActionSize = CloneOrNull(VectorActionSize),
            VectorActionSpaceType = VectorActionSpaceType,
            hasUpgradedBrainParametersWithActionSpec = hasUpgradedBrainParametersWithActionSpec,
        };
#pragma warning restore CS0618
    }

    /// <summary>
    /// Populates <c>m_ActionSpec</c> from the deprecated <c>VectorActionSize</c> and
    /// <c>VectorActionSpaceType</c> fields.
    /// </summary>
    /// <remarks>
    /// The copy only happens when the upgrade has not run yet and <c>m_ActionSpec</c> is still
    /// empty (no continuous actions, null branch sizes). The upgrade flag is set afterwards
    /// in every case.
    /// </remarks>
    private void UpdateToActionSpec()
    {
        // Disable deprecation warnings so we can read the old fields.
#pragma warning disable CS0618
        if (!hasUpgradedBrainParametersWithActionSpec
            && m_ActionSpec.NumContinuousActions == 0
            && m_ActionSpec.BranchSizes == null)
        {
            if (VectorActionSpaceType == SpaceType.Continuous)
            {
                // Guard against a missing/empty legacy size array instead of throwing on [0].
                m_ActionSpec.NumContinuousActions =
                    VectorActionSize != null && VectorActionSize.Length > 0 ? VectorActionSize[0] : 0;
                m_ActionSpec.BranchSizes = null;
            }
            else if (VectorActionSpaceType == SpaceType.Discrete)
            {
                m_ActionSpec.NumContinuousActions = 0;
                m_ActionSpec.BranchSizes = CloneOrNull(VectorActionSize);
            }
        }
        hasUpgradedBrainParametersWithActionSpec = true;
#pragma warning restore CS0618
    }

    /// <summary>
    /// Sync values in ActionSpec fields to deprecated fields.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item>No continuous actions: the branch sizes are copied and the type is Discrete.
    /// A null branch-size array is treated as zero branches.</item>
    /// <item>No discrete actions: a single-element size array holds the continuous count,
    /// and the type is Continuous.</item>
    /// <item>Mixed spec: <c>VectorActionSize</c> is set to <c>null</c>, and
    /// <c>VectorActionSpaceType</c> is left unchanged.</item>
    /// </list>
    /// </remarks>
    private void SyncDeprecatedActionFields()
    {
        // Disable deprecation warnings so we can read the old fields.
#pragma warning disable CS0618
        if (m_ActionSpec.NumContinuousActions == 0)
        {
            // BranchSizes may legitimately be null (see the m_ActionSpec default), so do not
            // dereference it unconditionally.
            VectorActionSize = CloneOrNull(m_ActionSpec.BranchSizes) ?? Array.Empty<int>();
            VectorActionSpaceType = SpaceType.Discrete;
        }
        else if (m_ActionSpec.NumDiscreteActions == 0)
        {
            VectorActionSize = new[] { m_ActionSpec.NumContinuousActions };
            VectorActionSpaceType = SpaceType.Continuous;
        }
        else
        {
            VectorActionSize = null;
        }
#pragma warning restore CS0618
    }

    /// <summary>
    /// Returns a shallow copy of <paramref name="source"/>, or <c>null</c> if it is <c>null</c>.
    /// </summary>
    private static T[] CloneOrNull<T>(T[] source)
    {
        return source == null ? null : (T[])source.Clone();
    }

    /// <summary>
    /// Called by Unity immediately before serializing this object. Migrates the deprecated
    /// fields into the ActionSpec if needed, then syncs the deprecated fields back.
    /// </summary>
    public void OnBeforeSerialize()
    {
        UpdateToActionSpec();
        SyncDeprecatedActionFields();
    }

    /// <summary>
    /// Called by Unity immediately after deserializing this object. Migrates the deprecated
    /// fields into the ActionSpec if needed, then syncs the deprecated fields back.
    /// </summary>
    public void OnAfterDeserialize()
    {
        UpdateToActionSpec();
        SyncDeprecatedActionFields();
    }
}
}