diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 94bc24bf07..a76c3b393d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -44,7 +44,12 @@ repos:
hooks:
- id: pyupgrade
args: [--py3-plus, --py36-plus]
- exclude: .*barracuda.py
+ exclude: >
+ (?x)^(
+            .*barracuda\.py|
+            .*_pb2\.py|
+            .*_pb2_grpc\.py
+ )$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.5.0
diff --git a/.yamato/com.unity.ml-agents-performance.yml b/.yamato/com.unity.ml-agents-performance.yml
index 62351afae9..f79e85d56a 100644
--- a/.yamato/com.unity.ml-agents-performance.yml
+++ b/.yamato/com.unity.ml-agents-performance.yml
@@ -12,7 +12,7 @@ Run_Mac_Perfomance_Tests{{ editor.version }}:
variables:
UNITY_VERSION: {{ editor.version }}
commands:
- - python -m pip install unity-downloader-cli --extra-index-url https://artifactory.eu-cph-1.unityops.net/api/pypi/common-python/simple
+ - python -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
- unity-downloader-cli -u {{ editor.version }} -c editor --wait --fast
- curl -s https://artifactory.internal.unity3d.com/core-automation/tools/utr-standalone/utr --output utr
- chmod +x ./utr
diff --git a/.yamato/com.unity.ml-agents-test.yml b/.yamato/com.unity.ml-agents-test.yml
index d0dc31b21a..71498d850f 100644
--- a/.yamato/com.unity.ml-agents-test.yml
+++ b/.yamato/com.unity.ml-agents-test.yml
@@ -113,7 +113,7 @@ test_{{ package.name }}_{{ platform.name }}_trunk:
image: {{ platform.image }}
flavor: {{ platform.flavor}}
commands:
- - python -m pip install unity-downloader-cli --extra-index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/unity-pypi-local/simple --upgrade
+ - python -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade
- unity-downloader-cli -u trunk -c editor --wait --fast
- npm install upm-ci-utils@stable -g --registry https://artifactory.prd.cds.internal.unity3d.com/artifactory/api/npm/upm-npm
- upm-ci project test -u {{ editor.version }} --project-path Project --package-filter {{ package.name }} {{ coverageOptions }}
diff --git a/.yamato/gym-interface-test.yml b/.yamato/gym-interface-test.yml
index a5fdcfdfc9..fd2aa8a09c 100644
--- a/.yamato/gym-interface-test.yml
+++ b/.yamato/gym-interface-test.yml
@@ -11,7 +11,7 @@ test_gym_interface_{{ editor.version }}:
variables:
UNITY_VERSION: {{ editor.version }}
commands:
- - pip install pyyaml
+ - pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
- python -u -m ml-agents.tests.yamato.setup_venv
- ./venv/bin/python ml-agents/tests/yamato/scripts/run_gym.py --env=artifacts/testPlayer-Basic
dependencies:
@@ -21,12 +21,12 @@ test_gym_interface_{{ editor.version }}:
expression: |
(pull_request.target eq "master" OR
pull_request.target match "release.+") AND
- NOT pull_request.draft AND
- (pull_request.changes.any match "com.unity.ml-agents/**" OR
- pull_request.changes.any match "Project/**" OR
- pull_request.changes.any match "ml-agents/**" OR
- pull_request.changes.any match "ml-agents-envs/**" OR
- pull_request.changes.any match "gym-unity/**" OR
+ NOT pull_request.draft AND
+ (pull_request.changes.any match "com.unity.ml-agents/**" OR
+ pull_request.changes.any match "Project/**" OR
+ pull_request.changes.any match "ml-agents/**" OR
+ pull_request.changes.any match "ml-agents-envs/**" OR
+ pull_request.changes.any match "gym-unity/**" OR
pull_request.changes.any match ".yamato/gym-interface-test.yml") AND
NOT pull_request.changes.all match "**/*.md"
{% endfor %}
diff --git a/.yamato/protobuf-generation-test.yml b/.yamato/protobuf-generation-test.yml
index 96f190775d..e91a2ddf89 100644
--- a/.yamato/protobuf-generation-test.yml
+++ b/.yamato/protobuf-generation-test.yml
@@ -13,9 +13,8 @@ test_mac_protobuf_generation:
nuget install Grpc.Tools -Version $GRPC_VERSION -OutputDirectory protobuf-definitions/
python3 -m venv venv
. venv/bin/activate
- pip install --upgrade pip
- pip install grpcio-tools==1.13.0 --progress-bar=off
- pip install mypy-protobuf==1.16.0 --progress-bar=off
+ pip install --upgrade pip --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
+ pip install grpcio==1.28.1 grpcio-tools==1.13.0 protobuf==3.11.3 six==1.14.0 mypy-protobuf==1.16.0 --progress-bar=off --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
cd protobuf-definitions
chmod +x Grpc.Tools.$GRPC_VERSION/tools/macosx_x64/protoc
chmod +x Grpc.Tools.$GRPC_VERSION/tools/macosx_x64/grpc_csharp_plugin
diff --git a/.yamato/python-ll-api-test.yml b/.yamato/python-ll-api-test.yml
index aa816ec68a..983597313e 100644
--- a/.yamato/python-ll-api-test.yml
+++ b/.yamato/python-ll-api-test.yml
@@ -11,9 +11,9 @@ test_mac_ll_api_{{ editor.version }}:
variables:
UNITY_VERSION: {{ editor.version }}
commands:
- - pip install pyyaml
+ - pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
- python -u -m ml-agents.tests.yamato.setup_venv
- - ./venv/bin/python ml-agents/tests/yamato/scripts/run_llapi.py
+ - ./venv/bin/python ml-agents/tests/yamato/scripts/run_llapi.py
- ./venv/bin/python ml-agents/tests/yamato/scripts/run_llapi.py --env=artifacts/testPlayer-Basic
- ./venv/bin/python ml-agents/tests/yamato/scripts/run_llapi.py --env=artifacts/testPlayer-WallJump
- ./venv/bin/python ml-agents/tests/yamato/scripts/run_llapi.py --env=artifacts/testPlayer-Bouncer
@@ -24,11 +24,11 @@ test_mac_ll_api_{{ editor.version }}:
expression: |
(pull_request.target eq "master" OR
pull_request.target match "release.+") AND
- NOT pull_request.draft AND
- (pull_request.changes.any match "com.unity.ml-agents/**" OR
- pull_request.changes.any match "Project/**" OR
- pull_request.changes.any match "ml-agents/**" OR
- pull_request.changes.any match "ml-agents-envs/**" OR
+ NOT pull_request.draft AND
+ (pull_request.changes.any match "com.unity.ml-agents/**" OR
+ pull_request.changes.any match "Project/**" OR
+ pull_request.changes.any match "ml-agents/**" OR
+ pull_request.changes.any match "ml-agents-envs/**" OR
pull_request.changes.any match ".yamato/python-ll-api-test.yml") AND
NOT pull_request.changes.all match "**/*.md"
{% endfor %}
diff --git a/.yamato/standalone-build-test.yml b/.yamato/standalone-build-test.yml
index a769ff9686..3077bff8bf 100644
--- a/.yamato/standalone-build-test.yml
+++ b/.yamato/standalone-build-test.yml
@@ -12,7 +12,7 @@ test_mac_standalone_{{ editor.version }}:
variables:
UNITY_VERSION: {{ editor.version }}
commands:
- - pip install pyyaml
+ - pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
- python -u -m ml-agents.tests.yamato.standalone_build_tests
- python -u -m ml-agents.tests.yamato.standalone_build_tests --scene=Assets/ML-Agents/Examples/Basic/Scenes/Basic.unity
- python -u -m ml-agents.tests.yamato.standalone_build_tests --scene=Assets/ML-Agents/Examples/Bouncer/Scenes/Bouncer.unity
diff --git a/.yamato/training-int-tests.yml b/.yamato/training-int-tests.yml
index b7839bb0f6..05f79e5cb2 100644
--- a/.yamato/training-int-tests.yml
+++ b/.yamato/training-int-tests.yml
@@ -12,7 +12,7 @@ test_mac_training_int_{{ editor.version }}:
variables:
UNITY_VERSION: {{ editor.version }}
commands:
- - pip install pyyaml
+ - pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
- python -u -m ml-agents.tests.yamato.training_int_tests
# Backwards-compatibility tests.
# If we make a breaking change to the communication protocol, these will need
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab
index ce6583ecd3..22474a8612 100644
--- a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab
+++ b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab
@@ -2742,6 +2742,7 @@ GameObject:
- component: {fileID: 4845971001715176662}
- component: {fileID: 4845971001715176663}
- component: {fileID: 4845971001715176660}
+ - component: {fileID: 4622120667686875944}
m_Layer: 0
m_Name: Crawler
m_TagString: Untagged
@@ -2779,7 +2780,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- VectorObservationSize: 138
+ VectorObservationSize: 21
NumStackedVectorObservations: 1
VectorActionSize: 14000000
VectorActionDescriptions: []
@@ -2872,6 +2873,30 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
debugCommandLineOverride:
+--- !u!114 &4622120667686875944
+MonoBehaviour:
+ m_ObjectHideFlags: 0
+ m_CorrespondingSourceObject: {fileID: 0}
+ m_PrefabInstance: {fileID: 0}
+ m_PrefabAsset: {fileID: 0}
+ m_GameObject: {fileID: 4845971001715176661}
+ m_Enabled: 1
+ m_EditorHideFlags: 0
+ m_Script: {fileID: 11500000, guid: df0f8be9a37d6486498061e2cbc4cd94, type: 3}
+ m_Name:
+ m_EditorClassIdentifier:
+ RootBody: {fileID: 4845971001588102145}
+ VirtualRoot: {fileID: 2270141184585723037}
+ Settings:
+ UseModelSpaceTranslations: 1
+ UseModelSpaceRotations: 1
+ UseLocalSpaceTranslations: 0
+ UseLocalSpaceRotations: 1
+ UseModelSpaceLinearVelocity: 1
+ UseLocalSpaceLinearVelocity: 0
+ UseJointPositionsAndAngles: 0
+ UseJointForces: 0
+ sensorName:
--- !u!1 &4845971001730692034
GameObject:
m_ObjectHideFlags: 0
@@ -3018,6 +3043,12 @@ PrefabInstance:
objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: 72f745913c5a34df5aaadd5c1f0024cb, type: 3}
+--- !u!1 &2270141184585723037 stripped
+GameObject:
+ m_CorrespondingSourceObject: {fileID: 2591864627249999519, guid: 72f745913c5a34df5aaadd5c1f0024cb,
+ type: 3}
+ m_PrefabInstance: {fileID: 4357529801223143938}
+ m_PrefabAsset: {fileID: 0}
--- !u!4 &2270141184585723026 stripped
Transform:
m_CorrespondingSourceObject: {fileID: 2591864627249999504, guid: 72f745913c5a34df5aaadd5c1f0024cb,
@@ -3030,7 +3061,7 @@ MonoBehaviour:
type: 3}
m_PrefabInstance: {fileID: 4357529801223143938}
m_PrefabAsset: {fileID: 0}
- m_GameObject: {fileID: 0}
+ m_GameObject: {fileID: 2270141184585723037}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 771e78c5e980e440e8cd19716b55075f, type: 3}
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
index fb182341b2..aa1378135e 100644
--- a/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
+++ b/Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
@@ -392,6 +392,11 @@ PrefabInstance:
propertyPath: targetToLookAt
value:
objectReference: {fileID: 2673081981996998229}
+ - target: {fileID: 4622120667686875944, guid: 0456c89e8c9c243d595b039fe7aa0bf9,
+ type: 3}
+ propertyPath: Settings.UseLocalSpaceLinearVelocity
+ value: 1
+ objectReference: {fileID: 0}
- target: {fileID: 4845971000000621469, guid: 0456c89e8c9c243d595b039fe7aa0bf9,
type: 3}
propertyPath: m_ConnectedAnchor.x
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs b/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
index e0b7951833..fda546a13c 100644
--- a/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
+++ b/Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
@@ -91,17 +91,8 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
//GROUND CHECK
sensor.AddObservation(bp.groundContact.touchingGround); // Is this bp touching the ground
- //Get velocities in the context of our orientation cube's space
- //Note: You can get these velocities in world space as well but it may not train as well.
- sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.velocity));
- sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.angularVelocity));
-
- //Get position relative to hips in the context of our orientation cube's space
- sensor.AddObservation(orientationCube.transform.InverseTransformDirection(bp.rb.position - body.position));
-
if (bp.rb.transform != body)
{
- sensor.AddObservation(bp.rb.transform.localRotation);
sensor.AddObservation(bp.currentStrength / m_JdController.maxJointForceLimit);
}
}
@@ -111,9 +102,6 @@ public void CollectObservationBodyPart(BodyPart bp, VectorSensor sensor)
///
public override void CollectObservations(VectorSensor sensor)
{
- //Add body rotation delta relative to orientation cube
- sensor.AddObservation(Quaternion.FromToRotation(body.forward, orientationCube.transform.forward));
-
//Add pos of target relative to orientation cube
sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.transform.position));
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn
index 9413f25653..9902063433 100644
Binary files a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn and b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn differ
diff --git a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn
index 70fac39c0c..85e15c089d 100644
Binary files a/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn and b/Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn differ
diff --git a/Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockVisualArea.prefab b/Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockVisualArea.prefab
index a7e3d67275..6bca403623 100644
--- a/Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockVisualArea.prefab
+++ b/Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockVisualArea.prefab
@@ -859,17 +859,18 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
- vectorObservationSize: 0
- numStackedVectorObservations: 1
- vectorActionSize: 07000000
- vectorActionDescriptions: []
- vectorActionSpaceType: 0
+ VectorObservationSize: 0
+ NumStackedVectorObservations: 1
+ VectorActionSize: 07000000
+ VectorActionDescriptions: []
+ VectorActionSpaceType: 0
m_Model: {fileID: 0}
m_InferenceDevice: 0
m_BehaviorType: 0
- m_BehaviorName: VisualHallway
- m_TeamID: 0
- m_useChildSensors: 1
+ m_BehaviorName: VisualPushBlock
+ TeamId: 0
+ m_UseChildSensors: 1
+ m_ObservableAttributeHandling: 0
--- !u!114 &114812843792483960
MonoBehaviour:
m_ObjectHideFlags: 0
@@ -882,7 +883,10 @@ MonoBehaviour:
m_Script: {fileID: 11500000, guid: dea8c4f2604b947e6b7b97750dde87ca, type: 3}
m_Name:
m_EditorClassIdentifier:
- maxStep: 5000
+ agentParameters:
+ maxStep: 0
+ hasUpgradedFromAgentParameters: 1
+ MaxStep: 5000
ground: {fileID: 1913379827958244}
area: {fileID: 1632733799967290}
areaBounds:
@@ -917,12 +921,12 @@ MonoBehaviour:
m_Script: {fileID: 11500000, guid: 282f342c2ab144bf38be65d4d0c4e07d, type: 3}
m_Name:
m_EditorClassIdentifier:
- camera: {fileID: 20961401228419460}
- sensorName: CameraSensor
- width: 84
- height: 84
- grayscale: 0
- compression: 1
+ m_Camera: {fileID: 20961401228419460}
+ m_SensorName: CameraSensor
+ m_Width: 84
+ m_Height: 84
+ m_Grayscale: 0
+ m_Compression: 1
--- !u!114 &9049837659352187721
MonoBehaviour:
m_ObjectHideFlags: 0
@@ -936,8 +940,7 @@ MonoBehaviour:
m_Name:
m_EditorClassIdentifier:
DecisionPeriod: 5
- RepeatAction: 1
- offsetStep: 0
+ TakeActionsBetweenDecisions: 1
--- !u!1 &1626651094211584
GameObject:
m_ObjectHideFlags: 0
diff --git a/Project/ProjectSettings/ProjectVersion.txt b/Project/ProjectSettings/ProjectVersion.txt
index ef4a753e2f..b71c05700f 100644
--- a/Project/ProjectSettings/ProjectVersion.txt
+++ b/Project/ProjectSettings/ProjectVersion.txt
@@ -1 +1 @@
-m_EditorVersion: 2018.4.20f1
+m_EditorVersion: 2018.4.24f1
diff --git a/README.md b/README.md
index 1a7d940249..6fc6a3af4b 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
# Unity ML-Agents Toolkit
-[![docs badge](https://img.shields.io/badge/docs-reference-blue.svg)](https://github.com/Unity-Technologies/ml-agents/tree/release_4_docs/docs/)
+[![docs badge](https://img.shields.io/badge/docs-reference-blue.svg)](https://github.com/Unity-Technologies/ml-agents/tree/release_5_docs/docs/)
[![license badge](https://img.shields.io/badge/license-Apache--2.0-green.svg)](LICENSE)
@@ -48,8 +48,8 @@ descriptions of all these features.
## Releases & Documentation
-**Our latest, stable release is `Release 4`. Click
-[here](https://github.com/Unity-Technologies/ml-agents/tree/release_4_docs/docs/Readme.md)
+**Our latest, stable release is `Release 5`. Click
+[here](https://github.com/Unity-Technologies/ml-agents/tree/release_5_docs/docs/Readme.md)
to get started with the latest release of ML-Agents.**
The table below lists all our releases, including our `master` branch which is
@@ -67,13 +67,13 @@ under active development and may be unstable. A few helpful guidelines:
| **Version** | **Release Date** | **Source** | **Documentation** | **Download** |
|:-------:|:------:|:-------------:|:-------:|:------------:|
| **master (unstable)** | -- | [source](https://github.com/Unity-Technologies/ml-agents/tree/master) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/master/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/master.zip) |
-| **Release 4** | **July 15, 2020** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_4)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_4_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_4.zip)** |
+| **Release 5** | **July 31, 2020** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_5)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_5_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_5.zip)** |
+| **Release 4** | July 15, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_4) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_4_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_4.zip) |
| **Release 3** | June 10, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_3) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_3_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_3.zip) |
| **Release 2** | May 20, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_2) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_2_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_2.zip) |
| **Release 1** | April 30, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_1) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_1_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_1.zip) |
| **0.15.1** | March 30, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/0.15.1) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/0.15.1/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/0.15.1.zip) |
| **0.15.0** | March 18, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/0.15.0) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/0.15.0/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/0.15.0.zip) |
-| **0.14.1** | February 26, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/0.14.1) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/0.14.1/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/0.14.1.zip) |
## Citation
diff --git a/com.unity.ml-agents.extensions/README.md b/com.unity.ml-agents.extensions/README.md
index 651f450e09..5cba2759c9 100644
--- a/com.unity.ml-agents.extensions/README.md
+++ b/com.unity.ml-agents.extensions/README.md
@@ -1,3 +1,5 @@
# ML-Agents Extensions
This is a source-only package for new features based on ML-Agents.
+
+More details coming soon.
diff --git a/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs
new file mode 100644
index 0000000000..0cd831e21f
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs
@@ -0,0 +1,3 @@
+using System.Runtime.CompilerServices;
+
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
diff --git a/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta
new file mode 100644
index 0000000000..21cec76829
--- /dev/null
+++ b/com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta
@@ -0,0 +1,11 @@
+fileFormatVersion: 2
+guid: 48c8790647c3345e19c57d6c21065112
+MonoImporter:
+ externalObjects: {}
+ serializedVersion: 2
+ defaultReferences: []
+ executionOrder: 0
+ icon: {instanceID: 0}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
index f354a614b7..49aef67b74 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
@@ -54,17 +54,17 @@ public ArticulationBodyPoseExtractor(ArticulationBody rootBody)
parentIndices[i] = bodyToIndex[parentArticBody];
}
- SetParentIndices(parentIndices);
+ Setup(parentIndices);
}
///
- protected override Vector3 GetLinearVelocityAt(int index)
+ protected internal override Vector3 GetLinearVelocityAt(int index)
{
return m_Bodies[index].velocity;
}
///
- protected override Pose GetPoseAt(int index)
+ protected internal override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];
var go = body.gameObject;
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
index de9d3866f6..ec9eddfae1 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
@@ -18,12 +18,20 @@ public class PhysicsBodySensor : ISensor
///
/// Construct a new PhysicsBodySensor
///
- ///
+ /// The root Rigidbody. This has no Joints on it (but other Joints may connect to it).
+ /// Optional GameObject used to find Rigidbodies in the hierarchy.
+ /// Optional GameObject used to determine the root of the poses,
///
///
- public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
+ public PhysicsBodySensor(
+ Rigidbody rootBody,
+ GameObject rootGameObject,
+ GameObject virtualRoot,
+ PhysicsSensorSettings settings,
+ string sensorName=null
+ )
{
- var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
+ var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot);
m_PoseExtractor = poseExtractor;
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
index 9109d9592e..5488be8666 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
@@ -1,5 +1,4 @@
using System;
-
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
@@ -95,25 +94,26 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
var offset = baseOffset;
if (settings.UseModelSpace)
{
- var poses = poseExtractor.ModelSpacePoses;
- var vels = poseExtractor.ModelSpaceVelocities;
-
- for(var i=0; i
- /// Read access to the model space transforms.
+ /// Read iterator for the enabled model space transforms.
///
- public IList ModelSpacePoses
+ public IEnumerable GetEnabledModelSpacePoses()
{
- get { return m_ModelSpacePoses; }
+ if (m_ModelSpacePoses == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_ModelSpacePoses.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_ModelSpacePoses[i];
+ }
+ }
}
///
- /// Read access to the local space transforms.
+ /// Read iterator for the enabled local space transforms.
///
- public IList LocalSpacePoses
+ public IEnumerable GetEnabledLocalSpacePoses()
{
- get { return m_LocalSpacePoses; }
+ if (m_LocalSpacePoses == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_LocalSpacePoses.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_LocalSpacePoses[i];
+ }
+ }
}
///
- /// Read access to the model space linear velocities.
+ /// Read iterator for the enabled model space linear velocities.
///
- public IList ModelSpaceVelocities
+ public IEnumerable GetEnabledModelSpaceVelocities()
{
- get { return m_ModelSpaceLinearVelocities; }
+ if (m_ModelSpaceLinearVelocities == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_ModelSpaceLinearVelocities[i];
+ }
+ }
}
///
- /// Read access to the local space linear velocities.
+ /// Read iterator for the enabled local space linear velocities.
///
- public IList LocalSpaceVelocities
+ public IEnumerable GetEnabledLocalSpaceVelocities()
{
- get { return m_LocalSpaceLinearVelocities; }
+ if (m_LocalSpaceLinearVelocities == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_LocalSpaceLinearVelocities[i];
+ }
+ }
}
///
- /// Number of poses in the hierarchy (read-only).
+ /// Number of enabled poses in the hierarchy (read-only).
+ ///
+ public int NumEnabledPoses
+ {
+ get
+ {
+ if (m_PoseEnabled == null)
+ {
+ return 0;
+ }
+
+ var numEnabled = 0;
+ for (var i = 0; i < m_PoseEnabled.Length; i++)
+ {
+ numEnabled += m_PoseEnabled[i] ? 1 : 0;
+ }
+
+ return numEnabled;
+ }
+ }
+
+ ///
+ /// Number of total poses in the hierarchy (read-only).
///
public int NumPoses
{
- get { return m_ModelSpacePoses?.Length ?? 0; }
+ get { return m_ModelSpacePoses?.Length ?? 0; }
}
///
@@ -77,20 +145,43 @@ public int GetParentIndex(int index)
return m_ParentIndices[index];
}
+ ///
+ /// Set whether the pose at the given index is enabled or disabled for observations.
+ ///
+ ///
+ ///
+ public void SetPoseEnabled(int index, bool val)
+ {
+ m_PoseEnabled[index] = val;
+ }
+
///
/// Initialize with the mapping of parent indices.
/// The 0th element is assumed to be -1, indicating that it's the root.
///
///
- protected void SetParentIndices(int[] parentIndices)
+ protected void Setup(int[] parentIndices)
{
+#if DEBUG
+ if (parentIndices[0] != -1)
+ {
+ throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}");
+ }
+#endif
m_ParentIndices = parentIndices;
- var numTransforms = parentIndices.Length;
- m_ModelSpacePoses = new Pose[numTransforms];
- m_LocalSpacePoses = new Pose[numTransforms];
+ var numPoses = parentIndices.Length;
+ m_ModelSpacePoses = new Pose[numPoses];
+ m_LocalSpacePoses = new Pose[numPoses];
- m_ModelSpaceLinearVelocities = new Vector3[numTransforms];
- m_LocalSpaceLinearVelocities = new Vector3[numTransforms];
+ m_ModelSpaceLinearVelocities = new Vector3[numPoses];
+ m_LocalSpaceLinearVelocities = new Vector3[numPoses];
+
+ m_PoseEnabled = new bool[numPoses];
+ // All poses are enabled by default. Generally we'll want to disable the root though.
+ for (var i = 0; i < numPoses; i++)
+ {
+ m_PoseEnabled[i] = true;
+ }
}
///
@@ -98,14 +189,14 @@ protected void SetParentIndices(int[] parentIndices)
///
///
///
- protected abstract Pose GetPoseAt(int index);
+ protected internal abstract Pose GetPoseAt(int index);
///
/// Return the world space linear velocity of the i'th object.
///
///
///
- protected abstract Vector3 GetLinearVelocityAt(int index);
+ protected internal abstract Vector3 GetLinearVelocityAt(int index);
///
@@ -113,24 +204,27 @@ protected void SetParentIndices(int[] parentIndices)
///
public void UpdateModelSpacePoses()
{
- if (m_ModelSpacePoses == null)
+ using (TimerStack.Instance.Scoped("UpdateModelSpacePoses"))
{
- return;
- }
+ if (m_ModelSpacePoses == null)
+ {
+ return;
+ }
- var rootWorldTransform = GetPoseAt(0);
- var worldToModel = rootWorldTransform.Inverse();
- var rootLinearVel = GetLinearVelocityAt(0);
+ var rootWorldTransform = GetPoseAt(0);
+ var worldToModel = rootWorldTransform.Inverse();
+ var rootLinearVel = GetLinearVelocityAt(0);
- for (var i = 0; i < m_ModelSpacePoses.Length; i++)
- {
- var currentWorldSpacePose = GetPoseAt(i);
- var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
- m_ModelSpacePoses[i] = currentModelSpacePose;
+ for (var i = 0; i < m_ModelSpacePoses.Length; i++)
+ {
+ var currentWorldSpacePose = GetPoseAt(i);
+ var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
+ m_ModelSpacePoses[i] = currentModelSpacePose;
- var currentBodyLinearVel = GetLinearVelocityAt(i);
- var relativeVelocity = currentBodyLinearVel - rootLinearVel;
- m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
+ var currentBodyLinearVel = GetLinearVelocityAt(i);
+ var relativeVelocity = currentBodyLinearVel - rootLinearVel;
+ m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
+ }
}
}
@@ -139,30 +233,33 @@ public void UpdateModelSpacePoses()
///
public void UpdateLocalSpacePoses()
{
- if (m_LocalSpacePoses == null)
- {
- return;
- }
-
- for (var i = 0; i < m_LocalSpacePoses.Length; i++)
+ using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses"))
{
- if (m_ParentIndices[i] != -1)
+ if (m_LocalSpacePoses == null)
{
- var parentTransform = GetPoseAt(m_ParentIndices[i]);
- // This is slightly inefficient, since for a body with multiple children, we'll end up inverting
- // the transform multiple times. Might be able to trade space for perf here.
- var invParent = parentTransform.Inverse();
- var currentTransform = GetPoseAt(i);
- m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
-
- var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
- var currentLinearVel = GetLinearVelocityAt(i);
- m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
+ return;
}
- else
+
+ for (var i = 0; i < m_LocalSpacePoses.Length; i++)
{
- m_LocalSpacePoses[i] = Pose.identity;
- m_LocalSpaceLinearVelocities[i] = Vector3.zero;
+ if (m_ParentIndices[i] != -1)
+ {
+ var parentTransform = GetPoseAt(m_ParentIndices[i]);
+ // This is slightly inefficient, since for a body with multiple children, we'll end up inverting
+ // the transform multiple times. Might be able to trade space for perf here.
+ var invParent = parentTransform.Inverse();
+ var currentTransform = GetPoseAt(i);
+ m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
+
+ var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
+ var currentLinearVel = GetLinearVelocityAt(i);
+ m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
+ }
+ else
+ {
+ m_LocalSpacePoses[i] = Pose.identity;
+ m_LocalSpaceLinearVelocities[i] = Vector3.zero;
+ }
}
}
}
@@ -183,7 +280,7 @@ public int GetNumPoseObservations(PhysicsSensorSettings settings)
obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;
- return NumPoses * obsPerPose;
+ return NumEnabledPoses * obsPerPose;
}
internal void DrawModelSpace(Vector3 offset)
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
index 05b55ef737..44ff9a7641 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
@@ -12,11 +12,21 @@ public class RigidBodyPoseExtractor : PoseExtractor
{
Rigidbody[] m_Bodies;
+ ///
+ /// Optional game object used to determine the root of the poses, separate from the actual Rigidbodies
+ /// in the hierarchy. For locomotion tasks.
+ ///
+ GameObject m_VirtualRoot;
+
///
/// Initialize given a root RigidBody.
///
- ///
- public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null)
+ /// The root Rigidbody. This has no Joints on it (but other Joints may connect to it).
+ /// Optional GameObject used to find Rigidbodies in the hierarchy.
+ /// Optional GameObject used to determine the root of the poses,
+ /// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides
+ /// a stabilized reference frame, which can improve learning.
+ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null, GameObject virtualRoot = null)
{
if (rootBody == null)
{
@@ -32,18 +42,42 @@ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = nu
{
rbs = rootGameObject.GetComponentsInChildren<Rigidbody>();
}
- var bodyToIndex = new Dictionary(rbs.Length);
- var parentIndices = new int[rbs.Length];
- if (rbs[0] != rootBody)
+ if (rbs == null || rbs.Length == 0)
+ {
+ Debug.Log("No rigid bodies found!");
+ return;
+ }
+
+ if (rbs[0] != rootBody)
{
Debug.Log("Expected root body at index 0");
return;
}
+ // Adjust the array if we have a virtual root.
+ // This will be at index 0, and the "real" root will be parented to it.
+ if (virtualRoot != null)
+ {
+ var extendedRbs = new Rigidbody[rbs.Length + 1];
+ for (var i = 0; i < rbs.Length; i++)
+ {
+ extendedRbs[i + 1] = rbs[i];
+ }
+
+ rbs = extendedRbs;
+ }
+
+ var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
+ var parentIndices = new int[rbs.Length];
+ parentIndices[0] = -1;
+
for (var i = 0; i < rbs.Length; i++)
{
- bodyToIndex[rbs[i]] = i;
+ if(rbs[i] != null)
+ {
+ bodyToIndex[rbs[i]] = i;
+ }
}
var joints = rootBody.GetComponentsInChildren ();
@@ -59,19 +93,44 @@ public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = nu
parentIndices[childIndex] = parentIndex;
}
+ if (virtualRoot != null)
+ {
+ // Make sure the original root treats the virtual root as its parent.
+ parentIndices[1] = 0;
+ m_VirtualRoot = virtualRoot;
+ }
+
m_Bodies = rbs;
- SetParentIndices(parentIndices);
+ Setup(parentIndices);
+
+ // By default, ignore the root
+ SetPoseEnabled(0, false);
}
///
- protected override Vector3 GetLinearVelocityAt(int index)
+ protected internal override Vector3 GetLinearVelocityAt(int index)
{
+ if (index == 0 && m_VirtualRoot != null)
+ {
+ // No velocity on the virtual root
+ return Vector3.zero;
+ }
return m_Bodies[index].velocity;
}
///
- protected override Pose GetPoseAt(int index)
+ protected internal override Pose GetPoseAt(int index)
{
+ if (index == 0 && m_VirtualRoot != null)
+ {
+ // Use the GameObject's world transform
+ return new Pose
+ {
+ rotation = m_VirtualRoot.transform.rotation,
+ position = m_VirtualRoot.transform.position
+ };
+ }
+
var body = m_Bodies[index];
return new Pose { rotation = body.rotation, position = body.position };
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
index ce6cf05379..9a077a6594 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
@@ -13,6 +13,11 @@ public class RigidBodySensorComponent : SensorComponent
///
public Rigidbody RootBody;
+ ///
+ /// Optional GameObject used to determine the root of the poses.
+ ///
+ public GameObject VirtualRoot;
+
///
/// Settings defining what types of observations will be generated.
///
@@ -30,7 +35,7 @@ public class RigidBodySensorComponent : SensorComponent
///
public override ISensor CreateSensor()
{
- return new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName);
+ return new PhysicsBodySensor(RootBody, gameObject, VirtualRoot, Settings, sensorName);
}
///
@@ -43,7 +48,7 @@ public override int[] GetObservationShape()
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
- var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject);
+ var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
index 94e708ed68..f642c32c5d 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
@@ -100,17 +100,12 @@ public void TestBodiesWithJoint()
// Local space
0f, 0f, 0f, // Root pos
-#if UNITY_2020_2_OR_NEWER
- 0f, 0f, 0f, // Root vel
-#endif
-
13.37f, 0f, 0f, // Attached pos
-#if UNITY_2020_2_OR_NEWER
- -1f, 1f, 0f, // Attached vel
-#endif
-
4.2f, 0f, 0f, // Leaf pos
+
#if UNITY_2020_2_OR_NEWER
+ 0f, 0f, 0f, // Root vel
+ -1f, 1f, 0f, // Attached vel
0f, -1f, 1f // Leaf vel
#endif
};
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
index 627b4d7b8b..5f862d613d 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
@@ -8,19 +8,19 @@ public class PoseExtractorTests
{
class UselessPoseExtractor : PoseExtractor
{
- protected override Pose GetPoseAt(int index)
+ protected internal override Pose GetPoseAt(int index)
{
return Pose.identity;
}
- protected override Vector3 GetLinearVelocityAt(int index)
+ protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
public void Init(int[] parentIndices)
{
- SetParentIndices(parentIndices);
+ Setup(parentIndices);
}
}
@@ -60,10 +60,10 @@ public ChainPoseExtractor(int size)
{
parents[i] = i - 1;
}
- SetParentIndices(parents);
+ Setup(parents);
}
- protected override Pose GetPoseAt(int index)
+ protected internal override Pose GetPoseAt(int index)
{
var rotation = Quaternion.identity;
var translation = offset + new Vector3(index, index, index);
@@ -74,7 +74,7 @@ protected override Pose GetPoseAt(int index)
};
}
- protected override Vector3 GetLinearVelocityAt(int index)
+ protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
@@ -91,23 +91,77 @@ public void TestChain()
chain.UpdateModelSpacePoses();
chain.UpdateLocalSpacePoses();
- // Root transforms are currently always the identity.
- Assert.IsTrue(chain.ModelSpacePoses[0] == Pose.identity);
- Assert.IsTrue(chain.LocalSpacePoses[0] == Pose.identity);
- // Check the non-root transforms
- for (var i = 1; i < size; i++)
+ var modelPoseIndex = 0;
+ foreach (var modelSpace in chain.GetEnabledModelSpacePoses())
{
- var modelSpace = chain.ModelSpacePoses[i];
- var expectedModelTranslation = new Vector3(i, i, i);
- Assert.IsTrue(expectedModelTranslation == modelSpace.position);
+ if (modelPoseIndex == 0)
+ {
+ // Root transforms are currently always the identity.
+ Assert.IsTrue(modelSpace == Pose.identity);
+ }
+ else
+ {
+ var expectedModelTranslation = new Vector3(modelPoseIndex, modelPoseIndex, modelPoseIndex);
+ Assert.IsTrue(expectedModelTranslation == modelSpace.position);
- var localSpace = chain.LocalSpacePoses[i];
- var expectedLocalTranslation = new Vector3(1, 1, 1);
- Assert.IsTrue(expectedLocalTranslation == localSpace.position);
+ }
+ modelPoseIndex++;
}
+ Assert.AreEqual(size, modelPoseIndex);
+
+ var localPoseIndex = 0;
+ foreach (var localSpace in chain.GetEnabledLocalSpacePoses())
+ {
+ if (localPoseIndex == 0)
+ {
+ // Root transforms are currently always the identity.
+ Assert.IsTrue(localSpace == Pose.identity);
+ }
+ else
+ {
+ var expectedLocalTranslation = new Vector3(1, 1, 1);
+ Assert.IsTrue(expectedLocalTranslation == localSpace.position, $"{expectedLocalTranslation} != {localSpace.position}");
+ }
+
+ localPoseIndex++;
+ }
+ Assert.AreEqual(size, localPoseIndex);
}
+ class BadPoseExtractor : PoseExtractor
+ {
+ public BadPoseExtractor()
+ {
+ var size = 2;
+ var parents = new int[size];
+ // Parents are intentionally invalid - expect -1 at root
+ for (var i = 0; i < size; i++)
+ {
+ parents[i] = i;
+ }
+ Setup(parents);
+ }
+
+ protected internal override Pose GetPoseAt(int index)
+ {
+ return Pose.identity;
+ }
+
+ protected internal override Vector3 GetLinearVelocityAt(int index)
+ {
+ return Vector3.zero;
+ }
+ }
+
+ [Test]
+ public void TestExpectedRoot()
+ {
+ Assert.Throws<UnityAgentsException>(() =>
+ {
+ var bad = new BadPoseExtractor();
+ });
+ }
}
public class PoseExtensionTests
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
index a5d8b5bcb5..2d157b88e0 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
@@ -1,6 +1,7 @@
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
+using UnityEditor;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
@@ -56,6 +57,63 @@ public void TestTwoBodies()
var poseExtractor = new RigidBodyPoseExtractor(rb1);
Assert.AreEqual(2, poseExtractor.NumPoses);
+
+ rb1.position = new Vector3(1, 0, 0);
+ rb1.rotation = Quaternion.Euler(0, 13.37f, 0);
+ rb1.velocity = new Vector3(2, 0, 0);
+
+ Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position);
+ Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation);
+ Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0));
+ }
+
+ [Test]
+ public void TestTwoBodiesVirtualRoot()
+ {
+ // * virtualRoot
+ // * rootObj
+ // - rb1
+ // * go2
+ // - rb2
+ // - joint
+ var virtualRoot = new GameObject("I am vroot");
+
+ var rootObj = new GameObject();
+ var rb1 = rootObj.AddComponent<Rigidbody>();
+
+ var go2 = new GameObject();
+ var rb2 = go2.AddComponent<Rigidbody>();
+ go2.transform.SetParent(rootObj.transform);
+
+ var joint = go2.AddComponent<ConfigurableJoint>();
+ joint.connectedBody = rb1;
+
+ var poseExtractor = new RigidBodyPoseExtractor(rb1, null, virtualRoot);
+ Assert.AreEqual(3, poseExtractor.NumPoses);
+
+ // "body" 0 has no parent
+ Assert.AreEqual(-1, poseExtractor.GetParentIndex(0));
+
+ // body 1 has parent 0
+ Assert.AreEqual(0, poseExtractor.GetParentIndex(1));
+
+ var virtualRootPos = new Vector3(0,2,0);
+ var virtualRootRot = Quaternion.Euler(0, 42, 0);
+ virtualRoot.transform.position = virtualRootPos;
+ virtualRoot.transform.rotation = virtualRootRot;
+
+ Assert.AreEqual(virtualRootPos, poseExtractor.GetPoseAt(0).position);
+ Assert.IsTrue(virtualRootRot == poseExtractor.GetPoseAt(0).rotation);
+ Assert.AreEqual(Vector3.zero, poseExtractor.GetLinearVelocityAt(0));
+
+ // Same as above test, but using index 1
+ rb1.position = new Vector3(1, 0, 0);
+ rb1.rotation = Quaternion.Euler(0, 13.37f, 0);
+ rb1.velocity = new Vector3(2, 0, 0);
+
+ Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position);
+ Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation);
+ Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1));
}
}
}
diff --git a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
index 279fc7007d..a6c8b9f366 100644
--- a/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
@@ -45,14 +45,12 @@ public void TestSingleRigidbody()
var sensor = sensorComponent.CreateSensor();
sensor.Update();
- var expected = new[]
- {
- 0f, 0f, 0f, // ModelSpaceLinearVelocity
- 0f, 0f, 0f, // LocalSpaceTranslations
- 0f, 0f, 0f, 1f // LocalSpaceRotations
- };
- SensorTestHelper.CompareObservation(sensor, expected);
+
+ // The root body is ignored since it always generates identity values
+ // and there are no other bodies to generate observations.
+ var expected = new float[0];
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
+ SensorTestHelper.CompareObservation(sensor, expected);
}
[Test]
@@ -78,6 +76,7 @@ public void TestBodiesWithJoint()
var joint2 = leafGameObj.AddComponent<ConfigurableJoint>();
joint2.connectedBody = middleRb;
+ var virtualRoot = new GameObject();
var sensorComponent = rootObj.AddComponent<RigidBodySensorComponent>();
sensorComponent.RootBody = rootRb;
@@ -87,9 +86,12 @@ public void TestBodiesWithJoint()
UseLocalSpaceTranslations = true,
UseLocalSpaceLinearVelocity = true
};
+ sensorComponent.VirtualRoot = virtualRoot;
var sensor = sensorComponent.CreateSensor();
sensor.Update();
+
+ // Note that the VirtualRoot is ignored from the observations
var expected = new[]
{
// Model space
@@ -99,16 +101,15 @@ public void TestBodiesWithJoint()
// Local space
0f, 0f, 0f, // Root pos
- 0f, 0f, 0f, // Root vel
-
13.37f, 0f, 0f, // Attached pos
- -1f, 1f, 0f, // Attached vel
-
4.2f, 0f, 0f, // Leaf pos
+
+ 1f, 0f, 0f, // Root vel (relative to virtual root)
+ -1f, 1f, 0f, // Attached vel
0f, -1f, 1f // Leaf vel
};
- SensorTestHelper.CompareObservation(sensor, expected);
Assert.AreEqual(expected.Length, sensorComponent.GetObservationShape()[0]);
+ SensorTestHelper.CompareObservation(sensor, expected);
// Update the settings to only process joint observations
sensorComponent.Settings = new PhysicsSensorSettings
diff --git a/com.unity.ml-agents.extensions/package.json b/com.unity.ml-agents.extensions/package.json
index 3ce85c652a..8b43fa05e4 100644
--- a/com.unity.ml-agents.extensions/package.json
+++ b/com.unity.ml-agents.extensions/package.json
@@ -5,6 +5,6 @@
"unity": "2018.4",
"description": "A source-only package for new features based on ML-Agents",
"dependencies": {
- "com.unity.ml-agents": "1.2.0-preview"
+ "com.unity.ml-agents": "1.3.0-preview"
}
}
diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 55aa465de6..987cdf12e4 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -11,19 +11,36 @@ and this project adheres to
### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
-- The minimum supported python version for ml-agents-envs was changed to 3.6.1. (#4244)
+
+### Minor Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+### Bug Fixes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+
+## [1.3.0-preview] 2020-08-12
+
+### Major Changes
+#### com.unity.ml-agents (C#)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The minimum supported Python version for ml-agents-envs was changed to 3.6.1. (#4244)
- The interaction between EnvManager and TrainerController was changed; EnvManager.advance() was split into to stages,
and TrainerController now uses the results from the first stage to handle new behavior names. This change speeds up
Python training by approximately 5-10%. (#4259)
### Minor Changes
#### com.unity.ml-agents (C#)
-#### ml-agents / ml-agents-envs / gym-unity (Python)
- StatsSideChannel now stores multiple values per key. This means that multiple
calls to `StatsRecorder.Add()` with the same key in the same step will no
longer overwrite each other. (#4236)
+#### ml-agents / ml-agents-envs / gym-unity (Python)
+- The versions of `numpy` supported by ml-agents-envs were changed to disallow 1.19.0 or later. This was done to reflect
+a similar change in TensorFlow's requirements. (#4274)
- Model checkpoints are now also saved as .nn files during training. (#4127)
- Model checkpoint info is saved in TrainingStatus.json after training is concluded (#4127)
+- CSV statistics writer was removed (#4300).
### Bug Fixes
#### com.unity.ml-agents (C#)
@@ -75,7 +92,7 @@ argument. (#4203)
- `max_step` in the `TerminalStep` and `TerminalSteps` objects was renamed `interrupted`.
- `beta` and `epsilon` in `PPO` are no longer decayed by default but follow the same schedule as learning rate. (#3940)
- `get_behavior_names()` and `get_behavior_spec()` on UnityEnvironment were replaced by the `behavior_specs` property. (#3946)
-- The first version of the Unity Environment Registry (Experimental) has been released. More information [here](https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Unity-Environment-Registry.md)(#3967)
+- The first version of the Unity Environment Registry (Experimental) has been released. More information [here](https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Unity-Environment-Registry.md)(#3967)
- `use_visual` and `allow_multiple_visual_obs` in the `UnityToGymWrapper` constructor
were replaced by `allow_multiple_obs` which allows one or more visual observations and
vector observations to be used simultaneously. (#3981) Thank you @shakenes !
@@ -83,7 +100,7 @@ vector observations to be used simultaneously. (#3981) Thank you @shakenes !
into the main training configuration file. Note that this means training
configuration files are now environment-specific. (#3791)
- The format for trainer configuration has changed, and the "default" behavior has been deprecated.
- See the [Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Migrating.md) for more details. (#3936)
+ See the [Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Migrating.md) for more details. (#3936)
- Training artifacts (trained models, summaries) are now found in the `results/`
directory. (#3829)
- When using Curriculum, the current lesson will resume if training is quit and resumed. As such,
diff --git a/com.unity.ml-agents/Documentation~/com.unity.ml-agents.md b/com.unity.ml-agents/Documentation~/com.unity.ml-agents.md
index 227a1eb2e7..4f79cde27c 100755
--- a/com.unity.ml-agents/Documentation~/com.unity.ml-agents.md
+++ b/com.unity.ml-agents/Documentation~/com.unity.ml-agents.md
@@ -114,7 +114,7 @@ a number of ways to [connect with us] including our [ML-Agents Forum].
[unity ML-Agents Toolkit]: https://github.com/Unity-Technologies/ml-agents
[unity inference engine]: https://docs.unity3d.com/Packages/com.unity.barracuda@latest/index.html
[package manager documentation]: https://docs.unity3d.com/Manual/upm-ui-install.html
-[installation instructions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Installation.md
+[installation instructions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Installation.md
[github repository]: https://github.com/Unity-Technologies/ml-agents
[python package]: https://github.com/Unity-Technologies/ml-agents
[execution order of event functions]: https://docs.unity3d.com/Manual/ExecutionOrder.html
diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs
index dc1d586923..8b0b1055f7 100644
--- a/com.unity.ml-agents/Runtime/Academy.cs
+++ b/com.unity.ml-agents/Runtime/Academy.cs
@@ -19,7 +19,7 @@
* API. For more information on each of these entities, in addition to how to
* set-up a learning environment and train the behavior of characters in a
* Unity scene, please browse our documentation pages on GitHub:
- * https://github.com/Unity-Technologies/ml-agents/tree/release_4_docs/docs/
+ * https://github.com/Unity-Technologies/ml-agents/tree/release_5_docs/docs/
*/
namespace Unity.MLAgents
@@ -51,7 +51,7 @@ void FixedUpdate()
/// fall back to inference or heuristic decisions. (You can also set agents to always use
/// inference or heuristics.)
///
- [HelpURL("https://github.com/Unity-Technologies/ml-agents/tree/release_4_docs/" +
+ [HelpURL("https://github.com/Unity-Technologies/ml-agents/tree/release_5_docs/" +
"docs/Learning-Environment-Design.md")]
public class Academy : IDisposable
{
@@ -68,7 +68,7 @@ public class Academy : IDisposable
/// Unity package version of com.unity.ml-agents.
/// This must match the version string in package.json and is checked in a unit test.
///
- internal const string k_PackageVersion = "1.2.0-preview";
+ internal const string k_PackageVersion = "1.3.0-preview";
const int k_EditorTrainingPort = 5004;
diff --git a/com.unity.ml-agents/Runtime/Actuators.meta b/com.unity.ml-agents/Runtime/Actuators.meta
new file mode 100644
index 0000000000..96bbfb99b3
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: 26733e59183b6479e8f0e892a8bf09a4
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
new file mode 100644
index 0000000000..feb06a708d
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
@@ -0,0 +1,181 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
+
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// ActionSegment{T} is a data structure that allows access to a segment of an underlying array
+ /// in order to avoid the copying and allocation of sub-arrays. The segment is defined by
+ /// the offset into the original array, and a length.
+ ///
+ /// The type of object stored in the underlying
+ internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
+ where T : struct
+ {
+ ///
+ /// The zero-based offset into the original array at which this segment starts.
+ ///
+ public readonly int Offset;
+
+ ///
+ /// The number of items this segment can access in the underlying array.
+ ///
+ public readonly int Length;
+
+ ///
+ /// An Empty segment which has an offset of 0, a Length of 0, and its underlying array
+ /// is also empty.
+ ///
+ public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0);
+
+ static void CheckParameters(T[] actionArray, int offset, int length)
+ {
+#if DEBUG
+ if (offset + length > actionArray.Length)
+ {
+ throw new ArgumentOutOfRangeException(nameof(offset),
+ $"Arguments offset: {offset} and length: {length} " +
+ $"are out of bounds of actionArray: {actionArray.Length}.");
+ }
+#endif
+ }
+
+ ///
+ /// Construct an with an underlying array
+ /// and offset, and a length.
+ ///
+ /// The underlying array which this segment has a view into
+ /// The zero-based offset into the underlying array.
+ /// The length of the segment.
+ public ActionSegment(T[] actionArray, int offset, int length)
+ {
+ CheckParameters(actionArray, offset, length);
+ Array = actionArray;
+ Offset = offset;
+ Length = length;
+ }
+
+ ///
+ /// Get the underlying of this segment.
+ ///
+ public T[] Array { get; }
+
+ ///
+ /// Allows access to the underlying array using array syntax.
+ ///
+ /// The zero-based index of the segment.
+ /// Thrown when the index is less than 0 or
+ /// greater than or equal to
+ public T this[int index]
+ {
+ get
+ {
+ if (index < 0 || index > Length)
+ {
+ throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}");
+ }
+ return Array[Offset + index];
+ }
+ }
+
+ ///
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return new Enumerator(this);
+ }
+
+ ///
+ public IEnumerator<T> GetEnumerator()
+ {
+ return new Enumerator(this);
+ }
+
+ ///
+ public override bool Equals(object obj)
+ {
+ if (!(obj is ActionSegment<T>))
+ {
+ return false;
+ }
+ return Equals((ActionSegment<T>)obj);
+ }
+
+ ///
+ public bool Equals(ActionSegment<T> other)
+ {
+ return Offset == other.Offset && Length == other.Length && Equals(Array, other.Array);
+ }
+
+ ///
+ public override int GetHashCode()
+ {
+ unchecked
+ {
+ var hashCode = Offset;
+ hashCode = (hashCode * 397) ^ Length;
+ hashCode = (hashCode * 397) ^ (Array != null ? Array.GetHashCode() : 0);
+ return hashCode;
+ }
+ }
+
+ ///
+ /// A private for the value type which follows its
+ /// rules of being a view into an underlying .
+ ///
+ struct Enumerator : IEnumerator<T>
+ {
+ readonly T[] m_Array;
+ readonly int m_Start;
+ readonly int m_End; // cache Offset + Length, since it's a little slow
+ int m_Current;
+
+ internal Enumerator(ActionSegment<T> arraySegment)
+ {
+ Debug.Assert(arraySegment.Array != null);
+ Debug.Assert(arraySegment.Offset >= 0);
+ Debug.Assert(arraySegment.Length >= 0);
+ Debug.Assert(arraySegment.Offset + arraySegment.Length <= arraySegment.Array.Length);
+
+ m_Array = arraySegment.Array;
+ m_Start = arraySegment.Offset;
+ m_End = arraySegment.Offset + arraySegment.Length;
+ m_Current = arraySegment.Offset - 1;
+ }
+
+ public bool MoveNext()
+ {
+ if (m_Current < m_End)
+ {
+ m_Current++;
+ return m_Current < m_End;
+ }
+ return false;
+ }
+
+ public T Current
+ {
+ get
+ {
+ if (m_Current < m_Start)
+ throw new InvalidOperationException("Enumerator not started.");
+ if (m_Current >= m_End)
+ throw new InvalidOperationException("Enumerator has reached the end already.");
+ return m_Array[m_Current];
+ }
+ }
+
+ object IEnumerator.Current => Current;
+
+ void IEnumerator.Reset()
+ {
+ m_Current = m_Start - 1;
+ }
+
+ public void Dispose()
+ {
+ }
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta
new file mode 100644
index 0000000000..8e08ed0a4a
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 4fa1432c1ba3460caaa84303a9011ef2
+timeCreated: 1595869823
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
new file mode 100644
index 0000000000..fbee0c4476
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
@@ -0,0 +1,75 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Unity.MLAgents.Policies;
+
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// Defines the structure of an Action Space to be used by the Actuator system.
+ ///
+ internal readonly struct ActionSpec
+ {
+
+ ///
+ /// An array of branch sizes for our action space.
+ ///
+ /// For an IActuator that uses a Discrete , the number of
+ /// branches is the Length of the Array and each index contains the branch size.
+ /// The cumulative sum of the total number of discrete actions can be retrieved
+ /// by the property.
+ ///
+ /// For an IActuator with a Continuous it will be null.
+ ///
+ public readonly int[] BranchSizes;
+
+ ///
+ /// The number of actions for a Continuous .
+ ///
+ public int NumContinuousActions { get; }
+
+ ///
+ /// The number of branches for a Discrete .
+ ///
+ public int NumDiscreteActions { get; }
+
+ ///
+ /// Get the total number of Discrete Actions that can be taken by calculating the Sum
+ /// of all of the Discrete Action branch sizes.
+ ///
+ public int SumOfDiscreteBranchSizes { get; }
+
+ ///
+ /// Creates a Continuous with the number of actions available.
+ ///
+ /// The number of actions available.
+ /// A Continuous ActionSpec initialized with the number of actions available.
+ public static ActionSpec MakeContinuous(int numActions)
+ {
+ var actuatorSpace = new ActionSpec(numActions, 0);
+ return actuatorSpace;
+ }
+
+ ///
+ /// Creates a Discrete with the array of branch sizes that
+ /// represents the action space.
+ ///
+ /// The array of branch sizes for the discrete action space. Each index
+ /// contains the number of actions available for that branch.
+ /// A Discrete ActionSpec initialized with the array of branch sizes.
+ public static ActionSpec MakeDiscrete(int[] branchSizes)
+ {
+ var numActions = branchSizes.Length;
+ var actuatorSpace = new ActionSpec(0, numActions, branchSizes);
+ return actuatorSpace;
+ }
+
+ ActionSpec(int numContinuousActions, int numDiscreteActions, int[] branchSizes = null)
+ {
+ NumContinuousActions = numContinuousActions;
+ NumDiscreteActions = numDiscreteActions;
+ BranchSizes = branchSizes;
+ SumOfDiscreteBranchSizes = branchSizes?.Sum() ?? 0;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta
new file mode 100644
index 0000000000..a442a91a5e
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: ecdd6deefba1416ca149fe09d2a5afd8
+timeCreated: 1595892361
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
new file mode 100644
index 0000000000..48fbba501e
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
@@ -0,0 +1,17 @@
+using UnityEngine;
+
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// Editor components for creating Actuators. Generally an IActuator component should
+ /// have a corresponding ActuatorComponent.
+ ///
+ internal abstract class ActuatorComponent : MonoBehaviour
+ {
+ ///
+ /// Create the IActuator. This is called by the Agent when it is initialized.
+ ///
+ /// Created IActuator object.
+ public abstract IActuator CreateActuator();
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta
new file mode 100644
index 0000000000..1b7a643ed1
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 77cefae5f6d841be9ff80b41293d271b
+timeCreated: 1593017318
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
new file mode 100644
index 0000000000..4904aded7f
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
@@ -0,0 +1,150 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// Implementation of IDiscreteActionMask that allows writing to the action mask from an .
+ ///
+ internal class ActuatorDiscreteActionMask : IDiscreteActionMask
+ {
+ /// When using discrete control, is the starting indices of the actions
+ /// when all the branches are concatenated with each other.
+ int[] m_StartingActionIndices;
+
+ int[] m_BranchSizes;
+
+ bool[] m_CurrentMask;
+
+ IList m_Actuators;
+
+ readonly int m_SumOfDiscreteBranchSizes;
+ readonly int m_NumBranches;
+
+ ///
+ /// The offset into the branches array that is used when actuators are writing to the action mask.
+ ///
+ public int CurrentBranchOffset { get; set; }
+
+ internal ActuatorDiscreteActionMask(IList actuators, int sumOfDiscreteBranchSizes, int numBranches)
+ {
+ m_Actuators = actuators;
+ m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes;
+ m_NumBranches = numBranches;
+ }
+
+ ///
+ public void WriteMask(int branch, IEnumerable actionIndices)
+ {
+ LazyInitialize();
+
+ // Perform the masking
+ foreach (var actionIndex in actionIndices)
+ {
+#if DEBUG
+ if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
+ {
+ throw new UnityAgentsException(
+ "Invalid Action Masking: Action Mask is too large for specified branch.");
+ }
+#endif
+ m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
+ }
+ }
+
+ void LazyInitialize()
+ {
+ if (m_BranchSizes == null)
+ {
+ m_BranchSizes = new int[m_NumBranches];
+ var start = 0;
+ for (var i = 0; i < m_Actuators.Count; i++)
+ {
+ var actuator = m_Actuators[i];
+ var branchSizes = actuator.ActionSpec.BranchSizes;
+ Array.Copy(branchSizes, 0, m_BranchSizes, start, branchSizes.Length);
+ start += branchSizes.Length;
+ }
+ }
+
+ // By default, the masks are null. If we want to specify a new mask, we initialize
+ // the actionMasks with trues.
+ if (m_CurrentMask == null)
+ {
+ m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes];
+ }
+
+ // If this is the first time the masked actions are used, we generate the starting
+ // indices for each branch.
+ if (m_StartingActionIndices == null)
+ {
+ m_StartingActionIndices = Utilities.CumSum(m_BranchSizes);
+ }
+ }
+
+ ///
+ public bool[] GetMask()
+ {
+#if DEBUG
+ if (m_CurrentMask != null)
+ {
+ AssertMask();
+ }
+#endif
+ return m_CurrentMask;
+ }
+
+ ///
+ /// Makes sure that the current mask is usable.
+ ///
+ void AssertMask()
+ {
+#if DEBUG
+ for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++)
+ {
+ if (AreAllActionsMasked(branchIndex))
+ {
+ throw new UnityAgentsException(
+ "Invalid Action Masking : All the actions of branch " + branchIndex +
+ " are masked.");
+ }
+ }
+#endif
+ }
+
+ ///
+ /// Resets the current mask for an agent.
+ ///
+ public void ResetMask()
+ {
+ if (m_CurrentMask != null)
+ {
+ Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
+ }
+ }
+
+ ///
+ /// Checks if all the actions in the input branch are masked.
+ ///
+ /// The index of the branch to check.
+ /// True if all the actions of the branch are masked.
+ bool AreAllActionsMasked(int branch)
+ {
+ if (m_CurrentMask == null)
+ {
+ return false;
+ }
+ var start = m_StartingActionIndices[branch];
+ var end = m_StartingActionIndices[branch + 1];
+ for (var i = start; i < end; i++)
+ {
+ if (!m_CurrentMask[i])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta
new file mode 100644
index 0000000000..09aa4784b0
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: d2a19e2f43fd4637a38d42b2a5f989f3
+timeCreated: 1595459316
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
new file mode 100644
index 0000000000..a1b953118f
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
@@ -0,0 +1,415 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using UnityEngine;
+
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// A class that manages the delegation of events, action buffers, and action mask for a list of IActuators.
+ ///
+ internal class ActuatorManager : IList
+ {
+ // IActuators managed by this object.
+ IList m_Actuators;
+
+ // An implementation of IDiscreteActionMask that allows for writing to it based on an offset.
+ ActuatorDiscreteActionMask m_DiscreteActionMask;
+
+ ///
+ /// Flag used to check if our IActuators are ready for execution.
+ ///
+ ///
+ bool m_ReadyForExecution;
+
+ ///
+ /// The sum of all of the discrete branches for all of the s in this manager.
+ ///
+ internal int SumOfDiscreteBranchSizes { get; private set; }
+
+ ///
+ /// The number of the discrete branches for all of the s in this manager.
+ ///
+ internal int NumDiscreteActions { get; private set; }
+
+ ///
+ /// The number of continuous actions for all of the s in this manager.
+ ///
+ internal int NumContinuousActions { get; private set; }
+
+ ///
+ /// Returns the total actions which is calculated by + .
+ ///
+ public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;
+
+ ///
+ /// Gets the managed by this object.
+ ///
+ public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;
+
+ ///
+ /// Returns the previously stored actions for the actuators in this list.
+ ///
+ public float[] StoredContinuousActions { get; private set; }
+
+ ///
+ /// Returns the previously stored actions for the actuators in this list.
+ ///
+ public int[] StoredDiscreteActions { get; private set; }
+
+ ///
+ /// Create an ActuatorList with a preset capacity.
+ ///
+ /// The capacity of the list to create.
+ public ActuatorManager(int capacity = 0)
+ {
+ m_Actuators = new List(capacity);
+ }
+
+ ///
+ ///
+ ///
+ void ReadyActuatorsForExecution()
+ {
+ ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes,
+ NumDiscreteActions);
+ }
+
+ ///
+ /// This method validates that all s have unique names and equivalent action space types
+ /// if the `DEBUG` preprocessor macro is defined, and allocates the appropriate buffers to manage the actions for
+ /// all of the s that may live on a particular object.
+ ///
+ /// The list of actuators to validate and allocate buffers for.
+ /// The total number of continuous actions for all of the actuators.
+ /// The total sum of the discrete branches for all of the actuators in order
+ /// to be able to allocate an .
+ /// The number of discrete branches for all of the actuators.
+ internal void ReadyActuatorsForExecution(IList actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
+ {
+ if (m_ReadyForExecution)
+ {
+ return;
+ }
+#if DEBUG
+ // Make sure the names are actually unique
+ // Make sure all Actuators have the same SpaceType
+ ValidateActuators();
+#endif
+
+ // Sort the Actuators by name to ensure determinism
+ SortActuators();
+ StoredContinuousActions = numContinuousActions == 0 ? Array.Empty() : new float[numContinuousActions];
+ StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty() : new int[numDiscreteBranches];
+ m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
+ m_ReadyForExecution = true;
+ }
+
+ ///
+ /// Updates the local action buffer with the action buffer passed in. If the buffer
+ /// passed in is null, the local action buffer will be cleared.
+ ///
+ /// The action buffer which contains all of the
+ /// continuous actions for the IActuators in this list.
+ /// The action buffer which contains all of the
+ /// discrete actions for the IActuators in this list.
+ public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
+ {
+ ReadyActuatorsForExecution();
+ UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
+ UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
+ }
+
+ static void UpdateActionArray(T[] sourceActionBuffer, T[] destination)
+ {
+ if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
+ {
+ Array.Clear(destination, 0, destination.Length);
+ }
+ else
+ {
+ Debug.Assert(sourceActionBuffer.Length == destination.Length,
+ $"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
+ $" size than destination: {destination.Length}.");
+
+ Array.Copy(sourceActionBuffer, destination, destination.Length);
+ }
+ }
+
+ ///
+ /// This method will trigger the writing to the by all of the actuators
+ /// managed by this object.
+ ///
+ public void WriteActionMask()
+ {
+ ReadyActuatorsForExecution();
+ m_DiscreteActionMask.ResetMask();
+ var offset = 0;
+ for (var i = 0; i < m_Actuators.Count; i++)
+ {
+ var actuator = m_Actuators[i];
+ m_DiscreteActionMask.CurrentBranchOffset = offset;
+ actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
+ offset += actuator.ActionSpec.NumDiscreteActions;
+ }
+ }
+
+ ///
+ /// Iterates through all of the IActuators in this list and calls their
+ /// method on them with the appropriate
+ /// s depending on their .
+ ///
+ public void ExecuteActions()
+ {
+ ReadyActuatorsForExecution();
+ var continuousStart = 0;
+ var discreteStart = 0;
+ for (var i = 0; i < m_Actuators.Count; i++)
+ {
+ var actuator = m_Actuators[i];
+ var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
+ var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;
+
+ var continuousActions = ActionSegment.Empty;
+ if (numContinuousActions > 0)
+ {
+ continuousActions = new ActionSegment(StoredContinuousActions,
+ continuousStart,
+ numContinuousActions);
+ }
+
+ var discreteActions = ActionSegment.Empty;
+ if (numDiscreteActions > 0)
+ {
+ discreteActions = new ActionSegment(StoredDiscreteActions,
+ discreteStart,
+ numDiscreteActions);
+ }
+
+ actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));
+ continuousStart += numContinuousActions;
+ discreteStart += numDiscreteActions;
+ }
+ }
+
+ ///
+ /// Resets the and buffers to be all
+ /// zeros and calls on each managed by this object.
+ ///
+ public void ResetData()
+ {
+ if (!m_ReadyForExecution)
+ {
+ return;
+ }
+ Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
+ Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
+ for (var i = 0; i < m_Actuators.Count; i++)
+ {
+ m_Actuators[i].ResetData();
+ }
+ }
+
+
+ ///
+ /// Sorts the s according to their value.
+ ///
+ void SortActuators()
+ {
+ ((List)m_Actuators).Sort((x,
+ y) => x.Name
+ .CompareTo(y.Name));
+ }
+
+ ///
+ /// Validates that the IActuators managed by this object have unique names and equivalent action space types.
+ /// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action
+ /// buffers, and execution of Actuators remains deterministic across different sessions of running.
+ ///
+ void ValidateActuators()
+ {
+ for (var i = 0; i < m_Actuators.Count - 1; i++)
+ {
+ Debug.Assert(
+ !m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name),
+ "Actuator names must be unique.");
+ var first = m_Actuators[i].ActionSpec;
+ var second = m_Actuators[i + 1].ActionSpec;
+ Debug.Assert(first.NumContinuousActions > 0 == second.NumContinuousActions > 0,
+ "Actuators on the same Agent must have the same action SpaceType.");
+ }
+ }
+
+ ///
+ /// Helper method to update bookkeeping items around buffer management for actuators added to this object.
+ ///
+ /// The IActuator to keep bookkeeping for.
+ void AddToBufferSizes(IActuator actuatorItem)
+ {
+ if (actuatorItem == null)
+ {
+ return;
+ }
+
+ NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
+ NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
+ SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
+ }
+
+ ///
+ /// Helper method to update bookkeeping items around buffer management for actuators removed from this object.
+ ///
+ /// The IActuator to keep bookkeeping for.
+ void SubtractFromBufferSize(IActuator actuatorItem)
+ {
+ if (actuatorItem == null)
+ {
+ return;
+ }
+
+ NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
+ NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
+ SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
+ }
+
+ ///
+ /// Sets all of the bookkeeping items back to 0.
+ ///
+ void ClearBufferSizes()
+ {
+ NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
+ }
+
+ /*********************************************************************************
+ * IList implementation that delegates to m_Actuators List. *
+ *********************************************************************************/
+
+ ///
+ ///
+ ///
+ public IEnumerator GetEnumerator()
+ {
+ return m_Actuators.GetEnumerator();
+ }
+
+ ///
+ ///
+ ///
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return ((IEnumerable)m_Actuators).GetEnumerator();
+ }
+
+ ///
+ ///
+ ///
+ ///
+ public void Add(IActuator item)
+ {
+ Debug.Assert(m_ReadyForExecution == false,
+ "Cannot add to the ActuatorManager after its buffers have been initialized");
+ m_Actuators.Add(item);
+ AddToBufferSizes(item);
+ }
+
+ ///
+ ///
+ ///
+ public void Clear()
+ {
+ Debug.Assert(m_ReadyForExecution == false,
+ "Cannot clear the ActuatorManager after its buffers have been initialized");
+ m_Actuators.Clear();
+ ClearBufferSizes();
+ }
+
+ ///
+ ///
+ ///
+ public bool Contains(IActuator item)
+ {
+ return m_Actuators.Contains(item);
+ }
+
+ ///
+ ///
+ ///
+ public void CopyTo(IActuator[] array, int arrayIndex)
+ {
+ m_Actuators.CopyTo(array, arrayIndex);
+ }
+
+ ///
+ ///
+ ///
+ public bool Remove(IActuator item)
+ {
+ Debug.Assert(m_ReadyForExecution == false,
+ "Cannot remove from the ActuatorManager after its buffers have been initialized");
+ if (m_Actuators.Remove(item))
+ {
+ SubtractFromBufferSize(item);
+ return true;
+ }
+ return false;
+ }
+
+ ///
+ ///
+ ///
+ public int Count => m_Actuators.Count;
+
+ ///
+ ///
+ ///
+ public bool IsReadOnly => m_Actuators.IsReadOnly;
+
+ ///
+ ///
+ ///
+ public int IndexOf(IActuator item)
+ {
+ return m_Actuators.IndexOf(item);
+ }
+
+ ///
+ ///
+ ///
+ public void Insert(int index, IActuator item)
+ {
+ Debug.Assert(m_ReadyForExecution == false,
+ "Cannot insert into the ActuatorManager after its buffers have been initialized");
+ m_Actuators.Insert(index, item);
+ AddToBufferSizes(item);
+ }
+
+ ///
+ ///
+ ///
+ public void RemoveAt(int index)
+ {
+ Debug.Assert(m_ReadyForExecution == false,
+ "Cannot remove from the ActuatorManager after its buffers have been initialized");
+ var actuator = m_Actuators[index];
+ SubtractFromBufferSize(actuator);
+ m_Actuators.RemoveAt(index);
+ }
+
+ ///
+ ///
+ ///
+ public IActuator this[int index]
+ {
+ get => m_Actuators[index];
+ set
+ {
+ Debug.Assert(m_ReadyForExecution == false,
+ "Cannot modify the ActuatorManager after its buffers have been initialized");
+ var old = m_Actuators[index];
+ SubtractFromBufferSize(old);
+ m_Actuators[index] = value;
+ AddToBufferSizes(value);
+ }
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta
new file mode 100644
index 0000000000..aa56b5ca9f
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 7bb5b1e3779d4342a8e70f6e3c1d67cc
+timeCreated: 1593031463
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
new file mode 100644
index 0000000000..4e2a251f10
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
@@ -0,0 +1,101 @@
+using System;
+using System.Linq;
+
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// A structure that wraps the s for a particular and is
+ /// used when is called.
+ ///
+ internal readonly struct ActionBuffers
+ {
+ ///
+ /// An empty action buffer.
+ ///
+ public static ActionBuffers Empty = new ActionBuffers(ActionSegment.Empty, ActionSegment.Empty);
+
+ ///
+ /// Holds the Continuous to be used by an .
+ ///
+ public ActionSegment ContinuousActions { get; }
+
+ ///
+ /// Holds the Discrete to be used by an .
+ ///
+ public ActionSegment DiscreteActions { get; }
+
+ ///
+ /// Construct an instance with the continuous and discrete actions that will
+ /// be used.
+ ///
+ /// The continuous actions to send to an .
+ /// The discrete actions to send to an .
+ public ActionBuffers(ActionSegment continuousActions, ActionSegment discreteActions)
+ {
+ ContinuousActions = continuousActions;
+ DiscreteActions = discreteActions;
+ }
+
+ ///
+ public override bool Equals(object obj)
+ {
+ if (!(obj is ActionBuffers))
+ {
+ return false;
+ }
+
+ var ab = (ActionBuffers)obj;
+ return ab.ContinuousActions.SequenceEqual(ContinuousActions) &&
+ ab.DiscreteActions.SequenceEqual(DiscreteActions);
+ }
+
+ ///
+ public override int GetHashCode()
+ {
+ unchecked
+ {
+ return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
+ }
+ }
+ }
+
+ ///
+ /// An interface that describes an object that can receive actions from a Reinforcement Learning network.
+ ///
+ internal interface IActionReceiver
+ {
+
+ ///
+ /// The specification of the Action space for this IActionReceiver.
+ ///
+ ///
+ ActionSpec ActionSpec { get; }
+
+ ///
+ /// Method called in order too allow object to execute actions based on the
+ /// contents. The structure of the contents in the
+ /// are defined by the .
+ ///
+ /// The data structure containing the action buffers for this object.
+ void OnActionReceived(ActionBuffers actionBuffers);
+
+ ///
+ /// Implement `WriteDiscreteActionMask()` to modify the masks for discrete
+ /// actions. When using discrete actions, the agent will not perform the masked
+ /// action.
+ ///
+ ///
+ /// The action mask for the agent.
+ ///
+ ///
+ /// When using Discrete Control, you can prevent the Agent from using a certain
+ /// action by masking it with .
+ ///
+ /// See [Agents - Actions] for more information on masking actions.
+ ///
+ /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
+ ///
+ ///
+ void WriteDiscreteActionMask(IDiscreteActionMask actionMask);
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta
new file mode 100644
index 0000000000..b14a69d21c
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: b25a5b3027c9476ea1a310241be0f10f
+timeCreated: 1594756775
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/IActuator.cs b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs
new file mode 100644
index 0000000000..eedb940a36
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs
@@ -0,0 +1,21 @@
+using System;
+using UnityEngine;
+
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// Abstraction that facilitates the execution of actions.
+ ///
+ internal interface IActuator : IActionReceiver
+ {
+ int TotalNumberOfActions { get; }
+
+ ///
+ /// Gets the name of this IActuator which will be used to sort it.
+ ///
+ ///
+ string Name { get; }
+
+ void ResetData();
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta
new file mode 100644
index 0000000000..4fd0d172ca
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 780d7f0a675f44bfa784b370025b51c3
+timeCreated: 1592848317
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
new file mode 100644
index 0000000000..7cb0e99f72
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
@@ -0,0 +1,38 @@
+using System.Collections.Generic;
+
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// Interface for writing a mask to disable discrete actions for agents for the next decision.
+ ///
+ internal interface IDiscreteActionMask
+ {
+ ///
+ /// Modifies an action mask for discrete control agents.
+ ///
+ ///
+ /// When used, the agent will not be able to perform the actions passed as argument
+ /// at the next decision for the specified action branch. The actionIndices correspond
+ /// to the action options the agent will be unable to perform.
+ ///
+ /// See [Agents - Actions] for more information on masking actions.
+ ///
+ /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_2_docs/docs/Learning-Environment-Design-Agents.md#actions
+ ///
+ /// The branch for which the actions will be masked.
+ /// The indices of the masked actions.
+ void WriteMask(int branch, IEnumerable actionIndices);
+
+ ///
+ /// Get the current mask for an agent.
+ ///
+ /// A mask for the agent. A boolean array of length equal to the total number of
+ /// actions.
+ bool[] GetMask();
+
+ ///
+ /// Resets the current mask for an agent.
+ ///
+ void ResetMask();
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta
new file mode 100644
index 0000000000..ebfa10158f
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 1bc4e4b71bf4470789488fab2ee65388
+timeCreated: 1595369065
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
new file mode 100644
index 0000000000..e2635c4164
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
@@ -0,0 +1,72 @@
+using System;
+
+using Unity.MLAgents.Policies;
+
+namespace Unity.MLAgents.Actuators
+{
+ internal class VectorActuator : IActuator
+ {
+ IActionReceiver m_ActionReceiver;
+
+ ActionBuffers m_ActionBuffers;
+ internal ActionBuffers ActionBuffers
+ {
+ get => m_ActionBuffers;
+ private set => m_ActionBuffers = value;
+ }
+
+ public VectorActuator(IActionReceiver actionReceiver,
+ int[] vectorActionSize,
+ SpaceType spaceType,
+ string name = "VectorActuator")
+ {
+ m_ActionReceiver = actionReceiver;
+ string suffix;
+ switch (spaceType)
+ {
+ case SpaceType.Continuous:
+ ActionSpec = ActionSpec.MakeContinuous(vectorActionSize[0]);
+ suffix = "-Continuous";
+ break;
+ case SpaceType.Discrete:
+ ActionSpec = ActionSpec.MakeDiscrete(vectorActionSize);
+ suffix = "-Discrete";
+ break;
+ default:
+ throw new ArgumentOutOfRangeException(nameof(spaceType),
+ spaceType,
+ "Unknown enum value.");
+ }
+ Name = name + suffix;
+ }
+
+ public void ResetData()
+ {
+ m_ActionBuffers = ActionBuffers.Empty;
+ }
+
+ public void OnActionReceived(ActionBuffers actionBuffers)
+ {
+ ActionBuffers = actionBuffers;
+ m_ActionReceiver.OnActionReceived(ActionBuffers);
+ }
+
+ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
+ {
+ m_ActionReceiver.WriteDiscreteActionMask(actionMask);
+ }
+
+ ///
+ /// Returns the number of discrete branches + the number of continuous actions.
+ ///
+ public int TotalNumberOfActions => ActionSpec.NumContinuousActions +
+ ActionSpec.NumDiscreteActions;
+
+ ///
+ ///
+ ///
+ public ActionSpec ActionSpec { get; }
+
+ public string Name { get; }
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta
new file mode 100644
index 0000000000..6e9f68b913
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: ff7a3292c0b24b23b3f1c0eeb690ec4c
+timeCreated: 1593023833
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index a12f1761d8..cad4575f0b 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -145,13 +145,13 @@ internal struct AgentAction
/// [OnDisable()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnDisable.html]
/// [OnBeforeSerialize()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnBeforeSerialize.html
/// [OnAfterSerialize()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnAfterSerialize.html
- /// [Agents]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md
- /// [Reinforcement Learning in Unity]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design.md
+ /// [Agents]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md
+ /// [Reinforcement Learning in Unity]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design.md
/// [Unity ML-Agents Toolkit]: https://github.com/Unity-Technologies/ml-agents
- /// [Unity ML-Agents Toolkit manual]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Readme.md
+ /// [Unity ML-Agents Toolkit manual]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Readme.md
///
///
- [HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/" +
+ [HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/" +
"docs/Learning-Environment-Design-Agents.md")]
[Serializable]
[RequireComponent(typeof(BehaviorParameters))]
@@ -603,8 +603,8 @@ public int CompletedEpisodes
/// for information about mixing reward signals from curiosity and Generative Adversarial
/// Imitation Learning (GAIL) with rewards supplied through this method.
///
- /// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#rewards
- /// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals
+ /// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#rewards
+ /// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals
///
/// The new value of the reward.
public void SetReward(float reward)
@@ -633,8 +633,8 @@ public void SetReward(float reward)
/// for information about mixing reward signals from curiosity and Generative Adversarial
/// Imitation Learning (GAIL) with rewards supplied through this method.
///
- /// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#rewards
- /// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals
+ /// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#rewards
+ /// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals
///
/// Incremental reward value.
public void AddReward(float increment)
@@ -790,8 +790,8 @@ public virtual void Initialize() {}
/// implementing a simple heuristic function can aid in debugging agent actions and interactions
/// with its environment.
///
- /// [Demonstration Recorder]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#recording-demonstrations
- /// [Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
+ /// [Demonstration Recorder]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#recording-demonstrations
+ /// [Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
///
///
@@ -996,7 +996,7 @@ void ResetSensors()
/// For more information about observations, see [Observations and Sensors].
///
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
- /// [Observations and Sensors]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#observations-and-sensors
+ /// [Observations and Sensors]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#observations-and-sensors
///
public virtual void CollectObservations(VectorSensor sensor)
{
@@ -1027,7 +1027,7 @@ public ReadOnlyCollection GetObservations()
///
/// See [Agents - Actions] for more information on masking actions.
///
- /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
+ /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions
///
///
public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
@@ -1097,7 +1097,7 @@ public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker
///
/// For more information about implementing agent actions see [Agents - Actions].
///
- /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
+ /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions
///
///
/// An array containing the action vector. The length of the array is specified
diff --git a/com.unity.ml-agents/Runtime/AssemblyInfo.cs b/com.unity.ml-agents/Runtime/AssemblyInfo.cs
index 5a6e5ced39..4bc7a8bbb0 100644
--- a/com.unity.ml-agents/Runtime/AssemblyInfo.cs
+++ b/com.unity.ml-agents/Runtime/AssemblyInfo.cs
@@ -2,3 +2,4 @@
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")]
+[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions")]
diff --git a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs
index 5017149e1c..f10c1a29e8 100644
--- a/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs
+++ b/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs
@@ -19,7 +19,7 @@ namespace Unity.MLAgents.Demonstrations
/// See [Imitation Learning - Recording Demonstrations] for more information.
///
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
- /// [Imitation Learning - Recording Demonstrations]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs//Learning-Environment-Design-Agents.md#recording-demonstrations
+ /// [Imitation Learning - Recording Demonstrations]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs//Learning-Environment-Design-Agents.md#recording-demonstrations
///
[RequireComponent(typeof(Agent))]
[AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)]
diff --git a/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs b/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
index e8c8538640..1a9b322a98 100644
--- a/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
+++ b/com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
@@ -40,7 +40,7 @@ internal DiscreteActionMasker(BrainParameters brainParameters)
///
/// See [Agents - Actions] for more information on masking actions.
///
- /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
+ /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions
///
/// The branch for which the actions will be masked.
/// The indices of the masked actions.
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators.meta b/com.unity.ml-agents/Tests/Editor/Actuators.meta
new file mode 100644
index 0000000000..5c6399dc6c
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators.meta
@@ -0,0 +1,8 @@
+fileFormatVersion: 2
+guid: c7e705f7d549e43c6be18ae809cd6f54
+folderAsset: yes
+DefaultImporter:
+ externalObjects: {}
+ userData:
+ assetBundleName:
+ assetBundleVariant:
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs
new file mode 100644
index 0000000000..874ef2e97a
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs
@@ -0,0 +1,55 @@
+using System;
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+
+namespace Unity.MLAgents.Tests.Actuators
+{
+ [TestFixture]
+ public class ActionSegmentTests
+ {
+ [Test]
+ public void TestConstruction()
+ {
+ var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
+ Assert.Throws(
+ () => new ActionSegment(floatArray, 100, 1));
+
+ var segment = new ActionSegment(Array.Empty(), 0, 0);
+ Assert.AreEqual(segment, ActionSegment.Empty);
+ }
+ [Test]
+ public void TestIndexing()
+ {
+ var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
+ for (var i = 0; i < floatArray.Length; i++)
+ {
+ var start = 0 + i;
+ var length = floatArray.Length - i;
+ var actionSegment = new ActionSegment(floatArray, start, length);
+ for (var j = 0; j < actionSegment.Length; j++)
+ {
+ Assert.AreEqual(actionSegment[j], floatArray[start + j]);
+ }
+ }
+ }
+
+ [Test]
+ public void TestEnumerator()
+ {
+ var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
+ for (var i = 0; i < floatArray.Length; i++)
+ {
+ var start = 0 + i;
+ var length = floatArray.Length - i;
+ var actionSegment = new ActionSegment(floatArray, start, length);
+ var j = 0;
+ foreach (var item in actionSegment)
+ {
+ Assert.AreEqual(item, floatArray[start + j++]);
+ }
+ }
+ }
+
+ }
+
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta
new file mode 100644
index 0000000000..2332580c17
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 18cb6d052fba43a2b7437d87c0d9abad
+timeCreated: 1596486604
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs
new file mode 100644
index 0000000000..d3cead8ec0
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs
@@ -0,0 +1,114 @@
+using System;
+using System.Collections.Generic;
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+
+namespace Unity.MLAgents.Tests.Actuators
+{
+ [TestFixture]
+ public class ActuatorDiscreteActionMaskTests
+ {
+ [Test]
+ public void Construction()
+ {
+ var masker = new ActuatorDiscreteActionMask(new List(), 0, 0);
+ Assert.IsNotNull(masker);
+ }
+
+ [Test]
+ public void NullMask()
+ {
+ var masker = new ActuatorDiscreteActionMask(new List(), 0, 0);
+ var mask = masker.GetMask();
+ Assert.IsNull(mask);
+ }
+
+ [Test]
+ public void FirstBranchMask()
+ {
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
+ var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
+ var mask = masker.GetMask();
+ Assert.IsNull(mask);
+ masker.WriteMask(0, new[] {1, 2, 3});
+ mask = masker.GetMask();
+ Assert.IsFalse(mask[0]);
+ Assert.IsTrue(mask[1]);
+ Assert.IsTrue(mask[2]);
+ Assert.IsTrue(mask[3]);
+ Assert.IsFalse(mask[4]);
+ Assert.AreEqual(mask.Length, 15);
+ }
+
+ [Test]
+ public void SecondBranchMask()
+ {
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
+ var masker = new ActuatorDiscreteActionMask(new[] {actuator1}, 15, 3);
+ masker.WriteMask(1, new[] {1, 2, 3});
+ var mask = masker.GetMask();
+ Assert.IsFalse(mask[0]);
+ Assert.IsFalse(mask[4]);
+ Assert.IsTrue(mask[5]);
+ Assert.IsTrue(mask[6]);
+ Assert.IsTrue(mask[7]);
+ Assert.IsFalse(mask[8]);
+ Assert.IsFalse(mask[9]);
+ }
+
+ [Test]
+ public void MaskReset()
+ {
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
+ var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
+ masker.WriteMask(1, new[] {1, 2, 3});
+ masker.ResetMask();
+ var mask = masker.GetMask();
+ for (var i = 0; i < 15; i++)
+ {
+ Assert.IsFalse(mask[i]);
+ }
+ }
+
+ [Test]
+ public void ThrowsError()
+ {
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
+ var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
+ Assert.Catch(
+ () => masker.WriteMask(0, new[] {5}));
+ Assert.Catch(
+ () => masker.WriteMask(1, new[] {5}));
+ masker.WriteMask(2, new[] {5});
+ Assert.Catch(
+ () => masker.WriteMask(3, new[] {1}));
+ masker.GetMask();
+ masker.ResetMask();
+ masker.WriteMask(0, new[] {0, 1, 2, 3});
+ Assert.Catch(
+ () => masker.GetMask());
+ }
+
+ [Test]
+ public void MultipleMaskEdit()
+ {
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
+ var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
+ masker.WriteMask(0, new[] {0, 1});
+ masker.WriteMask(0, new[] {3});
+ masker.WriteMask(2, new[] {1});
+ var mask = masker.GetMask();
+ for (var i = 0; i < 15; i++)
+ {
+ if ((i == 0) || (i == 1) || (i == 3) || (i == 10))
+ {
+ Assert.IsTrue(mask[i]);
+ }
+ else
+ {
+ Assert.IsFalse(mask[i]);
+ }
+ }
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta
new file mode 100644
index 0000000000..a5dd1f3ad9
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: b9f5f87049d04d8bba39d193a3ab2f5a
+timeCreated: 1596491682
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
new file mode 100644
index 0000000000..974c1fd35d
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
@@ -0,0 +1,310 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.RegularExpressions;
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Policies;
+using UnityEngine;
+using UnityEngine.TestTools;
+using Assert = UnityEngine.Assertions.Assert;
+
+namespace Unity.MLAgents.Tests.Actuators
+{
+ [TestFixture]
+ public class ActuatorManagerTests
+ {
+ [Test]
+ public void TestEnsureBufferSizeContinuous()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeContinuous(10), "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeContinuous(2), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+ var actuator1ActionSpaceDef = actuator1.ActionSpec;
+ var actuator2ActionSpaceDef = actuator2.ActionSpec;
+ manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 },
+ actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions,
+ actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
+ actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
+
+ manager.UpdateActions(new[]
+ { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty());
+
+ Assert.IsTrue(12 == manager.NumContinuousActions);
+ Assert.IsTrue(0 == manager.NumDiscreteActions);
+ Assert.IsTrue(0 == manager.SumOfDiscreteBranchSizes);
+ Assert.IsTrue(12 == manager.StoredContinuousActions.Length);
+ Assert.IsTrue(0 == manager.StoredDiscreteActions.Length);
+ }
+
+ [Test]
+ public void TestEnsureBufferDiscrete()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 1, 1}), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+ var actuator1ActionSpaceDef = actuator1.ActionSpec;
+ var actuator2ActionSpaceDef = actuator2.ActionSpec;
+ manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 },
+ actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions,
+ actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
+ actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
+
+ manager.UpdateActions(Array.Empty(),
+ new[] { 0, 1, 2, 3, 4, 5, 6});
+
+ Assert.IsTrue(0 == manager.NumContinuousActions);
+ Assert.IsTrue(7 == manager.NumDiscreteActions);
+ Assert.IsTrue(13 == manager.SumOfDiscreteBranchSizes);
+ Assert.IsTrue(0 == manager.StoredContinuousActions.Length);
+ Assert.IsTrue(7 == manager.StoredDiscreteActions.Length);
+ }
+
+ [Test]
+ public void TestFailOnMixedActionSpace()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+ manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4);
+ LogAssert.Expect(LogType.Assert, "Actuators on the same Agent must have the same action SpaceType.");
+ }
+
+ [Test]
+ public void TestFailOnSameActuatorName()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+ manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4);
+ LogAssert.Expect(LogType.Assert, "Actuator names must be unique.");
+ }
+
+ [Test]
+ public void TestExecuteActionsDiscrete()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 1, 1}), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+
+ var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5, 6};
+ manager.UpdateActions(Array.Empty(),
+ discreteActionBuffer);
+
+ manager.ExecuteActions();
+ var actuator1Actions = actuator1.LastActionBuffer.DiscreteActions;
+ var actuator2Actions = actuator2.LastActionBuffer.DiscreteActions;
+ TestSegmentEquality(actuator1Actions, discreteActionBuffer); TestSegmentEquality(actuator2Actions, discreteActionBuffer);
+ }
+
+ [Test]
+ public void TestExecuteActionsContinuous()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
+ "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+
+ var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
+ manager.UpdateActions(continuousActionBuffer,
+ Array.Empty());
+
+ manager.ExecuteActions();
+ var actuator1Actions = actuator1.LastActionBuffer.ContinuousActions;
+ var actuator2Actions = actuator2.LastActionBuffer.ContinuousActions;
+ TestSegmentEquality(actuator1Actions, continuousActionBuffer);
+ TestSegmentEquality(actuator2Actions, continuousActionBuffer);
+ }
+
+ static void TestSegmentEquality(ActionSegment actionSegment, T[] actionBuffer)
+ where T : struct
+ {
+ Assert.IsFalse(actionSegment.Length == 0);
+ for (var i = 0; i < actionSegment.Length; i++)
+ {
+ var action = actionSegment[i];
+ Assert.AreEqual(action, actionBuffer[actionSegment.Offset + i]);
+ }
+ }
+
+ [Test]
+ public void TestUpdateActionsContinuous()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
+ "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+ var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
+ manager.UpdateActions(continuousActionBuffer,
+ Array.Empty());
+
+ Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
+ }
+
+ [Test]
+ public void TestUpdateActionsDiscrete()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
+ "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+ var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5};
+ manager.UpdateActions(Array.Empty(),
+ discreteActionBuffer);
+
+ Debug.Log(manager.StoredDiscreteActions);
+ Debug.Log(discreteActionBuffer);
+ Assert.IsTrue(manager.StoredDiscreteActions.SequenceEqual(discreteActionBuffer));
+ }
+
+ [Test]
+ public void TestRemove()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
+ "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
+
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+ Assert.IsTrue(manager.NumDiscreteActions == 6);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12);
+
+ manager.Remove(actuator2);
+
+ Assert.IsTrue(manager.NumDiscreteActions == 3);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
+
+ manager.Remove(null);
+
+ Assert.IsTrue(manager.NumDiscreteActions == 3);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
+
+ manager.RemoveAt(0);
+ Assert.IsTrue(manager.NumDiscreteActions == 0);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0);
+ }
+
+ [Test]
+ public void TestClear()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
+ "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+
+ Assert.IsTrue(manager.NumDiscreteActions == 6);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12);
+
+ manager.Clear();
+
+ Assert.IsTrue(manager.NumDiscreteActions == 0);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0);
+ }
+
+ [Test]
+ public void TestIndexSet()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4}),
+ "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
+ manager.Add(actuator1);
+ Assert.IsTrue(manager.NumDiscreteActions == 4);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10);
+ manager[0] = actuator2;
+ Assert.IsTrue(manager.NumDiscreteActions == 3);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
+ }
+
+ [Test]
+ public void TestInsert()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4}),
+ "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
+ manager.Add(actuator1);
+ Assert.IsTrue(manager.NumDiscreteActions == 4);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10);
+ manager.Insert(0, actuator2);
+ Assert.IsTrue(manager.NumDiscreteActions == 7);
+ Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 16);
+ Assert.IsTrue(manager.IndexOf(actuator2) == 0);
+ }
+
+ [Test]
+ public void TestResetData()
+ {
+ var manager = new ActuatorManager();
+ var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
+ "actuator1");
+ var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
+ manager.Add(actuator1);
+ manager.Add(actuator2);
+ var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
+ manager.UpdateActions(continuousActionBuffer,
+ Array.Empty());
+
+ Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
+ Assert.IsTrue(manager.NumContinuousActions == 6);
+ manager.ResetData();
+
+ Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f}));
+ }
+
+ [Test]
+ public void TestWriteDiscreteActionMask()
+ {
+ var manager = new ActuatorManager(2);
+ var va1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "name");
+ var va2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {3, 2, 1}), "name1");
+ manager.Add(va1);
+ manager.Add(va2);
+
+ var groundTruthMask = new[]
+ {
+ false,
+ true, false,
+ false, true, true,
+ true, false, true,
+ false, true,
+ false
+ };
+
+ va1.Masks = new[]
+ {
+ Array.Empty(),
+ new[] { 0 },
+ new[] { 1, 2 }
+ };
+
+ va2.Masks = new[]
+ {
+ new[] {0, 2},
+ new[] {1},
+ Array.Empty()
+ };
+ manager.WriteActionMask();
+ Assert.IsTrue(groundTruthMask.SequenceEqual(manager.DiscreteActionMask.GetMask()));
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta
new file mode 100644
index 0000000000..4946ff19fb
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: d48ba72f0ac64d7db0af22c9d82b11d8
+timeCreated: 1596494279
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
new file mode 100644
index 0000000000..a8990b6100
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
@@ -0,0 +1,38 @@
+using Unity.MLAgents.Actuators;
+namespace Unity.MLAgents.Tests.Actuators
+{
+ internal class TestActuator : IActuator
+ {
+ public ActionBuffers LastActionBuffer;
+ public int[][] Masks;
+ public TestActuator(ActionSpec actuatorSpace, string name)
+ {
+ ActionSpec = actuatorSpace;
+ TotalNumberOfActions = actuatorSpace.NumContinuousActions +
+ actuatorSpace.NumDiscreteActions;
+ Name = name;
+ }
+
+ public void OnActionReceived(ActionBuffers actionBuffers)
+ {
+ LastActionBuffer = actionBuffers;
+ }
+
+ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
+ {
+ for (var i = 0; i < Masks.Length; i++)
+ {
+ actionMask.WriteMask(i, Masks[i]);
+ }
+ }
+
+ public int TotalNumberOfActions { get; }
+ public ActionSpec ActionSpec { get; }
+
+ public string Name { get; }
+
+ public void ResetData()
+ {
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta
new file mode 100644
index 0000000000..57e13a0e26
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: fa950d7b175749bfa287fd8761dd831f
+timeCreated: 1596665978
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
new file mode 100644
index 0000000000..e80daa4010
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
@@ -0,0 +1,98 @@
+using System.Collections.Generic;
+using System.Linq;
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Policies;
+using Assert = UnityEngine.Assertions.Assert;
+
+namespace Unity.MLAgents.Tests.Actuators
+{
+ [TestFixture]
+ public class VectorActuatorTests
+ {
+ class TestActionReceiver : IActionReceiver
+ {
+ public ActionBuffers LastActionBuffers;
+ public int Branch;
+ public IList Mask;
+ public ActionSpec ActionSpec { get; }
+
+ public void OnActionReceived(ActionBuffers actionBuffers)
+ {
+ LastActionBuffers = actionBuffers;
+ }
+
+ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
+ {
+ actionMask.WriteMask(Branch, Mask);
+ }
+ }
+
+ [Test]
+ public void TestConstruct()
+ {
+ var ar = new TestActionReceiver();
+ var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
+
+ Assert.IsTrue(va.ActionSpec.NumDiscreteActions == 3);
+ Assert.IsTrue(va.ActionSpec.SumOfDiscreteBranchSizes == 6);
+ Assert.IsTrue(va.ActionSpec.NumContinuousActions == 0);
+
+ var va1 = new VectorActuator(ar, new[] {4}, SpaceType.Continuous, "name");
+
+ Assert.IsTrue(va1.ActionSpec.NumContinuousActions == 4);
+ Assert.IsTrue(va1.ActionSpec.SumOfDiscreteBranchSizes == 0);
+ Assert.AreEqual(va1.Name, "name-Continuous");
+ }
+
+ [Test]
+ public void TestOnActionReceived()
+ {
+ var ar = new TestActionReceiver();
+ var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
+
+ var discreteActions = new[] { 0, 1, 1 };
+ var ab = new ActionBuffers(ActionSegment.Empty,
+ new ActionSegment(discreteActions, 0, 3));
+
+ va.OnActionReceived(ab);
+
+ Assert.AreEqual(ar.LastActionBuffers, ab);
+ va.ResetData();
+ Assert.AreEqual(va.ActionBuffers.ContinuousActions, ActionSegment.Empty);
+ Assert.AreEqual(va.ActionBuffers.DiscreteActions, ActionSegment.Empty);
+ }
+
+ [Test]
+ public void TestResetData()
+ {
+ var ar = new TestActionReceiver();
+ var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
+
+ var discreteActions = new[] { 0, 1, 1 };
+ var ab = new ActionBuffers(ActionSegment.Empty,
+ new ActionSegment(discreteActions, 0, 3));
+
+ va.OnActionReceived(ab);
+ }
+
+ [Test]
+ public void TestWriteDiscreteActionMask()
+ {
+ var ar = new TestActionReceiver();
+ var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
+ var bdam = new ActuatorDiscreteActionMask(new[] {va}, 6, 3);
+
+ var groundTruthMask = new[] { false, true, false, false, true, true };
+
+ ar.Branch = 1;
+ ar.Mask = new[] { 0 };
+ va.WriteDiscreteActionMask(bdam);
+ ar.Branch = 2;
+ ar.Mask = new[] { 1, 2 };
+ va.WriteDiscreteActionMask(bdam);
+
+ Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask()));
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta
new file mode 100644
index 0000000000..2a5a86efd0
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: c2b191d2929f49adab0769705d49d86a
+timeCreated: 1596580289
\ No newline at end of file
diff --git a/com.unity.ml-agents/package.json b/com.unity.ml-agents/package.json
index 4fb78dab61..a6a17115e8 100755
--- a/com.unity.ml-agents/package.json
+++ b/com.unity.ml-agents/package.json
@@ -1,7 +1,7 @@
{
"name": "com.unity.ml-agents",
"displayName": "ML Agents",
- "version": "1.2.0-preview",
+ "version": "1.3.0-preview",
"unity": "2018.4",
"description": "Use state-of-the-art machine learning to create intelligent character behaviors in any Unity environment (games, robotics, film, etc.).",
"dependencies": {
diff --git a/docs/Installation-Anaconda-Windows.md b/docs/Installation-Anaconda-Windows.md
index c4c824c797..7ca0a1d63c 100644
--- a/docs/Installation-Anaconda-Windows.md
+++ b/docs/Installation-Anaconda-Windows.md
@@ -123,10 +123,10 @@ commands in an Anaconda Prompt _(if you open a new prompt, be sure to activate
the ml-agents Conda environment by typing `activate ml-agents`)_:
```sh
-git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git
+git clone --branch release_5 https://github.com/Unity-Technologies/ml-agents.git
```
-The `--branch release_4` option will switch to the tag of the latest stable
+The `--branch release_5` option will switch to the tag of the latest stable
release. Omitting that will get the `master` branch which is potentially
unstable.
diff --git a/docs/Installation.md b/docs/Installation.md
index a3c1f41c38..917d8994e9 100644
--- a/docs/Installation.md
+++ b/docs/Installation.md
@@ -58,10 +58,10 @@ example environments and training configurations to experiment with them (some
of our tutorials / guides assume you have access to our example environments).
```sh
-git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git
+git clone --branch release_5 https://github.com/Unity-Technologies/ml-agents.git
```
-The `--branch release_4` option will switch to the tag of the latest stable
+The `--branch release_5` option will switch to the tag of the latest stable
release. Omitting that will get the `master` branch which is potentially
unstable.
@@ -69,7 +69,7 @@ unstable.
You will need to clone the repository if you plan to modify or extend the
ML-Agents Toolkit for your purposes. If you plan to contribute those changes
-back, make sure to clone the `master` branch (by omitting `--branch release_4`
+back, make sure to clone the `master` branch (by omitting `--branch release_5`
from the command above). See our
[Contributions Guidelines](../com.unity.ml-agents/CONTRIBUTING.md) for more
information on contributing to the ML-Agents Toolkit.
diff --git a/docs/Training-Configuration-File.md b/docs/Training-Configuration-File.md
index a8f3259944..721923fb76 100644
--- a/docs/Training-Configuration-File.md
+++ b/docs/Training-Configuration-File.md
@@ -58,7 +58,7 @@ the `trainer` setting above).
| `hyperparameters -> beta` | (default = `5.0e-3`) Strength of the entropy regularization, which makes the policy "more random." This ensures that agents properly explore the action space during training. Increasing this will ensure more random actions are taken. This should be adjusted such that the entropy (measurable from TensorBoard) slowly decreases alongside increases in reward. If entropy drops too quickly, increase beta. If entropy drops too slowly, decrease `beta`. <br><br> Typical range: `1e-4` - `1e-2` |
| `hyperparameters -> epsilon` | (default = `0.2`) Influences how rapidly the policy can evolve during training. Corresponds to the acceptable threshold of divergence between the old and new policies during gradient descent updating. Setting this value small will result in more stable updates, but will also slow the training process. <br><br> Typical range: `0.1` - `0.3` |
| `hyperparameters -> lambd` | (default = `0.95`) Regularization parameter (lambda) used when calculating the Generalized Advantage Estimate ([GAE](https://arxiv.org/abs/1506.02438)). This can be thought of as how much the agent relies on its current value estimate when calculating an updated value estimate. Low values correspond to relying more on the current value estimate (which can be high bias), and high values correspond to relying more on the actual rewards received in the environment (which can be high variance). The parameter provides a trade-off between the two, and the right value can lead to a more stable training process. <br><br> Typical range: `0.9` - `0.95` |
-| `hyperparameters -> num_epoch` | Number of passes to make through the experience buffer when performing gradient descent optimization.The larger the batch_size, the larger it is acceptable to make this. Decreasing this will ensure more stable updates, at the cost of slower learning. <br><br> Typical range: `3` - `10` |
+| `hyperparameters -> num_epoch` | (default = `3`) Number of passes to make through the experience buffer when performing gradient descent optimization.The larger the batch_size, the larger it is acceptable to make this. Decreasing this will ensure more stable updates, at the cost of slower learning. <br><br> Typical range: `3` - `10` |
### SAC-specific Configurations
diff --git a/docs/Training-ML-Agents.md b/docs/Training-ML-Agents.md
index 8262d31da7..2a1be41746 100644
--- a/docs/Training-ML-Agents.md
+++ b/docs/Training-ML-Agents.md
@@ -433,7 +433,86 @@ if we wanted to train the 3D ball agent with parameter randomization, we would r
mlagents-learn config/ppo/3DBall_randomize.yaml --run-id=3D-Ball-randomize
```
-We can observe progress and metrics via Tensorboard.
+We can observe progress and metrics via TensorBoard.
+
+#### Curriculum
+
+To enable curriculum learning, you need to add a `curriculum` sub-section to your environment
+parameter. Here is one example with the environment parameter `my_environment_parameter` :
+
+```yml
+behaviors:
+ BehaviorY:
+ # < Same as above >
+
+# Add this section
+environment_parameters:
+ my_environment_parameter:
+ curriculum:
+ - name: MyFirstLesson # The '-' is important as this is a list
+ completion_criteria:
+ measure: progress
+ behavior: my_behavior
+ signal_smoothing: true
+ min_lesson_length: 100
+ threshold: 0.2
+ value: 0.0
+ - name: MySecondLesson # This is the start of the second lesson
+ completion_criteria:
+ measure: progress
+ behavior: my_behavior
+ signal_smoothing: true
+ min_lesson_length: 100
+ threshold: 0.6
+ require_reset: true
+ value:
+ sampler_type: uniform
+ sampler_parameters:
+ min_value: 4.0
+ max_value: 7.0
+ - name: MyLastLesson
+ value: 8.0
+```
+
+Note that this curriculum __only__ applies to `my_environment_parameter`. The `curriculum` section
+contains a list of `Lessons`. In the example, the lessons are named `MyFirstLesson`, `MySecondLesson`
+and `MyLastLesson`.
+Each `Lesson` has 3 fields :
+
+ - `name` which is a user defined name for the lesson (The name of the lesson will be displayed in
+ the console when the lesson changes)
+ - `completion_criteria` which determines what needs to happen in the simulation before the lesson
+ can be considered complete. When that condition is met, the curriculum moves on to the next
+ `Lesson`. Note that you do not need to specify a `completion_criteria` for the last `Lesson`
+ - `value` which is the value the environment parameter will take during the lesson. Note that this
+ can be a float or a sampler.
+
+ These are the different settings of the `completion_criteria`:
+
+
+| **Setting** | **Description** |
+| :------------------ | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `measure` | What to measure learning progress, and advancement in lessons by. <br><br> `reward` uses a measure received reward, while `progress` uses the ratio of steps/max_steps. |
+| `behavior` | Specifies which behavior is being tracked. There can be multiple behaviors with different names, each at different points of training. This setting allows the curriculum to track only one of them. |
+| `threshold` | Determines at what point in value of `measure` the lesson should be increased. |
+| `min_lesson_length` | The minimum number of episodes that should be completed before the lesson can change. If `measure` is set to `reward`, the average cumulative reward of the last `min_lesson_length` episodes will be used to determine if the lesson should change. Must be nonnegative. <br><br> **Important**: the average reward that is compared to the thresholds is different than the mean reward that is logged to the console. For example, if `min_lesson_length` is `100`, the lesson will increment after the average cumulative reward of the last `100` episodes exceeds the current threshold. The mean reward logged to the console is dictated by the `summary_freq` parameter defined above. |
+| `signal_smoothing` | Whether to weight the current progress measure by previous values. |
+| `require_reset` | Whether changing lesson requires the environment to reset (default: false) |
+##### Training with a Curriculum
+
+Once we have specified our metacurriculum and curricula, we can launch
+`mlagents-learn` to point to the config file containing
+our curricula and PPO will train using Curriculum Learning. For example, to
+train agents in the Wall Jump environment with curriculum learning, we can run:
+
+```sh
+mlagents-learn config/ppo/WallJump_curriculum.yaml --run-id=wall-jump-curriculum
+```
+
+We can then keep track of the current lessons and progresses via TensorBoard. If you've terminated
+the run, you can resume it using `--resume` and lesson progress will start off where it
+ended.
+
#### Curriculum
diff --git a/docs/Training-on-Amazon-Web-Service.md b/docs/Training-on-Amazon-Web-Service.md
index b41e0006f7..bc78e2a28b 100644
--- a/docs/Training-on-Amazon-Web-Service.md
+++ b/docs/Training-on-Amazon-Web-Service.md
@@ -69,7 +69,7 @@ After launching your EC2 instance using the ami and ssh into it:
2. Clone the ML-Agents repo and install the required Python packages
```sh
- git clone --branch release_4 https://github.com/Unity-Technologies/ml-agents.git
+ git clone --branch release_5 https://github.com/Unity-Technologies/ml-agents.git
cd ml-agents/ml-agents/
pip3 install -e .
```
diff --git a/docs/Unity-Inference-Engine.md b/docs/Unity-Inference-Engine.md
index 213136d1a5..a4adfee9a0 100644
--- a/docs/Unity-Inference-Engine.md
+++ b/docs/Unity-Inference-Engine.md
@@ -7,9 +7,6 @@ your Unity games. This support is possible thanks to the
[compute shaders](https://docs.unity3d.com/Manual/class-ComputeShader.html) to
run the neural network within Unity.
-**Note**: The ML-Agents Toolkit only supports the models created with our
-trainers.
-
## Supported devices
See the Unity Inference Engine documentation for a list of the
@@ -45,3 +42,22 @@ use for Inference.
**Note:** For most of the models generated with the ML-Agents Toolkit, CPU will
be faster than GPU. You should use the GPU only if you use the ResNet visual
encoder or have a large number of agents with visual observations.
+
+# Unsupported use cases
+## Externally trained models
+The ML-Agents Toolkit only supports the models created with our trainers. Model
+loading expects certain conventions for constants and tensor names. While it is
+possible to construct a model that follows these conventions, we don't provide
+any additional help for this. More details can be found in
+[TensorNames.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/com.unity.ml-agents/Runtime/Inference/TensorNames.cs)
+and
+[BarracudaModelParamLoader.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs).
+
+If you wish to run inference on an externally trained model, you should use
+Barracuda directly, instead of trying to run it through ML-Agents.
+
+## Model inference outside of Unity
+We do not provide support for inference anywhere outside of Unity. The
+`frozen_graph_def.pb` and `.onnx` files produced by training are open formats
+for TensorFlow and ONNX respectively; if you wish to convert these to another
+format or run inference with them, refer to their documentation.
diff --git a/docs/Using-Tensorboard.md b/docs/Using-Tensorboard.md
index 1ea2fcbaad..16984df76d 100644
--- a/docs/Using-Tensorboard.md
+++ b/docs/Using-Tensorboard.md
@@ -119,9 +119,15 @@ The ML-Agents training program saves the following statistics:
skill level between two players. In a proper training run, the ELO of the
agent should steadily increase.
+## Exporting Data from TensorBoard
+To export timeseries data in CSV or JSON format, check the "Show data download
+links" in the upper left. This will enable download links below each chart.
+
+![Example TensorBoard Run](images/TensorBoard-download.png)
+
## Custom Metrics from Unity
-To get custom metrics from a C# environment into Tensorboard, you can use the
+To get custom metrics from a C# environment into TensorBoard, you can use the
`StatsRecorder`:
```csharp
diff --git a/docs/images/TensorBoard-download.png b/docs/images/TensorBoard-download.png
new file mode 100644
index 0000000000..d5f38a17b2
Binary files /dev/null and b/docs/images/TensorBoard-download.png differ
diff --git a/gym-unity/gym_unity/__init__.py b/gym-unity/gym_unity/__init__.py
index 42899e704c..5ad197dd8d 100644
--- a/gym-unity/gym_unity/__init__.py
+++ b/gym-unity/gym_unity/__init__.py
@@ -1,5 +1,5 @@
# Version of the library that will be used to upload to pypi
-__version__ = "0.19.0.dev0"
+__version__ = "0.20.0.dev0"
# Git tag that will be checked to determine whether to trigger upload to pypi
__release_tag__ = None
diff --git a/ml-agents-envs/mlagents_envs/__init__.py b/ml-agents-envs/mlagents_envs/__init__.py
index 42899e704c..5ad197dd8d 100644
--- a/ml-agents-envs/mlagents_envs/__init__.py
+++ b/ml-agents-envs/mlagents_envs/__init__.py
@@ -1,5 +1,5 @@
# Version of the library that will be used to upload to pypi
-__version__ = "0.19.0.dev0"
+__version__ = "0.20.0.dev0"
# Git tag that will be checked to determine whether to trigger upload to pypi
__release_tag__ = None
diff --git a/ml-agents-envs/setup.py b/ml-agents-envs/setup.py
index be9cab13e6..9e4a454059 100644
--- a/ml-agents-envs/setup.py
+++ b/ml-agents-envs/setup.py
@@ -48,7 +48,7 @@ def run(self):
install_requires=[
"cloudpickle",
"grpcio>=1.11.0",
- "numpy>=1.14.1,<2.0",
+ "numpy>=1.14.1,<1.19.0",
"Pillow>=4.2.1",
"protobuf>=3.6",
"pyyaml>=3.1.0",
diff --git a/ml-agents/mlagents/model_serialization.py b/ml-agents/mlagents/model_serialization.py
index edc7a5f6ee..11714c3ec2 100644
--- a/ml-agents/mlagents/model_serialization.py
+++ b/ml-agents/mlagents/model_serialization.py
@@ -1,5 +1,6 @@
from distutils.util import strtobool
import os
+import shutil
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion
@@ -227,3 +228,20 @@ def _enforce_onnx_conversion() -> bool:
return strtobool(val)
except Exception:
return False
+
+
+def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None:
+ """
+ Copy the .nn file at the given source to the destination.
+ Also copies the corresponding .onnx file if it exists.
+ """
+ shutil.copyfile(source_nn_path, destination_nn_path)
+ logger.info(f"Copied {source_nn_path} to {destination_nn_path}.")
+ # Copy the onnx file if it exists
+ source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx"
+ destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx"
+ try:
+ shutil.copyfile(source_onnx_path, destination_onnx_path)
+ logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.")
+ except OSError:
+ pass
diff --git a/ml-agents/mlagents/trainers/__init__.py b/ml-agents/mlagents/trainers/__init__.py
index 42899e704c..5ad197dd8d 100644
--- a/ml-agents/mlagents/trainers/__init__.py
+++ b/ml-agents/mlagents/trainers/__init__.py
@@ -1,5 +1,5 @@
# Version of the library that will be used to upload to pypi
-__version__ = "0.19.0.dev0"
+__version__ = "0.20.0.dev0"
# Git tag that will be checked to determine whether to trigger upload to pypi
__release_tag__ = None
diff --git a/ml-agents/mlagents/trainers/environment_parameter_manager.py b/ml-agents/mlagents/trainers/environment_parameter_manager.py
index 232dd0fb83..448bc2c28d 100644
--- a/ml-agents/mlagents/trainers/environment_parameter_manager.py
+++ b/ml-agents/mlagents/trainers/environment_parameter_manager.py
@@ -131,7 +131,7 @@ def update_lessons(
lesson = settings.curriculum[lesson_num]
if (
lesson.completion_criteria is not None
- and len(settings.curriculum) > lesson_num
+ and len(settings.curriculum) > lesson_num + 1
):
behavior_to_consider = lesson.completion_criteria.behavior
if behavior_to_consider in trainer_steps:
diff --git a/ml-agents/mlagents/trainers/exception.py b/ml-agents/mlagents/trainers/exception.py
index a2a77a60b2..3c0742bcec 100644
--- a/ml-agents/mlagents/trainers/exception.py
+++ b/ml-agents/mlagents/trainers/exception.py
@@ -19,6 +19,14 @@ class TrainerConfigError(Exception):
pass
+class TrainerConfigWarning(Warning):
+ """
+ Any warning related to the configuration of trainers in the ML-Agents Toolkit.
+ """
+
+ pass
+
+
class CurriculumError(TrainerError):
"""
Any error related to training with a curriculum.
diff --git a/ml-agents/mlagents/trainers/learn.py b/ml-agents/mlagents/trainers/learn.py
index d18b884f7a..16d1931d99 100644
--- a/ml-agents/mlagents/trainers/learn.py
+++ b/ml-agents/mlagents/trainers/learn.py
@@ -15,7 +15,6 @@
from mlagents.trainers.trainer_util import TrainerFactory, handle_existing_directories
from mlagents.trainers.stats import (
TensorboardWriter,
- CSVWriter,
StatsReporter,
GaugeWriter,
ConsoleWriter,
@@ -92,22 +91,13 @@ def run_training(run_seed: int, options: RunOptions) -> None:
os.path.join(run_logs_dir, "training_status.json")
)
- # Configure CSV, Tensorboard Writers and StatsReporter
- # We assume reward and episode length are needed in the CSV.
- csv_writer = CSVWriter(
- write_path,
- required_fields=[
- "Environment/Cumulative Reward",
- "Environment/Episode Length",
- ],
- )
+ # Configure Tensorboard Writers and StatsReporter
tb_writer = TensorboardWriter(
write_path, clear_past_data=not checkpoint_settings.resume
)
gauge_write = GaugeWriter()
console_writer = ConsoleWriter()
StatsReporter.add_writer(tb_writer)
- StatsReporter.add_writer(csv_writer)
StatsReporter.add_writer(gauge_write)
StatsReporter.add_writer(console_writer)
@@ -287,9 +277,11 @@ def run_cli(options: RunOptions) -> None:
add_timer_metadata("mlagents_envs_version", mlagents_envs.__version__)
add_timer_metadata("communication_protocol_version", UnityEnvironment.API_VERSION)
add_timer_metadata("tensorflow_version", tf_utils.tf.__version__)
+ add_timer_metadata("numpy_version", np.__version__)
if options.env_settings.seed == -1:
run_seed = np.random.randint(0, 10000)
+ logger.info(f"run_seed set to {run_seed}")
run_training(run_seed, options)
diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py
index bb170349ec..1eb384243e 100644
--- a/ml-agents/mlagents/trainers/settings.py
+++ b/ml-agents/mlagents/trainers/settings.py
@@ -1,3 +1,5 @@
+import warnings
+
import attr
import cattr
from typing import Dict, Optional, List, Any, DefaultDict, Mapping, Tuple, Union
@@ -10,7 +12,7 @@
from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser
from mlagents.trainers.cli_utils import load_config
-from mlagents.trainers.exception import TrainerConfigError
+from mlagents.trainers.exception import TrainerConfigError, TrainerConfigWarning
from mlagents_envs import logging_util
from mlagents_envs.side_channel.environment_parameters_channel import (
@@ -450,7 +452,7 @@ class EnvironmentParameterSettings:
def _check_lesson_chain(lessons, parameter_name):
"""
Ensures that when using curriculum, all non-terminal lessons have a valid
- CompletionCriteria
+ CompletionCriteria, and that the terminal lesson does not contain a CompletionCriteria.
"""
num_lessons = len(lessons)
for index, lesson in enumerate(lessons):
@@ -458,6 +460,12 @@ def _check_lesson_chain(lessons, parameter_name):
raise TrainerConfigError(
f"A non-terminal lesson does not have a completion_criteria for {parameter_name}."
)
+ if index == num_lessons - 1 and lesson.completion_criteria is not None:
+ warnings.warn(
+ f"Your final lesson definition contains completion_criteria for {parameter_name}."
+                f" It will be ignored.",
+ TrainerConfigWarning,
+ )
@staticmethod
def structure(d: Mapping, t: type) -> Dict[str, "EnvironmentParameterSettings"]:
diff --git a/ml-agents/mlagents/trainers/stats.py b/ml-agents/mlagents/trainers/stats.py
index 63fdb571a0..75655a17d6 100644
--- a/ml-agents/mlagents/trainers/stats.py
+++ b/ml-agents/mlagents/trainers/stats.py
@@ -3,7 +3,6 @@
from typing import List, Dict, NamedTuple, Any, Optional
import numpy as np
import abc
-import csv
import os
import time
from threading import RLock
@@ -49,8 +48,7 @@ def add_property(
"""
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
- with all types of properties. For instance, a TB writer doesn't need a max step, nor should
- we write hyperparameters to the CSV.
+ with all types of properties. For instance, a TB writer doesn't need a max step.
:param category: The category that the property belongs to.
:param type: The type of property.
:param value: The property itself.
@@ -241,58 +239,6 @@ def _dict_to_tensorboard(
return None
-class CSVWriter(StatsWriter):
- def __init__(self, base_dir: str, required_fields: List[str] = None):
- """
- A StatsWriter that writes to a Tensorboard summary.
- :param base_dir: The directory within which to place the CSV file, which will be {base_dir}/{category}.csv.
- :param required_fields: If provided, the CSV writer won't write until these fields have statistics to write for
- them.
- """
- # We need to keep track of the fields in the CSV, as all rows need the same fields.
- self.csv_fields: Dict[str, List[str]] = {}
- self.required_fields = required_fields if required_fields else []
- self.base_dir: str = base_dir
-
- def write_stats(
- self, category: str, values: Dict[str, StatsSummary], step: int
- ) -> None:
- if self._maybe_create_csv_file(category, list(values.keys())):
- row = [str(step)]
- # Only record the stats that showed up in the first valid row
- for key in self.csv_fields[category]:
- _val = values.get(key, None)
- row.append(str(_val.mean) if _val else "None")
- with open(self._get_filepath(category), "a") as file:
- writer = csv.writer(file)
- writer.writerow(row)
-
- def _maybe_create_csv_file(self, category: str, keys: List[str]) -> bool:
- """
- If no CSV file exists and the keys have the required values,
- make the CSV file and write hte title row.
- Returns True if there is now (or already is) a valid CSV file.
- """
- if category not in self.csv_fields:
- summary_dir = self.base_dir
- os.makedirs(summary_dir, exist_ok=True)
- # Only store if the row contains the required fields
- if all(item in keys for item in self.required_fields):
- self.csv_fields[category] = keys
- with open(self._get_filepath(category), "w") as file:
- title_row = ["Steps"]
- title_row.extend(keys)
- writer = csv.writer(file)
- writer.writerow(title_row)
- return True
- return False
- return True
-
- def _get_filepath(self, category: str) -> str:
- file_dir = os.path.join(self.base_dir, category + ".csv")
- return file_dir
-
-
class StatsReporter:
writers: List[StatsWriter] = []
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))
@@ -316,8 +262,7 @@ def add_property(self, property_type: StatsPropertyType, value: Any) -> None:
"""
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
- with all types of properties. For instance, a TB writer doesn't need a max step, nor should
- we write hyperparameters to the CSV.
+ with all types of properties. For instance, a TB writer doesn't need a max step.
:param key: The type of property.
:param value: The property itself.
"""
diff --git a/ml-agents/mlagents/trainers/tests/test_env_param_manager.py b/ml-agents/mlagents/trainers/tests/test_env_param_manager.py
index b8fb92e15e..aea072617c 100644
--- a/ml-agents/mlagents/trainers/tests/test_env_param_manager.py
+++ b/ml-agents/mlagents/trainers/tests/test_env_param_manager.py
@@ -2,7 +2,7 @@
import yaml
-from mlagents.trainers.exception import TrainerConfigError
+from mlagents.trainers.exception import TrainerConfigError, TrainerConfigWarning
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.settings import (
RunOptions,
@@ -154,6 +154,41 @@ def test_curriculum_conversion():
"""
+test_bad_curriculum_all_competion_criteria_config_yaml = """
+environment_parameters:
+ param_1:
+ curriculum:
+ - name: Lesson1
+ completion_criteria:
+ measure: reward
+ behavior: fake_behavior
+ threshold: 30
+ min_lesson_length: 100
+ require_reset: true
+ value: 1
+ - name: Lesson2
+ completion_criteria:
+ measure: reward
+ behavior: fake_behavior
+ threshold: 30
+ min_lesson_length: 100
+ require_reset: true
+ value: 2
+ - name: Lesson3
+ completion_criteria:
+ measure: reward
+ behavior: fake_behavior
+ threshold: 30
+ min_lesson_length: 100
+ require_reset: true
+ value:
+ sampler_type: uniform
+ sampler_parameters:
+ min_value: 1
+ max_value: 3
+"""
+
+
def test_curriculum_raises_no_completion_criteria_conversion():
with pytest.raises(TrainerConfigError):
RunOptions.from_dict(
@@ -161,6 +196,33 @@ def test_curriculum_raises_no_completion_criteria_conversion():
)
+def test_curriculum_raises_all_completion_criteria_conversion():
+ with pytest.warns(TrainerConfigWarning):
+ run_options = RunOptions.from_dict(
+ yaml.safe_load(test_bad_curriculum_all_competion_criteria_config_yaml)
+ )
+
+ param_manager = EnvironmentParameterManager(
+ run_options.environment_parameters, 1337, False
+ )
+ assert param_manager.update_lessons(
+ trainer_steps={"fake_behavior": 500},
+ trainer_max_steps={"fake_behavior": 1000},
+ trainer_reward_buffer={"fake_behavior": [1000] * 101},
+ ) == (True, True)
+ assert param_manager.update_lessons(
+ trainer_steps={"fake_behavior": 500},
+ trainer_max_steps={"fake_behavior": 1000},
+ trainer_reward_buffer={"fake_behavior": [1000] * 101},
+ ) == (True, True)
+ assert param_manager.update_lessons(
+ trainer_steps={"fake_behavior": 500},
+ trainer_max_steps={"fake_behavior": 1000},
+ trainer_reward_buffer={"fake_behavior": [1000] * 101},
+ ) == (False, False)
+ assert param_manager.get_current_lesson_number() == {"param_1": 2}
+
+
test_everything_config_yaml = """
environment_parameters:
param_1:
diff --git a/ml-agents/mlagents/trainers/tests/test_stats.py b/ml-agents/mlagents/trainers/tests/test_stats.py
index a99c6aede4..0fed8210de 100644
--- a/ml-agents/mlagents/trainers/tests/test_stats.py
+++ b/ml-agents/mlagents/trainers/tests/test_stats.py
@@ -3,13 +3,11 @@
import pytest
import tempfile
import unittest
-import csv
import time
from mlagents.trainers.stats import (
StatsReporter,
TensorboardWriter,
- CSVWriter,
StatsSummary,
GaugeWriter,
ConsoleWriter,
@@ -123,46 +121,6 @@ def test_tensorboard_writer_clear(tmp_path):
assert len(os.listdir(os.path.join(tmp_path, "category1"))) == 1
-def test_csv_writer():
- # Test write_stats
- category = "category1"
- with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
- csv_writer = CSVWriter(base_dir, required_fields=["key1", "key2"])
- statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
- csv_writer.write_stats("category1", {"key1": statssummary1}, 10)
-
- # Test that the filewriter has been created and the directory has been created.
- filewriter_dir = "{basedir}/{category}.csv".format(
- basedir=base_dir, category=category
- )
- # The required keys weren't in the stats
- assert not os.path.exists(filewriter_dir)
-
- csv_writer.write_stats(
- "category1", {"key1": statssummary1, "key2": statssummary1}, 10
- )
- csv_writer.write_stats(
- "category1", {"key1": statssummary1, "key2": statssummary1}, 20
- )
-
- # The required keys were in the stats
- assert os.path.exists(filewriter_dir)
-
- with open(filewriter_dir) as csv_file:
- csv_reader = csv.reader(csv_file, delimiter=",")
- line_count = 0
- for row in csv_reader:
- if line_count == 0:
- assert "key1" in row
- assert "key2" in row
- assert "Steps" in row
- line_count += 1
- else:
- assert len(row) == 3
- line_count += 1
- assert line_count == 3
-
-
def test_gauge_stat_writer_sanitize():
assert GaugeWriter.sanitize_string("Policy/Learning Rate") == "Policy.LearningRate"
assert (
diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py b/ml-agents/mlagents/trainers/trainer/rl_trainer.py
index 51b4eb919c..5794fcce19 100644
--- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py
+++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py
@@ -5,7 +5,7 @@
import abc
import time
import attr
-from mlagents.model_serialization import SerializationSettings
+from mlagents.model_serialization import SerializationSettings, copy_model_files
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,
@@ -187,12 +187,14 @@ def save_model(self) -> None:
logger.warning("Trainer has no policies, not saving anything.")
return
policy = list(self.policies.values())[0]
- settings = SerializationSettings(policy.model_path, self.brain_name)
model_checkpoint = self._checkpoint()
+
+ # Copy the checkpointed model files to the final output location
+ copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn")
+
final_checkpoint = attr.evolve(
model_checkpoint, file_path=f"{policy.model_path}.nn"
)
- policy.save(policy.model_path, settings)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
@abc.abstractmethod
diff --git a/ml-agents/tests/yamato/yamato_utils.py b/ml-agents/tests/yamato/yamato_utils.py
index 611dac7f17..7784687392 100644
--- a/ml-agents/tests/yamato/yamato_utils.py
+++ b/ml-agents/tests/yamato/yamato_utils.py
@@ -136,8 +136,9 @@ def init_venv(
if extra_packages:
pip_commands += extra_packages
for cmd in pip_commands:
+ pip_index_url = "--index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple"
subprocess.check_call(
- f"source {venv_path}/bin/activate; python -m pip install -q {cmd}",
+ f"source {venv_path}/bin/activate; python -m pip install -q {cmd} {pip_index_url}",
shell=True,
)
return venv_path
diff --git a/utils/make_readme_table.py b/utils/make_readme_table.py
index 757ac6a118..208ae61cad 100644
--- a/utils/make_readme_table.py
+++ b/utils/make_readme_table.py
@@ -70,6 +70,7 @@ def display_name(self) -> str:
ReleaseInfo("release_2", "1.0.2", "0.16.1", "May 20, 2020"),
ReleaseInfo("release_3", "1.1.0", "0.17.0", "June 10, 2020"),
ReleaseInfo("release_4", "1.2.0", "0.18.0", "July 15, 2020"),
+ ReleaseInfo("release_5", "1.2.1", "0.18.1", "July 31, 2020"),
]
MAX_DAYS = 150 # do not print releases older than this many days