chandar-lab/RL-Tuner-CP

RL-Tuner

This is the code for a modified version of RL-Tuner. The goal of this project is to incorporate the marginal probabilities provided by the Mini-CPBP solver into the reward function in order to enforce constraints.

Training the RNN

RL-Tuner/preprocessing/bach_dataset_loader.py creates note sequences from the Bach dataset. path_root is the path to the folder containing the MIDI dataset; change it to process a different dataset. quarter_note is the number of events per quarter note, which depends on the BPM of the MIDI files. The chosen resolution is passed to load_all_midi_files_in_folder. Running this Python file creates a pickle file containing the note sequences for the whole dataset.
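To illustrate how quarter_note sets the time resolution, here is a minimal sketch that renders monophonic notes onto a grid of quarter_note steps per quarter note. The (pitch, start, duration) tuple format with times measured in quarter notes, the function name quantize, and the use of None for silent steps are assumptions for illustration; this is not the code in bach_dataset_loader.py.

```python
from typing import List, Optional, Tuple

# Hypothetical input format: (pitch, start, duration), times in quarter notes.
Note = Tuple[int, float, float]


def quantize(notes: List[Note], quarter_note: int) -> List[Optional[int]]:
    """Render monophonic notes onto a grid of `quarter_note` steps per quarter note."""
    if not notes:
        return []
    end = max(start + duration for _, start, duration in notes)
    total_steps = round(end * quarter_note)
    steps: List[Optional[int]] = [None] * total_steps  # None marks a silent step
    for pitch, start, duration in notes:
        first = round(start * quarter_note)
        last = round((start + duration) * quarter_note)
        for i in range(first, last):
            steps[i] = pitch  # assumption: later notes overwrite any overlap
    return steps
```

With quarter_note = 4 (a sixteenth-note grid), a dotted-eighth note of pitch 4 followed by an eighth note of pitch 5, i.e. `quantize([(4, 0.0, 0.75), (5, 0.75, 0.5)], 4)`, yields the per-step sequence [4, 4, 4, 5, 5] used in the examples below.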

RL-Tuner/preprocessing/bach_note_sequences.py converts the pickled note sequences into the sequence examples used to train the RNN. Three modes can be specified:

  • melodic lines: removes the silences and any duration information (for example, the sequence [4 4 4 5 5] will become [4 5]). All the notes will have a value between 0 and 28.
  • note sequences: keeps the silences and the durations are expressed with the hold token (for example, the sequence [4 4 4 5 5] will become [4 1 1 5 1]). All the notes will have a value between 2 and 30. 0 represents silences and 1 is the hold token.
  • no hold: keeps the silences and the durations are expressed by repeating the note (for example, the sequence [4 4 4 5 5] will stay the same). All the notes will have a value between 1 and 29. 0 represents silences.
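The three encodings can be sketched as follows. This is not the code in bach_note_sequences.py: the raw input format (one pitch index from 0 to 28 per time step, None for a silent step), the function names, and the per-mode offsets are assumptions. The offsets (0, 2 and 1) are inferred from the value ranges listed above, one slot per reserved token; the examples in the list appear to show note values before any such offset is applied.

```python
from itertools import groupby
from typing import List, Optional, Sequence

# Hypothetical raw format: one entry per time step, a pitch index 0..28,
# or None for a silent step.
RawSequence = Sequence[Optional[int]]

SILENCE = 0  # silence token in "note sequences" and "no hold" modes
HOLD = 1     # hold token in "note sequences" mode


def melodic_lines(raw: RawSequence) -> List[int]:
    """Drop durations and silences; pitches stay in 0..28."""
    # Collapse runs of identical steps first, then drop the silent runs.
    return [value for value, _ in groupby(raw) if value is not None]


def note_sequences(raw: RawSequence) -> List[int]:
    """Keep silences, encode durations with the hold token; pitches in 2..30."""
    out: List[int] = []
    previous: Optional[int] = None
    for value in raw:
        if value is None:
            out.append(SILENCE)     # assumption: every silent step is emitted as 0
        elif value == previous:
            out.append(HOLD)        # same pitch as the previous step: keep holding
        else:
            out.append(value + 2)   # shift past the two reserved tokens
        previous = value
    return out


def no_hold(raw: RawSequence) -> List[int]:
    """Keep silences, encode durations by repetition; pitches in 1..29."""
    return [SILENCE if value is None else value + 1 for value in raw]
```

Collapsing runs before dropping silences is a design choice: in melodic-lines mode it keeps a pitch that is repeated after a rest as two separate notes, so `melodic_lines([4, None, 4])` gives [4, 4] rather than [4].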

input_file_name and output_file_name specify the paths of the input pickle file and the output TFRecord file, respectively.
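Conceptually, a sequence example for next-step prediction pairs every step with the token that follows it. Assuming the usual Magenta Melody RNN layout (one-hot inputs, next-event indices as labels), the pure-Python sketch below builds such pairs without TensorFlow. The function name and the class counts (29, 31 or 30 tokens for the three modes, read off the value ranges above) are assumptions, and serializing the result to a TFRecord is left to the actual script.

```python
from typing import List, Tuple


def make_training_pairs(
    tokens: List[int], num_classes: int
) -> Tuple[List[List[float]], List[int]]:
    """Hypothetical helper: one-hot input at step t, label = token at step t + 1."""
    inputs: List[List[float]] = []
    labels: List[int] = []
    for t in range(len(tokens) - 1):
        one_hot = [0.0] * num_classes
        one_hot[tokens[t]] = 1.0     # current token as a one-hot vector
        inputs.append(one_hot)
        labels.append(tokens[t + 1])  # the RNN learns to predict the next token
    return inputs, labels
```

For a "note sequences" encoding such as [6, 1, 1, 7, 1] with 31 classes, this produces four one-hot inputs and the labels [1, 1, 7, 1].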

To train the RNN on the TFRecord files, refer to the Magenta Melody RNN instructions: https://github.com/magenta/magenta/tree/main/magenta/models/melody_rnn

Training the RL Agent

The following parameters for training the DQN agent can be set in script.py:

  • Seed: The random seed to make results reproducible
  • Algorithm: 'q' for Q-Learning
  • Reward Scaler: Multiplies the reward from the CP model before adding the RNN reward
  • Reward Mode: See the paper or the collect_reward function for the available reward functions, or add your own
  • Restrict Domain: If True, filters the domain based on the number of violations before picking an action
  • Output Every Nth: Number of iterations between evaluations of the agent
  • Num steps: The number of training iterations
  • Num notes in composition: Number of notes in each composition (32 by default)
  • Prime with midi: If True, starts the composition with a MIDI primer
  • Checkpoint dir and checkpoint: Location of the pretrained RNN checkpoint
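Two of these parameters can be read as simple formulas. The sketch below shows the combined reward described under Reward Scaler and one possible reading of Restrict Domain. The function names are hypothetical, the CP reward is taken as given (how it is derived from the Mini-CPBP marginals depends on the reward mode), and the rule of keeping only the actions with the fewest violations is an assumption; exploration is omitted for brevity.

```python
from typing import List, Sequence


def combined_reward(rnn_reward: float, cp_reward: float, reward_scaler: float) -> float:
    """Scale the CP-model reward, then add the RNN reward."""
    return rnn_reward + reward_scaler * cp_reward


def pick_action(
    q_values: Sequence[float], violations: Sequence[int], restrict_domain: bool
) -> int:
    """Greedy action choice, optionally restricted to the least-violating actions.

    Assumption: filtering the domain keeps only the actions whose violation
    count equals the minimum over all actions; the actual rule may differ.
    """
    candidates: List[int] = list(range(len(q_values)))
    if restrict_domain:
        fewest = min(violations)
        candidates = [a for a in candidates if violations[a] == fewest]
    return max(candidates, key=lambda a: q_values[a])
```

For example, with Q-values [0.1, 0.9, 0.5] and violation counts [2, 1, 0], the unrestricted agent picks action 1, while the restricted agent picks action 2, the only action with no violations.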