Skip to content

Commit

Permalink
[WIP] Side Channel Design Changes (#3807)
Browse files Browse the repository at this point in the history
* Make EnvironmentParameters a first-class citizen in the API

Missing: Python counterparts and testing.

* Minor comment fix to Engine Parameters

* A second minor fix.

* Make EngineConfigChannel Internal and add a singleton/sealed accessor

* Make StatsSideChannel Internal and add a singleton/sealed accessor

* Changes to SideChannelUtils

- Disallow two sidechannels of the same type to be added
- Remove GetSideChannels that return a list as that is now unnecessary
- Make most methods (except register/unregister) internal to limit users' impact on the “system-level” side channels
- Add an improved comment to SideChannel.cs

* Added Dispose methods to system-level sidechannel wrappers

- Specifically to StatsRecorder, EnvironmentParameters and EngineParameters.
- Updated Academy.Dispose to take advantage of these.
- Updated Editor tests to cover all three “system-level” side channels.

Kudos to Unit Tests (TestAcademy / TestAcademyDispose) for catching these.

* Removed debug log.

* Back-up commit.

* Revert "Back-up commit."

This reverts commit f81e835.

* key changes to wrapper classes

made the wrapper classes non-singleton (but internal constructors)
made EngineParameters internal

* Re-enabled the option to add multiple side channels of the same type

* Fixed example env

* Add an enum flag to the EnvParamsChannel

* Adding .cs.meta files

* Update engine config side channel

Removed unnecessary accessors
Made capture frame rate a parameter

* Rename SideChannelUtils —> SideChannelsManager

* PR feedback

* Minor PR feedback.

* Python side changes to the SideChannel redesign (#3826)

* Modified the EngineConfig to send one message per field

* Created the Python Environment Parameters Channel and hooked it in

* Make OnMessageReceived protected

* addressing comments

* [Side Channels] Edited the documentation and renamed a few things (#3833)

* Edited the documentation and renamed a few things

* addressing comments

* Update docs/Python-API.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update com.unity.ml-agents/CHANGELOG.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Removing unnecessary migrating line

Co-authored-by: Chris Elion <chris.elion@unity3d.com>

* Addressing renaming comments

* Removing the EngineParameters class

Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>
Co-authored-by: Chris Elion <chris.elion@unity3d.com>
  • Loading branch information
3 people authored Apr 23, 2020
1 parent cc74f81 commit e4826b6
Show file tree
Hide file tree
Showing 47 changed files with 724 additions and 305 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ public class Ball3DAgent : Agent
[Header("Specific to Ball3D")]
public GameObject ball;
Rigidbody m_BallRb;
FloatPropertiesChannel m_ResetParams;
EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}

Expand Down Expand Up @@ -75,8 +75,8 @@ public override void Heuristic(float[] actionsOut)
public void SetBall()
{
//Set the attributes of the ball by fetching the information from the academy
m_BallRb.mass = m_ResetParams.GetPropertyWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetPropertyWithDefault("scale", 1.0f);
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
ball.transform.localScale = new Vector3(scale, scale, scale);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ public class Ball3DHardAgent : Agent
[Header("Specific to Ball3DHard")]
public GameObject ball;
Rigidbody m_BallRb;
FloatPropertiesChannel m_ResetParams;
EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}

Expand Down Expand Up @@ -66,8 +66,8 @@ public override void OnEpisodeBegin()
public void SetBall()
{
//Set the attributes of the ball by fetching the information from the academy
m_BallRb.mass = m_ResetParams.GetPropertyWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetPropertyWithDefault("scale", 1.0f);
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
ball.transform.localScale = new Vector3(scale, scale, scale);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ public class BouncerAgent : Agent
int m_NumberJumps = 20;
int m_JumpLeft = 20;

FloatPropertiesChannel m_ResetParams;
EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_Rb = gameObject.GetComponent<Rigidbody>();
m_LookDir = Vector3.zero;

m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_ResetParams = Academy.Instance.EnvironmentParameters;

SetResetParameters();
}
Expand Down Expand Up @@ -121,7 +121,7 @@ void Update()

public void SetTargetScale()
{
var targetScale = m_ResetParams.GetPropertyWithDefault("target_scale", 1.0f);
var targetScale = m_ResetParams.GetWithDefault("target_scale", 1.0f);
target.transform.localScale = new Vector3(targetScale, targetScale, targetScale);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ public class FoodCollectorAgent : Agent
public bool contribute;
public bool useVectorObs;

EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_AgentRb = GetComponent<Rigidbody>();
m_MyArea = area.GetComponent<FoodCollectorArea>();
m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();

m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}

Expand Down Expand Up @@ -271,12 +272,12 @@ void OnCollisionEnter(Collision collision)

public void SetLaserLengths()
{
m_LaserLength = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("laser_length", 1.0f);
m_LaserLength = m_ResetParams.GetWithDefault("laser_length", 1.0f);
}

public void SetAgentScale()
{
float agentScale = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("agent_scale", 1.0f);
float agentScale = m_ResetParams.GetWithDefault("agent_scale", 1.0f);
gameObject.transform.localScale = new Vector3(agentScale, agentScale, agentScale);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using System;
using UnityEngine;
using UnityEngine.UI;
using MLAgents;
using MLAgents.SideChannels;

public class FoodCollectorSettings : MonoBehaviour
{
Expand All @@ -14,15 +12,15 @@ public class FoodCollectorSettings : MonoBehaviour
public int totalScore;
public Text scoreText;

StatsSideChannel m_statsSideChannel;
StatsRecorder m_Recorder;

public void Awake()
{
Academy.Instance.OnEnvironmentReset += EnvironmentReset;
m_statsSideChannel = SideChannelUtils.GetSideChannel<StatsSideChannel>();
m_Recorder = Academy.Instance.StatsRecorder;
}

public void EnvironmentReset()
private void EnvironmentReset()
{
ClearObjects(GameObject.FindGameObjectsWithTag("food"));
ClearObjects(GameObject.FindGameObjectsWithTag("badFood"));
Expand Down Expand Up @@ -54,7 +52,7 @@ public void Update()
// need to send every Update() call.
if ((Time.frameCount % 100)== 0)
{
m_statsSideChannel?.AddStat("TotalScore", totalScore);
m_Recorder.Add("TotalScore", totalScore);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ public class GridAgent : Agent
const int k_Left = 3;
const int k_Right = 4;

EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_ResetParams = Academy.Instance.EnvironmentParameters;
}

public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
{
// Mask the necessary actions if selected by the user.
Expand All @@ -37,7 +44,7 @@ public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMaske
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
var positionZ = (int)transform.position.z;
var maxPosition = (int)SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("gridSize", 5f) - 1;
var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;

if (positionX == 0)
{
Expand Down
20 changes: 10 additions & 10 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ public class GridArea : MonoBehaviour

public GameObject trueAgent;

FloatPropertiesChannel m_ResetParameters;

Camera m_AgentCam;

public GameObject goalPref;
Expand All @@ -30,9 +28,11 @@ public class GridArea : MonoBehaviour

Vector3 m_InitialPosition;

EnvironmentParameters m_ResetParams;

public void Start()
{
m_ResetParameters = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_ResetParams = Academy.Instance.EnvironmentParameters;

m_Objects = new[] { goalPref, pitPref };

Expand All @@ -50,23 +50,23 @@ public void Start()
m_InitialPosition = transform.position;
}

public void SetEnvironment()
private void SetEnvironment()
{
transform.position = m_InitialPosition * (m_ResetParameters.GetPropertyWithDefault("gridSize", 5f) + 1);
transform.position = m_InitialPosition * (m_ResetParams.GetWithDefault("gridSize", 5f) + 1);
var playersList = new List<int>();

for (var i = 0; i < (int)m_ResetParameters.GetPropertyWithDefault("numObstacles", 1); i++)
for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numObstacles", 1); i++)
{
playersList.Add(1);
}

for (var i = 0; i < (int)m_ResetParameters.GetPropertyWithDefault("numGoals", 1f); i++)
for (var i = 0; i < (int)m_ResetParams.GetWithDefault("numGoals", 1f); i++)
{
playersList.Add(0);
}
players = playersList.ToArray();

var gridSize = (int)m_ResetParameters.GetPropertyWithDefault("gridSize", 5f);
var gridSize = (int)m_ResetParams.GetWithDefault("gridSize", 5f);
m_Plane.transform.localScale = new Vector3(gridSize / 10.0f, 1f, gridSize / 10.0f);
m_Plane.transform.localPosition = new Vector3((gridSize - 1) / 2f, -0.5f, (gridSize - 1) / 2f);
m_Sn.transform.localScale = new Vector3(1, 1, gridSize + 2);
Expand All @@ -84,7 +84,7 @@ public void SetEnvironment()

public void AreaReset()
{
var gridSize = (int)m_ResetParameters.GetPropertyWithDefault("gridSize", 5f);
var gridSize = (int)m_ResetParams.GetWithDefault("gridSize", 5f);
foreach (var actor in actorObjs)
{
DestroyImmediate(actor);
Expand All @@ -98,7 +98,7 @@ public void AreaReset()
{
numbers.Add(Random.Range(0, gridSize * gridSize));
}
var numbersA = Enumerable.ToArray(numbers);
var numbersA = numbers.ToArray();

for (var i = 0; i < players.Length; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public class GridSettings : MonoBehaviour

public void Awake()
{
SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().RegisterCallback("gridSize", f =>
Academy.Instance.EnvironmentParameters.RegisterCallback("gridSize", f =>
{
MainCamera.transform.position = new Vector3(-(f - 1) / 2f, f * 1.25f, -(f - 1) / 2f);
MainCamera.orthographicSize = (f + 5f) / 2f;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class PushAgentBasic : Agent
/// </summary>
Renderer m_GroundRenderer;

private EnvironmentParameters m_ResetParams;

void Awake()
{
m_PushBlockSettings = FindObjectOfType<PushBlockSettings>();
Expand All @@ -70,6 +72,8 @@ public override void Initialize()
// Starting material
m_GroundMaterial = m_GroundRenderer.material;

m_ResetParams = Academy.Instance.EnvironmentParameters;

SetResetParameters();
}

Expand Down Expand Up @@ -226,27 +230,23 @@ public override void OnEpisodeBegin()

public void SetGroundMaterialFriction()
{
var resetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();

var groundCollider = ground.GetComponent<Collider>();

groundCollider.material.dynamicFriction = resetParams.GetPropertyWithDefault("dynamic_friction", 0);
groundCollider.material.staticFriction = resetParams.GetPropertyWithDefault("static_friction", 0);
groundCollider.material.dynamicFriction = m_ResetParams.GetWithDefault("dynamic_friction", 0);
groundCollider.material.staticFriction = m_ResetParams.GetWithDefault("static_friction", 0);
}

public void SetBlockProperties()
{
var resetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();

var scale = resetParams.GetPropertyWithDefault("block_scale", 2);
var scale = m_ResetParams.GetWithDefault("block_scale", 2);
//Set the scale of the block
m_BlockRb.transform.localScale = new Vector3(scale, 0.75f, scale);

// Set the drag of the block
m_BlockRb.drag = resetParams.GetPropertyWithDefault("block_drag", 0.5f);
m_BlockRb.drag = m_ResetParams.GetWithDefault("block_drag", 0.5f);
}

public void SetResetParameters()
private void SetResetParameters()
{
SetGroundMaterialFriction();
SetBlockProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public class ReacherAgent : Agent
// Frequency of the cosine deviation of the goal along the vertical dimension
float m_DeviationFreq;

private EnvironmentParameters m_ResetParams;

/// <summary>
/// Collect the rigidbodies of the reacher in order to reuse them for
/// observations and actions.
Expand All @@ -30,6 +32,8 @@ public override void Initialize()
m_RbA = pendulumA.GetComponent<Rigidbody>();
m_RbB = pendulumB.GetComponent<Rigidbody>();

m_ResetParams = Academy.Instance.EnvironmentParameters;

SetResetParameters();
}

Expand Down Expand Up @@ -110,10 +114,9 @@ public override void OnEpisodeBegin()

public void SetResetParameters()
{
var fp = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_GoalSize = fp.GetPropertyWithDefault("goal_size", 5);
m_GoalSpeed = Random.Range(-1f, 1f) * fp.GetPropertyWithDefault("goal_speed", 1);
m_Deviation = fp.GetPropertyWithDefault("deviation", 0);
m_DeviationFreq = fp.GetPropertyWithDefault("deviation_freq", 0);
m_GoalSize = m_ResetParams.GetWithDefault("goal_size", 5);
m_GoalSpeed = Random.Range(-1f, 1f) * m_ResetParams.GetWithDefault("goal_speed", 1);
m_Deviation = m_ResetParams.GetWithDefault("deviation", 0);
m_DeviationFreq = m_ResetParams.GetWithDefault("deviation_freq", 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ public void Awake()
Physics.defaultSolverVelocityIterations = solverVelocityIterations;

// Make sure the Academy singleton is initialized first, since it will create the SideChannels.
var academy = Academy.Instance;
SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); });
Academy.Instance.EnvironmentParameters.RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); });
}

public void OnDestroy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public enum Position
BehaviorParameters m_BehaviorParameters;
Vector3 m_Transform;

private EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_Existential = 1f / MaxStep;
Expand All @@ -73,7 +75,7 @@ public override void Initialize()
m_LateralSpeed = 0.3f;
m_ForwardSpeed = 1.3f;
}
else
else
{
m_LateralSpeed = 0.3f;
m_ForwardSpeed = 1.0f;
Expand All @@ -91,6 +93,8 @@ public override void Initialize()
area.playerStates.Add(playerState);
m_PlayerIndex = area.playerStates.IndexOf(playerState);
playerState.playerIndex = m_PlayerIndex;

m_ResetParams = Academy.Instance.EnvironmentParameters;
}

public void MoveAgent(float[] act)
Expand Down Expand Up @@ -214,7 +218,7 @@ public override void OnEpisodeBegin()
{

timePenalty = 0;
m_BallTouch = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("ball_touch", 0);
m_BallTouch = m_ResetParams.GetWithDefault("ball_touch", 0);
if (team == Team.Purple)
{
transform.rotation = Quaternion.Euler(0f, -90f, 0f);
Expand Down
Loading

0 comments on commit e4826b6

Please sign in to comment.