# Training the Transformer-XL to generate music

The [Transformer](https://arxiv.org/abs/1706.03762), introduced in 2017, has become the top choice to tackle problems in natural language processing (NLP). It has been used in both language generation and [General Language Understanding Evaluation (GLUE)](https://gluebenchmark.com/) tasks. The Transformer model has even shown [promising results in computer vision related tasks](https://arxiv.org/pdf/2010.11929.pdf).

The Transformer has also helped push the state-of-the-art in [symbolic music generation](https://musicinformationretrieval.com/symbolic_representations.html), enabling minutes-long music performance generation with compelling musical structure. In this Jupyter notebook, you will look at how symbolic music can be generated using a recent variant of the Transformer called the Transformer-XL. 

The code provided in this notebook will allow you to access the dataset used for training, and examine how [MIDI](https://en.wikipedia.org/wiki/MIDI) files are transformed into the note sequences used for training. Next, you can examine the Transformer-XL architecture and the data loader. Then, you can examine the end-to-end training process used to train a new model. Finally, you can use the trained model to perform inference and extend an input track. You can also examine how sampling was performed during inference to extend the input track.

## Prerequisites
This notebook assumes an understanding of Convolutional Neural Networks (CNNs), Recurrent Neural Networks (RNNs) and a high level understanding of the original Transformer. The following links are useful resources to meet these prerequisites.
1. [Learn the basics of generative AI](https://d32g4xocucupjo.cloudfront.net/#welcome)
2. [Convolutional Neural Networks](https://d2l.ai/chapter_convolutional-neural-networks/index.html)
3. [Recurrent Neural Networks](https://d2l.ai/chapter_recurrent-neural-networks/index.html)
4. [Introduction to Transformers](http://jalammar.github.io/illustrated-transformer/)
5. [Attention mechanism in Transformers](https://d2l.ai/chapter_attention-mechanisms/attention.html)


## Using generative AI to create music 

There are two primary ways to represent music in a format that can be used for training a machine learning model. *Image-based* methods use a [pianoroll](https://salu133445.github.io/lakh-pianoroll-dataset/representation.html) image while *text-based* methods use a sequence of tokens to represent the musical events occurring in the piece of music.

1. The pianoroll is a 2D matrix that shows what pitches are played at each timestep. [MuseGAN](https://arxiv.org/abs/1709.06298) and [Coconet](https://arxiv.org/pdf/1903.07227.pdf) are examples of models that use piano roll images during training to later create music.
<img src="images/pianoroll.png" width="200">

2. There are several ways to convert MIDI files into a sequence of tokens. MIDI carries event messages: data that specify musical instructions such as a note's notation, pitch and velocity (velocity is typically heard as the loudness or softness of a note). The tokens represent the events found in the music/MIDI file, such as the onset of notes, the duration or offset of notes, and the passage of time. [Music Transformer](https://arxiv.org/abs/1809.04281) and [MuseNet](https://openai.com/blog/musenet/) are examples of models that treat music as a sequence. Later in this notebook, you can examine a specific sequential format called the note sequence.

In this notebook, music generation is treated as a sequence generation problem. You will train a strong sequential model - the Transformer-XL - to learn the distribution of these sequences.

## An introduction to Transformers and the Transformer-XL

In this section, you can examine the structure of the Transformer architecture, and see how the Transformer-XL improves on the original design.

### The original Transformer
Briefly, here are the advantages and disadvantages of using either CNNs or RNNs to solve sequence generation based problems.
- CNNs are easy to parallelize, but can only capture fixed length sequential dependencies.
- RNNs can learn long-range, variable length sequential dependencies but cannot be parallelized within a sequence.

To combine the advantages of CNNs and RNNs, the paper [Attention is all you need](https://arxiv.org/abs/1706.03762) proposed a novel architecture that relies solely on the [attention mechanism](https://d2l.ai/chapter_attention-mechanisms/attention.html), which lets the model focus on the relevant parts of the input or output sentence. This architecture, called the Transformer, first encodes each word's position in the input sequence and then acts on it with the attention mechanism. Because attention does not process the positions of a sequence one after another, the Transformer can be trained in parallel, achieving significantly shorter training time (often by a factor of 10 or more) when compared to an RNN.

<img src="images/transformer.png" width=400 align="center" />

The Transformer consists of an encoder, a decoder, and the encoder-decoder attention connections between them.
The encoder (on the left) is a stack of encoder layers and the decoder (on the right) is a stack of decoder layers of the same number.

The encoder layers are all identical in structure. Each layer's inputs first go through a multi-head attention layer, and then through a feed-forward neural network. The multi-head attention layer helps the encoder look at other words in the input sentence as it encodes a specific word.
The outputs of the multi-head attention layer are then fed into a feed-forward neural network. The exact same feed-forward network is independently applied to each position in the sequence. More details on the exact mechanisms used in the multi-head attention layer and self-attention can be found in the sections ahead.
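As a small illustration of the position-wise property, the sketch below applies one two-layer feed-forward network to a whole `(seq_len, batch_size, dim_model)` tensor and checks that this matches applying it to each position separately. The layer sizes and the ReLU activation are illustrative assumptions, not the hyperparameters used later in this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, dim_model, dim_inner = 6, 2, 16, 64  # illustrative sizes

# Two-layer feed-forward network shared by every position (assumed ReLU activation)
ffn = nn.Sequential(
    nn.Linear(dim_model, dim_inner),
    nn.ReLU(),
    nn.Linear(dim_inner, dim_model),
)

x = torch.randn(seq_len, batch_size, dim_model)

# nn.Linear acts on the last dimension, so all positions are processed in one call
out_all = ffn(x)

# Same network applied to one position at a time
out_per_position = torch.stack([ffn(x[t]) for t in range(seq_len)])

print(torch.allclose(out_all, out_per_position, atol=1e-5))
```

Because the same weights are used at every position, the network adds no interaction between positions; all mixing across the sequence happens in the attention layer.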

The structure of the decoder is similar to the encoder, but has an additional encoder-decoder attention layer that helps the decoder focus on relevant parts of the input sequence. The decoder uses a [causal mask](https://medium.com/analytics-vidhya/masking-in-transformers-self-attention-mechanism-bad3c9ec235c) in the attention heads to force predictions to only attend to the tokens at previous positions, so that the model can be used autoregressively at inference time.

The decoder of the Transformer alone, without the encoder, is widely used for the __Language Modeling (LM)__ task: predicting the next token given a sequence of tokens. However, it remains a challenge to equip neural networks to capture *long-term dependency*.
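In code, the language modeling task amounts to shifting the token sequence by one step: the model reads the tokens up to position $t$ and is scored on predicting the token at $t+1$. The sketch below assumes the `(seq_len, batch_size)` layout used later in this notebook and uses random logits as a hypothetical stand-in for a decoder's output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len, batch_size = 310, 12, 2

# Random token ids in (time, batch) layout; one extra step so inputs and targets line up
tokens = torch.randint(vocab_size, (seq_len + 1, batch_size))
inputs, targets = tokens[:-1], tokens[1:]  # target at step t is the input at step t + 1

# Hypothetical stand-in for decoder output: one score per vocabulary entry per position
logits = torch.randn(seq_len, batch_size, vocab_size)

# Average next-token cross-entropy over all positions and batch entries
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
print(loss.item())
```

A real decoder would compute `logits` from `inputs` under a causal mask, so that the prediction at step $t$ never sees `targets[t]`.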
 
####  What is long-term dependency and why is it needed for music generation?
Generating long pieces of music is challenging because music is structured on multiple timescales, from millisecond-scale notes and rests to repeating motifs, phrases and sections. To generate music that maintains long-term coherence, the model therefore needs to refer to musical events that occurred a few thousand notes or tokens earlier. This long-range dependence, a common phenomenon in sequential data, must be captured in order to generate realistic music. While humans do this naturally, [modeling long-term dependencies with neural networks is a challenge](https://www.aclweb.org/anthology/D18-1458.pdf). Transformers provide one mechanism to help with this challenge: attention. Attention creates direct connections between hidden states, which can be used to capture long-term dependency.

#### Limitations of the original Transformer
However, the original Transformer decoder uses a fixed-length context, meaning that a long sequence is truncated into fixed-length segments comprised of a few hundred tokens, and then each segment is processed separately. The original Transformer using a fixed-length context during training is depicted below.
 <img src="images/vanilla.gif" alt=" Vanilla Transformer" width="800" >
Using a fixed-length context introduces two __critical limitations__: 

1. The Transformer is not able to model dependencies that are longer than a fixed length. 
2. In music generation, a single musical note is frequently composed of multiple tokens. Fixed-length segments often do not respect musical note boundaries, resulting in context fragmentation which in turn causes inefficient optimization. This is problematic even for short sequences, where long range dependency isn't an issue.

### The Transformer-XL 

The [Transformer-XL](https://arxiv.org/abs/1901.02860) is a novel architecture based on the Transformer decoder. It improves on the original Transformer decoder and enables language modeling beyond a fixed-length context. It accomplishes this with two techniques: a segment-level recurrence mechanism and a relative positional encoding scheme.

*__Segment-level Recurrence__* During Transformer-XL training, the representations computed for the previous segment are fixed and cached to be reused as an extended context when the model processes the next new segment. This additional connection increases the largest possible dependency length by N times, where N is the depth of the network, because contextual information is now able to flow across segment boundaries. This recurrence mechanism also resolves the context fragmentation issue, providing necessary context for tokens in the front of a new segment.
 
 <img src="images/xl.gif" alt="Transformer-XL" width="800">
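The caching step can be sketched as follows. This is a minimal illustration, not the notebook's implementation: `update_memory` is a hypothetical helper, and in the full model each layer keeps its own memory tensor. The key points are that the cached states are concatenated along the sequence dimension, truncated to the most recent `mem_len` steps, and detached so that no gradients flow back into the previous segment.

```python
from typing import Optional

import torch


def update_memory(mems: Optional[torch.Tensor], hidden: torch.Tensor, mem_len: int) -> torch.Tensor:
    """Return the new memory: the last `mem_len` steps of [mems; hidden], detached."""
    cat = hidden if mems is None else torch.cat([mems, hidden], dim=0)  # sequence dim first
    return cat[-mem_len:].detach()  # "fixed and cached": no gradient into old segments


torch.manual_seed(0)
seg_len, mem_len, batch_size, dim_embed = 10, 10, 8, 32
long_sequence = torch.randn(3 * seg_len, batch_size, dim_embed, requires_grad=True)

mems = None
for start in range(0, long_sequence.size(0), seg_len):
    segment = long_sequence[start:start + seg_len]
    hidden = torch.tanh(segment)  # hypothetical stand-in for one layer's output
    # ... a real layer would attend over [mems; segment] here ...
    mems = update_memory(mems, hidden, mem_len)

print(mems.shape, mems.requires_grad)
```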
 
*__Relative Positional Encodings__* Since the Transformer-XL caches representations between segments, it needs a new way to represent the contextual positions of representations from different segments. For example, if an old segment uses contextual positions [0, 1, 2, 3], then when a new segment is processed, the positions are [0, 1, 2, 3, 0, 1, 2, 3] for the two segments combined. The same position id now refers to two different tokens, so the semantics of each position id are incoherent throughout the sequence. The Transformer-XL therefore proposes a novel relative positional encoding scheme to make the recurrence mechanism possible.
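The collision can be made concrete with a few lines of plain Python (the segment length of 4 matches the example above; the variable names are illustrative). With reused absolute positions, a token in the memory and a token in the current segment share the same id, whereas the distances from a query to every key in the combined context stay distinct.

```python
seg_len = 4

# Absolute positions restarted for each segment: memory + current segment
reused_positions = list(range(seg_len)) + list(range(seg_len))
print(reused_positions)  # [0, 1, 2, 3, 0, 1, 2, 3]

# Relative distances i - j from the last query (true position 7) to every key
query_pos = 2 * seg_len - 1
relative_distances = [query_pos - key_pos for key_pos in range(2 * seg_len)]
print(relative_distances)  # [7, 6, 5, 4, 3, 2, 1, 0]
```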
 
When both of these approaches are combined, Transformer-XL has a much longer effective context than the original Transformer model at evaluation time.

## Installing dependencies
First, let's install and import all of the Python packages that you will use in this tutorial.

In [None]:
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

# Create the environment and install required packages
!pip install -r requirements.txt

In [3]:
# Imports
import glob
import json
import math
import multiprocessing
import os
import pickle
import random
import time
from pprint import pprint
from typing import *

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IPython import display
from autocfg import dataclass, field
from utils.performance_event_repo import BaseVocab
from utils.midi_utils import play_midi, print_sample_array
from utils.music_encoder import MusicEncoder
from utils.utils import plot_losses, save_checkpoint

%matplotlib inline



## Importing the data 

The input dataset is in the [MIDI](https://en.wikipedia.org/wiki/MIDI) format. In this tutorial, you will use the [`JSB-Chorales-dataset`](http://www-etud.iro.umontreal.ca/~boulanni/icml2012). The cells below download a .zip archive of the dataset's MIDI (.mid) files from that page.

A chorale is a type of musical structure or form. Chorales usually consist of one voice singing a simple melody and three lower voices providing harmony. In this dataset, the voices are represented by four individual piano tracks.

You will now download and use this dataset locally.

### Downloading the `JSB-Chorales-dataset`


In [5]:
# Download a .zip file containing the .mid files from the dataset.
!wget http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.zip -P data/

# Unzip the contents of that directory
!unzip -q "data/JSB Chorales.zip" -d data/

# Rename downloaded file
!mv "data/JSB Chorales" data/jsb_chorales/

# Change the string in the `data_dir` variable to the correct file path
data_dir = "data/jsb_chorales/**/*.mid"

# Load midi files
midi_files = glob.glob(data_dir)

--2021-01-21 17:50:43--  http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.zip
Resolving www-etud.iro.umontreal.ca (www-etud.iro.umontreal.ca)... 132.204.26.158
Connecting to www-etud.iro.umontreal.ca (www-etud.iro.umontreal.ca)|132.204.26.158|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 215242 (210K) [application/zip]
Saving to: ‘data/JSB Chorales.zip.1’


2021-01-21 17:50:44 (2.31 MB/s) - ‘data/JSB Chorales.zip.1’ saved [215242/215242]



In [6]:
# Use this cell to play a random sample chorale snippet from the dataset `midi_files`
# You can run the cell over and over again to hear different samples

random_midi = random.randrange(len(midi_files))
play_midi(midi_files[random_midi])

If your dataset has been successfully downloaded, you should be able to play a track after you have run the previous code cell.

## Preprocessing the data into the note sequence format

In this section, you can see how the MIDI files from the dataset are preprocessed into the note sequence format needed to train the Transformer-XL model.

 
The MIDI files in the `JSB-Chorales-dataset` are polyphonic: several voices (here, four piano tracks) sound at the same time. To train a language model on these MIDI files, you first need to serialize the [polyphonic](https://ccnmtl.columbia.edu/projects/sonicg/terms/polyphony.html) performance into a single sequence of encoded words representing the different musical (MIDI) events. To do this, the different piano tracks or voices are interwoven into a sequence format (called the *Performance note sequence format*), as originally suggested in [This Time with Feeling: Learning Expressive Musical Performance](https://arxiv.org/pdf/1808.03715.pdf).

The musical events are converted into a sequence of words from a [predefined vocabulary](utils/magenta_vocab.txt). Of the 128 available MIDI pitches, 88 are used (the range of a standard piano keyboard). The vocabulary includes:
- 88 NOTE_ON events for starting a note using one of the 88 MIDI pitches
- 88 NOTE_OFF events for ending or releasing a note using one of the 88 MIDI pitches
- 100 TIME_SHIFT events that move the current time step forward. Each shift is a multiple of 10 ms, so the shifts range from 10 ms to 1 second.
- 32 SET_VELOCITY events that change the velocity applied to all subsequent notes (until the next velocity event). The 128 possible MIDI velocities are quantized into 32 bins. 
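The vocabulary above can be laid out as four contiguous blocks of token ids. The sketch below is an assumption for illustration: the block order, the lowest pitch (MIDI 21, the lowest piano key) and the equal-width velocity binning are hypothetical, and the actual ids come from [utils/magenta_vocab.txt](utils/magenta_vocab.txt). It also shows how a gap longer than 1 second is split into several TIME_SHIFT tokens.

```python
from typing import List

# Hypothetical block layout: NOTE_ON | NOTE_OFF | TIME_SHIFT | SET_VELOCITY
NUM_PITCHES = 88
NUM_TIME_SHIFTS = 100   # 10 ms ... 1 s in 10 ms steps
NUM_VELOCITY_BINS = 32
MIN_PITCH = 21          # assumed lowest pitch (A0, the lowest piano key)

NOTE_ON_OFFSET = 0
NOTE_OFF_OFFSET = NOTE_ON_OFFSET + NUM_PITCHES
TIME_SHIFT_OFFSET = NOTE_OFF_OFFSET + NUM_PITCHES
VELOCITY_OFFSET = TIME_SHIFT_OFFSET + NUM_TIME_SHIFTS
NUM_EVENTS = VELOCITY_OFFSET + NUM_VELOCITY_BINS  # 308 event types


def note_on(pitch: int) -> int:
    return NOTE_ON_OFFSET + pitch - MIN_PITCH


def note_off(pitch: int) -> int:
    return NOTE_OFF_OFFSET + pitch - MIN_PITCH


def time_shift(seconds: float) -> List[int]:
    """Split a time gap into TIME_SHIFT tokens of at most 1 s each."""
    steps = round(seconds * 100)  # number of 10 ms steps
    tokens = []
    while steps > 0:
        chunk = min(steps, NUM_TIME_SHIFTS)
        tokens.append(TIME_SHIFT_OFFSET + chunk - 1)  # token for a shift of chunk * 10 ms
        steps -= chunk
    return tokens


def set_velocity(velocity: int) -> int:
    """Quantize a MIDI velocity (0-127) into one of 32 equal-width bins."""
    return VELOCITY_OFFSET + velocity * NUM_VELOCITY_BINS // 128


print(NUM_EVENTS, time_shift(2.5))  # 308 [275, 275, 225]
```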

### Comparing a piano roll image to a sequence
This section highlights an example encoding where the pianoroll (pitches vs. time) on top is converted into a sequence of tokens below. A C Major chord is [arpeggiated](https://en.wikipedia.org/wiki/Arpeggio) along with an active sustain pedal. The pedal is released at the 2-second mark, ending all of the notes. At the 3-second mark, an F is played for a half second. The C chord is played at velocity (volume) 80 and the F is played at velocity (volume) 100.
 
<img src="images/event_rep.jpg" align="center" />
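To make the interleaving concrete, the sketch below serializes a list of notes into human-readable event strings in the spirit of the figure. It is a simplified, hypothetical encoder: the sustain pedal is modeled by extending each note of the C major chord (C4=60, E4=64, G4=67) until the 2-second release, the arpeggio onsets 0.5 s apart and the F pitch (F4=65) are assumed, and velocities are written unquantized for readability.

```python
from typing import List, Tuple

Note = Tuple[int, float, float, int]  # (pitch, start_sec, end_sec, velocity)


def notes_to_events(notes: List[Note]) -> List[str]:
    """Interleave polyphonic notes into one time-ordered event sequence."""
    timed = []
    for pitch, start, end, velocity in notes:
        # Quantize to 10 ms steps; at equal times NOTE_OFF (0) sorts before NOTE_ON (1)
        timed.append((round(start * 100), 1, pitch, velocity))
        timed.append((round(end * 100), 0, pitch, velocity))
    timed.sort()

    events: List[str] = []
    current_step, current_velocity = 0, None
    for step, is_on, pitch, velocity in timed:
        gap = step - current_step
        while gap > 0:  # one TIME_SHIFT covers at most 100 steps (1 s)
            chunk = min(gap, 100)
            events.append(f"TIME_SHIFT<{chunk * 10}ms>")
            gap -= chunk
        current_step = step
        if is_on:
            if velocity != current_velocity:  # only emit a velocity change when needed
                events.append(f"SET_VELOCITY<{velocity}>")
                current_velocity = velocity
            events.append(f"NOTE_ON<{pitch}>")
        else:
            events.append(f"NOTE_OFF<{pitch}>")
    return events


# Hypothetical timings loosely following the figure
notes = [(60, 0.0, 2.0, 80), (64, 0.5, 2.0, 80), (67, 1.0, 2.0, 80), (65, 3.0, 3.5, 100)]
for event in notes_to_events(notes):
    print(event)
```

The output starts with `SET_VELOCITY<80>`, `NOTE_ON<60>`, `TIME_SHIFT<500ms>`, and a single SET_VELOCITY token carries over to every following note until the velocity changes.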

### Data augmentation
[Data augmentation](https://d2l.ai/chapter_computer-vision/image-augmentation.html) provides a simple way of encoding domain-specific prior knowledge in any machine learning algorithm. You will augment the dataset as originally suggested in the paper [Music Transformer](https://arxiv.org/abs/1809.04281). Two kinds of data augmentation are applied, reflecting the fact that music keeps its melody when the pitch is transposed (shifted up or down by a few half-steps) or when the tempo is changed. Below are the two variables used to accomplish this:
- Pitch transpositions are controlled using the integers `pitch_transpose_lower` and `pitch_transpose_upper`. The transposition is sampled uniformly from the integer set of half-steps {pitch_transpose_lower,...,pitch_transpose_upper}. We use `pitch_transpose_lower`=-3 and `pitch_transpose_upper`=3.
- Time stretch factors are uniformly sampled from the set {0.95,0.975,1.0,1.025,1.05}, which we denote using the variable `stretch_factors`. The timing of the music, and therefore the duration of each note, is multiplied by the sampled stretch factor.
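The two augmentations can be sketched on a list of `(pitch, start_sec, end_sec, velocity)` notes. This is a hypothetical illustration, not the `MusicEncoder` implementation: the same transposition and stretch are applied to every note of a piece, and notes pushed outside an assumed 88-key range (MIDI 21 to 108) are simply dropped.

```python
import random
from typing import List, Tuple

Note = Tuple[int, float, float, int]  # (pitch, start_sec, end_sec, velocity)

PITCH_TRANSPOSE_LOWER, PITCH_TRANSPOSE_UPPER = -3, 3
STRETCH_FACTORS = [0.95, 0.975, 1.0, 1.025, 1.05]
MIN_PITCH, MAX_PITCH = 21, 108  # assumed piano range


def augment_notes(notes: List[Note], transpose: int, stretch: float) -> List[Note]:
    """Shift every pitch by `transpose` half-steps and scale every time by `stretch`."""
    augmented = []
    for pitch, start, end, velocity in notes:
        new_pitch = pitch + transpose
        if MIN_PITCH <= new_pitch <= MAX_PITCH:  # drop notes that leave the range
            augmented.append((new_pitch, start * stretch, end * stretch, velocity))
    return augmented


def random_augmentation(notes: List[Note], rng: random.Random) -> List[Note]:
    transpose = rng.randint(PITCH_TRANSPOSE_LOWER, PITCH_TRANSPOSE_UPPER)  # inclusive bounds
    stretch = rng.choice(STRETCH_FACTORS)
    return augment_notes(notes, transpose, stretch)


print(augment_notes([(60, 0.0, 1.0, 80)], transpose=3, stretch=1.05))  # [(63, 0.0, 1.05, 80)]
```

Passing an explicit `random.Random` instance keeps the sampling reproducible across runs.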

In [7]:
# Creates a class that will be used to transform the dataset from 
# MIDI to a numpy array based on the Performance vocabulary
music_encoder = MusicEncoder()
 
# Specify event representation algorithm ('performance') and dataset augmentation parameters
pitch_transpose_lower, pitch_transpose_upper = -3, 3
music_encoder.build_encoder(algorithm='performance', stretch_factors=[0.95,0.975,1.0,1.025,1.05],
                            pitch_transpose = (pitch_transpose_lower, pitch_transpose_upper))

# Convert midi dataset to numpy array
music_encoder.convert(input_folder= 'data/jsb_chorales', 
                      output_folder='data/jsb_chorales_numpy',
                      mode='midi_to_npy')


Converting midi files from data/jsb_chorales to npy...
Loaded dataset from data/jsb_chorales. Train/Val/Test=229/76/77
Split train converted! Spent 12.860990047454834s to convert 229 samples.
Split valid converted! Spent 0.4431154727935791s to convert 76 samples.
Split test converted! Spent 0.42433857917785645s to convert 77 samples.


Next, print a sample from the training set. Each MIDI file is now a numpy array of vocabulary indices, one per event.

In [None]:
# Prints a randomly sampled numpy array from the parent_dir
print_sample_array(split="train", parent_dir="data/jsb_chorales_numpy")

## Model architecture

In the sections below, you can see how the Transformer-XL decoder is implemented one piece at a time. You will begin by looking at the attention matrix and the terms within. You will then implement the Transformer-XL decoder using a sequence of multi-head attention layers along with feed forward layers, residual connections and `LayerNorm` layers.



### Self attention

The attention layer is the fundamental part of the Transformer-XL model architecture. In this section you can see how the attention score matrix for a single head is computed. Later on you will see how the multi-head attention layer is composed of several attention heads.

Say you want to apply self-attention to a music sequence of 10 events (`tgt_len`) with a batch size (`batch_size`) of 8. `tgt_len` and `batch_size` are examples of hyperparameters that are tuned while training the Transformer-XL.

In [None]:
tgt_len = 10  # Length of input music sequence
batch_size = 8  # Batch size
vocab_size = 310  # Effective Vocabulary size for the Performance event representation
 
sequence = torch.randint(vocab_size, (tgt_len, batch_size))

You first transform the sequence indices into word embeddings of dimension 32, which is another hyperparameter.

In [None]:
dim_embed = 32  # Embedding dimension

word_emb_layer = nn.Embedding(vocab_size, dim_embed)
word_emb = word_emb_layer(sequence)

In [None]:
word_emb.shape  # tgt_len x batch_size x dim_embed

The Transformer-XL introduces a segment level recurrence mechanism, where cached states for the previous segment (sequence) are also provided as additional input to the model.
The cached states are referred to as the memory. 


In [None]:
mem_len = 10  # Previous sequence len (same as input sequence length here)

mems = torch.rand(mem_len, batch_size, dim_embed)  # memory of cached states

While the memory length or number of cached states stored is the same as the sequence length (equal to 10) in this example, this need not be the case.


As in the original Transformer, the attention layer operates on keys, queries and values. The keys and values are comprised of both memory and current sequence while the queries are the current sequence alone.

You first linearly project the queries, keys and values to a dimension `dim_head` for efficient computation of attention.

In [None]:
dim_head = 5  # Inner dimension of attention head

# Define linear projection layers
q_net = nn.Linear(dim_embed, dim_head)
k_net = nn.Linear(dim_embed, dim_head)
v_net = nn.Linear(dim_embed, dim_head)

# Keys and values are comprised of memory and current sequence
word_emb_concat = torch.cat([mems, word_emb], dim=0)  # Concatenated along seq dim
K = k_net(word_emb_concat)
V = v_net(word_emb_concat)

# Queries are comprised of current sequence
Q = q_net(word_emb)

You now compute the scaled dot product attention between the query __Q__, key __K__ and value __V__ matrices. As in the original paper, you divide the attention score matrix by the square root of the key dimension $d_k$, which is `dim_head` here. For large values of $d_k$ the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. Scaling the dot products by $\sqrt{1/d_k}$ helps address this.

$$ \textrm{Attention}(Q, K, V) = \textrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

![image](images/attention_fig.png)

To compute the dot product $QK^T$ in PyTorch, you will use [einsum](https://pytorch.org/docs/stable/generated/torch.einsum.html) and then scale the result. This term is called `attn_a`; you will soon see why.


In [None]:
attn_a = torch.einsum("ibd,jbd->bij", Q, K) / (dim_head ** 0.5)  # batch_size x tgt_len x (mem_len + tgt_len)
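For reference, the complete content-only scaled dot-product attention from the equation above fits in a few lines. This sketch uses random tensors with the same shapes as in this section and leaves out the masking and the positional terms that the Transformer-XL adds; it also checks that the `einsum` string computes the same scores as a batched matrix multiplication.

```python
import torch

torch.manual_seed(0)
tgt_len, mem_len, batch_size, dim_head = 10, 10, 8, 5

Q = torch.randn(tgt_len, batch_size, dim_head)             # queries: current segment only
K = torch.randn(tgt_len + mem_len, batch_size, dim_head)   # keys: memory + current segment
V = torch.randn(tgt_len + mem_len, batch_size, dim_head)   # values: memory + current segment

# QK^T / sqrt(d_k) with einsum, and the same product with torch.bmm
scores = torch.einsum("ibd,jbd->bij", Q, K) / dim_head ** 0.5
scores_bmm = torch.bmm(Q.transpose(0, 1), K.permute(1, 2, 0)) / dim_head ** 0.5
print(torch.allclose(scores, scores_bmm, atol=1e-6))

probs = torch.softmax(scores, dim=-1)          # each query's weights over the keys sum to 1
out = torch.einsum("bij,jbd->ibd", probs, V)   # weighted sum of values
print(out.shape)  # torch.Size([10, 8, 5])
```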

The Transformer-XL attention matrix introduces additional components from relative positional embeddings as you will see next. Once you compute these additional components you add them to obtain the net attention score matrix before applying the softmax activation.

### Relative Positional Encoding

Observe that unlike the original Transformer, you did not add positional embeddings to `word_emb` before you computed the attention score. 

Music has multiple dimensions along which relative differences arguably matter more than their absolute values; the two most prominent are timing and pitch. To model such pairwise relations between representations, the paper [Self-Attention with Relative Position Representations](https://arxiv.org/abs/1803.02155) introduced a relation-aware version of self-attention.

The Transformer-XL uses an efficient implementation of this idea in the form of relative positional embeddings. Unlike the Transformer, which adds absolute positional embeddings to every token embedding, the Transformer-XL uses an embedding that represents the relative distance $i-j$ between query $q_i$ and key $k_j$.


To recap, take a look at the $ij^{th}$ entry in the attention matrix of the original Transformer, which uses absolute positional embeddings denoted by $U$:

\begin{align}
A^{abs}_{i,j} = 
    \underbrace{E_{x_i}^T W_q^T W_{k} E_{x_j}}_{(a)}
    + \underbrace{E_{x_i}^T W_q^T W_{k} U_{j}}_{(b)}
    \\ 
    + \underbrace{ U_{i}^T W_{q}^T W_{k} E_{x_j}}_{(c)} 
    + \underbrace{ U_{i}^T W_{q}^T W_{k} U_{j}}_{(d)}
\end{align}

Now look at the $ij^{th}$ entry in the relative attention matrix of the Transformer-XL:

\begin{align}
A^{rel}_{i,j} = 
    \underbrace{E_{x_i}^T W_q^T W_{k,E} E_{x_j}}_{(a)}
    + \underbrace{E_{x_i}^T W_q^T W_{k,R} {\color{blue}R_{i-j}}}_{(b)}
    \\ 
    + \underbrace{{\color{red}u}^T W_{k,E} E_{x_j}}_{(c)} 
    + \underbrace{{\color{red}v}^T W_{k,R} {\color{blue}R_{i-j}}}_{(d)}
\end{align}

$E_x$ is the sequence embedding representing content<br>
$W_q$, $W_{k,E}$ and $W_{k,R}$ are the linear transformation matrices<br>
$R_{i-j}$ is the relative positional embedding for the distance $i-j$<br>
$u$ and $v$ are learnable bias terms that represent content bias and position bias respectively<br>
$T$ denotes matrix transpose

What are the differences?

- The first change is to replace all appearances of the absolute positional embedding $U_j$ for computing key vectors in terms (b) and (d) with its relative counterpart $R_{i-j}$. This reflects the prior belief that only the relative distance matters for where to attend. $R$ is a [sinusoid encoding matrix](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/) without learnable parameters, which you will visualize in the cells below.

- Second, a trainable parameter $u \in \mathbb{R}^d$ is introduced to replace the position-dependent query $U_{i}^T W_{q}^T$ in term (c). Since this query vector is the same for all query positions, the attentive bias towards different words remains the same regardless of the query position. With similar reasoning, a trainable parameter $v \in \mathbb{R}^d$ is added to substitute $U_{i}^T W_{q}^T$ in term (d).

- Finally, the two weight matrices $W_{k,E}$ and $W_{k,R}$ are deliberately separated for producing the content-based key vectors and location-based key vectors respectively.


Above you saw the code to compute the content-based attention term (a), denoted by `attn_a`. Below you can see how the remaining terms (b), (c) and (d) are computed. Terms (b) and (d) involve the relative positional embeddings.

Term (c) is computed first. 

In [None]:
u = torch.rand(dim_head).expand_as(Q)  # Content bias (a random stand-in here; learnable in the model)

attn_c = torch.einsum("ibd,jbd->bij", u, K) / (dim_head ** 0.5)

In [None]:
attn_content = attn_a + attn_c  # Content attention is composed of terms (a) and (c)

Now turn to the relative positional embedding terms (b) and (d). The Transformer-XL uses [sinusoidal embeddings](https://console.aws.amazon.com/deepcomposer/home?region=us-east-1#learningCapsules/transformerTechnique) like the original Transformer.

In [None]:
# Define positional indices
pos_seq = torch.arange(tgt_len + mem_len - 1, -1, -1, dtype=torch.float)
pos_seq

In [None]:
# Compute sinusoidal positional embeddings
inv_freq = 1 / (10000 ** (torch.arange(0.0, dim_embed, 2.0) / dim_embed))
sinusoid_inp = torch.ger(pos_seq, inv_freq) # Outer product of vectors
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)[:, None, :]

Let us visualize a few dimensions in the positional embeddings matrix for better understanding. The sines and cosines act as basis functions to [represent the position](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/).

In [None]:
# Visualize a few sinusoidal embeddings in pos_emb
plt.plot(pos_seq, pos_emb[:, 0, 0:4])
plt.legend(["dim %d" % p for p in [0, 1, 2, 3]])
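One way to see why sines and cosines suit relative positions: for any fixed offset $k$, the embedding of position $p+k$ is a linear function of the embedding of position $p$, by the angle-addition identities. This property is noted in the original Transformer paper. The sketch below checks it numerically with the same `[sin, cos]` layout as `pos_emb`; the offset `k = 5` is an arbitrary choice.

```python
import torch

dim_embed = 32
inv_freq = 1 / (10000 ** (torch.arange(0.0, dim_embed, 2.0) / dim_embed))


def sinusoid(positions: torch.Tensor) -> torch.Tensor:
    """Same layout as pos_emb above: all sines first, then all cosines."""
    angles = positions[:, None] * inv_freq[None, :]  # outer product via broadcasting
    return torch.cat([angles.sin(), angles.cos()], dim=-1)


positions = torch.arange(20, dtype=torch.float)
k = 5  # arbitrary fixed offset

sin_p, cos_p = sinusoid(positions).chunk(2, dim=-1)
sin_k, cos_k = (k * inv_freq).sin(), (k * inv_freq).cos()

# sin(a + b) = sin a cos b + cos a sin b;  cos(a + b) = cos a cos b - sin a sin b
predicted = torch.cat([sin_p * cos_k + cos_p * sin_k, cos_p * cos_k - sin_p * sin_k], dim=-1)
print(torch.allclose(predicted, sinusoid(positions + k), atol=1e-5))  # True
```

The coefficients `sin_k` and `cos_k` depend only on the offset, not on the position, which is what makes the representation well suited to relative distances.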

You will now define the core components of the Transformer-XL as classes, starting with the relative positional embeddings, which are defined inside the class `PositionalEmbedding`.

In [None]:
class PositionalEmbedding(nn.Module):
    """ 
    Transformer-XL positional embedding definition 
    """

    def __init__(self, dim_embed):
        super(PositionalEmbedding, self).__init__()

        self.dim_embed = dim_embed

        inv_freq = 1 / (10000 ** (torch.arange(0.0, dim_embed, 2.0) / dim_embed))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, batch_size=None):

        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)  # Outer product

        # Define relative positional embeddings
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if batch_size is not None:
            return pos_emb[:, None, :].expand(-1, batch_size, -1)
        else:
            return pos_emb[:, None, :]

As with the queries, keys and values, you apply a linear transformation ($W_{k,R}$) to the positional embeddings, projecting them to the head dimension `dim_head` so that they can be combined with the queries of each head.

In [None]:
# Define linear transformation layer that acts on positional embeddings
r_net = nn.Linear(dim_embed, dim_head)

# Project positional embeddings
R = r_net(pos_emb) 

You will now compute terms (b) and (d) together.

In [None]:
v = torch.rand(dim_head)  # Position bias (a random stand-in here; learnable in the model)

attn_bd = torch.einsum("ibd,jd->bij", Q + v, R[:, 0, :]) / (dim_head ** 0.5)

In [None]:
attn_bd.shape

Observe that `R` scales linearly with the sequence length (say $n$), since `pos_seq` scales linearly with $n$. Terms (b) and (d), however, depend on both $i$ and $j$. Naively computing $W_{k,R}R_{i-j}$ separately for every query-key pair $(i, j)$ would involve computing and storing $O(n^2)$ projected vectors. The authors proposed a trick (padding + shifting) that reduces this to $O(n)$ projected vectors: each query is scored against every relative position once, as in `attn_bd` above, and each row of scores is then shifted so that entry $(i, j)$ lines up with the distance $i-j$. This trick makes the original [Relative Attention](https://arxiv.org/abs/1803.02155) idea feasible for modeling long sequences.

In [None]:
padding = torch.zeros((batch_size, tgt_len, 1), dtype=torch.float)

# padding + shifting is the trick for efficiently computing pos_attn
attn_pos = (
    torch.cat([padding, attn_bd], dim=-1)
    .view(batch_size, tgt_len + mem_len + 1, tgt_len)[:, 1:]
    .view_as(attn_bd)
)
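Because the pad-and-shift trick is easy to get wrong, here is a small standalone check of what it does. It is a sketch for a single `(qlen, klen)` score matrix (the notebook code works on a batch), and `rel_shift_2d` is a hypothetical helper name. Each column of the hypothetical input `bd` corresponds to one entry of `pos_seq`, which counts down from `klen - 1` to 0, and each entry is filled with `100 * i + distance` so the output is easy to read. After shifting, entry `(i, j)` should hold the score for the distance `i + mem_len - j` between query `i` (which sits after the memory) and key `j`; entries past the causal boundary hold leftover values, which the causal mask removes.

```python
import torch


def rel_shift_2d(x: torch.Tensor) -> torch.Tensor:
    """Pad-and-shift for a (qlen, klen) matrix whose columns index distances klen-1, ..., 0."""
    qlen, klen = x.shape
    padded = torch.cat([x.new_zeros(qlen, 1), x], dim=1)  # (qlen, klen + 1)
    padded = padded.view(klen + 1, qlen)                  # reinterpret the same memory
    return padded[1:].view(qlen, klen)                    # drop the first qlen values, reshape


qlen, mem_len = 4, 3
klen = qlen + mem_len
distances = torch.arange(klen - 1, -1, -1)  # same ordering as pos_seq
bd = (100 * torch.arange(qlen)[:, None] + distances[None, :]).float()

shifted = rel_shift_2d(bd)
for i in range(qlen):
    for j in range(i + mem_len + 1):  # keys the causal mask keeps
        assert shifted[i, j] == 100 * i + (i + mem_len - j)
print(shifted)
```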

The net attention matrix is the sum of terms (a), (b), (c) and (d).

In [None]:
attn = attn_content + attn_pos

Since you are interested in music generation (formulated as language modeling), you need to use a causal mask to prevent the model from looking into the future while it predicts.

In [None]:
mask = torch.triu(
    torch.ones((tgt_len, tgt_len + mem_len)),
    diagonal=1 + mem_len,
).bool()[None, ...]

attn = attn.masked_fill(mask, -float("inf"))
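To see which positions the mask hides, the sketch below builds the same mask for a tiny example with `tgt_len = 3` and `mem_len = 2` (illustrative sizes) and applies it to uniform scores. Each query can attend to all cached memory positions and to current positions up to and including its own; masked entries get zero probability after the softmax.

```python
import torch

tgt_len, mem_len = 3, 2
mask = torch.triu(torch.ones((tgt_len, tgt_len + mem_len)), diagonal=1 + mem_len).bool()
print(mask.int().tolist())  # [[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0]]

# Uniform scores, so the softmax spreads weight evenly over the visible keys
scores = torch.zeros(tgt_len, tgt_len + mem_len).masked_fill(mask, -float("inf"))
probs = torch.softmax(scores, dim=-1)
print(probs[0])  # query 0 sees the 2 memory positions plus itself
```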

You now compute the outputs of the attention layer by taking a weighted sum of the value vectors using the attention probabilities.

In [None]:
attn_probs = torch.softmax(attn, dim=2)
attn_vec = torch.einsum("bij,jbd->ibd", attn_probs, V)


Finally, you transform `attn_vec` using a linear layer to bring it back to its original dimension. You also apply the residual connection and layer normalization as in the original Transformer.

In [None]:
o_net = nn.Linear(dim_head, dim_embed)
layer_norm = nn.LayerNorm(dim_embed)
outputs = layer_norm(word_emb + o_net(attn_vec))

### Multi-head attention

You now define the Transformer-XL `MultiHeadAttn` module by applying self attention across several heads. 

In [None]:
class MultiHeadAttn(nn.Module):
    """ 
    Defines the Multihead Attention module in the Transformer-XL 
    """
    
    def __init__(
            self,
            n_head,
            dim_model,
            dim_head,
            dropout,
            dropatt=0,
    ):
        super(MultiHeadAttn, self).__init__()

        # number of heads
        self.n_head = n_head
        
        # input dimension
        self.dim_model = dim_model
        
        # inner dimension
        self.dim_head = dim_head
        
        self.dropout = dropout
        
        # you apply the linear transformation to queries, keys and values 
        # at once for efficiency
        self.qkv_net = nn.Linear(dim_model, 3 * n_head * dim_head, bias=False)
        
        # linear transformation for output and position embeddings
        self.o_net = nn.Linear(n_head * dim_head, dim_model, bias=False)
        self.r_net = nn.Linear(self.dim_model, self.n_head * self.dim_head, bias=False)
        
        # Parameters controlling dropout or drop attention
        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        
        self.layer_norm = nn.LayerNorm(dim_model)
        
        # Dot product attention scaling factor 
        self.scale = 1 / (dim_head ** 0.5)
         
    def rel_shift(self, x):
        """ 
        Function to help compute positional attention component efficiently 
        """
        
        padding = torch.zeros(
            (x.size(0), x.size(1), x.size(2), 1), device=x.device, dtype=x.dtype
        )
        x_padded = torch.cat([padding, x], dim=3)

        x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2))

        x = x_padded[:, :, 1:].view_as(x)

        return x

    def forward(self, w: torch.FloatTensor, # (q_len, batch_size, dim_model)
                r: torch.FloatTensor, # (k_len, dim_model)
                u: torch.FloatTensor, # (n_head, dim_head), content bias
                v: torch.FloatTensor,  # (n_head, dim_head), position bias
                attn_mask: Optional[torch.FloatTensor]=None, 
                mems: Optional[torch.FloatTensor]=None): #(prev_seq_len, batch_size, dim_model)
        
        # qlen is length of current segment
        # rlen is length of current segment + length of previous segment
        qlen, rlen, batch_size = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            # concatenate memory across sequence dimension
            cat = torch.cat([mems, w], 0)
            
            w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(
            qlen, batch_size, self.n_head, self.dim_head
        )  # [qlen x batch_size x n_head x dim_head]
        w_head_k = w_head_k.view(
            klen, batch_size, self.n_head, self.dim_head
        )  # [klen x batch_size x n_head x dim_head]
        w_head_v = w_head_v.view(
            klen, batch_size, self.n_head, self.dim_head
        )  # [klen x batch_size x n_head x dim_head]

        r_head_k = r_head_k.view(
            rlen, self.n_head, self.dim_head
        )  # [klen x n_head x dim_head]

        # Compute attention score
        
        # Terms (a) and (c)
        uw_head_q = w_head_q + u  # qlen x batch_size x n_head x dim_head
        
        AC = torch.einsum(
            "ibnd,jbnd->bnij", (uw_head_q, w_head_k)
        )  # [batch_size x n_head x qlen x klen]

        # Terms (b) and (d)
        vw_head_q = w_head_q + v
        BD = torch.einsum(
            "ibnd,jnd->bnij", (vw_head_q, r_head_k)
        )  # [batch_size x n_head x qlen x klen]
        
        # Compute positional attention component efficiently
        BD = self.rel_shift(BD) # [batch_size x n_head x qlen x klen]
        
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        # Compute attention probability
        
        # Use a causal mask if provided
        if attn_mask is not None:
            attn_score.masked_fill_(attn_mask[None, None, :, :], -float("inf"))

        
        attn_prob = F.softmax(attn_score, dim=3) # [batch_size x n_head x qlen x klen]
        attn_prob = self.dropatt(attn_prob) # [batch_size x n_head x qlen x klen]

        attn_vec = torch.einsum("bnij,jbnd->ibnd", (attn_prob, w_head_v))
        # [qlen x batch_size x n_head x dim_head]
        
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.dim_head
        )

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        # residual connection + layer normalization
        output = self.layer_norm(w + attn_out)

        return output


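The `rel_shift` method above uses a padding-and-reshape trick. `BD` is computed against the positional embeddings for distances `klen-1, ..., 0` (one column per distance), and the shift realigns it so that entry `[i, j]` holds the score for the distance between query `i` and key `j`. Query `i` sits at absolute position `mlen + i`, so for key `j` the distance is `mlen + i - j`, which lives in column `qlen - 1 - i + j`. The sketch below is a hypothetical NumPy port of the same reshape for a single head (NumPy is an assumption made only so it runs anywhere), checked against an explicit gather on the positions `j <= mlen + i` that the causal mask keeps; the remaining entries are leftovers that the mask removes.

```python
import numpy as np

def rel_shift(x: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy port of MultiHeadAttn.rel_shift for one (qlen, klen) head."""
    qlen, klen = x.shape
    # Prepend a column of zeros: shape (qlen, klen + 1)
    x_padded = np.concatenate([np.zeros((qlen, 1), dtype=x.dtype), x], axis=1)
    # Reinterpret the same row-major data as (klen + 1, qlen), drop the first row,
    # and read the remainder back as (qlen, klen)
    return x_padded.reshape(klen + 1, qlen)[1:].reshape(qlen, klen)

def explicit_shift(x: np.ndarray, mlen: int) -> np.ndarray:
    """Hypothetical reference: out[i, j] = x[i, qlen - 1 - i + j] wherever j <= mlen + i."""
    qlen, _ = x.shape
    out = np.full_like(x, np.nan)  # NaN marks positions the causal mask removes
    for i in range(qlen):
        for j in range(mlen + i + 1):
            out[i, j] = x[i, qlen - 1 - i + j]
    return out

qlen, mlen = 3, 2
klen = qlen + mlen
# Column c of bd holds the score against relative distance klen - 1 - c
bd = np.arange(qlen * klen, dtype=float).reshape(qlen, klen)
shifted = rel_shift(bd)
reference = explicit_shift(bd, mlen)
keep = ~np.isnan(reference)  # the positions the causal mask keeps
print(np.array_equal(shifted[keep], reference[keep]))  # True
```

The reshape avoids building an explicit index tensor, which is why the notebook code can compute `BD` for all distances with one `einsum` and then fix the alignment cheaply.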
To check that this layer works as expected, verify that the input and output shapes are identical.

In [None]:
mha_layer = MultiHeadAttn(n_head=2, dim_model=16, dim_head=5, dropout=0.1, dropatt=0)

seq = torch.rand(10, 8, 16)  # [q_len x batch_size x dim_model]
pos_emb = torch.rand(16, 16)  # [k_len x dim_model], where k_len = q_len + mem_len = 16
mems = torch.rand(6, 8, 16)  # [mem_len x batch_size x dim_model]
u, v = torch.rand(2, 5), torch.rand(2, 5)  # [n_head x dim_head]

outputs = mha_layer(w=seq, r=pos_emb, u=u, v=v, attn_mask=None, mems=mems)

In [None]:
print("Input sequence shape = ", seq.shape)
print("Output sequence shape = ", outputs.shape)

assert seq.shape == outputs.shape

### The decoder layer

You will build the decoder block by combining `MultiHeadAttn` with a position-wise feed-forward layer identical to the one in the original Transformer.
![image](images/decoder.png)

In [None]:
class PositionwiseFF(nn.Module):
    """ 
    Defines the Position wise FeedForward layer in the Transformer-XL 
    """
    
    def __init__(self, dim_model, dim_inner, dropout):
        super(PositionwiseFF, self).__init__()

        # Input dimension
        self.dim_model = dim_model

        # Inner dimension within Positionwise FF
        self.dim_inner = dim_inner
        
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(dim_model, dim_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(dim_model)

    def forward(
        self, inp: torch.FloatTensor  # (q_len, batch_size, dim_model)
    ) -> torch.FloatTensor:  # (q_len, batch_size, dim_model)

        # positionwise feed-forward
        core_out = self.CoreNet(inp)

        # residual connection + layer normalization
        output = self.layer_norm(inp + core_out)

        return output

You will use this `PositionwiseFF` layer in the `DecoderLayer`.

In [None]:
class DecoderLayer(nn.Module):
    """ 
    Transformer-XL decoder layer comprised of the Multihead attention 
    and Positionwise Feed Forward layers 
    """
    
    def __init__(self, n_head, dim_model, dim_head, dim_inner, dropout, dropatt=0):

        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, dim_model, dim_head, dropout, dropatt)

        self.pos_ff = PositionwiseFF(dim_model, dim_inner, dropout)

    def forward(self, dec_inp, r, u, v, dec_attn_mask=None, mems=None):

        output = self.dec_attn(
            dec_inp, r, u, v, attn_mask=dec_attn_mask, mems=mems
        )

        output = self.pos_ff(output)

        return output

### The Transformer-XL decoder

Using the `DecoderLayer` defined above you can now build the Transformer-XL.

You will also need to define the input `Embedding` layer, which maps input token indices to vectors of dimension `dim_model`, and the output layer, which maps from `dim_model` to the vocabulary size.

The Transformer-XL ties the weights of the embedding layer and the output layer, which reduces the total parameter count.

In [None]:
class Embedding(nn.Module):
    """
    Embedding layer in the Transformer-XL
    """
    def __init__(self, n_token, dim_embed):
        """
        Args:
            n_token: number of tokens in vocab
            dim_embed: dimension of embedding
        """

        super(Embedding, self).__init__()

        self.n_token = n_token
        self.dim_embed = dim_embed

        self.emb_scale = dim_embed ** 0.5

        self.emb_layers = nn.ModuleList()

        self.emb_layers.append(nn.Embedding(n_token, dim_embed, sparse=False))

    def forward(
        self,
        inp: torch.LongTensor,  # (qlen, batch_size)
    ) -> torch.FloatTensor:  # (qlen, batch_size, dim_embed)
        embed = self.emb_layers[0](inp)

        # Embeddings are scaled while the Output Layer of Transformer-XL is not
        embed.mul_(self.emb_scale)

        return embed

You will now define the Transformer-XL module using everything you have learnt so far.

In [None]:
class TransformerXL(nn.Module):
    """ 
    The Transformer-XL module comprised of the Embedding layer,
    multiple Decoder layers and the output layer
    """
    def __init__(
        self,
        n_layer,
        n_head,
        dim_model,
        dim_inner,
        dropout,
        dropatt,
        tie_weight,
        tgt_len,
        mem_len,
        n_token,
    ):

        super(TransformerXL, self).__init__()

        # Embedding layer
        self.word_emb = Embedding(
            n_token,
            dim_model,
        )

        dim_head = dim_model // n_head # Dimensionality of the model’s heads

        # Positional embedding
        self.pos_emb = PositionalEmbedding(dim_model)
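        # u and v are the learned content and position biases shared by all layers.
        # torch.Tensor(...) allocates uninitialized memory; weights_init sets their values.
        self.u = nn.Parameter(torch.Tensor(n_head, dim_head))
        self.v = nn.Parameter(torch.Tensor(n_head, dim_head))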

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len

        # Define the decoder layers that comprise the Transformer-XL
        self.layers = nn.ModuleList()

        for i in range(n_layer):
            self.layers.append(
                DecoderLayer(
                    n_head,
                    dim_model,
                    dim_head,
                    dim_inner,
                    dropout,
                    dropatt=dropatt,
                )
            )

        # Define output layer
        self.out_layers = nn.ModuleList()
        self.out_layers.append(nn.Linear(dim_model, n_token))

        # Tie weights of output layer with embedding layer
        if tie_weight:
            for i in range(len(self.out_layers)):
                self.out_layers[i].weight = self.word_emb.emb_layers[i].weight

    def reset_length(self, tgt_len, mem_len):
        """
        Resets tgt_len and mem_len to specified values
        
        Used when tgt_len and mem_len may be different between training,
        evaluation and generation
        """
        self.tgt_len = tgt_len
        self.mem_len = mem_len

    def init_mems(self, n_layers):
        """
        Initialize mems tensor if mems is None
        """
        param = next(self.parameters())
        mems = torch.empty(n_layers + 1, 0, dtype=param.dtype, device=param.device)
        return mems

    def update_mems(self, hids, mems, qlen, mlen):
        """
        This function is called at the end of a forward.
        Updates mems with hidden states of current segment
        """
        
        if mems is None:
            return None

        with torch.no_grad():

            # Update mems with the most recent `self.mem_len`
            # states that includes the previous memory

            stacked = torch.stack(hids)
            end_idx = mlen + max(0, qlen)
            start_idx = max(0, end_idx - self.mem_len)
            
            # Dimension of cat is (n_layer + 1, mlen + qlen, batch_size, dim_model)
            cat = torch.cat([mems, stacked], dim=1) if mems.numel() else stacked
            
            # Dimension of new_mems is (n_layer + 1, min(mem_len, mlen + qlen), batch_size, dim_model)
            new_mems = cat[:, start_idx:end_idx].detach()
            
        return new_mems

    def _forward(self, dec_inp, mems=None):
        """
        Helper function used by forward()

        """
        qlen, batch_size = dec_inp.size()[0], dec_inp.size()[1]
        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen

        # Construct attention mask
        dec_attn_mask = torch.triu(
            word_emb.new_ones(qlen, klen), diagonal=1 + mlen
        ).bool()

        # Construct positional embeddings
        pos_seq = torch.arange(
            klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype
        )

        pos_emb = self.pos_emb(pos_seq)
        pos_emb = self.drop(pos_emb)

        # Successively run through Decoder Layers
        hids = []
        core_out = self.drop(word_emb)
        hids.append(core_out)

        for i, layer in enumerate(self.layers):
            mems_i = None if mems is None else mems[i]
            core_out = layer(
                core_out,
                pos_emb,
                self.u,
                self.v,
                dec_attn_mask=dec_attn_mask,
                mems=mems_i,
            )
            hids.append(core_out)
        core_out = self.drop(core_out)

        # Update memory
        new_mems = self.update_mems(hids, mems, qlen, mlen)

        return core_out, new_mems

    def forward(self, data, target, mems=None):

        if mems is None and self.mem_len > 0:
            mems = self.init_mems(self.n_layer)

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]

        logit = self.out_layers[0](pred_hid.view(-1, pred_hid.size(-1)))

        loss = (
            -F.log_softmax(logit, dim=-1)
            .gather(1, target.view(-1).unsqueeze(1))
            .squeeze(1)
        )

        loss = loss.view(tgt_len, -1)

        return (loss, new_mems)

    def forward_generate(self, data, mems):
        """
        This function is called during inference (decoding)
        when one generates tokens incrementally.
        It is identical to forward() but does not compute the loss
        and returns the logits instead

        """
        if mems is None and self.mem_len > 0:
            mems = self.init_mems(self.n_layer)

        tgt_len = data.size(0)
        batch_size = data.size(1)

        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]

        logits = self.out_layers[0](pred_hid.view(-1, pred_hid.size(-1)))
        logits = logits.view(tgt_len, batch_size, -1)

        return (logits, new_mems)

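In `_forward`, the attention mask and the position sequence depend only on `qlen` and `mlen`. A minimal NumPy sketch of the same construction follows; `causal_mask` and `relative_positions` are hypothetical helpers, and `np.triu` takes its offset as `k`, which plays the role of `diagonal` in `torch.triu`. It shows that each query can see every cached memory position plus the current-segment positions up to and including its own.

```python
import numpy as np

def causal_mask(qlen: int, mlen: int) -> np.ndarray:
    """Hypothetical helper: True marks a masked (query, key) pair, as in TransformerXL._forward."""
    klen = mlen + qlen
    return np.triu(np.ones((qlen, klen)), k=1 + mlen).astype(bool)

def relative_positions(qlen: int, mlen: int) -> np.ndarray:
    """Hypothetical helper: distances fed to the positional embedding, largest first."""
    klen = mlen + qlen
    return np.arange(klen - 1, -1, -1.0)

mask = causal_mask(qlen=3, mlen=2)
print(mask.astype(int))
# [[0 0 0 1 1]
#  [0 0 0 0 1]
#  [0 0 0 0 0]]
print(relative_positions(qlen=3, mlen=2))  # [4. 3. 2. 1. 0.]
```

The first two columns (the memory) are never masked, while the last three form the usual lower-triangular causal pattern over the current segment.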
Test that the model works with some dummy inputs.

In [None]:
transformerxl = TransformerXL(
    n_layer=5,
    n_head=4,
    dim_model=10,
    dim_inner=5,
    dropout=0.1,
    dropatt=0,
    tie_weight=True,
    tgt_len=20,
    mem_len=5,
    n_token=310,
)

In [None]:
inputs = torch.randint(310, (20, 2)) # input indices of shape (tgt_len, batch_size)
tgts = torch.randint(310, (20, 2)) # target indices of shape (tgt_len, batch_size)
outputs = transformerxl(inputs, tgts)

print("Output is a tuple of length ", len(outputs))
assert len(outputs)==2

print("Loss is a tensor of shape ", outputs[0].shape)  # (tgt_len, batch_size)
assert outputs[0].shape == tgts.shape

print("Memory is a tensor of shape ", outputs[1].shape)  # (n_layer+1, mem_len, batch_size, dim_model)
assert outputs[1].shape == torch.Size([5+1, 5, tgts.shape[1], 10])
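The memory shape asserted above follows from the slicing in `update_mems`: with an empty initial memory (`mlen = 0`), `qlen = 20` and `mem_len = 5`, it keeps hidden states `15:20`, the last five time steps of the segment. The pure-Python sketch below traces the same index arithmetic over a few successive segments; `mem_window` and `trace_memory_lengths` are hypothetical helpers and the segment lengths are illustrative.

```python
def mem_window(mlen: int, qlen: int, mem_len: int):
    """Hypothetical helper: (start_idx, end_idx) into [old mems; new hidden states],
    using the same arithmetic as TransformerXL.update_mems."""
    end_idx = mlen + max(0, qlen)
    start_idx = max(0, end_idx - mem_len)
    return start_idx, end_idx

def trace_memory_lengths(segment_lengths, mem_len):
    """Hypothetical helper: memory length after each forward pass, starting empty."""
    mlen, lengths = 0, []
    for qlen in segment_lengths:
        start, end = mem_window(mlen, qlen, mem_len)
        mlen = end - start  # time steps kept as memory for the next segment
        lengths.append(mlen)
    return lengths

print(mem_window(mlen=0, qlen=20, mem_len=5))  # (15, 20)
print(trace_memory_lengths([2, 2, 2, 2], mem_len=5))  # [2, 4, 5, 5]
```

The memory grows until it reaches `mem_len` and then acts as a sliding window over the most recent hidden states.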

## Training the model

In the following sections, you will see how to train the Transformer-XL model.

You will first define the hyperparameters that you will use in the data loader, training, and evaluation loops.

### Training hyperparameters

Hyperparameters are broadly categorized into those that control training and evaluation, and those that define the model architecture.

In [None]:
@dataclass
class TrainConfig:
    """
    Defines configuration parameters used during model training
    """
    
    # Dataset
    data_dir = "data/jsb_chorales_numpy"

    # Checkpoint save path
    save_path = "checkpoints/"

    # Training and evaluation hyperparameters
    batch_size = 64 # Training batch size
    seed = 101 # Seed to reproduce losses
    tgt_len = 128 # Target length or bptt (use as large as fits in GPU memory)
    mem_len = 512 # Memory length (use as large as fits in GPU memory)
    clip = 1.0 # Grad norm clip constant
    scheduler = "inv_sqrt" # Learning rate scheduler
    warmup_step = 4000 # Learning rate warmup
    lr = 0.004 / 8 # Learning rate
    lr_min = 0.0001 / 4 # Min learning rate
    optim = "adam" # Optimizer
    weight_decay = 0.0  # Weight decay for adam
    max_step = 20000 # Max steps
    
    eval_batch_size = 2 # Evaluation batch size
    eval_tgt_len = 128 # Evaluation target length
    eval_mem_len = 512 # Evaluation memory length

    log_interval = 100 # Print logs every log_interval training iterations
    eval_interval = 500 # Evaluate after eval_interval training iterations

    # Plotting and saving params
    save_all_test_losses = True 
    plot_losses_while_training = True
    plot_interval = 100 # Plot losses every plot_interval training iterations

    # Weight initialization
    base_init = ["normal", 0.01] # Initialization parameters for weights
    embed_init = ["normal", 0.01] # Initialization parameters for embeddings

    # Model hyperparameters
    dropout = 0.1 # Dropout probability for the embeddings, attention outputs and feed-forward layers
    dropatt = 0.1  # The dropout ratio for the attention probabilities
    dim_inner = 1000 # Inner dimension within Positionwise FF 
    num_heads = 10 # Number of heads in Multihead attention
    num_layers = 4 # Number of layers in Transformer-XL
    tie_embedding = True # Share weights between input embedding and output layer
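    dim_model = 500 # Dimensionality of the model’s hidden states

    def save(self, path):
        """
        Saves the configuration values to path as JSON, which is also valid YAML
        """
        import json

        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        cfg = {
            k: v for k, v in vars(type(self)).items()
            if not k.startswith("_") and not callable(v)
        }
        with open(path, "w") as f:
            json.dump(cfg, f, indent=2)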

In [None]:
train_cfg = TrainConfig()
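As a sanity check on model size, you can count the parameters by hand from these hyperparameters. The hypothetical helper below assumes that `PositionalEmbedding` (defined earlier in the notebook) has no trainable parameters, and it uses a hypothetical vocabulary of 310 tokens (the value from the dummy test above); the real value is `len(vocab)`. With weight tying, the output layer adds only its bias, because its weight is the embedding matrix. For the real model, the result should match `sum(p.numel() for p in transformerxl.parameters())`, since `parameters()` yields a shared tensor only once.

```python
def transformer_xl_param_count(n_layer, n_head, dim_model, dim_inner, n_token, tie_weight):
    """Hypothetical helper: trainable parameters of the TransformerXL defined above,
    assuming PositionalEmbedding has none."""
    dim_head = dim_model // n_head
    proj = n_head * dim_head
    layer_norm = 2 * dim_model  # weight + bias
    attn = (
        dim_model * 3 * proj  # qkv_net (no bias)
        + dim_model * proj  # r_net (no bias)
        + proj * dim_model  # o_net (no bias)
        + layer_norm
    )
    ff = (dim_model * dim_inner + dim_inner) + (dim_inner * dim_model + dim_model) + layer_norm
    u_v = 2 * n_head * dim_head  # shared by all layers
    embedding = n_token * dim_model
    output = n_token + (0 if tie_weight else n_token * dim_model)  # bias (+ weight if untied)
    return n_layer * (attn + ff) + u_v + embedding + output

# Architecture values from TrainConfig, with a hypothetical vocabulary of 310 tokens
tied = transformer_xl_param_count(4, 10, 500, 1000, n_token=310, tie_weight=True)
untied = transformer_xl_param_count(4, 10, 500, 1000, n_token=310, tie_weight=False)
print(tied, untied)  # 9170310 9325310
```

Most parameters sit in the four decoder layers (about 2.25M each); tying saves `n_token * dim_model` parameters, which grows with the vocabulary size.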

### Building the data loader

While the Transformer-XL data loader resembles the standard data loading pipeline for RNN language models, it differs in one important respect: the Transformer-XL memory persists across batches.

You need to ensure that the element at a given position in successive batches continues the same MIDI file, picking up where the previous segment left off. This ensures that, when you train using minibatches, the Transformer-XL memory contains the cached segments that precede the current segment.

![image](images/dataloader.jpg)

In the figure above, we depict two successive batches (each with 2 elements) that the model is trained on. Each batch element is a sequence of length `tgt_len`. The batches are constructed so that each batch element (for example, Element 1 in Batch 1 and Element 1 in Batch 2) uses segments from the same MIDI file. This ensures that the Transformer-XL memory for a batch element caches segments from the same MIDI file as the current segment.
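The bookkeeping behind this figure can be sketched in pure Python. The hypothetical `stream_segments` helper below mirrors the `tracker_list` logic in `get_iterator` for a single unshuffled pass, but returns `(seq_id, start, n_tokens)` triples instead of filling tensors. Each slot keeps reading its current sequence in chunks of at most `bptt` tokens and moves to the next unused sequence only once the current one is exhausted; `None` marks a slot that has run out of sequences (the real loader fills it with `pad_id`).

```python
from typing import Iterator, List, Optional, Tuple

Entry = Optional[Tuple[int, int, int]]  # (seq_id, start, n_tokens) or None

def stream_segments(seq_lengths: List[int], batch_size: int, bptt: int) -> Iterator[List[Entry]]:
    """Hypothetical helper mirroring the tracker_list logic of get_iterator
    (single pass, no shuffling). A sequence of length L yields L - 1
    (input, target) pairs, so n_tokens counts those pairs."""
    # Slot i starts on sequence i; next_idx is the next sequence no slot has used yet
    tracker = [(i, 0) for i in range(batch_size)]
    next_idx = batch_size
    while True:
        batch: List[Entry] = []
        for slot in range(batch_size):
            idx, pos = tracker[slot]
            entry = None
            while idx < len(seq_lengths):
                if pos + 1 >= seq_lengths[idx]:
                    # Current sequence exhausted: move this slot to the next unused one
                    idx, pos = next_idx, 0
                    next_idx += 1
                    continue
                n_new = min(seq_lengths[idx] - 1 - pos, bptt)
                entry = (idx, pos, n_new)
                pos += n_new
                break
            tracker[slot] = (idx, pos)
            batch.append(entry)
        if all(e is None for e in batch):
            return  # every slot has run out of sequences
        yield batch

for batch in stream_segments([5, 3, 4], batch_size=2, bptt=2):
    print(batch)
# [(0, 0, 2), (1, 0, 2)]
# [(0, 2, 2), (2, 0, 2)]
# [None, (2, 2, 1)]
```

Slot 0 reads sequence 0 in the first two batches, so its memory always holds earlier segments of the same file, while slot 1 switches to sequence 2 as soon as sequence 1 ends.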

You will also use a `BaseVocab` class that helps you work with the Magenta vocabulary. This class provides several functions that map token names (such as NOTE_ON_70) to indices.

In [None]:
class MusicDataset:
    def __init__(self, data_dir):
        """Load the music corpus
        Args:
            data_dir: The base folder of the preprocessed music dataset
        """
        self.vocab_path = os.path.join(data_dir, "vocab.txt")
        self.train_folder = os.path.join(data_dir, "train")
        self.valid_folder = os.path.join(data_dir, "valid")
        self.test_folder = os.path.join(data_dir, "test")
        all_tokens = []
        with open(self.vocab_path, "r") as f:
            all_tokens = [token.strip() for token in f]
        
        # BaseVocab class that provides useful functions to interact with vocabulary
        self.vocab = BaseVocab(all_tokens)

        self.train_data = self.load_cache_data(self.train_folder)
        self.valid_data = self.load_cache_data(self.valid_folder)
        self.test_data = self.load_cache_data(self.test_folder)

        # Insert start tokens
        
        self.train_data = [
            torch.from_numpy(np.insert(arr, 0, self.vocab.bos_id))
            for arr in self.train_data
        ]
        self.valid_data = [
            torch.from_numpy(np.insert(arr, 0, self.vocab.bos_id))
            for arr in self.valid_data
        ]
        self.test_data = [
            torch.from_numpy(np.insert(arr, 0, self.vocab.bos_id))
            for arr in self.test_data
        ]

        # Extract sequence lengths for the different splits
        self.train_seq_length = np.array(
            [ele.shape[0] for ele in self.train_data], dtype=np.int32
        )
        self.valid_seq_length = np.array(
            [ele.shape[0] for ele in self.valid_data], dtype=np.int32
        )
        self.test_seq_length = np.array(
            [ele.shape[0] for ele in self.test_data], dtype=np.int32
        )
        print(
            "Loaded Data, #Samples Train/Val/Test:{}/{}/{}".format(
                len(self.train_data), len(self.valid_data), len(self.test_data)
            )
        )
        print(
            "#Avg Length:{}/{}/{}".format(
                np.mean([len(ele) for ele in self.train_data]),
                np.mean([len(ele) for ele in self.valid_data]),
                np.mean([len(ele) for ele in self.test_data]),
            )
        )
        print(
            "#Total Number of Valid/Test Tokens: {}/{}".format(
                (self.valid_seq_length - 1).sum(), (self.test_seq_length - 1).sum()
            )
        )

    def load_cache_data(self, dir_name):
        """
        Returns a list with the NumPy array loaded from each file in dir_name
        """
        all_fnames = sorted(glob.glob(os.path.join(dir_name, "*.npy")))
        print("Loading #{} files from {}".format(len(all_fnames), dir_name))
        # Load the files in parallel. The sequences have different lengths, so keep
        # them in a list: np.array() on ragged arrays raises an error in NumPy >= 1.24
        with multiprocessing.Pool(8) as pool:
            dat = pool.map(np.load, all_fnames)
        return dat

    def get_iterator(
            self, batch_size, bptt, device, split="train", do_shuffle=True, seed=None
    ):
        """
        Function that returns an iterator over the dataset specified by 
        batch_size, bptt, device and split
        """
        if split == "train":
            split_data = self.train_data
            split_seq_lengths = self.train_seq_length
        elif split == "valid":
            split_data = self.valid_data
            split_seq_lengths = self.valid_seq_length
        elif split == "test":
            split_data = self.test_data
            split_seq_lengths = self.test_seq_length
        else:
            raise NotImplementedError
        total_sample_num = len(split_data)

        def iterator():
            perm = np.arange(total_sample_num)
            if do_shuffle:
                rng = np.random.RandomState(seed)
                rng.shuffle(perm)
            assert batch_size < total_sample_num
            tracker_list = [(i, 0) for i in range(batch_size)]
            next_idx = batch_size
            data = torch.LongTensor(bptt, batch_size)
            target = torch.LongTensor(bptt, batch_size)

            while True:

                # Fill with pad_id
                data[:] = self.vocab.pad_id
                target[:] = self.vocab.pad_id

                batch_token_num = 0
                for i in range(batch_size):
                    idx, pos = tracker_list[i]
                    while idx < total_sample_num:
                        seq_id = perm[idx]
                        seq_length = split_seq_lengths[seq_id]
                        if pos + 1 >= seq_length:
                            idx, pos = next_idx, 0
                            tracker_list[i] = (idx, pos)
                            next_idx += 1
                            continue
                        else:
                            n_new = min(seq_length - 1 - pos, bptt)
                            data[:n_new, i] = split_data[seq_id][pos: pos + n_new]
                            target[:n_new, i] = split_data[seq_id][
                                                (pos + 1): (pos + 1 + n_new)]
                            batch_token_num += n_new
                            tracker_list[i] = (idx, pos + n_new)

                            break
                            
                if batch_token_num == 0:
                    # Haven't found anything to fill. This indicates we have reached the end
                    if do_shuffle:
                        rng.shuffle(perm)
                    else:
                        return  # One pass dataloader when do_shuffle is False
                    tracker_list = [(i, 0) for i in range(batch_size)]
                    next_idx = batch_size
                    continue

                yield data.to(device), target.to(device), batch_token_num

        return iterator


### Loading data

Use the `MusicDataset` class you just defined to load the dataset and create an iterator for each split.


In [None]:
dataset = MusicDataset(train_cfg.data_dir) # Dataset path
vocab = dataset.vocab  # Vocabulary class
seed = train_cfg.seed  # seed to ensure constant behavior across runs
device = (
    torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda:0")
)

batch_size = train_cfg.batch_size

# Train split iterator
train_iter = dataset.get_iterator(
    batch_size, train_cfg.tgt_len, device, "train", do_shuffle=True, seed=seed
)

# Validation split iterator
val_iter = dataset.get_iterator(
    train_cfg.eval_batch_size,
    train_cfg.eval_tgt_len,
    device,
    "valid",
    do_shuffle=False,
    seed=seed,
)

# Test split iterator
test_iter = dataset.get_iterator(
    train_cfg.eval_batch_size,
    train_cfg.eval_tgt_len,
    device,
    "test",
    do_shuffle=False,
    seed=seed,
)

### Evaluation loop

You will now define the evaluation loop used while training the model.

In [None]:
def evaluate(model, eval_iter):
    """
    Function to compute validation negative log-likelihood (nll) of a model
    on a dataset specified with the eval_iter iterator
    """
    
    # Turn on evaluation mode, which disables dropout
    model.eval()

    model.reset_length(tgt_len=train_cfg.eval_tgt_len, 
                       mem_len=train_cfg.eval_mem_len)
    
    # Evaluation
    total_token_num = 0
    total_nll = 0.0

    with torch.no_grad():
        mems = None

        for i, (data, target, batch_token_num) in enumerate(eval_iter()):

            loss, mems = model(data, target, mems)
            loss = loss[target != dataset.vocab.pad_id]
            loss = loss.mean()
            total_nll += batch_token_num * loss.float().item()
            total_token_num += batch_token_num
    
    model.reset_length(train_cfg.tgt_len, train_cfg.mem_len)
    model.train()
    
    return total_token_num, total_nll

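`evaluate` returns the summed NLL and the token count rather than an average, so that the caller can compute a token-level mean: each batch's mean loss is weighted by its number of non-pad tokens, which matters because the last batches of a split are often only partially filled. A small sketch of that aggregation and of the perplexity computed in `evaluate_and_log` follows (`aggregate_nll` is a hypothetical helper and the batch values are made up).

```python
import math

def aggregate_nll(batch_mean_losses, batch_token_nums):
    """Hypothetical helper: token-weighted mean NLL, as accumulated in evaluate()."""
    total_nll = sum(loss * n for loss, n in zip(batch_mean_losses, batch_token_nums))
    total_tokens = sum(batch_token_nums)
    return total_nll / total_tokens

nll = aggregate_nll([2.0, 1.0], [30, 10])  # 70 / 40
print(nll, math.exp(nll))  # 1.75 and a perplexity of about 5.75
```

A plain average of the two batch means would give 1.5 instead, overweighting the smaller batch.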

### Training loop

You will now write the training loop. First, define a few helper functions that help with training.

In [None]:
# Dictionaries to record train, val and test losses
train_losses = dict()
val_losses = dict()
test_losses = dict()

In [None]:
def plot_losses_while_training():
    """
    Helper function to plot losses while training
    """
    display.clear_output(wait=True)
    fig = plt.figure(figsize=(15, 5))

    def plot_lines(loss_dic, color):
        iters = list(loss_dic.keys())
        vals = [loss_dic[i] for i in iters]
        return plt.plot(iters, vals, color)

    (line1,) = plot_lines(train_losses, "r")
    (line2,) = plot_lines(val_losses, "k")
    (line3,) = plot_lines(test_losses, "b")

    plt.xlabel("Iterations")
    plt.ylabel("Losses")
    plt.legend((line1, line2, line3), ("train-loss", "val-loss", "test-loss"))
    display.display(fig)
    plt.close()

In [None]:
def evaluate_and_log(model, train_step, mode="val"):
    """
    Helper function to evaluate model in "val" or "test" mode and log losses
    """
    start_time = time.time()
    
    # Use the validation split for "val" and the test split for "test"
    eval_iter = val_iter if mode == "val" else test_iter
    token_num, total_nll = evaluate(model=model, eval_iter=eval_iter)

    nll = total_nll / token_num

    pprint(
        f"{mode} step {train_step}, time={(time.time() - start_time)}s, {mode} nll={nll}, " 
        f"{mode} ppl={math.exp(nll)}, #evaluated tokens={token_num}"
        )
    
    return nll

You are now ready to write the entire training loop. We minimize the [negative log-likelihood (NLL)](https://d2l.ai/chapter_appendix-mathematics-for-deep-learning/maximum-likelihood.html) loss, which is standard practice in language modeling.
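`TransformerXL.forward` computes this loss by gathering the negative log-softmax at each target index, and the training loop then averages it over non-pad positions only. The NumPy sketch below reproduces those two steps as an illustration, not as the notebook's PyTorch code; `masked_token_nll` is a hypothetical helper and `pad_id = 0` is an assumed value. With all-zero logits every token has probability `1/n_token`, so the loss equals `log(n_token)`, the NLL of a uniform guess.

```python
import numpy as np

def masked_token_nll(logits: np.ndarray, targets: np.ndarray, pad_id: int) -> float:
    """Hypothetical helper: mean NLL over non-pad positions.

    logits: (tgt_len, batch_size, n_token); targets: (tgt_len, batch_size) integer ids.
    """
    # Numerically stable log-softmax over the vocabulary dimension
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of each target token (the role of .gather(1, ...))
    token_nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    # Drop padded positions before averaging, like loss[target != pad_id]
    return float(token_nll[targets != pad_id].mean())

n_token, pad_id = 10, 0  # pad_id = 0 is an assumption for this example
targets = np.array([[3, 5], [7, pad_id]])  # (tgt_len=2, batch_size=2), one padded slot
logits = np.zeros((2, 2, n_token))
print(masked_token_nll(logits, targets, pad_id), np.log(n_token))  # both about 2.3026
```

A model that learns nothing beyond token frequencies stays near this uniform baseline, so a training NLL well below `log(n_token)` is a quick sign that learning is happening.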

In [None]:
def train(model):
    """
    Main training function that iterates over batches, computes the loss,
    computes gradients via the backward pass and updates weights using 
    the optimizer.
    
    Also includes functionality for plotting losses realtime and saving the
    best model checkpoint
    
    """
    train_step = 0
    best_val_nll = np.inf

    log_train_loss = torch.tensor(0.0).float().to(device)
    log_grad_norm = torch.tensor(0.0).float().to(device)
    log_token_num = torch.tensor(0).to(device)

    log_start_time = time.time()

    mems = None

    # Define optimizer
    if train_cfg.optim.lower() == "adam":
        optimizer = optim.Adam(
            model.parameters(), lr=train_cfg.lr, weight_decay=train_cfg.weight_decay
        )
    else:
        raise NotImplementedError

    # Define scheduler
    if train_cfg.scheduler == "inv_sqrt":
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and train_cfg.warmup_step == 0:
                return 1.0
            else:
                return (
                    max(
                        (train_cfg.warmup_step ** 0.5) / (step ** 0.5),
                        train_cfg.lr_min / train_cfg.lr,
                    )
                    if step > train_cfg.warmup_step
                    else step / train_cfg.warmup_step
                )

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    else:
        raise NotImplementedError

    train_real_iter = train_iter()

    # Iterate over batches; train_iter reshuffles and loops indefinitely, so training stops at max_step
    for batch, (data, target, batch_token_num) in enumerate(train_real_iter):

        model.zero_grad()

        loss, mems = model(data, target, mems)

        loss = loss[target != dataset.vocab.pad_id]
        loss = loss.float().mean()

        # Record total loss over all non pad tokens
        log_train_loss += loss.item() * (target != dataset.vocab.pad_id).sum()

        loss.backward()

        log_token_num += int(batch_token_num)

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), train_cfg.clip)

        log_grad_norm += grad_norm
        optimizer.step()
        optimizer.zero_grad()

        # step-wise learning rate annealing
        train_step += 1
        scheduler.step()

        # Log losses
        if train_step % train_cfg.log_interval == 0:
            
            log_train_loss /= log_token_num
            log_grad_norm /= train_cfg.log_interval

            elapsed = time.time() - log_start_time
            pprint(
                "train Step {}/{}, lr={:f}, tokens/s={:.1f},"
                " train nll={:.4f}, train ppl={:.2f}, grad norm={}".format(
                    train_step,
                    train_cfg.max_step,
                    optimizer.param_groups[0]["lr"],
                    log_token_num.item() / elapsed,
                    log_train_loss.item(),
                    math.exp(log_train_loss.item()),
                    log_grad_norm.item(),
                )
            )

            # Save train loss
            train_losses[train_step] = log_train_loss.item()

            log_train_loss[()] = 0
            log_grad_norm[()] = 0
            log_token_num[()] = 0

            log_start_time = time.time()
        
        # Evaluate
        if train_step % train_cfg.eval_interval == 0:
            val_nll = evaluate_and_log(model, train_step, mode="val")
            
            # Save val loss
            val_losses[train_step] = val_nll  # evaluate_and_log returns a Python float

            # Save best model
            if val_nll < best_val_nll or train_cfg.save_all_test_losses:
                
                if val_nll < best_val_nll:
                    best_val_nll = val_nll

                    save_checkpoint(
                        model,
                        train_step,
                        best_val_nll,
                        train_cfg.save_path,
                        "checkpoint_best.pt",
                    )

                # Get test nll
                test_nll = evaluate_and_log(model, train_step, mode="test")
                
                # Save test loss
                test_losses[train_step] = test_nll  # evaluate_and_log returns a Python float
        
        # Plot losses while training
        if train_cfg.plot_losses_while_training and train_step % train_cfg.plot_interval == 0:
            plot_losses_while_training()

        if train_step == train_cfg.max_step:
            pprint("-" * 100)
            pprint("Max steps reached. End of training")
            break

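To see what the `inv_sqrt` schedule does with the values in `TrainConfig`, the `lr_lambda` logic can be evaluated on its own. The sketch below copies that logic into a hypothetical standalone function; `LambdaLR` multiplies the base `lr` by the returned value. The learning rate ramps up linearly over `warmup_step` steps, then decays as `1/sqrt(step)`, and never drops below `lr_min`.

```python
import math

def inv_sqrt_multiplier(step, warmup_step, lr, lr_min):
    """Hypothetical copy of lr_lambda in train(): multiplier applied to the base lr."""
    if step == 0 and warmup_step == 0:
        return 1.0
    if step > warmup_step:
        return max(math.sqrt(warmup_step) / math.sqrt(step), lr_min / lr)
    return step / warmup_step

base_lr, lr_min, warmup = 0.004 / 8, 0.0001 / 4, 4000  # values from TrainConfig
# Expected: 0 at step 0, 2.5e-4 at 2000, 5e-4 at 4000, 2.5e-4 at 16000, about 2.24e-4 at 20000
for step in [0, 2000, 4000, 16000, 20000]:
    print(step, base_lr * inv_sqrt_multiplier(step, warmup, base_lr, lr_min))
```

With these values the `lr_min` floor (a multiplier of 0.05) only takes effect after 1,600,000 steps, so it never binds within `max_step = 20000`. Note also that `LambdaLR` evaluates the multiplier at step 0 when it is created, so the very first optimizer step runs with a learning rate of 0.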
### Weight initialization

Before you begin training, you will define functions to initialize the weights in the model. 


In [None]:
def init_weight(weight):
    """
    Function to help initialize all layer weights
    """
    if train_cfg.base_init[0] == "normal":
        init_std = train_cfg.base_init[1]
        nn.init.normal_(weight, 0.0, init_std)
    else:
        raise NotImplementedError
        
def init_embed(weight):
    """
    Function to help initialize embedding weights
    """
    if train_cfg.embed_init[0] == "normal":
        init_std = train_cfg.embed_init[1]
        nn.init.normal_(weight, 0.0, init_std)
    else:
        raise NotImplementedError

def init_bias(bias):
    """
    Function to help initialize layer bias
    """
    nn.init.constant_(bias, 0.0)

def weights_init(m):
    """
    Function that initializes layer weights and biases in the Transformer-XL
    based on name 
    """
    classname = m.__class__.__name__
    if classname.find("Linear") != -1:
        if hasattr(m, "weight") and m.weight is not None:
            init_weight(m.weight)
        if hasattr(m, "bias") and m.bias is not None:
            init_bias(m.bias)
    elif classname.find("Embedding") != -1:
        if hasattr(m, "weight"):
            # Embeddings use their own init settings (train_cfg.embed_init)
            init_embed(m.weight)
    elif classname.find("LayerNorm") != -1:
        if hasattr(m, "weight"):
            nn.init.normal_(m.weight, 1.0, train_cfg.base_init[1])
        if hasattr(m, "bias") and m.bias is not None:
            init_bias(m.bias)
    elif classname.find("TransformerXL") != -1:
        if hasattr(m, "u"):
            init_weight(m.u)
        if hasattr(m, "v"):
            init_weight(m.v)



### Let's train the model
You will now define the model, initialize its weights, and then begin training.


In [None]:
# Let us save our config file along with the saved checkpoints
train_cfg.save(os.path.join(train_cfg.save_path, "exp.yaml"))

In [None]:
# Create the model
transformerxl = TransformerXL(n_layer=train_cfg.num_layers, n_head=train_cfg.num_heads,
                            dim_model=train_cfg.dim_model, dim_inner=train_cfg.dim_inner,       
                            dropout=train_cfg.dropout, dropatt=train_cfg.dropatt,
                            tie_weight=train_cfg.tie_embedding, tgt_len=train_cfg.tgt_len,
                            mem_len=train_cfg.mem_len, n_token=len(vocab),)

# Apply weight initialization to model
transformerxl.apply(weights_init)
transformerxl.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing

# Send model to device
transformerxl = transformerxl.to(device)


Now call `train`, which also plots the training and validation losses as training progresses. With the default parameters, 5000 training iterations are sufficient.

In [None]:
train(transformerxl)

### Load a pretrained checkpoint (Optional)

If your model did not finish training, uncomment and run the cell below to load a pretrained checkpoint instead.

In [None]:
# pretrained_path = 'pretrained_checkpoints'

# # Load pretrained config file
# train_cfg = TrainConfig.load(os.path.join(pretrained_path,'exp.yaml'))

# # Create the model
# transformerxl = TransformerXL(n_layer=train_cfg.num_layers, n_head=train_cfg.num_heads,
#                             dim_model=train_cfg.dim_model, dim_inner=train_cfg.dim_inner,       
#                             dropout=train_cfg.dropout, dropatt=train_cfg.dropatt,
#                             tie_weight=train_cfg.tie_embedding, tgt_len=train_cfg.tgt_len,
#                             mem_len=train_cfg.mem_len, n_token=len(vocab),)

# # Load pretrained checkpoint
# model_fp = os.path.join(pretrained_path,'checkpoint_best.pt')
# checkpoint = torch.load(model_fp, map_location=device)
# transformerxl.load_state_dict(checkpoint["model"])

# # Send model to device
# transformerxl = transformerxl.to(device)

# # Load saved losses
# with open(os.path.join(pretrained_path,'losses.pickle'),'rb') as handle:
#     losses = pickle.load(handle)
# train_losses = losses['train_losses']
# val_losses = losses['val_losses']
# test_losses = losses['test_losses']

### Plot losses

Below, the training, validation and test losses are plotted after training is complete. The training loss keeps decreasing overall, while the validation and test losses start to increase after a certain number of iterations, a sign of overfitting. The best checkpoint is therefore the one with the lowest validation loss.


In [None]:
# Plot train, validation and test losses
plot_losses(train_losses, val_losses, test_losses)
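Since the best checkpoint is the one with the lowest validation loss, it can help to look up which training step it came from. The helper below is a hypothetical sketch that assumes `val_losses` is a dict mapping training step to validation NLL; the example values are made up for illustration and are not real results.

```python
def best_checkpoint_step(val_losses):
    """Return (step, loss) for the lowest validation loss.

    Assumes val_losses is a dict mapping training step -> validation NLL.
    """
    if not val_losses:
        raise ValueError("val_losses is empty; run training first")
    # min over the keys, compared by their loss values
    best_step = min(val_losses, key=val_losses.get)
    return best_step, val_losses[best_step]


# Illustrative losses only (not results from this notebook)
example_val_losses = {500: 1.92, 1000: 1.61, 1500: 1.55, 2000: 1.58, 2500: 1.66}
step, loss = best_checkpoint_step(example_val_losses)
print(f"Best checkpoint at step {step} with val NLL {loss:.2f}")
```

If your `val_losses` is a dict, you could call `best_checkpoint_step(val_losses)` directly; with another container type, adapt the lookup accordingly.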

## Generating samples 


Congratulations! You have now trained your very own Transformer-XL model and can use it to extend an input melody file.

The model generates a melody by sampling a sequence from the model's distribution.

Sampling in an autoregressive model like the Transformer is an iterative process. To sample a sequence from the model, you repeatedly sample the next token from the model's output probability distribution, conditioned on the history of tokens, and append that token to the history.
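The loop below is a minimal sketch of this iterative process. The "model" is a hypothetical toy bigram table that only looks at the last token, whereas the Transformer-XL conditions on the full history through its memory; the point is only the shape of the loop: predict a distribution, sample from it, append, repeat.

```python
import random

# Hypothetical toy "model": next-token probabilities given only the previous token.
# Token 0 plays the role of a start token and is never sampled (probability 0).
TOY_NEXT_TOKEN_PROBS = {
    0: [0.0, 0.5, 0.3, 0.2],
    1: [0.0, 0.1, 0.6, 0.3],
    2: [0.0, 0.3, 0.1, 0.6],
    3: [0.0, 0.6, 0.3, 0.1],
}


def toy_model(history):
    """Return a probability for every token in the vocabulary given the history."""
    return TOY_NEXT_TOKEN_PROBS[history[-1]]


def sample_sequence(model, prefix, length, rng):
    """Autoregressively extend prefix by `length` sampled tokens."""
    seq = list(prefix)
    for _ in range(length):
        probs = model(seq)  # distribution over the next token
        token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        seq.append(token)  # the sample becomes part of the history
    return seq


print(sample_sequence(toy_model, prefix=[0, 1], length=8, rng=random.Random(0)))
```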

This notebook covers two sampling techniques: TopK and nucleus sampling. Each is controlled by a sampling threshold, which you set with the `threshold` parameter.

__TopK sampling__

![image](images/top-k.jpg) 

In TopK sampling, the model samples only from the K tokens with the highest probability, after renormalizing their probabilities so that they sum to 1. K is set by the `threshold` parameter.

If your sampling threshold is set high, then the number of available tokens (K) is large. This means the model can choose from a wider variety of musical tokens. In your extended melody, this means the generated notes are likely to be more diverse, but it comes at the cost of potentially creating less coherent music.

On the other hand, if you choose a threshold value that is too low, the model is limited to choosing from a smaller set of tokens that the model believes have a higher probability of being correct. In your extended melody, you might notice less musical diversity and more repetitive results.
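As a framework-free sketch of the idea (not the notebook's `get_topk`, which works on PyTorch tensors), TopK filtering on a plain list of probabilities keeps the K largest entries, zeroes the rest, and renormalizes:

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize; all others get probability 0.

    probs: list of floats summing to 1. k: integer between 1 and len(probs).
    """
    if not 1 <= k <= len(probs):
        raise ValueError("k must be between 1 and the vocabulary size")
    # Indices of the k largest probabilities (stable sort: ties keep the lower index)
    keep = set(sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k])
    kept_mass = sum(probs[i] for i in keep)
    return [probs[i] / kept_mass if i in keep else 0.0 for i in range(len(probs))]


probs = [0.05, 0.40, 0.25, 0.20, 0.10]
print(top_k_filter(probs, k=2))  # only tokens 1 and 2 keep nonzero probability
```

A larger `k` leaves more candidates in play, which is why a higher threshold gives more varied but potentially less coherent output.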

__Nucleus sampling__

![image](images/nucleus.jpg) 

Instead of sampling only from the K most likely tokens, nucleus sampling chooses from the smallest set of tokens whose cumulative probability reaches the probability p, which is set by the `threshold` parameter. The probability mass is then redistributed among this set of tokens. As a result, the size of the candidate set grows or shrinks with the shape of the next token's probability distribution: a peaked distribution yields a small set, while a flat one yields a large set.

At a high level, nucleus sampling behaves much like TopK: a higher sampling threshold allows more diversity at the cost of coherence or consistency.
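The following plain-Python sketch mirrors the rule used by the notebook's `get_topp` (keep tokens in descending order of probability up to and including the first one whose cumulative probability reaches p) and shows how the kept set adapts to the distribution's shape. The probability values are illustrative.

```python
def nucleus_filter(probs, p):
    """Keep the smallest set of most-probable tokens whose cumulative probability
    reaches p, then renormalize. probs: list of floats summing to 1; 0 < p <= 1.
    """
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)  # include the token that crosses the threshold
        cumulative += probs[i]
        if cumulative >= p:
            break
    kept_mass = sum(probs[i] for i in keep)
    return [probs[i] / kept_mass if i in keep else 0.0 for i in range(len(probs))]


peaked = [0.70, 0.20, 0.05, 0.03, 0.02]
flat = [0.22, 0.21, 0.20, 0.19, 0.18]
print(sum(x > 0 for x in nucleus_filter(peaked, 0.85)))  # 2
print(sum(x > 0 for x in nucleus_filter(flat, 0.85)))  # 5
```

With the same `p`, the peaked distribution keeps only two candidates while the flat one keeps all five, which is the dynamic behavior that distinguishes nucleus sampling from a fixed K.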

__Other inference parameters__

- Number of conditional tokens: The number of tokens from the start of the input melody that the model conditions on before it begins generating.

- Temperature: The final layer uses a softmax activation to create the output probability distribution. Dividing the logits by a temperature before the softmax controls how peaked this distribution is: values below 1 make the output more conservative, and values above 1 make it more varied. You can [change the temperature](https://console.aws.amazon.com/deepcomposer/home?region=us-east-1#musicStudio) to produce different levels of creativity in the outputs generated by the model. A temperature of 0 always selects the most likely token.

- Generation length: The number of new tokens the Transformer-XL generates.
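To make the effect of temperature concrete, here is a small stdlib sketch of a temperature-scaled softmax. It treats a temperature of 0 as greedy selection, mirroring the special case in `extend_melody`; the logits are arbitrary example values.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing them by temperature.

    temperature == 0 is treated as greedy decoding (all mass on the argmax).
    """
    if temperature == 0:
        best = max(range(len(logits)), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Lower temperatures concentrate more probability on the top token; higher temperatures spread it across the alternatives.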

You can change the inference parameters in the `InferenceConfig` class to observe differences in the quality of the generated music.

### Inference hyperparameters

In [None]:
@dataclass
class InferenceConfig:
    """
    Defines configuration parameters used during inference (melody extension)
    """
    
    # Model parameters
    memory_length: int = 4096

    # Sampling parameters
    technique: str = 'nucleus'  # 'topk' or 'nucleus'
    threshold: float = 0.95  # k (integer, 1-309) for topk sampling, or p (0-1] for nucleus sampling
    temperature: float = 0.95

    # Input parameters
    num_conditional_tokens: int = 100  # Number of tokens (>= 1) from the start of the input melody used as context

    # Generation parameters
    generation_length: int = 1500  # Number of new tokens used to extend the melody


In [None]:
inference_cfg = InferenceConfig()
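Because `threshold` means different things for the two techniques, it is easy to pair them incorrectly: `get_topk` truncates the threshold with `int()`, so the default `0.95` would become k = 0 if you switched to `'topk'`. The function below is a hypothetical sanity check (not part of the original notebook); the vocabulary size passed in is an assumption you would replace with your own, for example `len(vocab)`.

```python
def check_sampling_params(technique, threshold, vocab_size):
    """Hypothetical sanity check for the sampling settings in InferenceConfig.

    Returns the threshold in the form the sampling helpers expect.
    """
    if technique == "topk":
        k = int(threshold)  # mirrors the int() truncation in get_topk
        if not 1 <= k <= vocab_size:
            raise ValueError(f"topk needs an integer threshold in [1, {vocab_size}], got {threshold}")
        return k
    if technique == "nucleus":
        if not 0.0 < threshold <= 1.0:
            raise ValueError(f"nucleus needs a threshold in (0, 1], got {threshold}")
        return float(threshold)
    raise ValueError(f"unknown technique {technique!r}; expected 'topk' or 'nucleus'")


print(check_sampling_params("nucleus", 0.95, vocab_size=310))  # illustrative vocab size
try:
    check_sampling_params("topk", 0.95, vocab_size=310)
except ValueError as err:
    print("Invalid settings:", err)
```

In the notebook, a call such as `check_sampling_params(inference_cfg.technique, inference_cfg.threshold, len(vocab))` would catch a mismatched threshold before inference starts.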

### Input MIDI file to extend


Below, you can update the code cell to select an input MIDI melody to extend.

The default is a melody from the test set that has already been converted to a NumPy (`.npy`) file.

In [None]:
input_melody_path = 'data/jsb_chorales_numpy/test/9.npy'

#### To choose a custom input melody (Optional)

1. Open the `sample_inputs` directory
2. Upload the file that you want to use into this folder. For example, `new_world.midi`
3. Uncomment and run the following cell, replacing the value of `midi_file` with the path to your custom file

In [None]:
# # Input melody name
# midi_file = "sample_inputs/new_world.midi"

# # Convert midi to numpy 
# out_dir = 'sample_inputs'
# music_encoder.run_to_npy(midi_file, out_dir)

# filename = os.path.splitext(os.path.basename(midi_file))[0]
# input_melody_path = os.path.join(out_dir, filename + '.npy')

### Let's run inference!

You will first define a few helper functions for top-k and nucleus sampling.

In [None]:
def get_topk(probs):
    """
    Apply a top-k mask to the probability vector probs in place and renormalize it
    """
    # Apply topk mask
    topk = int(inference_cfg.threshold)
    _, top_idx = torch.topk(probs, topk)
    mask = torch.zeros_like(probs)
    mask[top_idx] = 1.0
    probs *= mask
    probs /= probs.sum()
    
def get_topp(probs): 
    """
    Apply a nucleus (top-p) mask to the probability vector probs in place and renormalize it
    """
    p = inference_cfg.threshold
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)

    cumulative_probs = torch.cumsum(sorted_probs, dim=0)

    # Remove tokens with cumulative probability above the threshold
    sorted_indices_to_remove = cumulative_probs >= p

    # Shift the indices to the right to keep also the first token above the threshold
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
    sorted_indices_to_remove[0] = 0

    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        dim=0, index=sorted_indices, src=sorted_indices_to_remove
    )
    probs[indices_to_remove] = 0
    probs /= probs.sum()

#### Prepare model for inference

In [None]:
transformerxl.eval()

# Reset tgt_length to 1, so that 1 token is generated incrementally
transformerxl.reset_length(1, inference_cfg.memory_length)

#### Define function to generate tokens incrementally

In [None]:
def extend_melody():
    """
    Loads the input melody specified by input_melody_path and returns
    the extended melody based on parameters specified in inference_cfg
    
    """
    # Load input melody
    
    conditional_data = np.load(input_melody_path).tolist()
    print('Loaded conditional file {}'.format(input_melody_path))
    
    num_conditional_tokens = inference_cfg.num_conditional_tokens

    seq = [0]
    mems = None
    
    with torch.no_grad():   
        
        # Pass prefix through Transformer-XL
        context = np.array(seq + conditional_data[:num_conditional_tokens-1], dtype=np.int32)[:, np.newaxis]
        context = torch.from_numpy(context).to(device).type(torch.long)
        ret = transformerxl.forward_generate(context, mems)
        _, mems = ret
        seq = seq + conditional_data[:num_conditional_tokens]

        # Load generation length
        generation_length = inference_cfg.generation_length
        
        for _ in range(generation_length):
          
            # Create input array from last token
            inp = np.array([seq[-1]], dtype=np.int32)[:, np.newaxis]
            inp = torch.from_numpy(inp).to(device).type(torch.long)
            
            # Generate next token incrementally
            ret = transformerxl.forward_generate(inp, mems)
            all_logits, mems = ret

            # Select the last timestep from the single batch item
            logits = all_logits[-1, 0]

            # Do not predict start token
            logits = logits[1:]

            # Handle temp 0 (argmax) case
            if inference_cfg.temperature == 0:
                probs = torch.zeros_like(logits)
                probs[logits.argmax()] = 1.0
            else:
                # Apply temperature normalization
                logits /= inference_cfg.temperature

                # Compute softmax
                probs = F.softmax(logits, dim=-1)

            probs = F.pad(probs, [1, 0])

            # Apply sampling masks
            if inference_cfg.technique == "topk":
                get_topk(probs)
            elif inference_cfg.technique == "nucleus":
                get_topp(probs)
                
            # Sample from probabilities
            token = torch.multinomial(probs, 1)
            token = int(token.item())
            
            # Add to output list
            seq.append(token)

        # Convert output list to numpy, ignore start token and return
        return np.asarray(seq[1:])

#### Run inference and save outputs

In [None]:
# Run inference
outputs = extend_melody()

# Save numpy outputs (create the output directory if it does not exist)
output_dir = "sample_outputs"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "sample_melody.npy")
np.save(output_path, outputs)

### Listen to your output

Convert the generated NumPy output back to MIDI, then listen to the extended melody.


In [None]:
# Convert numpy to midi
music_encoder.run_npy_to_midi(output_path, output_dir)

In [None]:
# Play midi
filename, _ = os.path.splitext(os.path.basename(output_path))
midi_name = os.path.join(output_dir, filename + ".mid")
play_midi(midi_name)

## Cleaning up 

After completing this notebook, make sure that you stop your Amazon SageMaker notebook instance so that you don't incur unexpected costs. 

#### To stop an Amazon SageMaker notebook instance 

1. Open the [Amazon SageMaker console](https://console.aws.amazon.com/sagemaker/home?region=us-east-1#/dashboard).

2. In the navigation pane, choose **Notebook instances**.

3. Choose the notebook instance that you want to stop. 

4. From the **Actions** menu, choose **Stop**.

>**NOTE**: When your notebook instance stops, its status changes from **In service** to **Stopped**. 

# More info

For more open-source implementations of generative models for music, see the following:

- [Transformer-GAN](https://www.amazon.science/publications/symbolic-music-generation-with-transformer-gans): Trains the Transformer-XL in a GAN framework to generate music

- [LakhNES](https://arxiv.org/abs/1907.04868): Uses the Transformer-XL to generate multi-instrumental scores from the NES-MDB dataset

- [Pop music transformer](https://arxiv.org/abs/2002.00212): Uses the Transformer-XL to generate pop music by imposing a metrical structure

- [Jukebox](https://openai.com/blog/jukebox/): Uses various neural nets to generate music, including rudimentary singing, as raw audio in a variety of genres and artist styles

- [Music Transformer](https://github.com/tensorflow/magenta/tree/master/magenta/models/score2perf): Uses Transformers to generate music

- [MuseNet](https://openai.com/blog/musenet/): Uses a large-scale Transformer model based on the same technology as GPT-2 to generate multi-instrumental music

