Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Package changes to support deterministic inference #5599

Merged
9 changes: 6 additions & 3 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ and this project adheres to
2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]

- Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can
be achieved by adding `deterministic: true` under `network_settings` of the run options configuration.
- Extra tensors are now serialized to support deterministic action selection in onnx. (#5597)

- Deterministic action selection is now supported during training and inference
- Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can
be achieved by adding `deterministic: true` under `network_settings` of the run options configuration. (#5619)
- Extra tensors are now serialized to support deterministic action selection in onnx. (#5593)
- Support inference with deterministic action selection in editor (#5599)
### Bug Fixes
- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)

Expand Down
4 changes: 3 additions & 1 deletion com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
const string k_BrainParametersName = "m_BrainParameters";
const string k_ModelName = "m_Model";
const string k_InferenceDeviceName = "m_InferenceDevice";
const string k_DeterministicInference = "m_DeterministicInference";
const string k_BehaviorTypeName = "m_BehaviorType";
const string k_TeamIdName = "TeamId";
const string k_UseChildSensorsName = "m_UseChildSensors";
Expand Down Expand Up @@ -68,6 +69,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
EditorGUI.indentLevel++;
EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
EditorGUILayout.PropertyField(so.FindProperty(k_DeterministicInference), true);
EditorGUI.indentLevel--;
}
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
Expand Down Expand Up @@ -156,7 +158,7 @@ void DisplayFailedModelChecks()
{
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
barracudaModel, brainParameters, sensors, actuatorComponents,
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.DeterministicInference
);
foreach (var check in failedChecks)
{
Expand Down
6 changes: 4 additions & 2 deletions com.unity.ml-agents/Runtime/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -616,14 +616,16 @@ void EnvironmentReset()
/// <param name="inferenceDevice">
/// The inference device (CPU or GPU) the ModelRunner will use.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns> The ModelRunner compatible with the input settings.</returns>
internal ModelRunner GetOrCreateModelRunner(
NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice)
NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice, bool deterministicInference = false)
{
var modelRunner = m_ModelRunners.Find(x => x.HasModel(model, inferenceDevice));
if (modelRunner == null)
{
modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed);
modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed, deterministicInference);
m_ModelRunners.Add(modelRunner);
m_InferenceSeed++;
}
Expand Down
107 changes: 80 additions & 27 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ public static int GetNumVisualInputs(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>Array of the output tensor names of the model</returns>
public static string[] GetOutputNames(this Model model)
public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
{
var names = new List<string>();

Expand All @@ -122,13 +124,13 @@ public static string[] GetOutputNames(this Model model)
return names.ToArray();
}

if (model.HasContinuousOutputs())
if (model.HasContinuousOutputs(deterministicInference))
{
names.Add(model.ContinuousOutputName());
names.Add(model.ContinuousOutputName(deterministicInference));
}
if (model.HasDiscreteOutputs())
if (model.HasDiscreteOutputs(deterministicInference))
{
names.Add(model.DiscreteOutputName());
names.Add(model.DiscreteOutputName(deterministicInference));
}

var modelVersion = model.GetVersion();
Expand All @@ -149,8 +151,10 @@ public static string[] GetOutputNames(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>True if the model has continuous action outputs.</returns>
public static bool HasContinuousOutputs(this Model model)
public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
{
if (model == null)
return false;
Expand All @@ -160,8 +164,13 @@ public static bool HasContinuousOutputs(this Model model)
}
else
{
return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
bool hasStochasticOutput = !deterministicInference &&
model.outputs.Contains(TensorNames.ContinuousActionOutput);
bool hasDeterministicOutput = deterministicInference &&
model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput);

return (hasStochasticOutput || hasDeterministicOutput) &&
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
}
}

Expand Down Expand Up @@ -194,8 +203,10 @@ public static int ContinuousOutputSize(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>Tensor name of continuous action output.</returns>
public static string ContinuousOutputName(this Model model)
public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
{
if (model == null)
return null;
Expand All @@ -205,7 +216,7 @@ public static string ContinuousOutputName(this Model model)
}
else
{
return TensorNames.ContinuousActionOutput;
return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput;
}
}

Expand All @@ -215,8 +226,10 @@ public static string ContinuousOutputName(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>True if the model has discrete action outputs.</returns>
public static bool HasDiscreteOutputs(this Model model)
public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
{
if (model == null)
return false;
Expand All @@ -226,7 +239,12 @@ public static bool HasDiscreteOutputs(this Model model)
}
else
{
return model.outputs.Contains(TensorNames.DiscreteActionOutput) && model.DiscreteOutputSize() > 0;
bool hasStochasticOutput = !deterministicInference &&
model.outputs.Contains(TensorNames.DiscreteActionOutput);
bool hasDeterministicOutput = deterministicInference &&
model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput);
return (hasStochasticOutput || hasDeterministicOutput) &&
model.DiscreteOutputSize() > 0;
}
}

Expand Down Expand Up @@ -279,8 +297,10 @@ public static int DiscreteOutputSize(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>Tensor name of discrete action output.</returns>
public static string DiscreteOutputName(this Model model)
public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
{
if (model == null)
return null;
Expand All @@ -290,7 +310,7 @@ public static string DiscreteOutputName(this Model model)
}
else
{
return TensorNames.DiscreteActionOutput;
return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput;
}
}

Expand All @@ -316,9 +336,11 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="failedModelChecks">Output list of failure messages</param>
///
///<param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>True if the model contains all the expected tensors.</returns>
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
/// TODO: add checks for deterministic actions
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
{
// Check the presence of model version
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
Expand All @@ -343,7 +365,9 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
// Check the presence of action output tensor
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
!model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
!model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
!model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
{
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Node.")
Expand Down Expand Up @@ -373,22 +397,51 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
}
else
{
if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
{
failedModelChecks.Add(
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
{
failedModelChecks.Add(
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
);
return false;
return false;
}

else if (!model.HasContinuousOutputs(deterministicInference))
{
var actionType = deterministicInference ? "deterministic" : "stochastic";
var actionName = deterministicInference ? "Deterministic" : "";
failedModelChecks.Add(
FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Continuous Action Output Tensor. Uncheck `Deterministic inference` flag..")
);
return false;
}
}
if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)

if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
failedModelChecks.Add(
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
{
failedModelChecks.Add(
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
);
return false;
return false;
}
else if (!model.HasDiscreteOutputs(deterministicInference))
{
var actionType = deterministicInference ? "deterministic" : "stochastic";
var actionName = deterministicInference ? "Deterministic" : "";
failedModelChecks.Add(
FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.")
);
return false;
}

}




}
return true;
}
Expand Down
24 changes: 16 additions & 8 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,17 @@ public static FailedCheck CheckModelVersion(Model model)
/// <param name="actuatorComponents">Attached actuator components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <param name="behaviorType">BehaviorType or the Agent to check.</param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>A IEnumerable of the checks that failed</returns>
public static IEnumerable<FailedCheck> CheckModel(
Model model,
BrainParameters brainParameters,
ISensor[] sensors,
ActuatorComponent[] actuatorComponents,
int observableAttributeTotalSize = 0,
BehaviorType behaviorType = BehaviorType.Default
BehaviorType behaviorType = BehaviorType.Default,
bool deterministicInference = false
)
{
List<FailedCheck> failedModelChecks = new List<FailedCheck>();
Expand All @@ -148,7 +151,7 @@ public static IEnumerable<FailedCheck> CheckModel(
return failedModelChecks;
}

var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks, deterministicInference);
if (!hasExpectedTensors)
{
return failedModelChecks;
Expand Down Expand Up @@ -181,7 +184,7 @@ public static IEnumerable<FailedCheck> CheckModel(
else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
{
failedModelChecks.AddRange(
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
CheckInputTensorPresence(model, brainParameters, memorySize, sensors, deterministicInference)
);
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
Expand All @@ -195,7 +198,7 @@ public static IEnumerable<FailedCheck> CheckModel(
);

failedModelChecks.AddRange(
CheckOutputTensorPresence(model, memorySize)
CheckOutputTensorPresence(model, memorySize, deterministicInference)
);
return failedModelChecks;
}
Expand Down Expand Up @@ -318,14 +321,17 @@ ISensor[] sensors
/// The memory size that the model is expecting.
/// </param>
/// <param name="sensors">Array of attached sensor components</param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>
/// A IEnumerable of the checks that failed
/// </returns>
static IEnumerable<FailedCheck> CheckInputTensorPresence(
Model model,
BrainParameters brainParameters,
int memory,
ISensor[] sensors
ISensor[] sensors,
bool deterministicInference = false
)
{
var failedModelChecks = new List<FailedCheck>();
Expand Down Expand Up @@ -356,7 +362,7 @@ ISensor[] sensors
}

// If the model uses discrete control but does not have an input for action masks
if (model.HasDiscreteOutputs())
if (model.HasDiscreteOutputs(deterministicInference))
{
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{
Expand All @@ -376,17 +382,19 @@ ISensor[] sensors
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="memory">The memory size that the model is expecting/</param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>
/// A IEnumerable of the checks that failed
/// </returns>
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory)
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory, bool deterministicInference = false)
{
var failedModelChecks = new List<FailedCheck>();

// If there is no Recurrent Output but the model is Recurrent.
if (memory > 0)
{
var allOutputs = model.GetOutputNames().ToList();
var allOutputs = model.GetOutputNames(deterministicInference).ToList();
if (!allOutputs.Any(x => x == TensorNames.RecurrentOutput))
{
failedModelChecks.Add(
Expand Down
Loading