Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick] Integrate Group Manager to soccer/retrain with POCA (#5115) #5121

Merged
merged 3 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 113 additions & 26 deletions Project/Assets/ML-Agents/Examples/Soccer/Prefabs/SoccerFieldTwos.prefab

Large diffs are not rendered by default.

Large diffs are not rendered by default.

57 changes: 12 additions & 45 deletions Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;

/// <summary>
/// Side an agent plays for. The numeric values are compared against
/// BehaviorParameters.TeamId (cast to int), so they must stay stable.
/// </summary>
public enum Team
{
    /// <summary>Blue side; numeric value 0.</summary>
    Blue = 0,

    /// <summary>Purple side; numeric value 1.</summary>
    Purple = 1,
}

public class AgentSoccer : Agent
{
// Note that the detectable tags are different for the blue and purple teams. The order is
Expand All @@ -12,11 +18,6 @@ public class AgentSoccer : Agent
// * wall
// * own teammate
// * opposing player
public enum Team
{
Blue = 0,
Purple = 1
}

public enum Position
{
Expand All @@ -28,8 +29,6 @@ public enum Position
[HideInInspector]
public Team team;
float m_KickPower;
int m_PlayerIndex;
public SoccerFieldArea area;
// The coefficient for the reward for colliding with a ball. Set using curriculum.
float m_BallTouch;
public Position position;
Expand All @@ -39,14 +38,13 @@ public enum Position
float m_LateralSpeed;
float m_ForwardSpeed;

[HideInInspector]
public float timePenalty;

[HideInInspector]
public Rigidbody agentRb;
SoccerSettings m_SoccerSettings;
BehaviorParameters m_BehaviorParameters;
Vector3 m_Transform;
public Vector3 initialPos;
public float rotSign;

EnvironmentParameters m_ResetParams;

Expand All @@ -57,12 +55,14 @@ public override void Initialize()
if (m_BehaviorParameters.TeamId == (int)Team.Blue)
{
team = Team.Blue;
m_Transform = new Vector3(transform.position.x - 4f, .5f, transform.position.z);
initialPos = new Vector3(transform.position.x - 5f, .5f, transform.position.z);
rotSign = 1f;
}
else
{
team = Team.Purple;
m_Transform = new Vector3(transform.position.x + 4f, .5f, transform.position.z);
initialPos = new Vector3(transform.position.x + 5f, .5f, transform.position.z);
rotSign = -1f;
}
if (position == Position.Goalie)
{
Expand All @@ -83,16 +83,6 @@ public override void Initialize()
agentRb = GetComponent<Rigidbody>();
agentRb.maxAngularVelocity = 500;

var playerState = new PlayerState
{
agentRb = agentRb,
startingPos = transform.position,
agentScript = this,
};
area.playerStates.Add(playerState);
m_PlayerIndex = area.playerStates.IndexOf(playerState);
playerState.playerIndex = m_PlayerIndex;

m_ResetParams = Academy.Instance.EnvironmentParameters;
}

Expand Down Expand Up @@ -157,11 +147,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
// Existential penalty for Strikers
AddReward(-m_Existential);
}
else
{
// Existential penalty cumulant for Generic
timePenalty -= m_Existential;
}
MoveAgent(actionBuffers.DiscreteActions);
}

Expand Down Expand Up @@ -218,25 +203,7 @@ void OnCollisionEnter(Collision c)

public override void OnEpisodeBegin()
{

timePenalty = 0;
m_BallTouch = m_ResetParams.GetWithDefault("ball_touch", 0);
if (team == Team.Purple)
{
transform.rotation = Quaternion.Euler(0f, -90f, 0f);
}
else
{
transform.rotation = Quaternion.Euler(0f, 90f, 0f);
}
transform.position = m_Transform;
agentRb.velocity = Vector3.zero;
agentRb.angularVelocity = Vector3.zero;
SetResetParameters();
}

public void SetResetParameters()
{
area.ResetBall();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@

public class SoccerBallController : MonoBehaviour
{
public GameObject area;
[HideInInspector]
public SoccerFieldArea area;
public SoccerEnvController envController;
public string purpleGoalTag; //will be used to check if collided with purple goal
public string blueGoalTag; //will be used to check if collided with blue goal

void Start()
{
envController = area.GetComponent<SoccerEnvController>();
}

void OnCollisionEnter(Collision col)
{
if (col.gameObject.CompareTag(purpleGoalTag)) //ball touched purple goal
{
area.GoalTouched(AgentSoccer.Team.Blue);
envController.GoalTouched(Team.Blue);
}
if (col.gameObject.CompareTag(blueGoalTag)) //ball touched blue goal
{
area.GoalTouched(AgentSoccer.Team.Purple);
envController.GoalTouched(Team.Purple);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using UnityEngine;

/// <summary>
/// Coordinates one soccer field: registers the agents listed in
/// <see cref="AgentsList"/> into a blue and a purple multi-agent group,
/// hands out group rewards when a goal is scored, and resets agents and ball
/// when an episode ends or the step limit is reached.
/// </summary>
public class SoccerEnvController : MonoBehaviour
{
    /// <summary>
    /// Per-agent bookkeeping. <see cref="Agent"/> is assigned from outside
    /// (the field is serializable); the remaining fields are filled in Start().
    /// </summary>
    [System.Serializable]
    public class PlayerInfo
    {
        public AgentSoccer Agent;
        [HideInInspector]
        public Vector3 StartingPos;
        [HideInInspector]
        public Quaternion StartingRot;
        [HideInInspector]
        public Rigidbody Rb;
    }

    /// <summary>
    /// Number of FixedUpdate ticks after which the field is reset and both
    /// group episodes are interrupted. A value of 0 or less disables the limit.
    /// </summary>
    [Header("Max Environment Steps")] public int MaxEnvironmentSteps = 25000;

    /// <summary>
    /// The ball object that is repositioned on every reset.
    /// </summary>
    public GameObject ball;

    /// <summary>
    /// Rigidbody of <see cref="ball"/>, cached in Start().
    /// </summary>
    [HideInInspector]
    public Rigidbody ballRb;

    // Ball position captured in Start(); ResetBall() jitters around this point.
    Vector3 m_BallStartingPos;

    /// <summary>
    /// List of agents on this field.
    /// </summary>
    public List<PlayerInfo> AgentsList = new List<PlayerInfo>();

    // Looked up in Start(); not read anywhere else in this class.
    private SoccerSettings m_SoccerSettings;

    private SimpleMultiAgentGroup m_BlueAgentGroup;
    private SimpleMultiAgentGroup m_PurpleAgentGroup;

    // FixedUpdate ticks elapsed since the last ResetScene().
    private int m_ResetTimer;

    /// <summary>
    /// Caches component references, records starting transforms, registers
    /// every agent with the group matching its team, then resets the scene.
    /// </summary>
    void Start()
    {
        m_SoccerSettings = FindObjectOfType<SoccerSettings>();

        // Initialize the two team groups.
        m_BlueAgentGroup = new SimpleMultiAgentGroup();
        m_PurpleAgentGroup = new SimpleMultiAgentGroup();

        ballRb = ball.GetComponent<Rigidbody>();
        // Vector3 is a value type, so this assignment already takes a copy.
        m_BallStartingPos = ball.transform.position;

        foreach (var item in AgentsList)
        {
            item.StartingPos = item.Agent.transform.position;
            item.StartingRot = item.Agent.transform.rotation;
            item.Rb = item.Agent.GetComponent<Rigidbody>();

            var group = item.Agent.team == Team.Blue ? m_BlueAgentGroup : m_PurpleAgentGroup;
            group.RegisterAgent(item.Agent);
        }

        ResetScene();
    }

    /// <summary>
    /// Advances the step counter and, once the (positive) step limit is
    /// reached, interrupts both group episodes and resets the scene.
    /// </summary>
    void FixedUpdate()
    {
        m_ResetTimer += 1;
        if (m_ResetTimer >= MaxEnvironmentSteps && MaxEnvironmentSteps > 0)
        {
            m_BlueAgentGroup.GroupEpisodeInterrupted();
            m_PurpleAgentGroup.GroupEpisodeInterrupted();
            ResetScene();
        }
    }

    /// <summary>
    /// Moves the ball to its starting position plus a random X/Z offset in
    /// [-2.5, 2.5] and zeroes its linear and angular velocity.
    /// </summary>
    public void ResetBall()
    {
        var randomPosX = Random.Range(-2.5f, 2.5f);
        var randomPosZ = Random.Range(-2.5f, 2.5f);

        ball.transform.position = m_BallStartingPos + new Vector3(randomPosX, 0f, randomPosZ);
        ballRb.velocity = Vector3.zero;
        ballRb.angularVelocity = Vector3.zero;
    }

    /// <summary>
    /// Rewards the scoring team, penalizes the other team with -1, ends both
    /// group episodes and resets the scene.
    /// </summary>
    /// <param name="scoredTeam">The team that scored the goal.</param>
    public void GoalTouched(Team scoredTeam)
    {
        var scorers = scoredTeam == Team.Blue ? m_BlueAgentGroup : m_PurpleAgentGroup;
        var conceders = scoredTeam == Team.Blue ? m_PurpleAgentGroup : m_BlueAgentGroup;

        scorers.AddGroupReward(ScoringReward());
        conceders.AddGroupReward(-1);

        m_PurpleAgentGroup.EndGroupEpisode();
        m_BlueAgentGroup.EndGroupEpisode();
        ResetScene();
    }

    /// <summary>
    /// Reward for the scoring team: 1 minus the fraction of the step budget
    /// already used, so earlier goals are worth more.
    /// </summary>
    /// <returns>
    /// 1 when no step limit is configured; otherwise
    /// 1 - m_ResetTimer / MaxEnvironmentSteps computed in floating point.
    /// </returns>
    float ScoringReward()
    {
        // With no step limit there is nothing to discount against, and
        // dividing by zero must be avoided.
        if (MaxEnvironmentSteps <= 0)
        {
            return 1f;
        }

        // The cast is essential: int / int truncates to 0 for every goal
        // scored before the limit, which made the reward always exactly 1.
        return 1f - (float)m_ResetTimer / MaxEnvironmentSteps;
    }

    /// <summary>
    /// Restarts the step counter, places every agent at its initial position
    /// with a random X offset in [-5, 5] and a yaw of rotSign * [80, 100]
    /// degrees, zeroes agent velocities, and resets the ball.
    /// </summary>
    public void ResetScene()
    {
        m_ResetTimer = 0;

        // Reset agents.
        foreach (var item in AgentsList)
        {
            var randomPosX = Random.Range(-5f, 5f);
            var newStartPos = item.Agent.initialPos + new Vector3(randomPosX, 0f, 0f);
            var rot = item.Agent.rotSign * Random.Range(80.0f, 100.0f);
            var newRot = Quaternion.Euler(0, rot, 0);
            item.Agent.transform.SetPositionAndRotation(newStartPos, newRot);

            item.Rb.velocity = Vector3.zero;
            item.Rb.angularVelocity = Vector3.zero;
        }

        // Reset ball.
        ResetBall();
    }
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

This file was deleted.

Binary file not shown.

This file was deleted.

Binary file not shown.
Loading