{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training the Transformer-XL to generate music\n",
"\n",
"The [Transformer](https://arxiv.org/abs/1706.03762), introduced in 2017, has become the top choice to tackle problems in natural language processing (NLP). It has been used in both language generation, and [General Language Understanding Evaluation (GLUE)](https://gluebenchmark.com/) tasks. The Transformer model has even shown [promising results in computer vision related tasks](https://arxiv.org/pdf/2010.11929.pdf).\n",
"\n",
"The Transformer has also helped push the state-of-the-art in [symbolic music generation](https://musicinformationretrieval.com/symbolic_representations.html), enabling minutes-long music performance generation with compelling musical structure. In this Jupyter notebook, you will look at how symbolic music can be generated using a recent variant of the Transformer called the Transformer-XL. \n",
"\n",
"The code provided in this notebook will allow you to access the dataset used for training, and examine how [MIDI](https://en.wikipedia.org/wiki/MIDI) files are transformed into the note sequences used for training. Next, you can examine the Transformer-XL architecture and the data loader. Then, you can examine the end-to-end training process used to train a new model. Finally, you can use the trained model to perform inference and extend an input track. You can also examine how sampling was performed during inference to extend the input track.\n",
"\n",
"## Prerequisites\n",
"This notebook assumes an understanding of Convolutional Neural Networks (CNNs), Recurrent Neural Networks (RNNs) and a high level understanding of the original Transformer. The following links are useful resources to meet these prerequisites.\n",
"1. [Learn the basics of generative AI](https://d32g4xocucupjo.cloudfront.net/#welcome)\n",
"2. [Convolutional Neural Networks](https://d2l.ai/chapter_convolutional-neural-networks/index.html)\n",
"3. [Recurrent Neural Networks](https://d2l.ai/chapter_recurrent-neural-networks/index.html)\n",
"4. [Introduction to Transformers](http://jalammar.github.io/illustrated-transformer/)\n",
"5. [Attention mechanism in Transformers](https://d2l.ai/chapter_attention-mechanisms/attention.html)\n",
"\n",
"\n",
"## Using generative AI to create music \n",
"\n",
"There are two primary ways to represent music in a format that can be used for training a machine learning model. *Image-based* methods use a [pianoroll](https://salu133445.github.io/lakh-pianoroll-dataset/representation.html) image while *text-based* methods use a sequence of tokens to represent the musical events occurring in the piece of music.\n",
"\n",
"1. The pianoroll is a 2D matrix that shows what pitches are played at each timestep. [MuseGAN](https://arxiv.org/abs/1709.06298) and [Coconet](https://arxiv.org/pdf/1903.07227.pdf) are examples of models that use piano roll images during training to later create music.\n",
"\n",
"\n",
"2. There are several ways to convert MIDI files into a sequence of tokens. MIDI carries event messages, data that specify the instructions for music including a note's notation, pitch, velocity (which is heard typically as loudness or softness of volume). The tokens will accordingly represent these events found in the music/MIDI file such as the onset of notes, the duration or offset of notes, and the passage of time. [Music Transformer](https://arxiv.org/abs/1809.04281) and [Musenet](https://openai.com/blog/musenet/) are examples of models that treat music as a sequence. You can examine a specific sequential format, called the note sequence, later on in this notebook.\n",
"\n",
"In this notebook, music generation is treated as a sequence generation problem. You will train a strong sequential model - the Transformer-XL - to learn the distribution of these sequences."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## An introduction to Transformers and the Transformer-XL\n",
"\n",
"In this section, you can examine the structure of the Transformer architecture, and see how the Transformer-XL improves on the original design.\n",
"\n",
"### The original Transformer\n",
"Briefly, here are the advantages and disadvantages of using either CNNs or RNNs to solve sequence generation based problems.\n",
"- CNNs are easy to parallelize, but can only capture fixed length sequential dependencies.\n",
"- RNNs can learn long-range, variable length sequential dependencies but cannot be parallelized within a sequence.\n",
"\n",
"To combine the advantages of CNNs and RNNs, a novel architecture was proposed in the paper [Attention is all you need](https://arxiv.org/abs/1706.03762), which solely uses the [attention mechanism](https://d2l.ai/chapter_attention-mechanisms/attention.html) and therefore helps the model focus on relevant parts of the input or output sentence. This architecture, called the Transformer, first encodes each word's position in the input sequence and then acts on it with the attention mechanism. The attention mechanism enables the Transformer to be trained in parallel, achieving significantly shorter training time (often by a factor of 10 or more) when compared to an RNN.\n",
"\n",
"
\n",
"\n",
"The Transformer consists of an encoder, the decoder, and the encoder-decoder attention connections between them.\n",
"The encoder (on the left) is a stack of encoder layers and the decoder (on the right) is a stack of decoder layers of the same number.\n",
"\n",
"The encoder layers are all identical in structure. Its inputs first go through a multi-head attention layer, and then through a feed-forward neural network. The multi-head attention layer helps the encoder look at other words in the input sentence as it encodes a specific word. \n",
"The outputs of the multi-head attention layer are then fed into a feed-forward neural network. The exact same feed-forward network is independently applied to each position in the sequence. More details on the exact mechanisms used in the multi-head attention layer and self-attention can be found in the sections ahead.\n",
"\n",
"The structure of the decoder is similar to the encoder, but has an additional encoder-decoder attention layer that helps the decoder focus on relevant parts of the input sequence. The decoder uses a [causal mask](https://medium.com/analytics-vidhya/masking-in-transformers-self-attention-mechanism-bad3c9ec235c) in the attention heads to force predictions to only attend to the tokens at previous positions, so that the model can be used autoregressively at inference time.\n",
"\n",
"The decoder of the Transformer alone without the encoder is widely used for the __Language Modeling (LM)__ task: to predict the next token given a sequence of tokens. It still remains a challenge however to properly equip neural networks with *long-term dependency*.\n",
" \n",
"#### What is long-term dependency and why is it needed for music generation?\n",
"Generating long pieces of music is challenging, as music is structured to use multiple timescales, from millisecond notes and rests to repeating motifs, phrases and sections. Therefore, to generate music that maintains long-term coherence the model will need to refer to a musical note that occurs a few thousand notes or tokens back. This long-range dependence— a common phenomena in sequential data — must be understood in order to generate realistic music. While humans can do this naturally, [modeling long-term dependencies with neural networks is a challenge](https://www.aclweb.org/anthology/D18-1458.pdf). Transformers provide one mechanism, attention, to help with this challenge. Attention allows for direct connections between hidden states which can be used to capture long-term dependency.\n",
"\n",
"#### Limitations of the original Transformer\n",
"However, the original Transformer decoder uses a fixed-length context, meaning that a long sequence is truncated into fixed-length segments comprised of a few hundred tokens, and then each segment is processed separately. The original Transformer using a fixed-length context during training is depicted below.\n",
"
\n",
"Using a fixed-length context introduces two __critical limitations__: \n",
"\n",
"1. The Transformer is not able to model dependencies that are longer than a fixed length. \n",
"2. In music generation, a single musical note is frequently composed of multiple tokens. Fixed-length segments often do not respect musical note boundaries, resulting in context fragmentation which in turn causes inefficient optimization. This is problematic even for short sequences, where long range dependency isn't an issue.\n",
"\n",
"### The Transformer-XL \n",
"\n",
"The [Transformer-XL](https://arxiv.org/abs/1901.02860) is a novel architecture based on the Transformer decoder. It improves on the original Transformer decoder and enables language modeling beyond a fixed-length context. It accomplishes this with two techniques: a segment-level recurrence mechanism and a relative positional encoding scheme.\n",
"\n",
"*__Segment-level Recurrence__* During Transformer-XL training, the representations computed for the previous segment are fixed and cached to be reused as an extended context when the model processes the next new segment. This additional connection increases the largest possible dependency length by N times, where N is the depth of the network, because contextual information is now able to flow across segment boundaries. This recurrence mechanism also resolves the context fragmentation issue, providing necessary context for tokens in the front of a new segment.\n",
" \n",
"
\n",
" \n",
"*__Relative Positional Encodings__* Since the Transformer-XL caches representations between segments, a new way to represent the contexual positions of representations from different segments is needed. For example, if an old segment uses contextual positions [0, 1, 2, 3], then when a new segment is processed,the positions are [0, 1, 2, 3, 0, 1, 2, 3] for the two segments combined. The semantics of each position id is incoherent throughout the sequence and the Transformer-XL therefore proposes a novel relative positional encoding scheme to make the recurrence mechanism possible. \n",
" \n",
"When both of these approaches are combined, Transformer-XL has a much longer effective context than the original Transformer model at evaluation time."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installing dependencies\n",
"First, let's install and import all of the Python packages that you will use in this tutorial."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"\n",
"# Create the environment and install required packages\n",
"!pip install -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n",
" warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n"
]
}
],
"source": [
"# Imports\n",
"import glob\n",
"import json\n",
"import math\n",
"import multiprocessing\n",
"import os\n",
"import pickle\n",
"import random\n",
"import time\n",
"from pprint import pprint\n",
"from typing import *\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from IPython import display\n",
"from autocfg import dataclass, field\n",
"from utils.performance_event_repo import BaseVocab\n",
"from utils.midi_utils import play_midi, print_sample_array\n",
"from utils.music_encoder import MusicEncoder\n",
"from utils.utils import plot_losses, save_checkpoint\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Importing the data \n",
"\n",
"The input dataset is in the [MIDI](https://en.wikipedia.org/wiki/MIDI) format. In this tutorial, you will use the [`JSB-Chorales-dataset`](http://www-etud.iro.umontreal.ca/~boulanni/icml2012). The link contains pickled files that you will convert to MIDI in the cells below. \n",
"\n",
"A chorale is a type of musical structure or form. Chorales usually consist of one voice singing a simple melody and three lower voices providing harmony. In this dataset, the voices are represented by four individual piano tracks.\n",
"\n",
"You will now download and use this dataset locally.\n",
"\n",
"### Downloading the `JSB-Chorales-dataset`\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-01-21 17:50:43-- http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.zip\n",
"Resolving www-etud.iro.umontreal.ca (www-etud.iro.umontreal.ca)... 132.204.26.158\n",
"Connecting to www-etud.iro.umontreal.ca (www-etud.iro.umontreal.ca)|132.204.26.158|:80... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 215242 (210K) [application/zip]\n",
"Saving to: ‘data/JSB Chorales.zip.1’\n",
"\n",
"JSB Chorales.zip.1 100%[===================>] 210.20K --.-KB/s in 0.09s \n",
"\n",
"2021-01-21 17:50:44 (2.31 MB/s) - ‘data/JSB Chorales.zip.1’ saved [215242/215242]\n",
"\n"
]
}
],
"source": [
"# Download a .zip file containing the .mid files from the dataset.\n",
"!wget http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.zip -P data/\n",
"\n",
"# Unzip the contents of that directory\n",
"!unzip -q \"data/JSB Chorales.zip\" -d data/\n",
"\n",
"# Rename downloaded file\n",
"!mv \"data/JSB Chorales\" data/jsb_chorales/\n",
"\n",
"# Change the string in the the `data_dir` variable to the correct file path\n",
"data_dir = \"data/jsb_chorales/**/*.mid\"\n",
"\n",
"# Load midi files\n",
"midi_files = glob.glob(data_dir)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"