{ "cells": [ { "cell_type": "markdown", "id": "29267299", "metadata": {}, "source": [ "# Zero-Shot Mutation Analysis of LMs" ] }, { "cell_type": "markdown", "id": "452f132e", "metadata": {}, "source": [ "## Installs" ] }, { "cell_type": "code", "execution_count": null, "id": "1ac62f30", "metadata": {}, "outputs": [], "source": [ "!pip install transformers\n", "!pip install biopython\n", "!pip install pytorch-lightning" ] }, { "cell_type": "markdown", "id": "00e1119e", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "ee774a7c", "metadata": {}, "outputs": [], "source": [ "import os\n", "import shutil\n", "import re\n", "import pickle as pkl\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "from tqdm import tqdm\n", "tqdm.pandas()\n", "\n", "sns.set_theme(style='ticks')\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina'" ] }, { "cell_type": "markdown", "id": "6603ff25", "metadata": {}, "source": [ "## Prep DeepSequence Data" ] }, { "cell_type": "markdown", "id": "cb297700", "metadata": {}, "source": [ "Reference: https://github.com/debbiemarkslab/DeepSequence" ] }, { "cell_type": "code", "execution_count": null, "id": "a153c382", "metadata": {}, "outputs": [], "source": [ "from glob import glob\n", "\n", "mutations_path = \"/home/ec2-user/SageMaker/efs/brandry/DeepSequence/examples/mutations/\"\n", "alignments_path = \"/home/ec2-user/SageMaker/efs/brandry/DeepSequence/examples/alignments/\"\n", "\n", "mutations_files = glob(os.path.join(mutations_path, \"*\"))\n", "alignments_files = glob(os.path.join(alignments_path, \"*\"))" ] }, { "cell_type": "code", "execution_count": null, "id": "77e95b77", "metadata": {}, "outputs": [], "source": [ "mutations_base = [os.path.basename(m) for m in mutations_files]\n", "alignments_base = [os.path.basename(a) for a in alignments_files]" ] }, { "cell_type": "code", "execution_count": null, "id": "09382bfa", "metadata": {}, "outputs": [], "source": [ "genes = [\"_\".join(m.split(\"_\")[0:2]) for m in mutations_base]\n", "genes" ] }, { "cell_type": "code", "execution_count": null, "id": "70f449b2", "metadata": {}, "outputs": [], "source": [ "alignments = [x for x in alignments_base if x.startswith(tuple(genes))]\n", "alignments" ] }, { "cell_type": "code", "execution_count": null, "id": "fd48af76", "metadata": {}, "outputs": [], "source": [ "ALIGN = 'BLAT_ECOLX_1_b0.5.a2m'\n", "MUT = 'BLAT_ECOLX_Ranganathan2015.csv'" ] }, { "cell_type": "code", "execution_count": null, "id": "a8160662", "metadata": {}, "outputs": [], "source": [ "from Bio import SeqIO\n", "\n", "FASTA = os.path.join(alignments_path, ALIGN)\n", "\n", "with open(FASTA, \"r\") as handle:\n", " records = list(SeqIO.parse(handle, \"fasta\"))" ] }, { "cell_type": "code", "execution_count": null, "id": "b893b473", "metadata": {}, "outputs": [], "source": [ "offset = int(str(records[0].id).split(\"/\")[1].split(\"-\")[0])\n", "offset" ] }, { "cell_type": "code", "execution_count": null, "id": "6eab328b", "metadata": {}, "outputs": [], "source": [ "seqs = [str(record.seq) for record in records]" ] }, { "cell_type": "code", "execution_count": null, "id": "b06b92ab", "metadata": {}, "outputs": [], "source": [ "wt_seq = seqs[0].upper()\n", "wt_seq" ] }, { "cell_type": "code", "execution_count": null, "id": "be03ae0d", "metadata": {}, "outputs": [], "source": [ "df_mt = pd.read_csv(os.path.join(mutations_path, MUT))\n", "df_mt.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "e28556ab", "metadata": {}, "outputs": [], "source": [ "def generate_mt_seq(mutant):\n", " \"\"\"Generate full mutated sequences from annotated AAs.\"\"\"\n", " pos = int(mutant[1:-1])\n", " old_aa = mutant[0]\n", " new_aa = mutant[-1]\n", " assert old_aa == wt_seq[pos-offset]\n", " return wt_seq[:(pos-offset)] + new_aa + wt_seq[(pos-offset+1):], (old_aa, pos-offset, new_aa)" ] }, { "cell_type": "code", "execution_count": null, "id": "8e79c618", "metadata": {}, "outputs": [], "source": [ "df_mt[\"aligned_primary\"], df_mt[\"mutations\"] = zip(*df_mt.mutant.progress_apply(generate_mt_seq))\n", "df_mt.mutations = [[x] for x in df_mt.mutations]\n", "df_mt.head()" ] }, { "cell_type": "markdown", "id": "bcce1834", "metadata": {}, "source": [ "## Prep Fluorescence Data" ] }, { "cell_type": "markdown", "id": "3847d7df", "metadata": {}, "source": [ "Reference: https://github.com/songlab-cal/tape" ] }, { "cell_type": "code", "execution_count": null, "id": "cb3cb564", "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"/home/ec2-user/SageMaker/efs/paper_data/Flurescence/fluorescence_full.csv\", index_col=0)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "9c9b33dc", "metadata": {}, "outputs": [], "source": [ "df.protein_length.value_counts()" ] }, { "cell_type": "code", "execution_count": null, "id": "361b4780", "metadata": {}, "outputs": [], "source": [ "# Fetch Wildtype AA sequence\n", "df_wt = df[df.num_mutations == 0]\n", "wt_seq = df_wt.primary.tolist()[0]\n", "wt_seq" ] }, { "cell_type": "markdown", "id": "483a1706", "metadata": {}, "source": [ "### Align short sequences to Wildtype" ] }, { "cell_type": "code", "execution_count": null, "id": "6a560d61", "metadata": {}, "outputs": [], "source": [ "from Bio import pairwise2\n", "\n", "def align(mt_seq):\n", " \"\"\"Align variable-length mutant sequences to wildtype without introducing wildtype gaps.\"\"\"\n", " if len(mt_seq) == len(wt_seq):\n", " return mt_seq\n", " else:\n", " alignments = pairwise2.align.globalxd(wt_seq, mt_seq, -10, -1, -1, -.1, gap_char=\"X\")\n", " result = alignments[0]\n", " assert result.seqA == wt_seq, \"Bad alignment.\"\n", " assert len(result.seqB) == len(wt_seq), \"Bad alignment.\"\n", " return result.seqB" ] }, { "cell_type": "code", "execution_count": null, "id": "31798592", "metadata": {}, "outputs": [], "source": [ "df[\"aligned_primary\"] = df.primary.apply(align)" ] }, { "cell_type": "markdown", "id": "f40f016c", "metadata": {}, "source": [ "### Catalog mutations" ] }, { "cell_type": "code", "execution_count": null, "id": "45ef46a5", "metadata": {}, "outputs": [], "source": [ "# Extract mutants to separate dataframe\n", "df_mt = df[df.primary != wt_seq]\n", "df_mt.head()" ] }, { "cell_type": "code", "execution_count": null, "id": "e831e6c2", "metadata": {}, "outputs": [], "source": [ "def find_mutations(mt_seq):\n", " \"\"\"Extract mutation annotations from mutant sequence.\"\"\"\n", " mts = []\n", " for i in range(len(mt_seq)):\n", " if wt_seq[i] != mt_seq[i]:\n", " mts.append((wt_seq[i], i, mt_seq[i]))\n", " return mts" ] }, { "cell_type": "code", "execution_count": null, "id": "4ad02040", "metadata": {}, "outputs": [], "source": [ "df_mt[\"mutations\"] = df_mt.aligned_primary.progress_apply(find_mutations)\n", "df_mt.head()" ] }, { "cell_type": "markdown", "id": "79d0e72d", "metadata": {}, "source": [ "## Zero-Shot inference" ] }, { "cell_type": "markdown", "id": "86ff6e50", "metadata": {}, "source": [ "Use forward passes through BERT encoder to compute masked marginal probability of mutated sequence relative to wildtype" ] }, { "cell_type": "code", "execution_count": null, "id": "3ff2103c", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn.functional import softmax, log_softmax\n", "\n", "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", "print('Device:', device)" ] }, { "cell_type": "code", "execution_count": null, "id": "69b974ac", "metadata": {}, "outputs": [], "source": [ "from collections import OrderedDict\n", "\n", "def rename_state_dict_keys(state_dict, key_transformation):\n", " \"\"\"Utility function for remapping keys from PyTorch lightning.\"\"\"\n", " new_state_dict = OrderedDict()\n", "\n", " for key, value in state_dict.items():\n", " new_key = key_transformation(key)\n", " new_state_dict[new_key] = value\n", " \n", " return new_state_dict" ] }, { "cell_type": "markdown", "id": "ea410261", "metadata": {}, "source": [ "### Download trained model objects (if necessary)" ] }, { "cell_type": "markdown", "id": "c3170794", "metadata": {}, "source": [ "Seq-only:\n", "\n", "* GO-BP: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_bp_BERT/lightning_logs/version_0/checkpoints/epoch=76-step=143912.ckpt\n", "* GO-MF: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_mf_BERT/lightning_logs/version_0/checkpoints/epoch=99-step=186899.ckpt\n", "* GO-CC: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_cc_BERT/lightning_logs/version_0/checkpoints/epoch=46-step=87842.ckpt\n", "\n", "Seq+Structure (combo):\n", "\n", "* GO-BP model: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_bp_BERT_GVP_tf/lightning_logs/version_2/checkpoints/epoch=3-step=7475.ckpt\n", "* GO-MF model: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_mf_BERT_GVP_tf/lightning_logs/version_3/checkpoints/epoch=0-step=1868.ckpt \n", "* GO-CC model: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_cc_BERT_GVP_tf/lightning_logs/version_2/checkpoints/epoch=0-step=1868.ckpt" ] }, { "cell_type": "code", "execution_count": null, "id": "daad3b87", "metadata": {}, "outputs": [], "source": [ "!mkdir -p final-models\n", "!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_bp_BERT/lightning_logs/version_0/checkpoints/epoch=76-step=143912.ckpt final-models/go-bp-seq.pkl\n", "!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_mf_BERT/lightning_logs/version_0/checkpoints/epoch=99-step=186899.ckpt final-models/go-mf-seq.pkl\n", "!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_cc_BERT/lightning_logs/version_0/checkpoints/epoch=46-step=87842.ckpt final-models/go-cc-seq.pkl\n", "!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_bp_BERT_GVP_tf/lightning_logs/version_2/checkpoints/epoch=3-step=7475.ckpt final-models/go-bp-combo.pkl\n", "!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_mf_BERT_GVP_tf/lightning_logs/version_3/checkpoints/epoch=0-step=1868.ckpt final-models/go-mf-combo.pkl\n", "!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_cc_BERT_GVP_tf/lightning_logs/version_2/checkpoints/epoch=0-step=1868.ckpt final-models/go-cc-combo.pkl" ] }, { "cell_type": "markdown", "id": "de52fad3", "metadata": {}, "source": [ "### Load model onto device for analysis" ] }, { "cell_type": "code", "execution_count": null, "id": "47fc61b1", "metadata": {}, "outputs": [], "source": [ "# Specify model to use by name\n", "which_model = \"go-cc-combo\"\n", "valid_models = [\"go-cc-seq\", \"go-cc-combo\", \"go-mf-seq\", \"go-mf-combo\", \"go-bp-seq\", \"go-bp-combo\"]\n", "assert which_model in valid_models, \"Invalid value of 'which_model'.\"" ] }, { "cell_type": "code", "execution_count": null, "id": "2d447a8d", "metadata": { "scrolled": true }, "outputs": [], "source": [ "from transformers import BertForMaskedLM, BertTokenizer, pipeline\n", "\n", "tokenizer = BertTokenizer.from_pretrained(\"Rostlab/prot_bert\", do_lower_case=False)\n", "model = BertForMaskedLM.from_pretrained(\"Rostlab/prot_bert\")\n", "vocab = tokenizer.get_vocab()\n", "\n", "# Use pretrained weights from PyTorch lightning tuning\n", "model_path = f\"final-models/{which_model}.pkl\"\n", "state_dict = torch.load(model_path, map_location='cpu')\n", "renamed_state_dict = rename_state_dict_keys(state_dict['state_dict'], lambda key: key.replace(\"bert_model.\", \"\"))\n", "model.bert.load_state_dict(renamed_state_dict, strict=False)\n", "\n", "model.to(device)\n", "model.eval()" ] }, { "cell_type": "markdown", "id": "4beffb97", "metadata": {}, "source": [ "### Compute masked marginal probability score of mutated sequences" ] }, { "cell_type": "code", "execution_count": null, "id": "93e89597", "metadata": {}, "outputs": [], "source": [ "def compute_masked_marginal_score(mutations, wt_seq):\n", " \"\"\"Compute the masked marginal probability score of a set of mutations relative to the wildtype sequence.\"\"\"\n", " seq_spaced = \" \".join(wt_seq)\n", " seq_input = re.sub(r\"[UZOB]\", \"X\", seq_spaced)\n", " aa = seq_input.split()\n", " for mutation in mutations:\n", " aa[mutation[1]] = \"[MASK]\"\n", " seq_input = \" \".join(aa)\n", " inputs = tokenizer(seq_input, return_tensors='pt')\n", " outputs = model(inputs['input_ids'].to(device))\n", " logits = outputs.logits[:, 1:-1, :]\n", " log_probs = log_softmax(logits, dim=2)\n", "\n", " total_log_p = 0.\n", " for mutation in mutations:\n", " log_p = log_probs[0, mutation[1], :].detach().cpu().numpy()\n", " log_p_wt = log_p[vocab[mutation[0]]]\n", " log_p_mt = log_p[vocab[mutation[2]]]\n", " total_log_p += (log_p_mt - log_p_wt)\n", " return total_log_p" ] }, { "cell_type": "code", "execution_count": null, "id": "92fe2e1b", "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "\n", "func = partial(compute_masked_marginal_score, wt_seq=wt_seq)\n", "df_mt[\"masked_marginal_score\"] = df_mt.mutations.progress_apply(func)" ] }, { "cell_type": "code", "execution_count": null, "id": "71f241b4", "metadata": {}, "outputs": [], "source": [ "from scipy.stats import spearmanr\n", "\n", "# target_cols = ['log'] # PABP\n", "# target_cols = ['CRIPT', 'Tm2F'] # DLG4\n", "target_cols = ['km', 'vmax'] # BLAT\n", "# target_cols = ['log_fluorescence'] # GFP\n", "for c in target_cols:\n", " y_true = df_mt[c].values\n", " y_pred = df_mt.masked_marginal_score.values\n", "\n", " rho = spearmanr(y_true, y_pred).correlation\n", " print(f\"Spearman's Rho ({c}) =\", rho)" ] }, { "cell_type": "code", "execution_count": null, "id": "844f1668", "metadata": {}, "outputs": [], "source": [ "# Write result to disk\n", "output_path = \"/home/ec2-user/SageMaker/efs/brandry/ZeroShot\"\n", "\n", "base = MUT.split(\".\")[0]\n", "fname = base + f\"_masked_marginal_{which_model}.csv\"\n", "\n", "df_mt.to_csv(os.path.join(output_path, fname), index=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "19587596", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_latest_p36", "language": "python", "name": "conda_pytorch_latest_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 5 }