In [None]:
!pip install bio
!pip install captum
!pip install umap-learn
!pip install pytorch-lightning

In [1]:
import sys
sys.path.append('../')

## Load data and model

In [2]:
from torch_geometric.data import DataLoader
from lmgvp import data_loaders

import os

# Test split used for every analysis below.
# NOTE(review): 'mf' presumably selects the molecular-function task and
# 'seq_struct' the sequence+structure representation -- confirm against
# lmgvp.data_loaders.get_dataset.
dataset = data_loaders.get_dataset(
    'mf', 'seq_struct', split="test"
)

# Keep shuffle off so batch order matches dataset order; the name/label lists
# collected later are indexed positionally.
# The worker count is capped at the number of CPUs on this host: a fixed 32
# makes PyTorch warn about oversubscription on smaller machines (the
# "cpuset_checked" warning visible in the captured output).
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=min(32, os.cpu_count() or 1),
)

Loading BertTokenizer...


  cpuset_checked))


Get mapping from protein names to data index:

In [3]:
# Protein name -> position in `dataset`. Each dataset item is indexed with
# [0] to reach the object carrying `.name`; if a name occurs more than once,
# the last position wins.
protein_name_indices = dict(
    (item[0].name, position) for position, item in enumerate(dataset)
)

Build the model and load pretrained weights from the checkpoint:

In [None]:
from lmgvp.modules import BertMQAModel
import torch

# All inference and attribution below runs on this device.
device = torch.device("cuda")
checkpoint_path = "../../data/epoch=0-step=1868.ckpt"

# Restore the model from the checkpoint, passing the dataset's positive-class
# weights as the `weights` argument, then move it to the device.
model = BertMQAModel.load_from_checkpoint(
    checkpoint_path,
    weights=dataset.pos_weights,
).to(device)

# Put the model in evaluation mode; as the last expression of the cell this
# also displays the model.
model.eval()

## Get latent activation and prediction results

We use a PyTorch forward hook here to extract the latent activations in the penultimate layer for cluster analysis. The prediction results are also collected along the way.

In [5]:
import tqdm
import torch

def collect_activation(self, input, output):
    """Forward hook: record the hooked module's first positional input.

    A detached CPU copy is appended to the module-level ``activations`` list.
    ``output`` is ignored and nothing is returned, so the forward pass itself
    is not modified.
    """
    hooked_input = input[0]
    activations.append(hooked_input.clone().detach().cpu())

# Per-batch accumulators, filled in loader order.
activations, y_preds, y_true, names = [], [], [], []
handle = None

try:
    # The hook sits on model.dense[3], so what gets stored is the tensor fed
    # into that module.
    handle = model.dense[3].register_forward_hook(collect_activation)
    with torch.no_grad():
        for batch in tqdm.tqdm(loader, total=len(loader)):
            # Names and labels are read before the batch is moved to the device.
            names.extend(batch[0].name)
            y_true.append(batch[-1])
            batch = [b.to(device) for b in batch]
            y_pred = model(batch)
            y_preds.append(y_pred.cpu())
    y_preds = torch.vstack(y_preds)
    y_true = torch.vstack(y_true)
finally:
    # Always detach the hook, even if inference failed part-way, so later
    # forward passes do not keep appending to `activations`.
    if handle is not None:
        handle.remove()

# From here on `activations` is a single tensor (batches concatenated on dim 0).
activations = torch.cat(activations, dim=0)

  cpuset_checked))
100%|██████████| 107/107 [05:23<00:00,  3.02s/it]


## Load ground-truth binding sites obtained from BioLiP

Source data: https://zhanggroup.org/BioLiP/download.html. We load a processed pickle file here, so please run `Process_binding_site_data.ipynb` on the downloaded data first.

In [13]:
import pickle

# Load the pre-processed binding-site annotations. Later cells index the result
# as binding_data[mf_term][protein_name]["sites"].
# SECURITY NOTE: pickle.load can execute arbitrary code; only load a pickle
# you generated yourself (e.g. with the preprocessing notebook).
binding_data = None
with open('../../data/BioLiP_binding_sites.pickle', 'rb') as f:
    binding_data = pickle.load(f)

## Load mapping from molecular functions to indices in model output

In [7]:
from lmgvp import deepfrier_utils
import os

# Read the GO annotation table from the data directory.
prot2annot, goterms, gonames, counts = deepfrier_utils.load_GO_annot(
    os.path.join("../../data", "nrPDB-GO_2019.06.18_annot.tsv")
)

# GO term name -> position in gonames['mf']. This position is later used as
# the column index into the label and prediction tensors.
name_indices = dict(
    (go_name, position) for position, go_name in enumerate(gonames['mf'])
)

## Cluster analysis in latent space

Use UMAP projection and DBSCAN to obtain clusters of proteins based on the latent activations.

In [8]:
import umap.umap_ as umap
from sklearn.cluster import DBSCAN

def get_umap_projection_and_cluster(mf_term, eps=0.4, min_samples=2, random_state=None):
    """Project the activations of proteins annotated with ``mf_term`` and cluster them.

    Only proteins whose ground-truth label for ``mf_term`` is positive are
    kept. Their activations are embedded with UMAP and the embedding is
    clustered with DBSCAN.

    Args:
        mf_term: Key into the module-level ``name_indices`` mapping.
        eps: DBSCAN neighbourhood radius (default matches the original value).
        min_samples: DBSCAN ``min_samples`` (default matches the original value).
        random_state: Seed forwarded to ``umap.UMAP``. The default ``None``
            keeps the projection stochastic, so cluster ids can change
            between runs; pass an int for reproducible clusters.

    Returns:
        One dict per retained protein with keys ``umap_x``, ``umap_y``,
        ``name``, ``cluster_id``, ``pred`` (``"True"``/``"False"``, whether
        the prediction value is > 0) and ``binding_data`` (whether
        ``binding_data`` holds an entry for this term and protein). An empty
        list is returned when no protein carries the label.

    Raises:
        KeyError: If ``mf_term`` is not in ``name_indices``.
    """
    label_index = name_indices[mf_term]

    # Compute the positive-label mask once and reuse it for every filter.
    positive_mask = y_true[:, label_index] > 0
    filtered_activations = activations[positive_mask, :]
    filtered_predictions = y_preds[positive_mask, label_index]
    filtered_names = [
        name for name, is_positive in zip(names, positive_mask.tolist()) if is_positive
    ]

    # Nothing to project: UMAP cannot be fitted on an empty matrix.
    if not filtered_names:
        return []

    reducer = umap.UMAP(random_state=random_state)
    embedding = reducer.fit_transform(filtered_activations.numpy())
    cluster_labels = DBSCAN(eps=eps, min_samples=min_samples).fit(embedding).labels_

    term_has_binding_data = mf_term in binding_data

    results = []
    for i, name in enumerate(filtered_names):
        results.append({
            'umap_x': embedding[i, 0],
            'umap_y': embedding[i, 1],
            'name': name,
            'cluster_id': cluster_labels[i],
            # Stored as a string so it can be used as a nominal chart field.
            'pred': str(float(filtered_predictions[i]) > 0),
            'binding_data': (term_has_binding_data and name in binding_data[mf_term])
        })

    return results

Run cluster analysis on proteins with `ATP binding` function and visualize the results

In [9]:
import pandas as pd

# Cluster the proteins annotated with this molecular function and tabulate
# one row per protein.
mf_term = "ATP binding"
results = get_umap_projection_and_cluster(mf_term)
df = pd.DataFrame(results)

In [10]:
import altair as alt

# Scatter plot of the UMAP embedding: position = UMAP coordinates,
# shape = predicted label ('pred'), colour = DBSCAN cluster id.
# NOTE(review): the axis domains below are hard-coded and clamp=True is set,
# while the UMAP projection is not seeded -- confirm the domains still cover
# the data after a re-run.
points = alt.Chart(df).mark_point(
    filled=True,
    size=36,    
).encode(
    alt.X('umap_x:Q',
        scale=alt.Scale(
            domain=(0, 10.5),
            clamp=True
        )
    ),
    alt.Y('umap_y:Q',
        scale=alt.Scale(
            domain=(5, 14),
            clamp=True
        )
    ),
    shape = alt.Shape(
       "pred:N",
        scale = alt.Scale(range=["triangle", "circle"],zero=True)),
    color='cluster_id:N'
).properties(
    width=600,
    height=600
)

# Text layer labelling each point with the protein name. It is built here but
# not combined with `points` or displayed in this cell.
text = points.mark_text(
    align='left',
    baseline='middle',
    dx=7,
    color='black'
).encode(
    text='name'
)

# Display the point layer as an interactive chart.
points.interactive()

## Integrated Gradients (from Sequence Embeddings)

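Create baseline reference sequences by replacing every attended token with [PAD], then restoring [CLS] at the first position and [SEP] at the last attended position. The reference sequence must have the same length as the input sequence.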

In [14]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(
            "Rostlab/prot_bert", do_lower_case=False)

def construct_input_ref_pair(input_ids, attention_mask):
    """Build the (input, baseline) token-id pair for integrated gradients.

    The baseline is a copy of ``input_ids`` in which every position with
    ``attention_mask > 0`` is set to the tokenizer's pad id, position 0 is set
    to the CLS id, and position ``(number of attended positions) - 1`` is set
    to the SEP id. Positions outside the mask keep their original ids.

    Args:
        input_ids: Token ids of one sequence (assumed 1-D -- the code indexes
            single positions directly).
        attention_mask: Mask aligned with ``input_ids``; values > 0 mark
            attended positions (assumed to form a contiguous prefix).

    Returns:
        ``(input_ids, ref_input_ids)``, both cloned and given a leading batch
        dimension of size 1. The argument tensors are not modified.
    """
    attended = attention_mask > 0
    last_attended = int(attended.sum()) - 1

    ref_input_ids = input_ids.clone()
    ref_input_ids[attended] = tokenizer.pad_token_id
    ref_input_ids[0] = tokenizer.cls_token_id
    ref_input_ids[last_attended] = tokenizer.sep_token_id

    batched_input = input_ids.clone().unsqueeze(0)
    batched_reference = ref_input_ids.unsqueeze(0)
    return batched_input, batched_reference

Wrap the original model to get the model output for a particular molecular function. The molecular function is selected using `label_idx`.

In [15]:
def get_forward_func_wrapper(label_idx):
    """Return a forward function that yields the model output for one label.

    The returned callable takes ``input_ids`` plus the batch (passed as
    ``additional_forward_args``), calls the module-level ``model`` with both,
    and returns column ``label_idx`` of the model output.
    """
    def wrapper(input_ids, additional_forward_args=None):
        # The batch travels in via additional_forward_args; the token ids are
        # forwarded through the model's `input_ids` keyword.
        all_label_outputs = model(additional_forward_args, input_ids=input_ids)
        return all_label_outputs[:, label_idx]

    return wrapper

The `LayerIntegratedGradientsRevisited` class is created to resolve out-of-memory issues caused by large BERT models:

In [16]:
from captum.attr import GradientAttribution, LayerAttribution

from captum._utils.gradient import _forward_layer_eval, _run_forward
from captum._utils.common import (
    _extract_device
)
from torch.nn.parallel.scatter_gather import scatter

class LayerIntegratedGradientsRevisited(LayerAttribution, GradientAttribution):
    """Layer integrated gradients evaluated one interpolation step at a time.

    The straight-line path between the layer activations produced by
    ``baselines`` and those produced by ``inputs`` is sampled at
    ``n_steps + 1`` points. Each point gets its own forward/backward pass
    (the layer output is overridden through a forward hook), so only one
    interpolated activation is pushed through the model at any time.
    """

    def __init__(
        self,
        forward_func,
        layer,
        device_ids = None,
        multiply_by_inputs = True,
    ):

        r"""
        Args:
            forward_func (callable):  The forward function of the model or any
                    modification of it
            layer (torch.nn.Module): Layer whose output is attributed. Its
                    output is replaced by interpolated activations while
                    gradients are computed.
            device_ids (list of int, optional): Device ids used to scatter the
                    interpolated activations. When None, the ``device_ids``
                    attribute of ``forward_func`` is used if it has one.
            multiply_by_inputs (bool, optional): Indicates whether to factor
                    model inputs' multiplier in the final attribution scores.
                    More detailed can be found here:
                    https://arxiv.org/abs/1711.06104
                    In case of integrated gradients, if `multiply_by_inputs`
                    is set to True, final sensitivity scores are being multiplied by
                    (inputs - baselines).
        """
        LayerAttribution.__init__(self, forward_func, layer, device_ids=device_ids)
        GradientAttribution.__init__(self, forward_func)
        self.multiply_by_inputs = multiply_by_inputs

    def attribute(
        self,
        inputs,
        baselines = None,
        target = None,
        additional_forward_args = None,
        n_steps = 50,
        internal_batch_size = None
    ):
        """Compute integrated gradients with respect to the layer output.

        Args:
            inputs: Model input whose layer activations are the path end point.
            baselines: Reference input whose layer activations are the path
                start point. Required despite the ``None`` default.
            target: Forwarded to the model forward run.
            additional_forward_args: Extra arguments for ``forward_func``.
            n_steps: Number of trapezoid intervals; ``n_steps + 1`` gradients
                are evaluated. Must be at least 1.
            internal_batch_size: Accepted for signature compatibility; unused.

        Returns:
            ``(saliency, gradient)``. ``saliency`` is the trapezoid-rule
            integral of the gradients along the path, multiplied by
            ``inputs_layer - baselines_layer`` when ``multiply_by_inputs`` is
            set, then summed over the last dimension and squeezed.
            ``gradient`` is the raw layer gradient at the final step
            (alpha = 1); no norm is taken here.

        Raises:
            ValueError: If ``baselines`` is None or ``n_steps`` < 1.
        """
        if baselines is None:
            raise ValueError("baselines must be provided to compute integrated gradients")
        if n_steps < 1:
            raise ValueError("n_steps must be a positive integer")

        if self.device_ids is None:
            self.device_ids = getattr(self.forward_func, "device_ids", None)

        # Layer activations at the two ends of the integration path; the first
        # element of each returned tuple is used.
        inputs_layer = _forward_layer_eval(
            self.forward_func,
            inputs,
            self.layer,
            device_ids=self.device_ids,
            additional_forward_args=additional_forward_args
        )[0]

        baselines_layer = _forward_layer_eval(
            self.forward_func,
            baselines,
            self.layer,
            device_ids=self.device_ids,
            additional_forward_args=additional_forward_args
        )[0]

        layer_delta = inputs_layer - baselines_layer

        # `inputs` here are the already-interpolated layer activations.
        def gradient_func(
            forward_fn,
            inputs,
            target = None,
            additional_forward_args = None,
        ):
            if self.device_ids is None or len(self.device_ids) == 0:
                scattered_inputs = (inputs,)
            else:
                # scatter method does not have a precise enough return type in its
                # stub, so suppress the type warning.
                scattered_inputs = scatter(  # type:ignore
                    inputs, target_gpus=self.device_ids
                )

            # Lets the hook pick the activation chunk living on its own device.
            scattered_inputs_dict = {
                scattered_input[0].device: scattered_input
                for scattered_input in scattered_inputs
            }

            with torch.autograd.set_grad_enabled(True):

                def layer_forward_hook(
                    module, hook_inputs, hook_outputs=None
                ):
                    # Returning a value from a forward hook replaces the
                    # layer's output with the interpolated activations.
                    device = _extract_device(module, hook_inputs, hook_outputs)
                    return scattered_inputs_dict[device]

                hook = None
                try:
                    layer = self.layer
                    hook = layer.register_forward_hook(layer_forward_hook)
                    # The model is driven with `baselines`; the hook overrides
                    # the layer output on every call.
                    # NOTE(review): this assumes nothing downstream of the
                    # layer reads the raw model input -- confirm for the model.
                    output = _run_forward(
                        self.forward_func, baselines, target,
                            additional_forward_args=additional_forward_args)
                finally:
                    if hook is not None:
                        hook.remove()

                assert output[0].numel() == 1, (
                    "Target not provided when necessary, cannot"
                    " take gradient with respect to multiple outputs."
                )
                # torch.unbind(output) yields one scalar output per batch
                # element; gradients are taken w.r.t. the interpolated activations.
                grads = torch.autograd.grad(torch.unbind(output), inputs)
            return grads

        grads = []
        for step in range(n_steps + 1):
            alpha = step / n_steps
            _inputs = (baselines_layer + alpha * layer_delta).requires_grad_()
            grad = gradient_func(
                forward_fn=self.forward_func,
                inputs=_inputs,
                target=target,
                additional_forward_args=additional_forward_args,
            )
            grads.append(grad[0].detach())

        # Gradient at alpha = 1, i.e. at the real input's activations.
        final_grad = grads[-1]

        # Stack the per-step gradients along a new trailing "step" axis.
        grads = torch.stack(grads, dim=-1)

        # Trapezoid rule over alpha in [0, 1]: average neighbouring steps, sum,
        # and scale by the step width 1 / n_steps. The ellipsis makes the slice
        # act on the step axis whatever the rank of the layer output.
        integral = ((grads[..., :-1] + grads[..., 1:]) / 2).sum(dim=-1) / n_steps
        saliency = integral

        if self.multiply_by_inputs:
            saliency = saliency * layer_delta

        # Collapse the last dimension to one score per remaining position.
        saliency = saliency.sum(dim=-1).squeeze()

        # Despite the name, this is the raw final-step gradient; callers take
        # the norm themselves.
        gradient_norm = final_grad

        return saliency, gradient_norm

    def has_convergence_delta(self):
        # NOTE(review): attribute() offers no convergence-delta output;
        # confirm that reporting True here is intended.
        return True

    def multiplies_by_inputs(self):
        """Return whether attributions are multiplied by (inputs - baselines)."""
        # Return the configured flag rather than this bound method itself.
        return self.multiply_by_inputs

In [17]:
from torch_geometric.data import DataLoader
from torch.utils import data

def get_ig_attribution(mf_term, data_indice):
    """Run layer integrated gradients for one protein and one molecular function.

    Args:
        mf_term: Key into the module-level ``name_indices`` mapping; selects
            the model output column to attribute.
        data_indice: Index of the protein in the module-level ``dataset``.

    Returns:
        A 4-tuple ``(attributions, grad_norm, output, sequence)``:
        the attribution scores as a NumPy array, the norm (over dim 1) of the
        final-step layer gradient as a NumPy array, the model output for the
        selected label as a NumPy array, and the list of tokens decoded from
        the protein's ``input_ids``.
    """
    label_index = name_indices[mf_term]
    forward_func = get_forward_func_wrapper(label_index)
    lig2 = LayerIntegratedGradientsRevisited(forward_func, model.identity, multiply_by_inputs=True, device_ids=[0])

    # A one-element subset gives a properly collated single-protein batch.
    subset = data.Subset(dataset, [data_indice])
    batch_loader = DataLoader(subset, batch_size=len(subset), shuffle=False)
    batch = next(iter(batch_loader))
    batch = [b.to(device) for b in batch]

    # The prediction is only reported, never differentiated, so skip building
    # an autograd graph for it (saves memory on large models).
    with torch.no_grad():
        output = model(batch)[:, label_index]

    # Fetch the dataset item once and reuse it for ids and mask.
    sample = subset[0][0]
    input_ids, ref_input_ids = construct_input_ref_pair(sample.input_ids, sample.attention_mask)
    input_ids = input_ids.to(device)
    ref_input_ids = ref_input_ids.to(device)
    sequence = tokenizer.convert_ids_to_tokens(input_ids[0])

    attr_node_embeddings, grad_norm = lig2.attribute(inputs=input_ids, baselines=ref_input_ids, additional_forward_args=batch, n_steps=50)
    grad_norm = grad_norm.norm(dim=1)
    return attr_node_embeddings.cpu().numpy(), grad_norm.cpu().numpy(), output.detach().cpu().numpy(), sequence

In [18]:
from sklearn import metrics

def get_compiled_file(mf_term, protein_name, pred, binding_sites=None, folder='saliency_weights'):
    """Compute IG attributions for one protein and pickle them to ``folder``.

    Args:
        mf_term: Molecular-function term passed to ``get_ig_attribution``.
        protein_name: Key into the module-level ``protein_name_indices``; also
            used as the output file stem.
        pred: Prediction value stored verbatim in the result.
        binding_sites: Optional indices of binding-site positions. When given,
            a 0/1 vector is built and the AUROC of the attributions against it
            is computed.
        folder: Output directory; created if it does not exist.

    Returns:
        Dict with keys ``name``, ``mf-term``, ``sequence``, ``binding_sites``,
        ``attribution_integrated_gradient``, ``pred`` and ``auroc``. ``auroc``
        is ``None`` when no binding sites were given or when the binding-site
        vector holds a single class. The same dict is written to
        ``<folder>/<protein_name>.pkl``.
    """
    # Local imports keep the function usable regardless of which notebook
    # cells have already been run.
    import os
    import pickle

    import numpy as np

    protein_index = protein_name_indices[protein_name]
    attrs, _grad_norm, _model_output, sequence = get_ig_attribution(mf_term, protein_index)

    auroc_attr_ig, binding_sites_vector = None, None
    if binding_sites is not None:
        binding_sites_vector = np.zeros(len(attrs))
        binding_sites_vector[binding_sites] = 1
        # AUROC is undefined when every position has the same label.
        n_positive = int(binding_sites_vector.sum())
        if 0 < n_positive < len(binding_sites_vector):
            auroc_attr_ig = metrics.roc_auc_score(binding_sites_vector, attrs)

    # Skip the first token and keep one token per attribution score.
    sequence = ''.join(sequence[1:len(attrs) + 1])
    result = {'name': protein_name,
            'mf-term': mf_term,
            'sequence': sequence,
            'binding_sites': binding_sites_vector,
            'attribution_integrated_gradient': attrs,
            'pred': pred,
            'auroc': float(auroc_attr_ig) if auroc_attr_ig is not None else None
           }

    # The default folder may not exist yet.
    if folder:
        os.makedirs(folder, exist_ok=True)
    with open(os.path.join(folder, protein_name+'.pkl'), 'wb') as f:
        pickle.dump(result, f)
    return result


In [19]:
import numpy as np
import pandas as pd

# Recompute the projection and clusters for this term, tabulate one row per
# protein, and save the table next to the notebook.
mf_term = "ATP binding"
results = get_umap_projection_and_cluster(mf_term)
df = pd.DataFrame(results)

df.to_csv("./ATP_binding_Umap_and_Cluster.csv")

In [20]:
import altair as alt

# Scatter plot of the UMAP embedding: position = UMAP coordinates,
# shape = predicted label ('pred'), colour = DBSCAN cluster id.
# NOTE(review): the axis domains are hard-coded with clamp=True while the UMAP
# projection is not seeded -- confirm they still cover the data after a re-run.
points = alt.Chart(df).mark_point(
    filled=True,
    size=36,    
).encode(
    alt.X('umap_x:Q',
        scale=alt.Scale(
            domain=(0, 10.5),
            clamp=True
        )
    ),
    alt.Y('umap_y:Q',
        scale=alt.Scale(
            domain=(5, 14),
            clamp=True
        )
    ),
    shape = alt.Shape(
       "pred:N",
        scale = alt.Scale(range=["triangle", "circle"],zero=True)),
    color='cluster_id:N'
).properties(
    width=600,
    height=600
)

In [21]:
points.interactive()

## Run feature attribution via IG

In [22]:
# Run IG for every protein assigned to cluster 8 and collect the per-protein
# result dicts.
sequences = []

for i, r in df.iterrows():
    if r['cluster_id'] != 8:
        continue
    # Binding sites are only available for proteins present in binding_data.
    binding_sites = (
        binding_data[mf_term][r['name']]["sites"]
        if mf_term in binding_data and r['name'] in binding_data[mf_term]
        else None
    )
    d = get_compiled_file(mf_term, r['name'], r['pred'], binding_sites=binding_sites, folder='.')
    sequences.append(d)
    print(r['name'], r['pred'], "cluster_id:" , r['cluster_id'], '\t', d['auroc'])


1E2Q-A True cluster_id: 8 	 0.9073604060913706
1QPG-A True cluster_id: 8 	 None
2ORV-A True cluster_id: 8 	 None
4AKE-A True cluster_id: 8 	 None
2BBW-A True cluster_id: 8 	 None
3ZLB-A True cluster_id: 8 	 None
5NP8-A True cluster_id: 8 	 None
2AKY-A True cluster_id: 8 	 None
4Q1A-A True cluster_id: 8 	 None
2C9Y-A True cluster_id: 8 	 None
1ZD8-A True cluster_id: 8 	 None
5JZV-A True cluster_id: 8 	 None
3CH4-B True cluster_id: 8 	 None
2FEM-A True cluster_id: 8 	 None
1UKY-A True cluster_id: 8 	 None
2A30-A True cluster_id: 8 	 None
1FW8-A True cluster_id: 8 	 None
1Z83-A True cluster_id: 8 	 None
2TMK-A True cluster_id: 8 	 None
2PAA-A True cluster_id: 8 	 0.7678117048346056
1P4S-A True cluster_id: 8 	 None
2IYT-A True cluster_id: 8 	 None
1TEV-A True cluster_id: 8 	 None
4TMK-A True cluster_id: 8 	 None
3IIK-A True cluster_id: 8 	 None


In [26]:
# Align the cluster's sequences; each entry gains an "alignment_result" mapping.
# NOTE: `msa_alignment` is defined in the "MSA Alignment" cell at the bottom of
# this notebook, so that cell must be executed before this one.
sequences = msa_alignment(sequences, 'ATP_binding_cluster_8')

25it [00:00, 147271.91it/s]


Using 8 threads
Read 25 sequences (type: Protein) from ./inputATP_binding_cluster_8.fasta
not more sequences (25) than cluster-size (100), turn off mBed
Calculating pairwise ktuple-distances...
Ktuple-distance calculation progress: 0 % (0 out of 325)
Ktuple-distance calculation progress: 1 % (5 out of 325)
Ktuple-distance calculation progress: 2 % (7 out of 325)
Ktuple-distance calculation progress: 3 % (10 out of 325)
Ktuple-distance calculation progress: 19 % (63 out of 325)
Ktuple-distance calculation progress: 30 % (100 out of 325)
Ktuple-distance calculation progress: 34 % (111 out of 325)
Ktuple-distance calculation progress: 36 % (118 out of 325)
Ktuple-distance calculation progress: 42 % (138 out of 325)
Ktuple-distance calculation progress: 47 % (154 out of 325)
Ktuple-distance calculation progress: 56 % (183 out of 325)
Ktuple-distance calculation progress: 57 % (187 out of 325)
Ktuple-distance calculation progress: 60 % (197 out of 325)
Ktuple-distance calculation progress: 

25it [00:00, 132229.00it/s]


In [28]:
# Imported here so the cell does not depend on another cell having rebound the
# name `tqdm` (an earlier cell imports the *module* under the same name).
from tqdm import tqdm
import pandas as pd

# Flatten the per-protein results into one row per residue: residue letter,
# attribution score, position in the sequence (j), aligned column (j_aligned),
# and the protein's position (i) and name.
all_results = []

for i, d in tqdm(enumerate(sequences), total=len(sequences)):
    attribution = d["attribution_integrated_gradient"]
    sequence = d["sequence"]
    name = d["name"]
    alignment = d["alignment_result"]
    for j, aa in enumerate(sequence):
        all_results.append({
            'aa': aa,
            'attr': float(attribution[j]),
            'j': j,
            'j_aligned': alignment[j],
            'i': i,
            'name': name
        })

# NOTE: this rebinds `df`, replacing the per-protein cluster table built earlier.
df = pd.DataFrame.from_dict(all_results)

df.to_csv(os.path.join('./', "ATP_binding_MSA_Cluster_8.csv"))
        

25it [00:00, 2569.03it/s]


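## MSA Alignment

Note: the cell below defines `msa_alignment`, which the alignment cells above call. Run it first (or move it above them) so that "Restart & Run All" works.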

In [25]:
# input sequences in the format of [{sequence: , name: }], file name marks mf function and cluster number
# output sequences in the format of [{sequence: , name: , aligned_sequence}]
from tqdm import tqdm

def _ungapped_to_aligned_index(aligned_sequence):
    """Map each residue's position in the ungapped sequence to its alignment column."""
    residue_columns = (
        column for column, symbol in enumerate(aligned_sequence) if symbol != '-'
    )
    return dict(enumerate(residue_columns))


def msa_alignment(
    sequences,
    filename_postfix,
    clustal_path="/home/ec2-user/SageMaker/efs/install/clustalo-1.2.4-Ubuntu-x86_64",
    temp_dir='./',
):
    """Align sequences with Clustal Omega and attach the alignment index maps.

    Args:
        sequences: List of dicts with at least ``"name"`` and ``"sequence"``.
        filename_postfix: Suffix used in the FASTA file names
            (``input<postfix>.fasta`` / ``output<postfix>.fasta``).
        clustal_path: Path to the Clustal Omega executable. The default is the
            location used originally; override it on other machines.
        temp_dir: Directory for the FASTA files and the ``distmat`` file.

    Returns:
        The same ``sequences`` list. Each dict is updated in place with
        ``"alignment_result"``: a dict mapping residue position in the
        ungapped sequence to its column in the aligned sequence. An empty
        input is returned unchanged without invoking Clustal Omega.

    Raises:
        KeyError: If an input name has no matching record id in the alignment
            output.
    """
    # Nothing to align, and no files to write.
    if not sequences:
        return sequences

    import os
    from Bio import SeqIO
    from Bio.Align.Applications import ClustalOmegaCommandline

    infile = os.path.join(temp_dir, "input" + filename_postfix + ".fasta")
    outfile = os.path.join(temp_dir, "output" + filename_postfix + ".fasta")
    distmat = os.path.join(temp_dir, "distmat")

    # Generate FASTA file
    with open(infile, "w") as f:
        for d in tqdm(sequences):
            f.write(f">{d['name']}\n{d['sequence']}\n")

    clustalo_cline = ClustalOmegaCommandline(clustal_path,
                                             infile=infile,
                                             outfile=outfile,
                                             verbose=True,
                                             force=True,
                                             distmat_full=True,
                                             distmat_out=distmat,
                                             percentid=True
                                            )

    stdout, stderr = clustalo_cline()
    print(stdout)
    print(stderr)

    # Index map for every aligned record, keyed by record id.
    alignment_results = {
        record.id: _ungapped_to_aligned_index(record.seq)
        for record in SeqIO.parse(outfile, "fasta")
    }

    for d in tqdm(sequences):
        d["alignment_result"] = alignment_results[d["name"]]

    return sequences