# Zero-Shot Mutation Analysis of LMs

## Installs

In [None]:
!pip install transformers
!pip install biopython
!pip install pytorch-lightning

## Imports

In [None]:
import os
import shutil
import re
import pickle as pkl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
# Register tqdm with pandas; the `progress_apply` calls further down rely on this.
tqdm.pandas()

# Plot styling, inline rendering, and the 'retina' figure format for inline output.
sns.set_theme(style='ticks')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

## Prep DeepSequence Data

Reference: https://github.com/debbiemarkslab/DeepSequence

In [None]:
from glob import glob

# Directories holding the DeepSequence example mutation files and alignments.
mutations_path = "/home/ec2-user/SageMaker/efs/brandry/DeepSequence/examples/mutations/"
alignments_path = "/home/ec2-user/SageMaker/efs/brandry/DeepSequence/examples/alignments/"

# Collect every entry found directly inside each of the two directories.
mutations_files, alignments_files = (
    glob(os.path.join(directory, "*"))
    for directory in (mutations_path, alignments_path)
)

In [None]:
# Reduce the full paths to bare file names.
mutations_base = list(map(os.path.basename, mutations_files))
alignments_base = list(map(os.path.basename, alignments_files))

In [None]:
# The first two underscore-separated fields of each mutation file name form its
# prefix, e.g. "BLAT_ECOLX" for "BLAT_ECOLX_Ranganathan2015.csv".
genes = ["_".join(file_name.split("_")[:2]) for file_name in mutations_base]
genes

In [None]:
# Keep only the alignment files whose name starts with one of the gene prefixes.
_gene_prefixes = tuple(genes)
alignments = [file_name for file_name in alignments_base if file_name.startswith(_gene_prefixes)]
alignments

In [None]:
# Alignment file and mutation file selected for this run; both names share the
# "BLAT_ECOLX" prefix. They are joined onto `alignments_path` / `mutations_path` below.
ALIGN = 'BLAT_ECOLX_1_b0.5.a2m'
MUT = 'BLAT_ECOLX_Ranganathan2015.csv'

In [None]:
from Bio import SeqIO

# Full path of the selected alignment file.
FASTA = os.path.join(alignments_path, ALIGN)

# Parse the file as FASTA and keep every record in memory.
with open(FASTA, "r") as fasta_handle:
    records = [*SeqIO.parse(fasta_handle, "fasta")]

In [None]:
# The id of the first record is parsed as "<name>/<start>-<end>"; the start of that
# range is kept as the numbering offset used to index into the wildtype sequence.
_residue_range = str(records[0].id).split("/")[1]
offset = int(_residue_range.split("-")[0])
offset

In [None]:
# Plain-string form of every sequence in the alignment, in file order.
seqs = [str(rec.seq) for rec in records]

In [None]:
# Use the first sequence of the alignment, upper-cased, as the wildtype.
# NOTE(review): presumably the first record is the query/target protein — confirm.
wt_seq = seqs[0].upper()
wt_seq

In [None]:
# Load the table of the selected mutation file; it must provide a `mutant` column.
_mutations_csv = os.path.join(mutations_path, MUT)
df_mt = pd.read_csv(_mutations_csv)
df_mt.head()

In [None]:
def generate_mt_seq(mutant, *, wildtype=None, first_pos=None):
    """Build the full mutated sequence for one single-substitution annotation.

    Args:
        mutant: Annotation of the form ``<old_aa><position><new_aa>``, e.g.
            ``"K25A"``. ``position`` uses the numbering in which index 0 of the
            wildtype sequence corresponds to ``first_pos``.
        wildtype: Reference sequence to mutate. Defaults to the module-level
            ``wt_seq`` (the notebook rebinds that name in a later section, so
            pass it explicitly when in doubt).
        first_pos: Position number of the first wildtype residue. Defaults to
            the module-level ``offset``.

    Returns:
        A tuple ``(mutated_sequence, (old_aa, index, new_aa))`` where ``index``
        is the 0-based index of the substitution within ``wildtype``.

    Raises:
        ValueError: If the position falls outside the wildtype sequence, or if
            the annotated original residue does not match the wildtype.
    """
    if wildtype is None:
        wildtype = wt_seq
    if first_pos is None:
        first_pos = offset

    old_aa, new_aa = mutant[0], mutant[-1]
    pos = int(mutant[1:-1])
    idx = pos - first_pos

    # A negative index would silently wrap around to the end of the sequence and
    # produce a garbage result, so reject out-of-range positions explicitly.
    if not 0 <= idx < len(wildtype):
        raise ValueError(
            f"Mutation {mutant!r}: position {pos} is outside the wildtype range "
            f"{first_pos}-{first_pos + len(wildtype) - 1}."
        )
    if wildtype[idx] != old_aa:
        raise ValueError(
            f"Mutation {mutant!r}: wildtype has {wildtype[idx]!r} at position {pos}, "
            f"not {old_aa!r}."
        )
    return wildtype[:idx] + new_aa + wildtype[idx + 1:], (old_aa, idx, new_aa)

In [None]:
# Each annotation yields a (full mutant sequence, (old_aa, index, new_aa)) pair.
# Building the two columns separately (instead of unpacking zip(*...)) also works
# when the frame is empty, and avoids attribute-style column assignment.
_generated = df_mt["mutant"].progress_apply(generate_mt_seq)
df_mt["aligned_primary"] = [sequence for sequence, _ in _generated]
# Wrap each mutation in a list: the scoring step iterates over a list of mutations per row.
df_mt["mutations"] = [[mutation] for _, mutation in _generated]
df_mt.head()

## Prep Fluorescence Data

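Reference: https://github.com/songlab-cal/tape

Note: this section reassigns `wt_seq` and `df_mt`, which the DeepSequence section above also defines. Run only one of the two data-prep sections before the zero-shot inference below.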

In [None]:
# Load the full fluorescence table, using its first column as the index.
_fluorescence_csv = "/home/ec2-user/SageMaker/efs/paper_data/Flurescence/fluorescence_full.csv"
df = pd.read_csv(_fluorescence_csv, index_col=0)
df.head()

In [None]:
# How many sequences there are of each length.
df["protein_length"].value_counts()

In [None]:
# The wildtype is the first row that carries no mutations.
# NOTE: this rebinds the module-level `wt_seq` that the DeepSequence section also assigns.
df_wt = df.loc[df["num_mutations"] == 0]
wt_seq = df_wt["primary"].tolist()[0]
wt_seq

### Align short sequences to Wildtype

In [None]:
from Bio import pairwise2

def align(mt_seq, *, wildtype=None):
    """Align a mutant sequence to the wildtype without gapping the wildtype.

    Sequences that already have the wildtype length are returned unchanged.
    Otherwise a global pairwise alignment is computed with
    ``pairwise2.align.globalxd`` and the gapped mutant is returned, with gaps
    written as ``"X"``. The penalties passed (-10 / -1 for the first sequence,
    -1 / -0.1 for the second) make gaps in the wildtype far more expensive than
    gaps in the mutant (argument order per the Bio.pairwise2 documentation).

    Args:
        mt_seq: Mutant amino-acid sequence.
        wildtype: Reference sequence. Defaults to the module-level ``wt_seq``.

    Returns:
        A string of the same length as the wildtype.

    Raises:
        ValueError: If no alignment is produced, the best alignment puts gaps
            into the wildtype, or the aligned mutant has the wrong length.
    """
    if wildtype is None:
        wildtype = wt_seq

    if len(mt_seq) == len(wildtype):
        return mt_seq

    candidates = pairwise2.align.globalxd(wildtype, mt_seq, -10, -1, -1, -.1, gap_char="X")
    if not candidates:
        raise ValueError("Bad alignment.")

    best = candidates[0]
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if best.seqA != wildtype or len(best.seqB) != len(wildtype):
        raise ValueError("Bad alignment.")
    return best.seqB

In [None]:
# Bring every sequence to wildtype length so positions can be compared index by index.
df["aligned_primary"] = df["primary"].apply(align)

### Catalog mutations

In [None]:
# Extract mutants to separate dataframe.
# Take an explicit copy: columns are added to `df_mt` later, and assigning into a
# filtered view of `df` triggers pandas' SettingWithCopyWarning.
# NOTE: this rebinds the module-level `df_mt` that the DeepSequence section also assigns.
df_mt = df[df.primary != wt_seq].copy()
df_mt.head()

In [None]:
def find_mutations(mt_seq, *, wildtype=None):
    """List the substitutions that turn the wildtype into ``mt_seq``.

    Args:
        mt_seq: Mutant sequence, already aligned to the wildtype (same length).
        wildtype: Reference sequence. Defaults to the module-level ``wt_seq``.

    Returns:
        A list of ``(wildtype_aa, index, mutant_aa)`` tuples, one per differing
        position, ordered by 0-based ``index``. Empty when the sequences match.

    Raises:
        ValueError: If the two sequences differ in length; a position-wise
            comparison would otherwise miss or misplace mutations.
    """
    if wildtype is None:
        wildtype = wt_seq

    if len(mt_seq) != len(wildtype):
        raise ValueError(
            f"Mutant length {len(mt_seq)} does not match wildtype length {len(wildtype)}."
        )
    return [
        (wt_aa, idx, mt_aa)
        for idx, (wt_aa, mt_aa) in enumerate(zip(wildtype, mt_seq))
        if wt_aa != mt_aa
    ]

In [None]:
# Annotate every mutant with its list of (wildtype_aa, index, mutant_aa) substitutions.
df_mt["mutations"] = df_mt["aligned_primary"].progress_apply(find_mutations)
df_mt.head()

## Zero-Shot inference

Use forward passes through the BERT encoder to compute the masked marginal probability of each mutated sequence relative to the wildtype.

In [None]:
import torch
import torch.nn as nn
from torch.nn.functional import softmax, log_softmax

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

In [None]:
from collections import OrderedDict

def rename_state_dict_keys(state_dict, key_transformation):
    """Return a copy of ``state_dict`` whose keys have been remapped.

    Used to strip the prefix a PyTorch Lightning module puts in front of the
    wrapped model's parameter names.

    Args:
        state_dict: Mapping from parameter name to value. It is not modified.
        key_transformation: Callable taking an old key and returning the new key.

    Returns:
        An ``OrderedDict`` with the transformed keys, in the original order.

    Raises:
        ValueError: If two different keys are transformed to the same new key;
            silently keeping only the last one would drop a tensor.
    """
    new_state_dict = OrderedDict()

    for key, value in state_dict.items():
        new_key = key_transformation(key)
        if new_key in new_state_dict:
            raise ValueError(
                f"Key collision: {key!r} maps to {new_key!r}, which is already present."
            )
        new_state_dict[new_key] = value

    return new_state_dict

### Download trained model objects (if necessary)

Seq-only:

* GO-BP: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_bp_BERT/lightning_logs/version_0/checkpoints/epoch=76-step=143912.ckpt
* GO-MF: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_mf_BERT/lightning_logs/version_0/checkpoints/epoch=99-step=186899.ckpt
* GO-CC: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_cc_BERT/lightning_logs/version_0/checkpoints/epoch=46-step=87842.ckpt

Seq+Structure (combo):

* GO-BP model: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_bp_BERT_GVP_tf/lightning_logs/version_2/checkpoints/epoch=3-step=7475.ckpt
* GO-MF model: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_mf_BERT_GVP_tf/lightning_logs/version_3/checkpoints/epoch=0-step=1868.ckpt 
* GO-CC model: s3://janssen-mlsl-dev-data/zichen-dev-data/GO_cc_BERT_GVP_tf/lightning_logs/version_2/checkpoints/epoch=0-step=1868.ckpt

In [None]:
# Copy each checkpoint listed above from S3 into final-models/, saved as
# <model-name>.pkl so it can later be selected by name through `which_model`.
!mkdir -p final-models
!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_bp_BERT/lightning_logs/version_0/checkpoints/epoch=76-step=143912.ckpt final-models/go-bp-seq.pkl
!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_mf_BERT/lightning_logs/version_0/checkpoints/epoch=99-step=186899.ckpt final-models/go-mf-seq.pkl
!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_cc_BERT/lightning_logs/version_0/checkpoints/epoch=46-step=87842.ckpt final-models/go-cc-seq.pkl
!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_bp_BERT_GVP_tf/lightning_logs/version_2/checkpoints/epoch=3-step=7475.ckpt final-models/go-bp-combo.pkl
!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_mf_BERT_GVP_tf/lightning_logs/version_3/checkpoints/epoch=0-step=1868.ckpt final-models/go-mf-combo.pkl
!aws s3 cp s3://janssen-mlsl-dev-data/zichen-dev-data/GO_cc_BERT_GVP_tf/lightning_logs/version_2/checkpoints/epoch=0-step=1868.ckpt final-models/go-cc-combo.pkl

### Load model onto device for analysis

In [None]:
# Name of the checkpoint to analyse; it must be one of the downloaded models,
# i.e. "go-<ontology>-<variant>" with ontology cc/mf/bp and variant seq/combo.
which_model = "go-cc-combo"
valid_models = [
    f"go-{ontology}-{variant}"
    for ontology in ("cc", "mf", "bp")
    for variant in ("seq", "combo")
]
assert which_model in valid_models, "Invalid value of 'which_model'."

In [None]:
from transformers import BertForMaskedLM, BertTokenizer, pipeline

# Start from the public ProtBert tokenizer and masked-LM model.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert")
vocab = tokenizer.get_vocab()

# Use pretrained weights from PyTorch lightning tuning
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
model_path = f"final-models/{which_model}.pkl"
state_dict = torch.load(model_path, map_location='cpu')
renamed_state_dict = rename_state_dict_keys(state_dict['state_dict'], lambda key: key.replace("bert_model.", ""))

# strict=False skips every key that does not match, so a wrong key prefix would
# silently leave the stock ProtBert encoder in place. Inspect the result instead
# of discarding it, and fail when nothing at all was loaded.
load_result = model.bert.load_state_dict(renamed_state_dict, strict=False)
n_expected = len(model.bert.state_dict())
n_missing = len(load_result.missing_keys)
if n_missing >= n_expected:
    raise RuntimeError(
        f"No encoder weights from {model_path} matched the model; check the key renaming."
    )
print(
    f"Loaded {n_expected - n_missing}/{n_expected} encoder tensors from {model_path} "
    f"({len(load_result.unexpected_keys)} checkpoint keys unused)."
)

model.to(device)
model.eval()

### Compute masked marginal probability score of mutated sequences

In [None]:
def compute_masked_marginal_score(mutations, wt_seq):
    """Compute the masked marginal score of a set of mutations relative to the wildtype.

    Every mutated position of the wildtype is replaced by ``[MASK]``, the model
    is run once on the masked sequence, and for each mutation the difference
    ``log p(mutant_aa) - log p(wildtype_aa)`` at its position is accumulated.

    Relies on the module-level ``tokenizer``, ``model``, ``vocab`` and ``device``.

    Args:
        mutations: Sequence of ``(wildtype_aa, index, mutant_aa)`` entries, with
            ``index`` 0-based into ``wt_seq``.
        wt_seq: Wildtype amino-acid sequence as a plain string.

    Returns:
        The summed log-probability difference; ``0.0`` when ``mutations`` is empty.
    """
    # Nothing to score: skip the (expensive) forward pass altogether.
    if len(mutations) == 0:
        return 0.

    # The tokenizer expects space-separated residues; U, Z, O and B are mapped to X.
    residues = re.sub(r"[UZOB]", "X", " ".join(wt_seq)).split()
    for mutation in mutations:
        residues[mutation[1]] = "[MASK]"
    inputs = tokenizer(" ".join(residues), return_tensors='pt')

    # Inference only: no_grad avoids building the autograd graph for every sequence.
    with torch.no_grad():
        outputs = model(inputs['input_ids'].to(device))

    # Drop the first and last positions (the special tokens the tokenizer adds
    # around the sequence) so that indices line up with `wt_seq`.
    logits = outputs.logits[:, 1:-1, :]
    # Single device-to-host transfer instead of one per mutation.
    log_probs = log_softmax(logits, dim=2)[0].cpu().numpy()

    total_log_p = 0.
    for mutation in mutations:
        log_p = log_probs[mutation[1]]
        log_p_wt = log_p[vocab[mutation[0]]]
        log_p_mt = log_p[vocab[mutation[2]]]
        total_log_p += (log_p_mt - log_p_wt)
    return total_log_p

In [None]:
from functools import partial

# Bind the current wildtype once, then score the mutation list of every row.
func = partial(compute_masked_marginal_score, wt_seq=wt_seq)
df_mt["masked_marginal_score"] = df_mt["mutations"].progress_apply(func)

In [None]:
from scipy.stats import spearmanr

# Pick the measurement columns that belong to the dataset loaded above.
# target_cols = ['log'] # PABP
# target_cols = ['CRIPT', 'Tm2F'] # DLG4
target_cols = ['km', 'vmax'] # BLAT
# target_cols = ['log_fluorescence'] # GFP
for c in target_cols:
    # Rank correlation is undefined for missing values, and spearmanr propagates
    # NaN by default, so keep only rows where both measurement and score exist.
    paired = df_mt[[c, "masked_marginal_score"]].dropna()
    y_true = paired[c].values
    y_pred = paired["masked_marginal_score"].values

    rho = spearmanr(y_true, y_pred).correlation
    print(f"Spearman's Rho ({c}) =", rho)

In [None]:
# Write result to disk
output_path = "/home/ec2-user/SageMaker/efs/brandry/ZeroShot"
# to_csv does not create missing directories, so make sure the target exists.
os.makedirs(output_path, exist_ok=True)

# NOTE(review): the file name is derived from MUT even when `df_mt` was built from
# the fluorescence data — confirm the name matches the dataset that was scored.
base = MUT.split(".")[0]
fname = base + f"_masked_marginal_{which_model}.csv"

df_mt.to_csv(os.path.join(output_path, fname), index=False)