Skip to content

Commit

Permalink
Add support for audio queries (#579)
Browse files Browse the repository at this point in the history
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
  • Loading branch information
gkennickell and pseudo-rnd-thoughts authored Nov 28, 2024
1 parent e2ff2a4 commit 2d8ae89
Show file tree
Hide file tree
Showing 19 changed files with 594 additions and 13 deletions.
12 changes: 12 additions & 0 deletions docs/cpp-interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ ale::Action a = legal_actions[rand() % legal_actions.size()];
float reward = ale.act(a);
```

An optional sound observation is provided. To enable, set the associated environment parameter:

```cpp
ale.setBool("sound_obs", True);
```

Once enabled, the sound observation may be obtained by calling:

```cpp
ale.getAudio()
```

Finally, one can check whether the episode has terminated using the function `ale.game_over()`. With these functions one can already implement a very simple agent that plays randomly for one episode:

```cpp
Expand Down
5 changes: 5 additions & 0 deletions src/ale/ale_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ void ALEInterface::getScreenRGB(std::vector<unsigned char>& output_rgb_buffer) c
ale_screen_data, screen_size);
}

// Returns the current audio data
const std::vector<uint8_t>& ALEInterface::getAudio() const {
return environment->getAudio();
}

// Returns the current RAM content
const ALERAM& ALEInterface::getRAM() const { return environment->getRAM(); }

Expand Down
3 changes: 3 additions & 0 deletions src/ale/ale_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ class ALEInterface {
//followed by the green colours and then the blue colours
void getScreenRGB(std::vector<unsigned char>& output_rgb_buffer) const;

// Returns the current audio data
const std::vector<uint8_t> &getAudio() const;

// Returns the current RAM content
const ALERAM& getRAM() const;

Expand Down
1 change: 1 addition & 0 deletions src/ale/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ target_sources(ale
ScreenExporter.cpp
SoundExporter.cpp
SoundNull.cxx
SoundRaw.cxx
SoundSDL.cxx
SDL2.cpp
DynamicLoad.cpp
Expand Down
5 changes: 5 additions & 0 deletions src/ale/common/SoundNull.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class SoundNull : public stella::Sound
*/
virtual void recordNextFrame() { }

/**
* Processes audio for raw sample generation (applies all reg updates, fills buffer)
*/
virtual void process(uint8_t* buffer, uint32_t samples) { }

public:
/**
Loads the current state of this device from the given Deserializer.
Expand Down
228 changes: 228 additions & 0 deletions src/ale/common/SoundRaw.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/* *****************************************************************************
* A.L.E (Arcade Learning Environment)
* Copyright (c) 2009-2013 by Yavar Naddaf, Joel Veness, Marc G. Bellemare,
* Matthew Hausknecht and the Reinforcement Learning and Artificial Intelligence
* Laboratory
* Released under the GNU General Public License; see License.txt for details.
*
* Based on: Stella -- "An Atari 2600 VCS Emulator"
* Copyright (c) 1995-2007 by Bradford W. Mott and the Stella team
*
* *****************************************************************************
* SoundRaw.cxx
*
* A class for generating raw Atari 2600 sound samples.
*
**************************************************************************** */

#include "ale/emucore/Serializer.hxx"
#include "ale/emucore/Deserializer.hxx"

#include "ale/emucore/Settings.hxx"
#include "ale/common/SoundRaw.hxx"

#include "ale/common/Log.hpp"

namespace ale {
using namespace stella; // Settings, Serializer, Deserializer

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
SoundRaw::SoundRaw(Settings* settings)
: Sound(settings),
myIsEnabled(settings->getBool("sound_obs")),
myIsInitializedFlag(false),
myLastRegisterSetCycle(0)
{
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
SoundRaw::~SoundRaw()
{
close();
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void SoundRaw::setEnabled(bool state)
{
myIsEnabled = state;
mySettings->setBool("sound_obs", state);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void SoundRaw::initialize()
{
// Check whether to start the sound subsystem
if(!myIsEnabled)
{
close();
return;
}

// Make sure the sound queue is clear
myRegWriteQueue.clear();
myTIASound.reset();

myLastRegisterSetCycle = 0;
myIsInitializedFlag = true;

// Now initialize the TIASound object which will actually generate sound
int frequency = mySettings->getInt("freq");
myTIASound.outputFrequency(frequency);

int tiafreq = mySettings->getInt("tiafreq");
myTIASound.tiaFrequency(tiafreq);

// currently only support mono
myTIASound.channels(1);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void SoundRaw::close()
{
myIsInitializedFlag = false;
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
bool SoundRaw::isSuccessfullyInitialized() const
{
return myIsInitializedFlag;
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void SoundRaw::reset()
{
if(myIsInitializedFlag)
{
myLastRegisterSetCycle = 0;
myTIASound.reset();
myRegWriteQueue.clear();
}
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void SoundRaw::adjustCycleCounter(int amount)
{
myLastRegisterSetCycle += amount;
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void SoundRaw::set(uint16_t addr, uint8_t value, int cycle)
{
TIARegister info;
info.addr = addr;
info.value = value;
myRegWriteQueue.push_back(info);

// Update last cycle counter to the current cycle
myLastRegisterSetCycle = cycle;
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
void SoundRaw::process(uint8_t* buffer, uint32_t samples)
{
// Process all the audio register updates up to this frame
// Set audio registers
uint32_t regSize = myRegWriteQueue.size();
for(uint32_t i = 0; i < regSize; ++i) {
TIARegister& info = myRegWriteQueue.front();
myTIASound.set(info.addr, info.value);
myRegWriteQueue.pop_front();
}

// Process audio registers
myTIASound.process(buffer, samples);
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
bool SoundRaw::load(Deserializer& in)
{
std::string device = "TIASound";

try
{
if(in.getString() != device)
return false;

uint8_t reg1 = 0, reg2 = 0, reg3 = 0, reg4 = 0, reg5 = 0, reg6 = 0;
reg1 = (uint8_t) in.getInt();
reg2 = (uint8_t) in.getInt();
reg3 = (uint8_t) in.getInt();
reg4 = (uint8_t) in.getInt();
reg5 = (uint8_t) in.getInt();
reg6 = (uint8_t) in.getInt();

myLastRegisterSetCycle = (int) in.getInt();

// Only update the TIA sound registers if sound is enabled
// Make sure to empty the queue of previous sound fragments
if(myIsInitializedFlag)
{
myRegWriteQueue.clear();
myTIASound.set(0x15, reg1);
myTIASound.set(0x16, reg2);
myTIASound.set(0x17, reg3);
myTIASound.set(0x18, reg4);
myTIASound.set(0x19, reg5);
myTIASound.set(0x1a, reg6);
}
}
catch(char *msg)
{
ale::Logger::Error << msg << std::endl;
return false;
}
catch(...)
{
ale::Logger::Error << "Unknown error in load state for " << device << std::endl;
return false;
}

return true;
}

// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
bool SoundRaw::save(Serializer& out)
{
std::string device = "TIASound";

try
{
out.putString(device);

uint8_t reg1 = 0, reg2 = 0, reg3 = 0, reg4 = 0, reg5 = 0, reg6 = 0;

// Only get the TIA sound registers if sound is enabled
if(myIsInitializedFlag)
{
reg1 = myTIASound.get(0x15);
reg2 = myTIASound.get(0x16);
reg3 = myTIASound.get(0x17);
reg4 = myTIASound.get(0x18);
reg5 = myTIASound.get(0x19);
reg6 = myTIASound.get(0x1a);
}

out.putInt(reg1);
out.putInt(reg2);
out.putInt(reg3);
out.putInt(reg4);
out.putInt(reg5);
out.putInt(reg6);

out.putInt(myLastRegisterSetCycle);
}
catch(char *msg)
{
ale::Logger::Error << msg << std::endl;
return false;
}
catch(...)
{
ale::Logger::Error << "Unknown error in save state for " << device << std::endl;
return false;
}

return true;
}

} // namespace ale
Loading

0 comments on commit 2d8ae89

Please sign in to comment.