Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Side Channel Design Changes #3807

Merged
merged 24 commits into from
Apr 23, 2020
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
963a60d
Make EnvironmentParameters a first-class citizen in the API
Apr 19, 2020
ce1525e
Minor comment fix to Engine Parameters
Apr 19, 2020
62c868b
A second minor fix.
Apr 19, 2020
53f9850
Make EngineConfigChannel Internal and add a singleton/sealed accessor
Apr 19, 2020
fc28824
Make StatsSideChannel Internal and add a singleton/sealed accessor
Apr 19, 2020
12e88f9
Changes to SideChannelUtils
Apr 19, 2020
cbe21ad
Added Dispose methods to system-level sidechannel wrappers
Apr 19, 2020
194cea4
Removed debub log.
Apr 19, 2020
f81e835
Back-up commit.
Apr 21, 2020
26bad99
Revert "Back-up commit."
Apr 22, 2020
c1b7bc5
key changes to wrapper classes
Apr 22, 2020
8c33551
Re-enabled the option to add multiple side channels of the same type
Apr 22, 2020
1dce636
Fixed example env
Apr 22, 2020
bbea585
Add an enum flag to the EnvParamsChannel
Apr 22, 2020
797cf12
Adding .cs.meta files
Apr 22, 2020
c9b5cbc
Update engine config side channel
Apr 22, 2020
5a8351d
Rename SideChannelUtils —> SideChannelsManager
Apr 22, 2020
8ed980a
PR feedback
Apr 22, 2020
2db13f0
Minor PR feedback.
Apr 22, 2020
ab9dc93
Python side changes to the SideChannel redesign (#3826)
vincentpierre Apr 22, 2020
be42950
[Side Channels] Edited the documenation and renamed a few things (#3833)
vincentpierre Apr 23, 2020
1479351
Merge branch 'master' into develop-mm-env-params-unity
vincentpierre Apr 23, 2020
7bb1cf1
Addressing renaming comments
vincentpierre Apr 23, 2020
f326961
Removing the EngineParameters class
vincentpierre Apr 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ public class Ball3DAgent : Agent
[Header("Specific to Ball3D")]
public GameObject ball;
Rigidbody m_BallRb;
FloatPropertiesChannel m_ResetParams;
EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}

Expand Down Expand Up @@ -75,8 +75,8 @@ public override void Heuristic(float[] actionsOut)
public void SetBall()
{
//Set the attributes of the ball by fetching the information from the academy
m_BallRb.mass = m_ResetParams.GetPropertyWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetPropertyWithDefault("scale", 1.0f);
m_BallRb.mass = m_ResetParams.GetParameterWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetParameterWithDefault("scale", 1.0f);
ball.transform.localScale = new Vector3(scale, scale, scale);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ public class Ball3DHardAgent : Agent
[Header("Specific to Ball3DHard")]
public GameObject ball;
Rigidbody m_BallRb;
FloatPropertiesChannel m_ResetParams;
EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}

Expand Down Expand Up @@ -66,8 +66,8 @@ public override void OnEpisodeBegin()
public void SetBall()
{
//Set the attributes of the ball by fetching the information from the academy
m_BallRb.mass = m_ResetParams.GetPropertyWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetPropertyWithDefault("scale", 1.0f);
m_BallRb.mass = m_ResetParams.GetParameterWithDefault("mass", 1.0f);
var scale = m_ResetParams.GetParameterWithDefault("scale", 1.0f);
ball.transform.localScale = new Vector3(scale, scale, scale);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ public class BouncerAgent : Agent
int m_NumberJumps = 20;
int m_JumpLeft = 20;

FloatPropertiesChannel m_ResetParams;
EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_Rb = gameObject.GetComponent<Rigidbody>();
m_LookDir = Vector3.zero;

m_ResetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_ResetParams = Academy.Instance.EnvironmentParameters;

SetResetParameters();
}
Expand Down Expand Up @@ -121,7 +121,7 @@ void Update()

public void SetTargetScale()
{
var targetScale = m_ResetParams.GetPropertyWithDefault("target_scale", 1.0f);
var targetScale = m_ResetParams.GetParameterWithDefault("target_scale", 1.0f);
target.transform.localScale = new Vector3(targetScale, targetScale, targetScale);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ public class FoodCollectorAgent : Agent
public bool contribute;
public bool useVectorObs;

EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_AgentRb = GetComponent<Rigidbody>();
m_MyArea = area.GetComponent<FoodCollectorArea>();
m_FoodCollecterSettings = FindObjectOfType<FoodCollectorSettings>();

m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters();
}

Expand Down Expand Up @@ -271,12 +272,12 @@ void OnCollisionEnter(Collision collision)

public void SetLaserLengths()
{
m_LaserLength = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("laser_length", 1.0f);
m_LaserLength = m_ResetParams.GetParameterWithDefault("laser_length", 1.0f);
}

public void SetAgentScale()
{
float agentScale = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("agent_scale", 1.0f);
float agentScale = m_ResetParams.GetParameterWithDefault("agent_scale", 1.0f);
gameObject.transform.localScale = new Vector3(agentScale, agentScale, agentScale);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using System;
using UnityEngine;
using UnityEngine.UI;
using MLAgents;
using MLAgents.SideChannels;

public class FoodCollectorSettings : MonoBehaviour
{
Expand All @@ -14,15 +12,15 @@ public class FoodCollectorSettings : MonoBehaviour
public int totalScore;
public Text scoreText;

StatsSideChannel m_statsSideChannel;
StatsRecorder m_Recorder;

public void Awake()
{
Academy.Instance.OnEnvironmentReset += EnvironmentReset;
m_statsSideChannel = SideChannelUtils.GetSideChannel<StatsSideChannel>();
m_Recorder = Academy.Instance.StatsRecorder;
}

public void EnvironmentReset()
private void EnvironmentReset()
{
ClearObjects(GameObject.FindGameObjectsWithTag("food"));
ClearObjects(GameObject.FindGameObjectsWithTag("badFood"));
Expand Down Expand Up @@ -54,7 +52,7 @@ public void Update()
// need to send every Update() call.
if ((Time.frameCount % 100)== 0)
{
m_statsSideChannel?.AddStat("TotalScore", totalScore);
m_Recorder.Add("TotalScore", totalScore);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ public class GridAgent : Agent
const int k_Left = 3;
const int k_Right = 4;

EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_ResetParams = Academy.Instance.EnvironmentParameters;
}

public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
{
// Mask the necessary actions if selected by the user.
Expand All @@ -37,7 +44,7 @@ public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMaske
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
var positionZ = (int)transform.position.z;
var maxPosition = (int)SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("gridSize", 5f) - 1;
var maxPosition = (int)m_ResetParams.GetParameterWithDefault("gridSize", 5f) - 1;

if (positionX == 0)
{
Expand Down
20 changes: 10 additions & 10 deletions Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ public class GridArea : MonoBehaviour

public GameObject trueAgent;

FloatPropertiesChannel m_ResetParameters;

Camera m_AgentCam;

public GameObject goalPref;
Expand All @@ -30,9 +28,11 @@ public class GridArea : MonoBehaviour

Vector3 m_InitialPosition;

EnvironmentParameters m_ResetParams;

public void Start()
{
m_ResetParameters = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_ResetParams = Academy.Instance.EnvironmentParameters;

m_Objects = new[] { goalPref, pitPref };

Expand All @@ -50,23 +50,23 @@ public void Start()
m_InitialPosition = transform.position;
}

public void SetEnvironment()
private void SetEnvironment()
{
transform.position = m_InitialPosition * (m_ResetParameters.GetPropertyWithDefault("gridSize", 5f) + 1);
transform.position = m_InitialPosition * (m_ResetParams.GetParameterWithDefault("gridSize", 5f) + 1);
var playersList = new List<int>();

for (var i = 0; i < (int)m_ResetParameters.GetPropertyWithDefault("numObstacles", 1); i++)
for (var i = 0; i < (int)m_ResetParams.GetParameterWithDefault("numObstacles", 1); i++)
{
playersList.Add(1);
}

for (var i = 0; i < (int)m_ResetParameters.GetPropertyWithDefault("numGoals", 1f); i++)
for (var i = 0; i < (int)m_ResetParams.GetParameterWithDefault("numGoals", 1f); i++)
{
playersList.Add(0);
}
players = playersList.ToArray();

var gridSize = (int)m_ResetParameters.GetPropertyWithDefault("gridSize", 5f);
var gridSize = (int)m_ResetParams.GetParameterWithDefault("gridSize", 5f);
m_Plane.transform.localScale = new Vector3(gridSize / 10.0f, 1f, gridSize / 10.0f);
m_Plane.transform.localPosition = new Vector3((gridSize - 1) / 2f, -0.5f, (gridSize - 1) / 2f);
m_Sn.transform.localScale = new Vector3(1, 1, gridSize + 2);
Expand All @@ -84,7 +84,7 @@ public void SetEnvironment()

public void AreaReset()
{
var gridSize = (int)m_ResetParameters.GetPropertyWithDefault("gridSize", 5f);
var gridSize = (int)m_ResetParams.GetParameterWithDefault("gridSize", 5f);
foreach (var actor in actorObjs)
{
DestroyImmediate(actor);
Expand All @@ -98,7 +98,7 @@ public void AreaReset()
{
numbers.Add(Random.Range(0, gridSize * gridSize));
}
var numbersA = Enumerable.ToArray(numbers);
var numbersA = numbers.ToArray();

for (var i = 0; i < players.Length; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public class GridSettings : MonoBehaviour

public void Awake()
{
SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().RegisterCallback("gridSize", f =>
Academy.Instance.EnvironmentParameters.RegisterCallback("gridSize", f =>
{
MainCamera.transform.position = new Vector3(-(f - 1) / 2f, f * 1.25f, -(f - 1) / 2f);
MainCamera.orthographicSize = (f + 5f) / 2f;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class PushAgentBasic : Agent
/// </summary>
Renderer m_GroundRenderer;

private EnvironmentParameters m_ResetParams;

void Awake()
{
m_PushBlockSettings = FindObjectOfType<PushBlockSettings>();
Expand All @@ -70,6 +72,8 @@ public override void Initialize()
// Starting material
m_GroundMaterial = m_GroundRenderer.material;

m_ResetParams = Academy.Instance.EnvironmentParameters;

SetResetParameters();
}

Expand Down Expand Up @@ -226,27 +230,23 @@ public override void OnEpisodeBegin()

public void SetGroundMaterialFriction()
{
var resetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();

var groundCollider = ground.GetComponent<Collider>();

groundCollider.material.dynamicFriction = resetParams.GetPropertyWithDefault("dynamic_friction", 0);
groundCollider.material.staticFriction = resetParams.GetPropertyWithDefault("static_friction", 0);
groundCollider.material.dynamicFriction = m_ResetParams.GetParameterWithDefault("dynamic_friction", 0);
groundCollider.material.staticFriction = m_ResetParams.GetParameterWithDefault("static_friction", 0);
}

public void SetBlockProperties()
{
var resetParams = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();

var scale = resetParams.GetPropertyWithDefault("block_scale", 2);
var scale = m_ResetParams.GetParameterWithDefault("block_scale", 2);
//Set the scale of the block
m_BlockRb.transform.localScale = new Vector3(scale, 0.75f, scale);

// Set the drag of the block
m_BlockRb.drag = resetParams.GetPropertyWithDefault("block_drag", 0.5f);
m_BlockRb.drag = m_ResetParams.GetParameterWithDefault("block_drag", 0.5f);
}

public void SetResetParameters()
private void SetResetParameters()
{
SetGroundMaterialFriction();
SetBlockProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ public class ReacherAgent : Agent
// Frequency of the cosine deviation of the goal along the vertical dimension
float m_DeviationFreq;

private EnvironmentParameters m_ResetParams;

/// <summary>
/// Collect the rigidbodies of the reacher in order to resue them for
/// observations and actions.
Expand All @@ -30,6 +32,8 @@ public override void Initialize()
m_RbA = pendulumA.GetComponent<Rigidbody>();
m_RbB = pendulumB.GetComponent<Rigidbody>();

m_ResetParams = Academy.Instance.EnvironmentParameters;

SetResetParameters();
}

Expand Down Expand Up @@ -110,10 +114,9 @@ public override void OnEpisodeBegin()

public void SetResetParameters()
{
var fp = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>();
m_GoalSize = fp.GetPropertyWithDefault("goal_size", 5);
m_GoalSpeed = Random.Range(-1f, 1f) * fp.GetPropertyWithDefault("goal_speed", 1);
m_Deviation = fp.GetPropertyWithDefault("deviation", 0);
m_DeviationFreq = fp.GetPropertyWithDefault("deviation_freq", 0);
m_GoalSize = m_ResetParams.GetParameterWithDefault("goal_size", 5);
m_GoalSpeed = Random.Range(-1f, 1f) * m_ResetParams.GetParameterWithDefault("goal_speed", 1);
m_Deviation = m_ResetParams.GetParameterWithDefault("deviation", 0);
m_DeviationFreq = m_ResetParams.GetParameterWithDefault("deviation_freq", 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void Awake()

// Make sure the Academy singleton is initialized first, since it will create the SideChannels.
mmattar marked this conversation as resolved.
Show resolved Hide resolved
var academy = Academy.Instance;
SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); });
Academy.Instance.EnvironmentParameters.RegisterCallback("gravity", f => { Physics.gravity = new Vector3(0, -f, 0); });
}

public void OnDestroy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public enum Position
BehaviorParameters m_BehaviorParameters;
Vector3 m_Transform;

private EnvironmentParameters m_ResetParams;

public override void Initialize()
{
m_Existential = 1f / maxStep;
Expand All @@ -73,7 +75,7 @@ public override void Initialize()
m_LateralSpeed = 0.3f;
m_ForwardSpeed = 1.3f;
}
else
else
{
m_LateralSpeed = 0.3f;
m_ForwardSpeed = 1.0f;
Expand All @@ -91,6 +93,8 @@ public override void Initialize()
area.playerStates.Add(playerState);
m_PlayerIndex = area.playerStates.IndexOf(playerState);
playerState.playerIndex = m_PlayerIndex;

m_ResetParams = Academy.Instance.EnvironmentParameters;
}

public void MoveAgent(float[] act)
Expand Down Expand Up @@ -214,7 +218,7 @@ public override void OnEpisodeBegin()
{

timePenalty = 0;
m_BallTouch = SideChannelUtils.GetSideChannel<FloatPropertiesChannel>().GetPropertyWithDefault("ball_touch", 0);
m_BallTouch = m_ResetParams.GetParameterWithDefault("ball_touch", 0);
if (team == Team.Purple)
{
transform.rotation = Quaternion.Euler(0f, -90f, 0f);
Expand Down
Loading