{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This tutorial is a brief introduction to music generation using **Generative Adversarial Networks** (**GAN**s). \n",
"\n",
"The goal of this tutorial is to train a machine learning model using a dataset of Bach compositions so that the model learns to add accompaniments to a single track input melody. In other words, if the user provides a single piano track of a song such as \"twinkle twinkle little star\", the GAN model would add three other piano tracks to make the music sound more Bach-inspired.\n",
"\n",
"The proposed algorithm consists of two competing networks: a generator and a critic (discriminator). A generator is a deep neural network that learns to create new synthetic data that resembles the distribution of the dataset on which it was trained. A critic is another deep neural network that is trained to differentiate between real and synthetic data. The generator and the critic are trained in alternating cycles such that the generator learns to produce more and more realistic data (Bach-like music in this use case) while the critic iteratively gets better at learning to differentiate real data (Bach music) from the synthetic ones.\n",
"\n",
"As a result, the quality of music produced by the generator gets more and more realistic with time."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dependencies\n",
"First, let's import all of the python packages we will use throughout the tutorial.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.\n",
"\n",
"# Permission is hereby granted, free of charge, to any person obtaining a copy of\n",
"# this software and associated documentation files (the \"Software\"), to deal in\n",
"# the Software without restriction, including without limitation the rights to\n",
"# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of\n",
"# the Software, and to permit persons to whom the Software is furnished to do so.\n",
"\n",
"# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS\n",
"# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR\n",
"# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER\n",
"# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN\n",
"# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n",
"\n",
"\n",
"# Create the environment\n",
"import subprocess\n",
"print(\"Please wait, while the required packages are being installed...\")\n",
"subprocess.call(['./requirements.sh'], shell=True)\n",
"print(\"All the required packages are installed successfully...\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# IMPORTS\n",
"import os \n",
"import numpy as np\n",
"from PIL import Image\n",
"import logging\n",
"import pypianoroll\n",
"import scipy.stats\n",
"import pickle\n",
"import music21\n",
"from IPython import display\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Configure Tensorflow\n",
"import tensorflow as tf\n",
"print(tf.__version__)\n",
"tf.logging.set_verbosity(tf.logging.ERROR)\n",
"tf.enable_eager_execution()\n",
"\n",
"# Use this command to make a subset of GPUS visible to the jupyter notebook.\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
"\n",
"# Utils library for plotting, loading and saving midi among other functions\n",
"from utils import display_utils, metrics_utils, path_utils, inference_utils, midi_utils\n",
"\n",
"LOGGER = logging.getLogger(\"gan.train\")\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we configure paths to retrieve our dataset and save our experiments."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"root_dir = './Experiments'\n",
"\n",
"# Directory to save checkpoints\n",
"model_dir = os.path.join(root_dir,'2Bar') # JSP: 229, Bach: 19199\n",
"\n",
"# Directory to save pianorolls during training\n",
"train_dir = os.path.join(model_dir, 'train')\n",
"\n",
"# Directory to save checkpoint generated during training\n",
"check_dir = os.path.join(model_dir, 'preload')\n",
"\n",
"# Directory to save midi during training\n",
"sample_dir = os.path.join(model_dir, 'sample')\n",
"\n",
"# Directory to save samples generated during inference\n",
"eval_dir = os.path.join(model_dir, 'eval')\n",
"\n",
"os.makedirs(train_dir, exist_ok=True)\n",
"os.makedirs(eval_dir, exist_ok=True)\n",
"os.makedirs(sample_dir, exist_ok=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Preparation\n",
"\n",
"### Dataset summary\n",
"\n",
"In this tutorial, we use the [`JSB-Chorales-dataset`](http://www-etud.iro.umontreal.ca/~boulanni/icml2012), comprising 229 chorale snippets. A chorale is a hymn that is usually sung with a single voice playing a simple melody and three lower voices providing harmony. In this dataset, these voices are represented by four piano tracks.\n",
"\n",
"Let's listen to a song from this dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"display_utils.playmidi('./original_midi/MIDI-0.mid')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data format - piano roll"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the purpose of this tutorial, we represent music from the JSB-Chorales dataset in the piano roll format.\n",
"\n",
"**Piano roll** is a discrete representation of music which is intelligible by many machine learning algorithms. Piano rolls can be viewed as a two-dimensional grid with \"Time\" on the horizontal axis and \"Pitch\" on the vertical axis. A one or zero in any particular cell in this grid indicates if a note was played or not at that time for that pitch.\n",
"\n",
"Let us look at a few piano rolls in our dataset. In this example, a single piano roll track has 32 discrete time steps and 128 pitches. We see four piano rolls here, each one representing a separate piano track in the song."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
\n",
"\n",
"You might notice this representation looks similar to an image. While the sequence of notes is often the natural way that people view music, many modern machine learning models instead treat music as images and leverage existing techniques within the computer vision domain. You will see such techniques used in our architecture later in this tutorial."
]
},
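{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "To make the grid representation concrete, here is a small, purely illustrative sketch (not part of the tutorial pipeline; the note choices are arbitrary) that builds a toy single-track piano roll with NumPy and switches on a few notes by hand."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Purely illustrative: build a toy single-track piano roll\n",
  "# (32 time steps x 128 pitches) and turn on a few notes by hand.\n",
  "toy_roll = np.zeros((32, 128), dtype=np.float32)\n",
  "\n",
  "# Hold middle C (MIDI pitch 60) for the first 8 time steps,\n",
  "# then E (64) and G (67) for 8 steps each.\n",
  "toy_roll[0:8, 60] = 1.0\n",
  "toy_roll[8:16, 64] = 1.0\n",
  "toy_roll[16:24, 67] = 1.0\n",
  "\n",
  "print(toy_roll.shape)             # (32, 128)\n",
  "print(toy_roll.sum(axis=1)[:10])  # active pitches per time step"
 ]
},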
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Why 32 time steps?**\n",
"\n",
"For the purpose of this tutorial, we sample two non-empty bars (https://en.wikipedia.org/wiki/Bar_(music)) from each song in the JSB-Chorales dataset. A **bar** (or **measure**) is a unit of composition and contains four beats for songs in our particular dataset (our songs are all in 4/4 time) :\n",
"\n",
"We’ve found that using a resolution of four time steps per beat captures enough of the musical detail in this dataset.\n",
"\n",
"This yields...\n",
"\n",
"$$ \\frac{4\\;timesteps}{1\\;beat} * \\frac{4\\;beats}{1\\;bar} * \\frac{2\\;bars}{1} = 32\\;timesteps $$\n",
"\n",
"Let us now load our dataset as a numpy array. Our dataset comprises 229 samples of 4 tracks (all tracks are piano). Each sample is a 32 time-step snippet of a song, so our dataset has a shape of...\n",
"(num_samples, time_steps, pitch_range, tracks) = (229, 32, 128, 4)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"training_data = np.load('./dataset/train.npy')\n",
"print(training_data.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see a sample of the data we'll feed into our model. The four graphs represent the four tracks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"display_utils.show_pianoroll(training_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load data "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now create a Tensorflow dataset object from our numpy array to feed into our model. The dataset object helps us feed batches of data into our model. A batch is a subset of the data that is passed through the deep learning network before the weights are updated. Batching data is necessary in most training scenarios as our training environment might not be able to load the entire dataset into memory at once."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Number of input data samples in a batch\n",
"BATCH_SIZE = 64\n",
"\n",
"#Shuffle buffer size for shuffling data\n",
"SHUFFLE_BUFFER_SIZE = 1000\n",
"\n",
"#Preloads PREFETCH_SIZE batches so that there is no idle time between batches\n",
"PREFETCH_SIZE = 4"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def prepare_dataset(filename):\n",
" \n",
" \"\"\"Load the samples used for training.\"\"\"\n",
" \n",
" data = np.load(filename)\n",
" data = np.asarray(data, dtype=np.float32) # {-1, 1}\n",
"\n",
" print('data shape = {}'.format(data.shape))\n",
"\n",
" dataset = tf.data.Dataset.from_tensor_slices(data)\n",
" dataset = dataset.shuffle(SHUFFLE_BUFFER_SIZE).repeat()\n",
" dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)\n",
" dataset = dataset.prefetch(PREFETCH_SIZE)\n",
"\n",
" return dataset \n",
"\n",
"dataset = prepare_dataset('./dataset/train.npy')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model architecture\n",
"In this section, we will walk through the architecture of the proposed GAN.\n",
"\n",
"The model consists of two networks, a generator and a critic. These two networks work in a tight loop as following:\n",
"\n",
"* Generator:\n",
" 1. The generator takes in a batch of single-track piano rolls (melody) as the input and generates a batch of multi-track piano rolls as the output by adding accompaniments to each of the input music tracks. \n",
" 2. The critic then takes these generated music tracks and predicts how far it deviates from the real data present in your training dataset.\n",
" 3. This feedback from the critic is used by the generator to update its weights.\n",
"* Critic: As the generator gets better at creating better music accompaniments using the feedback from the critic, the critic needs to be retrained as well.\n",
" 1. Train the critic with the music tracks just generated by the generator as fake inputs and an equivalent number of songs from the original dataset as the real input. \n",
"* Alternate between training these two networks until the model converges and produces realistic music, beginning with the critic on the first iteration.\n",
"\n",
"We use a special type of GAN called the **Wasserstein GAN with Gradient Penalty** (or **WGAN-GP**) to generate music. While the underlying architecture of a WGAN-GP is very similar to vanilla variants of GAN, WGAN-GPs help overcome some of the commonly seen defects in GANs such as the vanishing gradient problem and mode collapse (see appendix for more details).\n",
"\n",
"Note our \"critic\" network is more generally called a \"discriminator\" network in the more general context of vanilla GANs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The generator is adapted from the U-Net architecture (a popular CNN that is used extensively in the computer vision domain), consisting of an “encoder” that maps the single track music data (represented as piano roll images) to a relatively lower dimensional “latent space“ and a ”decoder“ that maps the latent space back to multi-track music data.\n",
"\n",
"Here are the inputs provided to the generator:\n",
"\n",
"**Single-track piano roll input**: A single melody track of size (32, 128, 1) => (TimeStep, NumPitches, NumTracks) is provided as the input to the generator. \n",
"\n",
"**Latent noise vector**: A latent noise vector z of dimension (2, 8, 512) is also passed in as input and this is responsible for ensuring that there is a distinctive flavor to each output generated by the generator, even when the same input is provided.\n",
"\n",
"Notice from the figure below that the encoding layers of the generator on the left side and decoder layer on on the right side are connected to create a U-shape, thereby giving the name U-Net to this architecture."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this implementation, we build the generator following a simple four-level Unet architecture by combining `_conv2d`s and `_deconv2d`, where `_conv2d` compose the contracting path and `_deconv2d` forms the expansive path. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _conv2d(layer_input, filters, f_size=4, bn=True):\n",
" \"\"\"Generator Basic Downsampling Block\"\"\"\n",
" d = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=2,\n",
" padding='same')(layer_input)\n",
" d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)\n",
" if bn:\n",
" d = tf.keras.layers.BatchNormalization(momentum=0.8)(d)\n",
" return d\n",
"\n",
"\n",
"def _deconv2d(layer_input, pre_input, filters, f_size=4, dropout_rate=0):\n",
" \"\"\"Generator Basic Upsampling Block\"\"\"\n",
" u = tf.keras.layers.UpSampling2D(size=2)(layer_input)\n",
" u = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=1,\n",
" padding='same')(u)\n",
" u = tf.keras.layers.BatchNormalization(momentum=0.8)(u)\n",
" u = tf.keras.layers.ReLU()(u)\n",
"\n",
" if dropout_rate:\n",
" u = tf.keras.layers.Dropout(dropout_rate)(u)\n",
" \n",
" u = tf.keras.layers.Concatenate()([u, pre_input])\n",
" return u\n",
"\n",
" \n",
"def build_generator(condition_input_shape=(32, 128, 1), filters=64,\n",
" instruments=4, latent_shape=(2, 8, 512)):\n",
" \"\"\"Buld Generator\"\"\"\n",
" c_input = tf.keras.layers.Input(shape=condition_input_shape)\n",
" z_input = tf.keras.layers.Input(shape=latent_shape)\n",
"\n",
" d1 = _conv2d(c_input, filters, bn=False)\n",
" d2 = _conv2d(d1, filters * 2)\n",
" d3 = _conv2d(d2, filters * 4)\n",
" d4 = _conv2d(d3, filters * 8)\n",
"\n",
" d4 = tf.keras.layers.Concatenate(axis=-1)([d4, z_input])\n",
"\n",
" u4 = _deconv2d(d4, d3, filters * 4)\n",
" u5 = _deconv2d(u4, d2, filters * 2)\n",
" u6 = _deconv2d(u5, d1, filters)\n",
"\n",
" u7 = tf.keras.layers.UpSampling2D(size=2)(u6)\n",
" output = tf.keras.layers.Conv2D(instruments, kernel_size=4, strides=1,\n",
" padding='same', activation='tanh')(u7) # 32, 128, 4\n",
"\n",
" generator = tf.keras.models.Model([c_input, z_input], output, name='Generator')\n",
"\n",
" return generator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us now dive into each layer of the generator to see the inputs/outputs at each layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Models\n",
"generator = build_generator()\n",
"generator.summary()"
]
},
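{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "As a quick, purely illustrative sanity check (the `demo_*` tensors below are random placeholders, not real data), we can feed a random condition track and latent vector through the generator we just built and confirm that it returns a four-track piano roll."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Illustrative sanity check of the generator's input/output shapes\n",
  "demo_c = tf.random.uniform([4, 32, 128, 1])          # single-track condition\n",
  "demo_z = tf.random.truncated_normal([4, 2, 8, 512])  # latent noise vectors\n",
  "demo_out = generator((demo_c, demo_z), training=False)\n",
  "print(demo_out.shape)  # expected: (4, 32, 128, 4)"
 ]
},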
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Critic (Discriminator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal of the critic is to provide feedback to the generator about how realistic the generated piano rolls are, so that the generator can learn to produce more realistic data. The critic provides this feedback by outputting a scalar that represents how “real” or “fake” a piano roll is.\n",
"\n",
"Since the critic tries to classify data as “real” or “fake”, it is not very different from commonly used binary classifiers. We use a simple architecture for the critic, composed of four convolutional layers and a dense layer at the end."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def _build_critic_layer(layer_input, filters, f_size=4):\n",
" \"\"\"\n",
" This layer decreases the spatial resolution by 2:\n",
"\n",
" input: [batch_size, in_channels, H, W]\n",
" output: [batch_size, out_channels, H/2, W/2]\n",
" \"\"\"\n",
" d = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=2,\n",
" padding='same')(layer_input)\n",
" # Critic does not use batch-norm\n",
" d = tf.keras.layers.LeakyReLU(alpha=0.2)(d) \n",
" return d\n",
"\n",
"\n",
"def build_critic(pianoroll_shape=(32, 128, 4), filters=64):\n",
" \"\"\"WGAN critic.\"\"\"\n",
" \n",
" condition_input_shape = (32,128,1)\n",
" groundtruth_pianoroll = tf.keras.layers.Input(shape=pianoroll_shape)\n",
" condition_input = tf.keras.layers.Input(shape=condition_input_shape)\n",
" combined_imgs = tf.keras.layers.Concatenate(axis=-1)([groundtruth_pianoroll, condition_input])\n",
"\n",
"\n",
" \n",
" d1 = _build_critic_layer(combined_imgs, filters)\n",
" d2 = _build_critic_layer(d1, filters * 2)\n",
" d3 = _build_critic_layer(d2, filters * 4)\n",
" d4 = _build_critic_layer(d3, filters * 8)\n",
"\n",
" x = tf.keras.layers.Flatten()(d4)\n",
" logit = tf.keras.layers.Dense(1)(x)\n",
"\n",
" critic = tf.keras.models.Model([groundtruth_pianoroll,condition_input], logit,\n",
" name='Critic')\n",
" \n",
"\n",
" return critic"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create the Discriminator\n",
"\n",
"critic = build_critic()\n",
"critic.summary() # View discriminator architecture."
]
},
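{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "Again as a purely illustrative check (the `demo_*` tensors are random stand-ins, not real data), we can pass a fake piano roll together with its condition track through the critic and confirm that it returns a single score per sample."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Illustrative sanity check: the critic maps (pianoroll, condition) to one score per sample\n",
  "demo_c = tf.random.uniform([4, 32, 128, 1])           # single-track condition\n",
  "demo_pianoroll = tf.random.uniform([4, 32, 128, 4])   # stand-in for a generated piano roll\n",
  "demo_score = critic((demo_pianoroll, demo_c), training=False)\n",
  "print(demo_score.shape)  # expected: (4, 1)"
 ]
},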
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"We train our models by searching for model parameters which optimize an objective function. For our WGAN-GP, we have special loss functions that we minimize as we alternate between training our generator and critic networks:\n",
"\n",
"*Generator Loss:*\n",
"* We use the Wasserstein (Generator) loss function which is negative of the Critic Loss function. The generator is trained to bring the generated pianoroll as close to the real pianoroll as possible.\n",
" * $\\frac{1}{m} \\sum_{i=1}^{m} -D_w(G(z^{i}|c^{i})|c^{i})$\n",
"\n",
"*Critic Loss:*\n",
"\n",
"* We begin with the Wasserstein (Critic) loss function designed to maximize the distance between the real piano roll distribution and generated (fake) piano roll distribution.\n",
" * $\\frac{1}{m} \\sum_{i=1}^{m} [D_w(G(z^{i}|c^{i})|c^{i}) - D_w(x^{i}|c^{i})]$\n",
"\n",
"* We add a gradient penalty loss function term designed to control how the gradient of the critic with respect to its input behaves. This makes optimization of the generator easier. \n",
" * $\\frac{1}{m} \\sum_{i=1}^{m}(\\lVert \\nabla_{\\hat{x}^i}D_w(\\hat{x}^i|c^{i}) \\rVert_2 - 1)^2 $"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the different loss functions\n",
"\n",
"def generator_loss(critic_fake_output):\n",
" \"\"\" Wasserstein GAN loss\n",
" (Generator) -D(G(z|c))\n",
" \"\"\"\n",
" return -tf.reduce_mean(critic_fake_output)\n",
"\n",
"\n",
"def wasserstein_loss(critic_real_output, critic_fake_output):\n",
" \"\"\" Wasserstein GAN loss\n",
" (Critic) D(G(z|c)) - D(x|c)\n",
" \"\"\"\n",
" return tf.reduce_mean(critic_fake_output) - tf.reduce_mean(\n",
" critic_real_output)\n",
"\n",
"\n",
"def compute_gradient_penalty(critic, x, fake_x):\n",
" \n",
" c = tf.expand_dims(x[..., 0], -1)\n",
" batch_size = x.get_shape().as_list()[0]\n",
" eps_x = tf.random.uniform(\n",
" [batch_size] + [1] * (len(x.get_shape()) - 1)) # B, 1, 1, 1, 1\n",
" inter = eps_x * x + (1.0 - eps_x) * fake_x\n",
"\n",
" with tf.GradientTape() as g:\n",
" g.watch(inter)\n",
" disc_inter_output = critic((inter,c), training=True)\n",
" grads = g.gradient(disc_inter_output, inter)\n",
" slopes = tf.sqrt(1e-8 + tf.reduce_sum(\n",
" tf.square(grads),\n",
" reduction_indices=tf.range(1, grads.get_shape().ndims)))\n",
" gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0))\n",
" \n",
" return gradient_penalty\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With our loss functions defined, we associate them with Tensorflow optimizers to define how our model will search for a good set of model parameters. We use the *Adam* algorithm, a commonly used general-purpose optimizer. We also set up checkpoints to save our progress as we train."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Setup Adam optimizers for both G and D\n",
"generator_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5, beta_2=0.9)\n",
"critic_optimizer = tf.keras.optimizers.Adam(1e-3, beta_1=0.5, beta_2=0.9)\n",
"\n",
"# We define our checkpoint directory and where to save trained checkpoints\n",
"ckpt = tf.train.Checkpoint(generator=generator,\n",
" generator_optimizer=generator_optimizer,\n",
" critic=critic,\n",
" critic_optimizer=critic_optimizer)\n",
"ckpt_manager = tf.train.CheckpointManager(ckpt, check_dir, max_to_keep=5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define the `generator_train_step` and `critic_train_step` functions, each of which performs a single forward pass on a batch and returns the corresponding loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tf.function\n",
"def generator_train_step(x, condition_track_idx=0):\n",
"\n",
" ############################################\n",
" #(1) Update G network: maximize D(G(z|c))\n",
" ############################################\n",
"\n",
" # Extract condition track to make real batches pianoroll\n",
" c = tf.expand_dims(x[..., condition_track_idx], -1)\n",
"\n",
" # Generate batch of latent vectors\n",
" z = tf.random.truncated_normal([BATCH_SIZE, 2, 8, 512])\n",
"\n",
" with tf.GradientTape() as tape:\n",
" fake_x = generator((c, z), training=True)\n",
" fake_output = critic((fake_x,c), training=False)\n",
"\n",
" # Calculate Generator's loss based on this generated output\n",
" gen_loss = generator_loss(fake_output)\n",
"\n",
" # Calculate gradients for Generator\n",
" gradients_of_generator = tape.gradient(gen_loss,\n",
" generator.trainable_variables)\n",
" # Update Generator\n",
" generator_optimizer.apply_gradients(\n",
" zip(gradients_of_generator, generator.trainable_variables))\n",
"\n",
" return gen_loss\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@tf.function\n",
"def critic_train_step(x, condition_track_idx=0):\n",
"\n",
" ############################################################################\n",
" #(2) Update D network: maximize (D(x|c)) + (1 - D(G(z|c))|c) + GradientPenality() \n",
" ############################################################################\n",
"\n",
" # Extract condition track to make real batches pianoroll\n",
" c = tf.expand_dims(x[..., condition_track_idx], -1)\n",
"\n",
" # Generate batch of latent vectors\n",
" z = tf.random.truncated_normal([BATCH_SIZE, 2, 8, 512])\n",
"\n",
" # Generated fake pianoroll\n",
" fake_x = generator((c, z), training=False)\n",
"\n",
"\n",
" # Update critic parameters\n",
" with tf.GradientTape() as tape:\n",
" real_output = critic((x,c), training=True)\n",
" fake_output = critic((fake_x,c), training=True)\n",
" critic_loss = wasserstein_loss(real_output, fake_output)\n",
"\n",
" # Caculate the gradients from the real and fake batches\n",
" grads_of_critic = tape.gradient(critic_loss,\n",
" critic.trainable_variables)\n",
"\n",
" with tf.GradientTape() as tape:\n",
" gp_loss = compute_gradient_penalty(critic, x, fake_x)\n",
" gp_loss *= 10.0\n",
"\n",
" # Calculate the gradients penalty from the real and fake batches\n",
" grads_gp = tape.gradient(gp_loss, critic.trainable_variables)\n",
" gradients_of_critic = [g + ggp for g, ggp in\n",
" zip(grads_of_critic, grads_gp)\n",
" if ggp is not None]\n",
"\n",
" # Update Critic\n",
" critic_optimizer.apply_gradients(\n",
" zip(gradients_of_critic, critic.trainable_variables))\n",
"\n",
" return critic_loss + gp_loss\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we begin training, let's define some training configuration parameters and prepare to monitor important quantities. Here we log the losses and metrics which we can use to determine when to stop training. Consider coming back here to tweak these parameters and explore how your model responds. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We use load_melody_samples() to load 10 input data samples from our dataset into sample_x \n",
"# and 10 random noise latent vectors into sample_z\n",
"sample_x, sample_z = inference_utils.load_melody_samples(n_sample=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of iterations to train for\n",
"iterations = 1000\n",
"\n",
"# Update critic n times per generator update \n",
"n_dis_updates_per_gen_update = 5\n",
"\n",
"# Determine input track in sample_x that we condition on\n",
"condition_track_idx = 0 \n",
"sample_c = tf.expand_dims(sample_x[..., condition_track_idx], -1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us now train our model!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Clear out any old metrics we've collected\n",
"metrics_utils.metrics_manager.initialize()\n",
"\n",
"# Keep a running list of various quantities:\n",
"c_losses = []\n",
"g_losses = []\n",
"\n",
"# Data iterator to iterate over our dataset\n",
"it = iter(dataset)\n",
"\n",
"for iteration in range(iterations):\n",
"\n",
" # Train critic\n",
" for _ in range(n_dis_updates_per_gen_update):\n",
" c_loss = critic_train_step(next(it))\n",
"\n",
" # Train generator\n",
" g_loss = generator_train_step(next(it))\n",
"\n",
" # Save Losses for plotting later\n",
" c_losses.append(c_loss)\n",
" g_losses.append(g_loss)\n",
"\n",
" display.clear_output(wait=True)\n",
" fig = plt.figure(figsize=(15, 5))\n",
" line1, = plt.plot(range(iteration+1), c_losses, 'r')\n",
" line2, = plt.plot(range(iteration+1), g_losses, 'k')\n",
" plt.xlabel('Iterations')\n",
" plt.ylabel('Losses')\n",
" plt.legend((line1, line2), ('C-loss', 'G-loss'))\n",
" display.display(fig)\n",
" plt.close(fig)\n",
" \n",
" # Output training stats\n",
" print('Iteration {}, c_loss={:.2f}, g_loss={:.2f}'.format(iteration, c_loss, g_loss))\n",
" \n",
" # Save checkpoints, music metrics, generated output\n",
" if iteration < 100 or iteration % 50 == 0 :\n",
" # Check how the generator is doing by saving G's samples on fixed_noise\n",
" fake_sample_x = generator((sample_c, sample_z), training=False)\n",
" metrics_utils.metrics_manager.append_metrics_for_iteration(fake_sample_x.numpy(), iteration)\n",
"\n",
" if iteration % 50 == 0:\n",
" # Save the checkpoint to disk.\n",
" ckpt_manager.save(checkpoint_number=iteration) \n",
" \n",
" fake_sample_x = fake_sample_x.numpy()\n",
" \n",
" # plot the pianoroll\n",
" display_utils.plot_pianoroll(iteration, sample_x[:4], fake_sample_x[:4], save_dir=train_dir)\n",
"\n",
" # generate the midi\n",
" destination_path = path_utils.generated_midi_path_for_iteration(iteration, saveto_dir=sample_dir)\n",
" midi_utils.save_pianoroll_as_midi(fake_sample_x[:4], destination_path=destination_path)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### We have started training!\n",
"\n",
"When using the Wasserstein loss function, we should train the critic to converge to ensure that the gradients for the generator update are accurate. This is in contrast to a standard GAN, where it is important not to let the critic get too strong, to avoid vanishing gradients.\n",
"\n",
"Therefore, using the Wasserstein loss removes one of the key difficulties of training GANs—how to balance the training of the discriminator and generator. With WGANs, we can simply train the critic several times between generator updates, to ensure it is close to convergence. A typical ratio used is five critic updates to one generator update.\n",
"\n",
"### \"Babysitting\" the learning process\n",
"\n",
"Given that training these models can be an investment in time and resources, we must to continuously monitor training in order to catch and address anomalies if/when they occur. Here are some things to look out for:\n",
"\n",
"**What should the losses look like?**\n",
"\n",
"The adversarial learning process is highly dynamic and high-frequency oscillations are quite common. However if either loss (critic or generator) skyrockets to huge values, plunges to 0, or get stuck on a single value, there is likely an issue somewhere.\n",
"\n",
"**Is my model learning?**\n",
"- Monitor the critic loss and other music quality metrics (if applicable). Are they following the expected trajectories?\n",
"- Monitor the generated samples (piano rolls). Are they improving over time? Do you see evidence of mode collapse? Have you tried listening to your samples?\n",
"\n",
"**How do I know when to stop?**\n",
"- If the samples meet your expectations\n",
"- Critic loss no longer improving\n",
"- The expected value of the musical quality metrics converge to the corresponding expected value of the same metric on the training data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How to measure sample quality during training \n",
"\n",
"Typically, when training any sort of neural networks, it is standard practice to monitor the value of the loss function throughout the duration of the training. The critic loss in WGANs has been found to correlate well with sample quality.\n",
"\n",
"While standard mechanisms exist for evaluating the accuracy of more traditional models like classifiers or regressors, evaluating generative models is an active area of research. Within the domain of music generation, this hard problem is even less well-understood.\n",
"\n",
"To address this, we take high-level measurements of our data and show how well our model produces music that aligns with those measurements. If our model produces music which is close to the mean value of these measurements for our training dataset, our music should match on general “shape”.\n",
"\n",
"We’ll look at three such measurements:\n",
"- **Empty bar rate:** The ratio of empty bars to total number of bars.\n",
"- **Pitch histogram distance:** A metric that captures the distribution and position of pitches.\n",
"- **In Scale Ratio:** Ratio of the number of notes that are in C major key, which is a common key found in music, to the total number of notes. \n"
]
},
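{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The exact implementations live in the `metrics_utils` module. To make the definitions concrete, here is a simplified, hypothetical sketch of two of these measurements (empty bar rate and in-scale ratio) computed directly on a binary piano roll array; the `metrics_utils` versions may differ in detail."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Simplified, illustrative versions of two of the metrics described above.\n",
  "# Assumes a piano roll of shape (time_steps, 128, tracks) with 16 time steps\n",
  "# per bar; the metrics_utils implementations may compute these differently.\n",
  "C_MAJOR_PITCH_CLASSES = [0, 2, 4, 5, 7, 9, 11]  # C, D, E, F, G, A, B\n",
  "\n",
  "def sketch_empty_bar_rate(pianoroll, steps_per_bar=16):\n",
  "    \"\"\"Fraction of (bar, track) pairs that contain no notes at all.\"\"\"\n",
  "    time_steps, n_pitches, n_tracks = pianoroll.shape\n",
  "    n_bars = time_steps // steps_per_bar\n",
  "    bars = pianoroll[:n_bars * steps_per_bar].reshape(\n",
  "        n_bars, steps_per_bar, n_pitches, n_tracks)\n",
  "    return float((bars.sum(axis=(1, 2)) == 0).mean())\n",
  "\n",
  "def sketch_in_scale_ratio(pianoroll):\n",
  "    \"\"\"Fraction of played notes whose pitch class lies in the C major scale.\"\"\"\n",
  "    _, pitches, _ = np.nonzero(pianoroll)\n",
  "    if len(pitches) == 0:\n",
  "        return 0.0\n",
  "    return float(np.isin(pitches % 12, C_MAJOR_PITCH_CLASSES).mean())\n",
  "\n",
  "# Binarize one training sample (the dataset stores notes as {-1, 1} values)\n",
  "sample_roll = (training_data[0] > 0).astype(np.float32)\n",
  "print('empty bar rate:', sketch_empty_bar_rate(sample_roll))\n",
  "print('in-scale ratio:', sketch_in_scale_ratio(sample_roll))"
 ]
},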
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate results\n",
"\n",
"Now that we have finished training, let's find out how we did. We will analyze our model in several ways:\n",
"1. Examine how the generator and critic losses changed while training\n",
"2. Understand how certain musical metrics changed while training\n",
"3. Visualize generated piano roll output for a fixed input at every iteration and create a video\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us first restore our last saved checkpoint. If you did not complete training but still want to continue with a pre-trained version, set `TRAIN = False`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ckpt = tf.train.Checkpoint(generator=generator)\n",
"ckpt_manager = tf.train.CheckpointManager(ckpt, check_dir, max_to_keep=5)\n",
"\n",
"ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()\n",
"print('Latest checkpoint {} restored.'.format(ckpt_manager.latest_checkpoint))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot losses"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"display_utils.plot_loss_logs(g_losses, c_losses, figsize=(15, 5), smoothing=0.01)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Observe how the critic loss (C_loss in the graph) decays to zero as we train. In WGAN-GPs, the critic loss decreases (almost) monotonically as you train."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"metrics_utils.metrics_manager.set_reference_metrics(training_data)\n",
"metrics_utils.metrics_manager.plot_metrics()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row here corresponds to a different music quality metric and each column denotes an instrument track. \n",
"\n",
"Observe how the expected value of the different metrics (blue scatter) approach the corresponding training set expected values (red) as the number of iterations increase. You might expect to see diminishing returns as the model converges.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generated samples during training\n",
"\n",
"The function below helps you probe intermediate samples generated in the training process. Remember that the conditioned input here is sampled from our training data. Let's start by listening to and observing a sample at iteration 0 and then iteration 100. Notice the difference!\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Enter an iteration number (can be divided by 50) and listen to the midi at that iteration\n",
"iteration = 50\n",
"midi_file = os.path.join(sample_dir, 'iteration-{}.mid'.format(iteration))\n",
"display_utils.playmidi(midi_file) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Enter an iteration number (can be divided by 50) and look at the generated pianorolls at that iteration\n",
"iteration = 50\n",
"pianoroll_png = os.path.join(train_dir, 'sample_iteration_%05d.png' % iteration)\n",
"display.Image(filename=pianoroll_png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how the generated piano rolls change with the number of iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Video\n",
"\n",
"\n",
"display_utils.make_training_video(train_dir)\n",
"video_path = \"movie.mp4\"\n",
"Video(video_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generating accompaniment for custom input\n",
"\n",
"Congratulations! You have trained your very own WGAN-GP to generate music. Let us see how our generator performs on a custom input.\n",
"\n",
"The function below generates a new song based on \"Twinkle Twinkle Little Star\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"latest_midi = inference_utils.generate_midi(generator, eval_dir, input_midi_file='./input_twinkle_twinkle.mid')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"display_utils.playmidi(latest_midi)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also take a look at the generated piano rolls for a certain sample, to see how diverse they are!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inference_utils.show_generated_pianorolls(generator, eval_dir, input_midi_file='./input_twinkle_twinkle.mid')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# What's next?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using your own data (Optional)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To create your own dataset you can extract the piano roll from MIDI data. An example of creating a piano roll from a MIDI file is given below"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pypianoroll import Multitrack\n",
"\n",
"midi_data = Multitrack('./input_twinkle_twinkle.mid')\n",
"tracks = [track.pianoroll for track in midi_data.tracks]\n",
"sample = np.stack(tracks, axis=-1)\n",
"\n",
"print(sample.shape)"
]
},
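{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The extracted piano roll is typically much longer than 32 time steps and stores MIDI velocities rather than the {-1, 1} values used by the tutorial dataset. Below is a hedged sketch of one possible way to slice it into 2-bar training samples with the same (num_samples, 32, 128, tracks) layout; your own preprocessing may differ, for example in the time resolution per beat or in how silent segments are handled."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Illustrative preprocessing sketch: slice the extracted piano roll (`sample`\n",
  "# from the previous cell, shape (time_steps, 128, tracks)) into 32-step\n",
  "# segments and map velocities to the {-1, 1} encoding used by prepare_dataset().\n",
  "SEGMENT_LEN = 32  # 2 bars * 4 beats * 4 time steps per beat\n",
  "\n",
  "n_segments = sample.shape[0] // SEGMENT_LEN\n",
  "segments = sample[:n_segments * SEGMENT_LEN].reshape(\n",
  "    n_segments, SEGMENT_LEN, sample.shape[1], sample.shape[2])\n",
  "\n",
  "# Binarize (velocity > 0 means a note is on) and rescale to {-1, 1}\n",
  "segments = (segments > 0).astype(np.float32) * 2.0 - 1.0\n",
  "\n",
  "# Optionally drop segments whose melody track (track 0) is completely silent\n",
  "non_empty = segments[..., 0].max(axis=(1, 2)) > 0\n",
  "my_training_data = segments[non_empty]\n",
  "\n",
  "print(my_training_data.shape)  # (num_samples, 32, 128, tracks)\n",
  "# np.save('./dataset/my_train.npy', my_training_data)"
 ]
},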
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Appendix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Open source implementations\n",
"For more open-source implementations of generative models for music, check out:\n",
"\n",
"- [MuseGAN](https://github.com/salu133445/musegan): Official TensorFlow Implementation that uses GANs to generate multi track polyphonic music\n",
"- [GANSynth](https://github.com/tensorflow/magenta/tree/master/magenta/models/gansynth): GANSynth uses a Progressive GAN architecture to incrementally upsample with convolution from a single vector to the full audio spectrogram\n",
"- [Music Transformer](https://github.com/tensorflow/magenta/tree/master/magenta/models/score2perf): Uses transformers to generate music!\n",
"\n",
"GANs have also achieved state of the generative modeling in several other domains including cross domain image tranfer, celebrity face generation, super resolution text to image and image inpainting.\n",
"\n",
"- [Keras-GAN](https://github.com/eriklindernoren/Keras-GAN): Library of reference implementations in Keras for image generation(good for educational purposes).\n",
"\n",
"There's an ocean of literatures out there that use GANs for modeling distributions across fields! If you are interested, [Gan Zoo](https://github.com/hindupuravinash/the-gan-zoo) is a good place to start."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### References\n",
"\n",
"1. [Dong, H.W., Hsiao, W.Y., Yang, L.C. and Yang, Y.H., 2018, April. MuseGAN: Multi-track sequential generative adversarial networks for symbolic music generation and accompaniment. In Thirty-Second AAAI Conference on Artificial Intelligence.](https://arxiv.org/abs/1709.06298)\n",
"2. [Ishaan, G., Faruk, A., Martin, A., Vincent, D. and Aaron, C., 2017. Improved training of wasserstein gans. In Advances in Neural Information Processing Systems.](https://arxiv.org/abs/1704.00028)\n",
"3. [Arjovsky, M., Chintala, S. and Bottou, L., 2017. Wasserstein gan. arXiv preprint arXiv:1701.07875.](https://arxiv.org/abs/1701.07875)\n",
"4. [Foster, D., 2019. Generative Deep Learning: Teaching Machines to Paint, Write, Compose, and Play. O'Reilly Media.](https://www.amazon.com/Generative-Deep-Learning-Teaching-Machines/dp/1492041947)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### More on Wassertein GAN with Gradient Penalty (optional)\n",
"\n",
"While GANs are a major breakthrough for generative modeling, plain GANs are also notoriously difficult to train. Some common problems encountered are:\n",
"\n",
"* **Oscillating loss:** The loss of the discriminator and generator can start to oscillate without exhibiting any long term stability.\n",
"* **Mode collapse:** The generator may get stuck on a small set of samples that always fool the discriminator. This reduces the capability of the network to produce novel samples.\n",
"* **Uninformative loss:** The lack of correlation between the generator loss and quality of generated output makes plain GAN training difficult to interpret.\n",
"\n",
"\n",
"The [Wasserstein GAN](#references) was a major advancement in GANs and helped mitigate to some of these issues. Some of its features are:\n",
"\n",
"1. It significantly improves the interpretability of loss functions and provides clearer stopping criteria\n",
"2. WGANs generally produce results of higher quality (demonstrated within the image generation domain)\n",
"\n",
"**Mathematics of Wasserstein GAN with Gradient Penalty**\n",
"\n",
"The [Wasserstein distance](https://en.wikipedia.org/wiki/Wasserstein_metric) between the true distribution $P_r$ and generated piano roll distribution $P_g$ is defined as follows:\n",
"\n",
"$$\\mathbb{W}(P_{r},P_{g})=\\sup_{\\lVert{f} \\rVert_{L} \\le 1} \\mathbb{E}_{x \\sim \\mathbb{P}_r}(f(x)) - \\mathbb{E}_{x \\sim \\mathbb{P}_g}(f(x)) $$\n",
"\n",
"In this equation we are trying to minimize the distance between the expectation of the real distribution and the expectation of the generation distribution. $f$ is subject to a technical constraint in that it must be [1-Lipschitz](https://en.wikipedia.org/wiki/Lipschitz_continuity).\n",
"\n",
"To enforce the 1-Lipschitz condition that basically constraints the gradients from varying too rapidly we use the gradient penalty.\n",
"\n",
"**Gradient penalty**: We want to penalize the gradients of the critic. We implicitly define $P_{\\hat{x}}$ by sampling uniformly along straight lines between pairs of points sampled from the data distribution $P_r$ and the generator distribution $P_g$. This was originally motivated by the fact that the optimal critic contains straight lines with gradient norm 1 connecting coupled points from $P_r$ and $P_g$. We use a penalty coefficient $\\lambda$= 10 as was recommended in the original paper. \n",
"\n",
"The loss with gradient penalty is:\n",
"\n",
"$$\\mathbb{L}(P_{r},P_{g},P_{\\hat{x}} )= \\mathbb{W}(P_{r},P_{g}) + \\lambda \\mathbb{E}_{\\hat{x} \\sim \\mathbb{P}_\\hat{x}}[(\\lVert \\nabla_{\\hat{x}}D(\\hat{x}) \\rVert_2 - 1)^2]$$\n",
"|\n",
"This loss can be parametrized in terms of $w$ and $\\theta$. We then use neural networks to learn the functions $f_w$ (discriminator) and $g_\\theta$ (generator).\n",
"$$\\mathbb{W}(P_{r},P_{\\theta})=\\max_{w \\in \\mathbb{W}} \\mathbb{E}_{x \\sim \\mathbb{P}_r}(D_w(x)) - \\mathbb{E}_{z \\sim p(z)}(D_w(G_{\\theta}(z)) $$\n",
"$$\\mathbb{L}(P_{r},P_{\\theta},P_{\\hat{x}})=\\max_{w \\in \\mathbb{W}} \\mathbb{E}_{x \\sim \\mathbb{P}_r}(D_w(x)) - \\mathbb{E}_{z \\sim p(z)}(D_w(G_{\\theta}(z)) + \\lambda \\mathbb{E}_{\\hat{x} \\sim \\mathbb{P}_\\hat{x}}[(\\lVert \\nabla_{\\hat{x}}D_w(\\hat{x}) \\rVert_2 - 1)^2]$$\n",
"\n",
"where $$ \\hat{x} = \\epsilon x + (1- \\epsilon) G(z) $$ and $$\\epsilon \\sim Unif(0,1)$$\n",
"\n",
"The basic procedure to train is as following: \n",
"1. We draw real_x from the real distribution $P_r$ and fake_x from the generated distribution $G_{\\theta}(z)$ where $z \\sim p(z)$\n",
"2. The latent vectors are sampled from z and then tranformed using the generator $G_{\\theta}$ to get the fake samples fake_x. They are evaluated using the critic function $D_w$\n",
"3. We are trying to minimize the Wasserstein distance between the two distributions\n",
"\n",
"Both the generator and critic are conditioned on the input pianoroll melody."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}