Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop remove past action communication #2913

Merged
merged 19 commits into from
Nov 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ public void TestAgentWrite()
reader.Seek(DemonstrationStore.MetaDataBytes + 1, 0);
BrainParametersProto.Parser.ParseDelimitedFrom(reader);

var agentInfoProto = AgentInfoProto.Parser.ParseDelimitedFrom(reader);
var agentInfoProto = AgentInfoActionPairProto.Parser.ParseDelimitedFrom(reader).AgentInfo;
var obs = agentInfoProto.Observations[2]; // skip dummy sensors
{
var vecObs = obs.FloatData.Data;
Assert.AreEqual(bpA.brainParameters.vectorObservationSize, vecObs.Count);
for (var i = 0; i < vecObs.Count; i++)
{
Assert.AreEqual((float) i+1, vecObs[i]);
Assert.AreEqual((float)i + 1, vecObs[i]);
}
}

Expand Down
4 changes: 2 additions & 2 deletions UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void CreateDemonstrationFile(string demonstrationName)
}

m_Writer = m_FileSystem.File.Create(m_FilePath);
m_MetaData = new DemonstrationMetaData {demonstrationName = demonstrationName};
m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
var metaProto = m_MetaData.ToProto();
metaProto.WriteDelimitedTo(m_Writer);
}
Expand Down Expand Up @@ -102,7 +102,7 @@ public void Record(AgentInfo info)
}

// Write AgentInfo to file.
var agentProto = info.ToProto();
var agentProto = info.ToInfoActionPairProto();
agentProto.WriteDelimitedTo(m_Writer);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,17 @@ static AgentInfoReflection() {
string.Concat(
"CjNtbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm8ucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjRtbGFnZW50cy9lbnZz",
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvIuoBCg5B",
"Z2VudEluZm9Qcm90bxIdChVzdG9yZWRfdmVjdG9yX2FjdGlvbnMYBCADKAIS",
"DgoGcmV3YXJkGAcgASgCEgwKBGRvbmUYCCABKAgSGAoQbWF4X3N0ZXBfcmVh",
"Y2hlZBgJIAEoCBIKCgJpZBgKIAEoBRITCgthY3Rpb25fbWFzaxgLIAMoCBI8",
"CgxvYnNlcnZhdGlvbnMYDSADKAsyJi5jb21tdW5pY2F0b3Jfb2JqZWN0cy5P",
"YnNlcnZhdGlvblByb3RvSgQIARACSgQIAhADSgQIAxAESgQIBRAGSgQIBhAH",
"SgQIDBANQh+qAhxNTEFnZW50cy5Db21tdW5pY2F0b3JPYmplY3RzYgZwcm90",
"bzM="));
"L2NvbW11bmljYXRvcl9vYmplY3RzL29ic2VydmF0aW9uLnByb3RvItEBCg5B",
"Z2VudEluZm9Qcm90bxIOCgZyZXdhcmQYByABKAISDAoEZG9uZRgIIAEoCBIY",
"ChBtYXhfc3RlcF9yZWFjaGVkGAkgASgIEgoKAmlkGAogASgFEhMKC2FjdGlv",
"bl9tYXNrGAsgAygIEjwKDG9ic2VydmF0aW9ucxgNIAMoCzImLmNvbW11bmlj",
"YXRvcl9vYmplY3RzLk9ic2VydmF0aW9uUHJvdG9KBAgBEAJKBAgCEANKBAgD",
"EARKBAgEEAVKBAgFEAZKBAgGEAdKBAgMEA1CH6oCHE1MQWdlbnRzLkNvbW11",
"bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.ObservationReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "StoredVectorActions", "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoProto), global::MLAgents.CommunicatorObjects.AgentInfoProto.Parser, new[]{ "Reward", "Done", "MaxStepReached", "Id", "ActionMask", "Observations" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -69,7 +68,6 @@ public AgentInfoProto() {

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentInfoProto(AgentInfoProto other) : this() {
storedVectorActions_ = other.storedVectorActions_.Clone();
reward_ = other.reward_;
done_ = other.done_;
maxStepReached_ = other.maxStepReached_;
Expand All @@ -84,16 +82,6 @@ public AgentInfoProto Clone() {
return new AgentInfoProto(this);
}

/// <summary>Field number for the "stored_vector_actions" field.</summary>
public const int StoredVectorActionsFieldNumber = 4;
private static readonly pb::FieldCodec<float> _repeated_storedVectorActions_codec
= pb::FieldCodec.ForFloat(34);
private readonly pbc::RepeatedField<float> storedVectorActions_ = new pbc::RepeatedField<float>();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public pbc::RepeatedField<float> StoredVectorActions {
get { return storedVectorActions_; }
}

/// <summary>Field number for the "reward" field.</summary>
public const int RewardFieldNumber = 7;
private float reward_;
Expand Down Expand Up @@ -171,7 +159,6 @@ public bool Equals(AgentInfoProto other) {
if (ReferenceEquals(other, this)) {
return true;
}
if(!storedVectorActions_.Equals(other.storedVectorActions_)) return false;
if (!pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.Equals(Reward, other.Reward)) return false;
if (Done != other.Done) return false;
if (MaxStepReached != other.MaxStepReached) return false;
Expand All @@ -184,7 +171,6 @@ public bool Equals(AgentInfoProto other) {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
hash ^= storedVectorActions_.GetHashCode();
if (Reward != 0F) hash ^= pbc::ProtobufEqualityComparers.BitwiseSingleEqualityComparer.GetHashCode(Reward);
if (Done != false) hash ^= Done.GetHashCode();
if (MaxStepReached != false) hash ^= MaxStepReached.GetHashCode();
Expand All @@ -204,7 +190,6 @@ public override string ToString() {

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
storedVectorActions_.WriteTo(output, _repeated_storedVectorActions_codec);
if (Reward != 0F) {
output.WriteRawTag(61);
output.WriteFloat(Reward);
Expand All @@ -231,7 +216,6 @@ public void WriteTo(pb::CodedOutputStream output) {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
size += storedVectorActions_.CalculateSize(_repeated_storedVectorActions_codec);
if (Reward != 0F) {
size += 1 + 4;
}
Expand All @@ -257,7 +241,6 @@ public void MergeFrom(AgentInfoProto other) {
if (other == null) {
return;
}
storedVectorActions_.Add(other.storedVectorActions_);
if (other.Reward != 0F) {
Reward = other.Reward;
}
Expand All @@ -283,11 +266,6 @@ public void MergeFrom(pb::CodedInputStream input) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 34:
case 37: {
storedVectorActions_.AddEntriesFrom(input, _repeated_storedVectorActions_codec);
break;
}
case 61: {
Reward = input.ReadFloat();
break;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// <auto-generated>
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: mlagents/envs/communicator_objects/agent_info_action_pair.proto
// </auto-generated>
#pragma warning disable 1591, 0612, 3021
#region Designer generated code

using pb = global::Google.Protobuf;
using pbc = global::Google.Protobuf.Collections;
using pbr = global::Google.Protobuf.Reflection;
using scg = global::System.Collections.Generic;
namespace MLAgents.CommunicatorObjects {

/// <summary>Holder for reflection information generated from mlagents/envs/communicator_objects/agent_info_action_pair.proto</summary>
public static partial class AgentInfoActionPairReflection {

#region Descriptor
/// <summary>File descriptor for mlagents/envs/communicator_objects/agent_info_action_pair.proto</summary>
public static pbr::FileDescriptor Descriptor {
get { return descriptor; }
}
private static pbr::FileDescriptor descriptor;

static AgentInfoActionPairReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"Cj9tbGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2lu",
"Zm9fYWN0aW9uX3BhaXIucHJvdG8SFGNvbW11bmljYXRvcl9vYmplY3RzGjNt",
"bGFnZW50cy9lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2FnZW50X2luZm8u",
"cHJvdG8aNW1sYWdlbnRzL2VudnMvY29tbXVuaWNhdG9yX29iamVjdHMvYWdl",
"bnRfYWN0aW9uLnByb3RvIpEBChhBZ2VudEluZm9BY3Rpb25QYWlyUHJvdG8S",
"OAoKYWdlbnRfaW5mbxgBIAEoCzIkLmNvbW11bmljYXRvcl9vYmplY3RzLkFn",
"ZW50SW5mb1Byb3RvEjsKC2FjdGlvbl9pbmZvGAIgASgLMiYuY29tbXVuaWNh",
"dG9yX29iamVjdHMuQWdlbnRBY3Rpb25Qcm90b0IfqgIcTUxBZ2VudHMuQ29t",
"bXVuaWNhdG9yT2JqZWN0c2IGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { global::MLAgents.CommunicatorObjects.AgentInfoReflection.Descriptor, global::MLAgents.CommunicatorObjects.AgentActionReflection.Descriptor, },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::MLAgents.CommunicatorObjects.AgentInfoActionPairProto), global::MLAgents.CommunicatorObjects.AgentInfoActionPairProto.Parser, new[]{ "AgentInfo", "ActionInfo" }, null, null, null)
}));
}
#endregion

}
#region Messages
public sealed partial class AgentInfoActionPairProto : pb::IMessage<AgentInfoActionPairProto> {
private static readonly pb::MessageParser<AgentInfoActionPairProto> _parser = new pb::MessageParser<AgentInfoActionPairProto>(() => new AgentInfoActionPairProto());
private pb::UnknownFieldSet _unknownFields;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<AgentInfoActionPairProto> Parser { get { return _parser; } }

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::MLAgents.CommunicatorObjects.AgentInfoActionPairReflection.Descriptor.MessageTypes[0]; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentInfoActionPairProto() {
OnConstruction();
}

partial void OnConstruction();

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentInfoActionPairProto(AgentInfoActionPairProto other) : this() {
AgentInfo = other.agentInfo_ != null ? other.AgentInfo.Clone() : null;
ActionInfo = other.actionInfo_ != null ? other.ActionInfo.Clone() : null;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public AgentInfoActionPairProto Clone() {
return new AgentInfoActionPairProto(this);
}

/// <summary>Field number for the "agent_info" field.</summary>
public const int AgentInfoFieldNumber = 1;
private global::MLAgents.CommunicatorObjects.AgentInfoProto agentInfo_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::MLAgents.CommunicatorObjects.AgentInfoProto AgentInfo {
get { return agentInfo_; }
set {
agentInfo_ = value;
}
}

/// <summary>Field number for the "action_info" field.</summary>
public const int ActionInfoFieldNumber = 2;
private global::MLAgents.CommunicatorObjects.AgentActionProto actionInfo_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::MLAgents.CommunicatorObjects.AgentActionProto ActionInfo {
get { return actionInfo_; }
set {
actionInfo_ = value;
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as AgentInfoActionPairProto);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(AgentInfoActionPairProto other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (!object.Equals(AgentInfo, other.AgentInfo)) return false;
if (!object.Equals(ActionInfo, other.ActionInfo)) return false;
return Equals(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (agentInfo_ != null) hash ^= AgentInfo.GetHashCode();
if (actionInfo_ != null) hash ^= ActionInfo.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
return hash;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (agentInfo_ != null) {
output.WriteRawTag(10);
output.WriteMessage(AgentInfo);
}
if (actionInfo_ != null) {
output.WriteRawTag(18);
output.WriteMessage(ActionInfo);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (agentInfo_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(AgentInfo);
}
if (actionInfo_ != null) {
size += 1 + pb::CodedOutputStream.ComputeMessageSize(ActionInfo);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
return size;
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(AgentInfoActionPairProto other) {
if (other == null) {
return;
}
if (other.agentInfo_ != null) {
if (agentInfo_ == null) {
agentInfo_ = new global::MLAgents.CommunicatorObjects.AgentInfoProto();
}
AgentInfo.MergeFrom(other.AgentInfo);
}
if (other.actionInfo_ != null) {
if (actionInfo_ == null) {
actionInfo_ = new global::MLAgents.CommunicatorObjects.AgentActionProto();
}
ActionInfo.MergeFrom(other.ActionInfo);
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
_unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input);
break;
case 10: {
if (agentInfo_ == null) {
agentInfo_ = new global::MLAgents.CommunicatorObjects.AgentInfoProto();
}
input.ReadMessage(agentInfo_);
break;
}
case 18: {
if (actionInfo_ == null) {
actionInfo_ = new global::MLAgents.CommunicatorObjects.AgentActionProto();
}
input.ReadMessage(actionInfo_);
break;
}
}
}
}

}

#endregion

}

#endregion Designer generated code

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading