
# Transformer XL

## Table of contents

1. Description
2. Architecture
3. Dataset
4. Dependencies
5. Usage for music
6. Usage for text

## Description

The goal of this project is to generate long and coherent sequences of data using Transformer architectures, based on the Transformer-XL and GTrXL (*Stabilizing Transformers for Reinforcement Learning*) papers.

The neural networks are tested on two separate tasks: music generation and text generation. All the models are implemented from scratch in TensorFlow 2.

## Architecture

*(figure: architecture diagrams of the Music Model and the Text Model)*

The structure of the GTrXL (Gated Transformer XL) block is illustrated in detail below:

*(figure: structure of the GTrXL block)*

The architecture used for text generation is the one proposed in the paper *Stabilizing Transformers for Reinforcement Learning*. Music generation requires a modified model where the input features are split into MIDI events (`note_on`, `note_off` and `control_change`) and MIDI deltas (the time intervals between consecutive MIDI events).
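The defining feature of the GTrXL block is a GRU-style gating layer in place of the plain residual connections. A pure-NumPy sketch of that gating function is shown below (the weight matrices here are hypothetical random initialisations for illustration; the repository's actual layers are written in TensorFlow 2):

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_gate(x, y, d, rng, b_g=2.0):
    """GRU-style gating as described in the GTrXL paper.

    x: block input (residual stream), y: sublayer output, both shape (d,).
    A positive bias b_g pushes the update gate towards the identity map
    at initialisation, which is what stabilises early training.
    """
    # Hypothetical randomly initialised weights, for illustration only.
    W_r, U_r = rng.standard_normal((d, d)), rng.standard_normal((d, d))
    W_z, U_z = rng.standard_normal((d, d)), rng.standard_normal((d, d))
    W_g, U_g = rng.standard_normal((d, d)), rng.standard_normal((d, d))

    r = sigmoid(W_r @ y + U_r @ x)          # reset gate
    z = sigmoid(W_z @ y + U_z @ x - b_g)    # update gate, biased towards x
    h = np.tanh(W_g @ y + U_g @ (r * x))    # candidate activation
    return (1.0 - z) * x + z * h            # gated combination of x and h

rng = np.random.default_rng(0)
x, y = rng.standard_normal(8), rng.standard_normal(8)
out = gru_gate(x, y, 8, rng)
print(out.shape)  # (8,)
```

In a full block, `y` would be the output of the self-attention or feed-forward sublayer and the gate output would feed the next sublayer.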

## Dataset

For the task of music generation, the union of the following datasets is used:

  1. The MAESTRO Dataset
  2. SMD MIDI-Audio Piano Music
  3. Stanford University Piano Roll Archive
  4. Classical Music ML Format

All of the above contain classical piano music in MIDI format. The MIDI files are preprocessed with the `mido` library.
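The split into event and delta sequences described in the architecture section can be sketched as a toy tokenizer over `(type, delta)` message pairs. The function name, the quantisation step and the sample stream below are illustrative assumptions, not the repository's code; the real preprocessing reads actual MIDI files with `mido`:

```python
# Split a MIDI message stream into two parallel feature sequences:
# integer event tokens and quantised time deltas.
EVENT_TYPES = ("note_on", "note_off", "control_change")

def split_stream(messages, time_step=0.01):
    """messages: iterable of (event_type, delta_seconds) pairs."""
    events, deltas = [], []
    for event_type, delta in messages:
        if event_type not in EVENT_TYPES:
            continue  # skip meta messages and other event types
        events.append(EVENT_TYPES.index(event_type))
        deltas.append(round(delta / time_step))  # quantise to time_step bins
    return events, deltas

stream = [("note_on", 0.0), ("control_change", 0.02), ("note_off", 0.5)]
events, deltas = split_stream(stream)
print(events, deltas)  # [0, 2, 1] [0, 2, 50]
```

The two sequences are then fed to the music model as separate input features.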


As for the text generation, the CLAIR collection of "Nigerian" fraud emails is used.


Generated data for both datasets can be found here.

## Dependencies

- NumPy
- TensorFlow
- tqdm
- joblib
- mido
- bs4
- dload
- argparse, pathlib, pickle, re, glob (Python standard library)

## Usage for music

### Data preprocessing

```shell
python preprocess_music.py -d
```

### Training

```shell
python train_music.py
```

### Music generation

```shell
python generate_music.py <n_songs> <checkpoint path>
```

## Usage for text

### Data preprocessing

```shell
python preprocess_text.py <corpus path>
```

### Training

```shell
python train_text.py
```

### Text generation

```shell
python generate_text.py <n_samples> <checkpoint path>
```