Merge pull request #27 from Unity-Technologies/fix-internal-placeholder
made a better error if a placeholder is missing or if a placeholder is …
vincentpierre authored Sep 22, 2017
2 parents 53bf659 + caa5ba9 commit 4b7d0c9
Showing 1 changed file with 38 additions and 9 deletions.
47 changes: 38 additions & 9 deletions unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
@@ -51,7 +51,7 @@ public enum tensorType
     public string[] ObservationPlaceholderName;
     /// Modify only in inspector : Name of the action node
     public string ActionPlaceholderName = "action";
-#if ENABLE_TENSORFLOW
+#if ENABLE_TENSORFLOW
     TFGraph graph;
     TFSession session;
     bool hasRecurrent;
@@ -62,7 +62,7 @@ public enum tensorType
     float[,] inputState;
     List<float[,,,]> observationMatrixList;
     float[,] inputOldMemories;
-#endif
+#endif
 
     /// Reference to the brain that uses this CoreBrainInternal
     public Brain brain;
@@ -190,13 +190,22 @@ public void DecideAction()
 
            foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
            {
-               if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
+               try
                {
-                   runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
+                   if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
+                   {
+                       runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
+                   }
+                   else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
+                   {
+                       runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
+                   }
                }
-               else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
+               catch
                {
-                   runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
+                   throw new UnityAgentsException(string.Format(@"One of the TensorFlow placeholders could not be found.
+                       In brain {0}, there is no {1} placeholder named {2}.",
+                       brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
                }
            }
 
@@ -212,6 +221,26 @@ public void DecideAction()
                runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]);
            }
 
+           TFTensor[] networkOutput;
+           try
+           {
+               networkOutput = runner.Run();
+           }
+           catch (TFException e)
+           {
+               string errorMessage = e.Message;
+               try
+               {
+                   errorMessage = string.Format(@"The tensorflow graph needs an input for {0} of type {1}",
+                       e.Message.Split(new string[]{ "Node: " }, 0)[1].Split('=')[0],
+                       e.Message.Split(new string[]{ "dtype=" }, 0)[1].Split(',')[0]);
+               }
+               finally
+               {
+                   throw new UnityAgentsException(errorMessage);
+               }
+
+           }
+
            // Create the recurrent tensor
            if (hasRecurrent)
@@ -220,7 +249,7 @@
 
                runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
                runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
-               float[,] recurrent_tensor = runner.Run()[1].GetValue() as float[,];
+               float[,] recurrent_tensor = networkOutput[1].GetValue() as float[,];
 
                int i = 0;
                foreach (int k in agentKeys)
@@ -241,7 +270,7 @@
 
            if (brain.brainParameters.actionSpaceType == StateType.continuous)
            {
-               float[,] output = runner.Run()[0].GetValue() as float[,];
+               float[,] output = networkOutput[0].GetValue() as float[,];
                int i = 0;
                foreach (int k in agentKeys)
                {
@@ -256,7 +285,7 @@
            }
            else if (brain.brainParameters.actionSpaceType == StateType.discrete)
            {
-               long[,] output = runner.Run()[0].GetValue() as long[,];
+               long[,] output = networkOutput[0].GetValue() as long[,];
                int i = 0;
                foreach (int k in agentKeys)
                {
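
For readers skimming the diff, the pattern the commit introduces is small: instead of letting a missing graph node surface as a bare lookup failure, the placeholder feed is wrapped in try/catch and rethrown as a UnityAgentsException that names the brain, the placeholder type, and the fully scoped node name. The sketch below is a minimal, self-contained illustration of that pattern only, not the actual ML-Agents code: the Dictionary stands in for the TFGraph node lookup, PlaceholderException stands in for UnityAgentsException, and the brain, scope, and node names ("Ball3DBrain", "brain/", "epsilon") are made up for the example.

// Minimal sketch of the commit's error-reporting pattern (hypothetical stand-ins,
// not TensorFlowSharp or Unity types).
using System;
using System.Collections.Generic;

class PlaceholderException : Exception
{
    public PlaceholderException(string message) : base(message) { }
}

static class PlaceholderFeedSketch
{
    // Stand-in for graph[graphScope + placeholder.name]: the lookup throws when
    // the node does not exist, just as the real graph lookup can fail.
    static readonly Dictionary<string, float[]> fakeGraph =
        new Dictionary<string, float[]> { { "brain/epsilon", new float[1] } };

    static void FeedPlaceholder(string brainName, string graphScope,
                                string placeholderName, string valueType)
    {
        try
        {
            // In CoreBrainInternal.cs this is runner.AddInput(graph[...][0], ...).
            fakeGraph[graphScope + placeholderName][0] = 0.5f; // pretend we fed a sampled value
        }
        catch
        {
            // Same idea as the commit: rethrow with the brain name, the placeholder
            // type, and the fully scoped node name instead of a bare lookup failure.
            throw new PlaceholderException(string.Format(
                "One of the TensorFlow placeholders could not be found. " +
                "In brain {0}, there is no {1} placeholder named {2}.",
                brainName, valueType, graphScope + placeholderName));
        }
    }

    static void Main()
    {
        FeedPlaceholder("Ball3DBrain", "brain/", "epsilon", "FloatingPoint");       // succeeds
        try
        {
            FeedPlaceholder("Ball3DBrain", "brain/", "missing_node", "FloatingPoint");
        }
        catch (PlaceholderException e)
        {
            Console.WriteLine(e.Message); // names the brain, type, and node
        }
    }
}

The second half of the diff applies the same principle to the session itself: runner.Run() is called once inside a try/catch that rewrites the opaque TFException into a message naming the missing input node and its dtype, and the cached networkOutput array is then reused by the recurrent, continuous, and discrete branches instead of re-running the graph each time.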
