Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only send rewards if triggered #137

Merged
merged 21 commits into from
Jul 4, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
71582f6
Added dimension attribute to all reward producers.
Jul 1, 2016
2372f0c
Using Reward structure defined in schemas for storing rewards. Work i…
Jul 1, 2016
2e38e08
New class MultidimensionalReward.java. Work in progress.
timhutton Jul 1, 2016
896ef6f
Return of reward message as Reward XML now working. Work in progress.
timhutton Jul 1, 2016
b46a439
Made MissionEnded final reward optional, to match behavior with other…
timhutton Jul 2, 2016
a9c4c15
Fix: add(other) wasn't working. Separated clear().
timhutton Jul 3, 2016
4cefed9
Storing multidimensional rewards in TimestampedRewards, was Timestamp…
Jul 4, 2016
9e11f55
Tweaked TimestampedReward API.
Jul 4, 2016
3f77127
Added conversion functions for working with TimestampedReward instanc…
Jul 4, 2016
e5594ac
Handle parse errors in reward message by adding to world_state.error_…
Jul 4, 2016
e32be88
Merge branch 'master' into rewards
Jul 4, 2016
86ade07
Minor: updated changelog with rewards changes.
Jul 4, 2016
11e6703
Merge branch 'master' into rewards
Jul 4, 2016
99b5a28
Merge branch 'master' into rewards
Jul 4, 2016
f8209cb
Merge branch 'master' into rewards
Jul 4, 2016
4ce21b6
Merge branch 'master' into rewards
Jul 4, 2016
a59b5ed
Fix: need to iterate with Map, not HashMap - error on some platforms.
Jul 4, 2016
cebb40f
tabular_q_learning can now use the number of rewards, instead of thei…
Jul 4, 2016
9f845e3
Fix: restored mutex guard location to as before.
Jul 4, 2016
e273548
Minor: printing any errors received is good practice.
Jul 4, 2016
3f2a483
Merge branch 'master' into rewards
Jul 4, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Malmo/humanAction/MainWindow.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -911,13 +911,13 @@ private void runMission(MissionSpec mission)
{
foreach (var reward in worldState.rewards)
{
_score += reward.value;
if (reward.value < 0)
_score += reward.getValue();
if (reward.getValue() < 0)
{
failure = true;
}

_pendingMessages.Enqueue(string.Format("{0}> score {1}", reward.timestamp.ToString("hh:mm:ss.fff"), reward.value));
_pendingMessages.Enqueue(string.Format("{0}> score {1}", reward.timestamp.ToString("hh:mm:ss.fff"), reward.getValue()));
}

_score = Math.Max(_score, 0);
Expand Down
2 changes: 1 addition & 1 deletion Malmo/samples/CSharp_examples/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public static void Main()
worldState.number_of_video_frames_since_last_state,
worldState.number_of_observations_since_last_state,
worldState.number_of_rewards_since_last_state);
foreach (TimestampedFloat reward in worldState.rewards) Console.Error.WriteLine("Summed reward: {0}", reward.value);
foreach (TimestampedReward reward in worldState.rewards) Console.Error.WriteLine("Summed reward: {0}", reward.getValue());
foreach (TimestampedString error in worldState.errors) Console.Error.WriteLine("Error: {0}", error.text);
}
while (worldState.is_mission_running);
Expand Down
4 changes: 2 additions & 2 deletions Malmo/samples/Cpp_examples/run_mission.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ int main(int argc, const char **argv)
<< world_state.number_of_video_frames_since_last_state << ","
<< world_state.number_of_observations_since_last_state << ","
<< world_state.number_of_rewards_since_last_state << endl;
for( boost::shared_ptr<TimestampedFloat> reward : world_state.rewards )
cout << "Summed reward: " << reward->value << endl;
for( boost::shared_ptr<TimestampedReward> reward : world_state.rewards )
cout << "Summed reward: " << reward->getValue() << endl;
for( boost::shared_ptr<TimestampedString> error : world_state.errors )
cout << "Error: " << error->text << endl;
} while (world_state.is_mission_running);
Expand Down
2 changes: 1 addition & 1 deletion Malmo/samples/Java_examples/JavaExamples_run_mission.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public static void main(String argv[])
System.out.print( world_state.getNumberOfObservationsSinceLastState() + "," );
System.out.println( world_state.getNumberOfRewardsSinceLastState() );
for( int i = 0; i < world_state.getRewards().size(); i++ ) {
TimestampedFloat reward = world_state.getRewards().get(i);
TimestampedReward reward = world_state.getRewards().get(i);
System.out.println( "Summed reward: " + reward.getValue() );
}
for( int i = 0; i < world_state.getErrors().size(); i++ ) {
Expand Down
2 changes: 1 addition & 1 deletion Malmo/samples/Lua_examples/run_mission.lua
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ while world_state.is_mission_running do
world_state = agent_host:getWorldState()
print( "video,observations,rewards received: "..world_state.number_of_video_frames_since_last_state..","..world_state.number_of_observations_since_last_state..","..world_state.number_of_rewards_since_last_state )
for reward in world_state.rewards do
print( "Summed reward: "..reward.value )
print( "Summed reward: "..reward.getValue() )
print( "Timestamp of most recent reward: "..reward:timestamp() ) -- in milliseconds since Jan 1st, 1970.
end
for error in world_state.errors do
Expand Down
4 changes: 2 additions & 2 deletions Malmo/samples/Python_examples/ALE_HAC.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def sendCommand():
# The ALE only updates in response to a command, so get the new world state now.
world_state = agent_host.getWorldState()
for reward in world_state.rewards:
if reward.value > 0:
print "Summed reward:",reward.value
if reward.getValue() > 0:
print "Summed reward:",reward.getValue()
for error in world_state.errors:
print "Error:",error.text
if world_state.number_of_video_frames_since_last_state > 0 and want_own_display:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ def GetMissionXML(num):
sys.stdout.write('O{0:.0f}'.format(distance))
if world_state.number_of_rewards_since_last_state > 0:
for rew in world_state.rewards:
if rew.value == 0:
if rew.getValue() == 0:
sys.stdout.write("r")
elif rew.value == 100:
elif rew.getValue() == 100:
sys.stdout.write("R")
elif rew.value == -1000:
elif rew.getValue() == -1000:
sys.stdout.write("*")
else:
sys.stdout.write("?")
total_rewards += rew.value
total_rewards += rew.getValue()
if world_state.is_mission_running:
sys.stdout.write("T")
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def SetTurn(turn):
world_state = agent_host.getWorldState()
if world_state.number_of_rewards_since_last_state > 0:
# A reward signal has come in - see what it is:
delta = world_state.rewards[0].value
delta = world_state.rewards[0].getValue()
reward+=delta
if delta==10:
agent_host.sendCommand("chat " + random.choice(["Have a fish!", "Free trout!", "Fishy!", "Bleurgh, catch"]))
Expand Down
2 changes: 1 addition & 1 deletion Malmo/samples/Python_examples/reward_for_items_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def SetTurn(turn):
world_state = agent_host.getWorldState()
if world_state.number_of_rewards_since_last_state > 0:
# A reward signal has come in - see what it is:
delta = world_state.rewards[0].value
delta = world_state.rewards[0].getValue()
if delta != 0:
# The total reward has changed - use this to determine our turn.
reward += delta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def SetTurn(turn):
world_state = agent_host.getWorldState()
if world_state.number_of_rewards_since_last_state > 0:
# A reward signal has come in - see what it is:
delta = world_state.rewards[0].value
delta = world_state.rewards[0].getValue()
if delta != 0:
print "New reward: " + str(delta)
reward += delta
Expand Down
2 changes: 1 addition & 1 deletion Malmo/samples/Python_examples/run_mission.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
world_state = agent_host.getWorldState()
print "video,observations,rewards received:",world_state.number_of_video_frames_since_last_state,world_state.number_of_observations_since_last_state,world_state.number_of_rewards_since_last_state
for reward in world_state.rewards:
print "Summed reward:",reward.value
print "Summed reward:",reward.getValue()
for error in world_state.errors:
print "Error:",error.text
for frame in world_state.video_frames:
Expand Down
4 changes: 2 additions & 2 deletions Malmo/samples/Python_examples/sample_missions_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@
world_state = agent_host.getWorldState()
print "video,observations,rewards received:",world_state.number_of_video_frames_since_last_state,world_state.number_of_observations_since_last_state,world_state.number_of_rewards_since_last_state
for reward in world_state.rewards:
print "Summed reward:",reward.value
total_reward += reward.value
print "Summed reward:",reward.getValue()
total_reward += reward.getValue()
for error in world_state.errors:
print "Error:",error.text

Expand Down
12 changes: 8 additions & 4 deletions Malmo/samples/Python_examples/tabular_q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def run(self, agent_host):
world_state = agent_host.peekWorldState()
while world_state.is_mission_running and all(e.text=='{}' for e in world_state.observations):
world_state = agent_host.peekWorldState()
world_state = agent_host.getWorldState()
world_state = agent_host.getWorldState()
for err in world_state.errors:
print err

if not world_state.is_mission_running:
return 0 # mission already ended
Expand All @@ -142,14 +144,14 @@ def run(self, agent_host):
# main loop:
while world_state.is_mission_running:

# wait for the position to have changed and a non-zero reward received
# wait for the position to have changed and a reward received
print 'Waiting for data...',
while True:
world_state = agent_host.peekWorldState()
if not world_state.is_mission_running:
print 'mission ended.'
break
if not sum(r.value for r in world_state.rewards) == 0 and not all(e.text=='{}' for e in world_state.observations):
if len(world_state.rewards) > 0 and not all(e.text=='{}' for e in world_state.observations):
obs = json.loads( world_state.observations[-1].text )
curr_x = int(obs[u'XPos'])
curr_z = int(obs[u'ZPos'])
Expand All @@ -158,7 +160,9 @@ def run(self, agent_host):
break

world_state = agent_host.getWorldState()
current_r = sum(r.value for r in world_state.rewards)
for err in world_state.errors:
print err
current_r = sum(r.getValue() for r in world_state.rewards)

if world_state.is_mission_running:
obs = json.loads( world_state.observations[-1].text )
Expand Down
6 changes: 3 additions & 3 deletions Malmo/samples/Python_examples/tutorial_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def run(self, agent_host):
for error in world_state.errors:
self.logger.error("Error: %s" % error.text)
for reward in world_state.rewards:
current_r += reward.value
current_r += reward.getValue()
if world_state.is_mission_running and len(world_state.observations)>0 and not world_state.observations[-1].text=="{}":
total_reward += self.act(world_state, agent_host, current_r)
break
Expand All @@ -163,15 +163,15 @@ def run(self, agent_host):
for error in world_state.errors:
self.logger.error("Error: %s" % error.text)
for reward in world_state.rewards:
current_r += reward.value
current_r += reward.getValue()
# allow time to stabilise after action
while True:
time.sleep(0.1)
world_state = agent_host.getWorldState()
for error in world_state.errors:
self.logger.error("Error: %s" % error.text)
for reward in world_state.rewards:
current_r += reward.value
current_r += reward.getValue()
if world_state.is_mission_running and len(world_state.observations)>0 and not world_state.observations[-1].text=="{}":
total_reward += self.act(world_state, agent_host, current_r)
break
Expand Down
2 changes: 1 addition & 1 deletion Malmo/samples/Torch_examples/run_mission.lua
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ while world_state.is_mission_running do
world_state = agent_host:getWorldState()
print( "video,observations,rewards received: "..world_state.number_of_video_frames_since_last_state..","..world_state.number_of_observations_since_last_state..","..world_state.number_of_rewards_since_last_state )
for reward in world_state.rewards do
print( "Summed reward: "..reward.value )
print( "Summed reward: "..reward.getValue() )
end
for error in world_state.errors do
print( "Error: "..error.text )
Expand Down
59 changes: 23 additions & 36 deletions Malmo/src/AgentHost.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,14 +398,12 @@ namespace malmo
}

if (this->world_state.is_mission_running) {
TimestampedFloat final_reward;
final_reward.timestamp = xml.timestamp;
final_reward.value = static_cast<float>(mission_ended->FinalReward());
this->processReceivedReward(final_reward);

std::stringstream json;
json << "{\"Reward\":" << final_reward.value << "}";
this->rewards_server->recordMessage(TimestampedString(xml.timestamp, json.str()));
schemas::MissionEnded::Reward_optional final_reward_optional = mission_ended->Reward();
if( final_reward_optional.present() ) {
TimestampedReward final_reward(xml.timestamp,final_reward_optional.get());
this->processReceivedReward(final_reward);
this->rewards_server->recordMessage(TimestampedString(xml.timestamp, final_reward.getAsXML(false)));
}
}
}
catch (const xml_schema::exception& e) {
Expand Down Expand Up @@ -486,54 +484,41 @@ namespace malmo
this->world_state.number_of_video_frames_since_last_state++;
}

void AgentHost::onReward(TimestampedString json)
void AgentHost::onReward(TimestampedString message)
{
boost::lock_guard<boost::mutex> scope_guard(this->world_state_mutex);

std::stringstream ss( json.text );
boost::property_tree::ptree pt;

try {
boost::property_tree::read_json( ss, pt);
}
catch( std::exception&e ) {
TimestampedString error_message( json );
error_message.text = std::string("Error parsing reward JSON: ") + e.what() + ":" + json.text.substr(0, 20) + "...";
this->world_state.errors.push_back( boost::make_shared<TimestampedString>( error_message ) );
return;
TimestampedReward reward(message.timestamp, message.text);
this->processReceivedReward(reward);
}

TimestampedFloat reward;
reward.timestamp = json.timestamp;
try {
reward.value = pt.get<float>( "Reward" );
} catch( std::exception& e ) {
TimestampedString error_message( json );
error_message.text = std::string("Error retrieving reward value from JSON: ") + e.what() + ":" + json.text.substr(0, 20) + "...";
this->world_state.errors.push_back( boost::make_shared<TimestampedString>( error_message ) );
return;
catch( const xml_schema::exception& e ) {
std::ostringstream oss;
oss << "Error parsing Reward message XML: " << e.what() << " : " << e << ":" << message.text.substr(0, 20) << "...";
TimestampedString error_message(message);
error_message.text = oss.str();
this->world_state.errors.push_back(boost::make_shared<TimestampedString>(error_message));
}

this->processReceivedReward( reward );
}

void AgentHost::processReceivedReward( TimestampedFloat reward )
void AgentHost::processReceivedReward( TimestampedReward reward )
{
switch( this->rewards_policy )
{
case RewardsPolicy::LATEST_REWARD_ONLY:
this->world_state.rewards.clear();
this->world_state.rewards.push_back( boost::make_shared<TimestampedFloat>( reward ) );
this->world_state.rewards.push_back( boost::make_shared<TimestampedReward>( reward ) );
break;
case RewardsPolicy::SUM_REWARDS:
if( !this->world_state.rewards.empty() ) {
reward.value += this->world_state.rewards.front()->value;
reward.add(*this->world_state.rewards.front());
this->world_state.rewards.clear();
}
this->world_state.rewards.push_back( boost::make_shared<TimestampedFloat>( reward ) );
this->world_state.rewards.push_back( boost::make_shared<TimestampedReward>( reward ) );
// (timestamp is that of latest reward, even if zero)
break;
case RewardsPolicy::KEEP_ALL_REWARDS:
this->world_state.rewards.push_back( boost::make_shared<TimestampedFloat>( reward ) );
this->world_state.rewards.push_back( boost::make_shared<TimestampedReward>( reward ) );
break;
}

Expand All @@ -560,6 +545,8 @@ namespace malmo

void AgentHost::sendCommand(std::string command)
{
boost::lock_guard<boost::mutex> scope_guard(this->world_state_mutex);

if( !this->commands_connection ) {
TimestampedString error_message(
boost::posix_time::microsec_clock::universal_time(),
Expand Down
6 changes: 3 additions & 3 deletions Malmo/src/AgentHost.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,16 @@ namespace malmo
void listenForRewards( int port );
void listenForObservations( int port );

void onMissionControlMessage(TimestampedString xml);
void onMissionControlMessage(TimestampedString message);
void onVideo(TimestampedVideoFrame message);
void onReward(TimestampedString json);
void onReward(TimestampedString message);
void onObservation(TimestampedString message);

void openCommandsConnection();

void close();

void processReceivedReward( TimestampedFloat reward );
void processReceivedReward( TimestampedReward reward );

boost::asio::io_service io_service;
boost::shared_ptr<StringServer> mission_control_server;
Expand Down
15 changes: 8 additions & 7 deletions Malmo/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ set( SOURCES
TCPClient.cpp
TCPConnection.cpp
TCPServer.cpp
TimestampedFloat.cpp
TimestampedReward.cpp
TimestampedString.cpp
TimestampedVideoFrame.cpp
VideoFrameWriter.cpp
Expand Down Expand Up @@ -64,7 +64,7 @@ set( HEADERS
TCPConnection.h
TCPServer.h
TimestampedUnsignedCharVector.h
TimestampedFloat.h
TimestampedReward.h
TimestampedString.h
TimestampedVideoFrame.h
VideoFrameWriter.h
Expand Down Expand Up @@ -103,12 +103,13 @@ add_definitions( -DMALMO_VERSION=${MALMO_VERSION} )

add_custom_command( # Run CodeSynthesis on the .xsd files to generate C++ sources we can use
OUTPUT Mission.h Mission.cpp MissionHandlers.h MissionHandlers.cpp MissionEnded.h MissionEnded.cpp MissionInit.h MissionInit.cpp Types.h Types.cpp
COMMAND ${XSD_EXECUTABLE} cxx-tree --generate-polymorphic --namespace-map http://ProjectMalmo.microsoft.com=malmo::schemas --root-element Mission --root-element MissionInit --root-element MissionEnded
COMMAND ${XSD_EXECUTABLE} cxx-tree --generate-polymorphic --namespace-map http://ProjectMalmo.microsoft.com=malmo::schemas
--root-element Mission --root-element MissionInit --root-element MissionEnded --root-element Reward
--generate-serialization --hxx-suffix .h --cxx-suffix .cpp #--std c++11
${CMAKE_SOURCE_DIR}/Schemas/Mission.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionEnded.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionHandlers.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionInit.xsd
${CMAKE_SOURCE_DIR}/Schemas/Types.xsd
DEPENDS ${CMAKE_SOURCE_DIR}/Schemas/Mission.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionEnded.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionHandlers.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionInit.xsd
${CMAKE_SOURCE_DIR}/Schemas/Types.xsd
${CMAKE_SOURCE_DIR}/Schemas/Mission.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionEnded.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionHandlers.xsd
${CMAKE_SOURCE_DIR}/Schemas/MissionInit.xsd ${CMAKE_SOURCE_DIR}/Schemas/Types.xsd
DEPENDS ${CMAKE_SOURCE_DIR}/Schemas/Mission.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionEnded.xsd ${CMAKE_SOURCE_DIR}/Schemas/MissionHandlers.xsd
${CMAKE_SOURCE_DIR}/Schemas/MissionInit.xsd ${CMAKE_SOURCE_DIR}/Schemas/Types.xsd
COMMENT "Generating C++ from XSD files..."
)

Expand Down
Loading