Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop modify stepping logic #3448

Merged
merged 10 commits into from
Feb 19, 2020
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Several classes were changed from public to internal visibility. (#3390)
- Academy.RegisterSideChannel and UnregisterSideChannel methods were added. (#3391)
- A tutorial on adding custom SideChannels was added (#3391)
- The stepping logic for the Agent and the Academy has been simplified (#3448)
- Update Barracuda to 0.6.0-preview

### Bugfixes
Expand Down
10 changes: 7 additions & 3 deletions com.unity.ml-agents/Runtime/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ public bool IsCommunicatorOn
// Signals to all the listeners that the academy is being destroyed
internal event Action DestroyAction;

// Signals the Agents that a new environment step is about to start.
// Raised from EnvironmentStep right after the Academy's own step counters
// are incremented and before the AgentSendState phase. Each Agent subscribes
// a handler that advances its own step count; the Agent's maxStep check is
// performed in its AgentAct handler, not by this event.
internal event Action AgentIncrementStep;

// Signals to all the agents at each environment step along with the
// Academy's maxStepReached, done and stepCount values. The agents rely
// on this event to update their own values of max step reached and done
Expand Down Expand Up @@ -418,6 +422,9 @@ public void EnvironmentStep()

AgentSetStatus?.Invoke(m_StepCount);

m_StepCount += 1;
m_TotalStepCount += 1;
AgentIncrementStep?.Invoke();

using (TimerStack.Instance.Scoped("AgentSendState"))
{
Expand All @@ -433,9 +440,6 @@ public void EnvironmentStep()
{
AgentAct?.Invoke();
}

m_StepCount += 1;
m_TotalStepCount += 1;
}

/// <summary>
Expand Down
23 changes: 13 additions & 10 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ public void LazyInitialize()
m_Action = new AgentAction();
sensors = new List<ISensor>();

Academy.Instance.AgentIncrementStep += AgentIncrementStep;
Academy.Instance.AgentSendState += SendInfo;
Academy.Instance.DecideAction += DecideAction;
Academy.Instance.AgentAct += AgentStep;
Expand All @@ -256,6 +257,7 @@ void OnDisable()
// We don't want to even try, because this will lazily create a new Academy!
if (Academy.IsInitialized)
{
Academy.Instance.AgentIncrementStep -= AgentIncrementStep;
Academy.Instance.AgentSendState -= SendInfo;
Academy.Instance.DecideAction -= DecideAction;
Academy.Instance.AgentAct -= AgentStep;
Expand Down Expand Up @@ -685,24 +687,25 @@ void SendInfo()
}
}

/// <summary>
/// Handler for the Academy's AgentIncrementStep event: advances this
/// agent's step counter by one at the start of each environment step.
/// Subscribed in LazyInitialize and unsubscribed in OnDisable.
/// </summary>
void AgentIncrementStep()
{
    m_StepCount += 1;
}

/// <summary>
/// Used by the brain to make the agent perform a step.
/// Handler for the Academy's AgentAct event: runs the pending action, if
/// any, and then ends the episode once the step count has reached maxStep.
/// </summary>
void AgentStep()
{
    // Act only when an action was requested and a brain is attached.
    // The request flag is cleared before the action is applied.
    if ((m_RequestAction) && (m_Brain != null))
    {
        m_RequestAction = false;
        AgentAction(m_Action.vectorActions);
    }

    // The step counter is advanced by AgentIncrementStep, so it must not be
    // incremented again here. The limit is checked after acting; a maxStep
    // of zero or less means no limit is enforced.
    if ((m_StepCount >= maxStep) && (maxStep > 0))
    {
        NotifyAgentDone(true);
        _AgentReset();
    }
}

void DecideAction()
Expand Down
60 changes: 39 additions & 21 deletions com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public override void AgentAction(float[] vectorAction)

public override void AgentReset()
{

agentResetCalls += 1;
collectObservationsCallsSinceLastReset = 0;
agentActionCallsSinceLastReset = 0;
Expand Down Expand Up @@ -485,19 +486,16 @@ public void TestCumulativeReward()
agent1.LazyInitialize();
agent2.SetPolicy(new TestPolicy());

var j = 0;
for (var i = 0; i < 500; i++)
var expectedAgent1ActionSinceReset = 0;

for (var i = 0; i < 50; i++)
{
if (i % 21 == 0)
{
j = 0;
}
else
{
j++;
expectedAgent1ActionSinceReset += 1;
if (expectedAgent1ActionSinceReset == agent1.maxStep || i == 0){
expectedAgent1ActionSinceReset = 0;
}
agent2.RequestAction();
Assert.LessOrEqual(Mathf.Abs(j * 10.1f - agent1.GetCumulativeReward()), 0.05f);
Assert.LessOrEqual(Mathf.Abs(expectedAgent1ActionSinceReset * 10.1f - agent1.GetCumulativeReward()), 0.05f);
Assert.LessOrEqual(Mathf.Abs(i * 0.1f - agent2.GetCumulativeReward()), 0.05f);

agent1.AddReward(10f);
Expand All @@ -517,26 +515,46 @@ public void TestMaxStepsReset()
decisionRequester.DecisionPeriod = 1;
decisionRequester.Awake();

var maxStep = 6;
const int maxStep = 6;
agent1.maxStep = maxStep;
agent1.LazyInitialize();

var expectedAgentStepCount = 0;
var expectedResets= 0;
var expectedAgentAction = 0;
var expectedAgentActionSinceReset = 0;
var expectedCollectObsCalls = 0;
var expectedCollectObsCallsSinceReset = 0;

for (var i = 0; i < 15; i++)
{
// We expect resets to occur when there are maxSteps actions since the last reset (and on the first step)
var expectReset = agent1.agentActionCallsSinceLastReset == maxStep || (i == 0);
var previousNumResets = agent1.agentResetCalls;

aca.EnvironmentStep();

if (expectReset)
// Agent should observe and act on each Academy step
expectedAgentAction += 1;
expectedAgentActionSinceReset += 1;
expectedCollectObsCalls += 1;
expectedCollectObsCallsSinceReset += 1;
expectedAgentStepCount += 1;

// If the next step will put the agent at maxSteps, we expect it to reset
if (agent1.GetStepCount() == maxStep - 1 || (i == 0))
{
Assert.AreEqual(previousNumResets + 1, agent1.agentResetCalls);
expectedResets +=1;
}
else

if (agent1.GetStepCount() == maxStep - 1)
{
Assert.AreEqual(previousNumResets, agent1.agentResetCalls);
expectedAgentActionSinceReset = 0;
expectedCollectObsCallsSinceReset = 0;
expectedAgentStepCount = 0;
}
aca.EnvironmentStep();

Assert.AreEqual(expectedAgentStepCount, agent1.GetStepCount());
Assert.AreEqual(expectedResets, agent1.agentResetCalls);
Assert.AreEqual(expectedAgentAction, agent1.agentActionCalls);
Assert.AreEqual(expectedAgentActionSinceReset, agent1.agentActionCallsSinceLastReset);
Assert.AreEqual(expectedCollectObsCalls, agent1.collectObservationsCalls);
Assert.AreEqual(expectedCollectObsCallsSinceReset, agent1.collectObservationsCallsSinceLastReset);
}
}
}
Expand Down