{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "4839cf79", "metadata": {}, "outputs": [], "source": [ "!pip install bio\n", "!pip install captum\n", "!pip install umap-learn\n", "!pip install pytorch-lightning" ] }, { "cell_type": "code", "execution_count": 1, "id": "cf4a48f3", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('../')" ] }, { "cell_type": "markdown", "id": "1fd01bd9", "metadata": {}, "source": [ "## Load data and model" ] }, { "cell_type": "code", "execution_count": 2, "id": "0144808c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading BertTokenizer...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/ec2-user/anaconda3/envs/pytorch_latest_p36_clone/lib/python3.6/site-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", " cpuset_checked))\n" ] } ], "source": [ "from torch_geometric.data import DataLoader\n", "from lmgvp import data_loaders\n", "\n", "dataset = data_loaders.get_dataset(\n", " 'mf', 'seq_struct', split=\"test\"\n", ")\n", "\n", "loader = DataLoader(\n", " dataset,\n", " batch_size=32,\n", " shuffle=False,\n", " num_workers=32,\n", ")" ] }, { "cell_type": "markdown", "id": "535440e6", "metadata": {}, "source": [ "Get mapping from protein names to data index:" ] }, { "cell_type": "code", "execution_count": 3, "id": "a52f9370", "metadata": {}, "outputs": [], "source": [ "protein_name_indices = {d[0].name:i for i, d in enumerate(dataset)}" ] }, { "cell_type": "markdown", "id": "7f48d5b5", "metadata": {}, "source": [ "Build model and load pretained weights from checkpoint:" ] }, { "cell_type": "code", "execution_count": null, "id": "809d1278", "metadata": {}, "outputs": [], "source": [ "from lmgvp.modules import BertMQAModel\n", "import torch\n", "\n", "checkpoint_path = \"../../data/epoch=0-step=1868.ckpt\"\n", "\n", "model = BertMQAModel.load_from_checkpoint(\n", " checkpoint_path,\n", " weights=dataset.pos_weights,\n", ")\n", "\n", "device = torch.device(\"cuda\")\n", "model = model.to(device)\n", "model.eval()" ] }, { "cell_type": "markdown", "id": "5098a99a", "metadata": {}, "source": [ "## Get latent activation and prediction results\n", "\n", "We use PyTorch forward hook here to extract the latent activations in the pen-ultimate layer for cluster analysis. The prediction results are also collected along the way." ] }, { "cell_type": "code", "execution_count": 5, "id": "fb59bf8b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/107 [00:00 0, :]\n", " filtered_names = []\n", " filtered_predictions = y_preds[y_true[:, label_index] > 0, label_index]\n", " for i, v in enumerate(y_true[:, label_index]):\n", " if v > 0:\n", " filtered_names.append(names[i]) \n", " reducer = umap.UMAP()\n", " embedding = reducer.fit_transform(filtered_activations.numpy())\n", " clustering = DBSCAN(eps=0.4, min_samples=2).fit(embedding)\n", " cluster_labels = clustering.labels_\n", " \n", " results = []\n", " for i, name in enumerate(filtered_names):\n", " results.append({\n", " 'umap_x': embedding[i, 0],\n", " 'umap_y': embedding[i, 1],\n", " 'name': name,\n", " 'cluster_id': cluster_labels[i],\n", " 'pred': str(float(filtered_predictions[i]) > 0),\n", " 'binding_data': (mf_term in binding_data and name in binding_data[mf_term])\n", " })\n", " \n", " return results" ] }, { "cell_type": "markdown", "id": "885d0116", "metadata": {}, "source": [ "Run cluster analysis on proteins with `ATP binding` function and visualize the results" ] }, { "cell_type": "code", "execution_count": 9, "id": "2588ef54", "metadata": {}, "outputs": [], "source": [ "mf_term = \"ATP binding\"\n", "results = get_umap_projection_and_cluster(mf_term)\n", "\n", "import pandas as pd\n", "df = pd.DataFrame.from_dict(results)" ] }, { "cell_type": "code", "execution_count": 10, "id": "9377cc89", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import altair as alt\n", "\n", "points = alt.Chart(df).mark_point(\n", " filled=True,\n", " size=36, \n", ").encode(\n", " alt.X('umap_x:Q',\n", " scale=alt.Scale(\n", " domain=(0, 10.5),\n", " clamp=True\n", " )\n", " ),\n", " alt.Y('umap_y:Q',\n", " scale=alt.Scale(\n", " domain=(5, 14),\n", " clamp=True\n", " )\n", " ),\n", " shape = alt.Shape(\n", " \"pred:N\",\n", " scale = alt.Scale(range=[\"triangle\", \"circle\"],zero=True)),\n", " color='cluster_id:N'\n", ").properties(\n", " width=600,\n", " height=600\n", ")\n", "\n", "text = points.mark_text(\n", " align='left',\n", " baseline='middle',\n", " dx=7,\n", " color='black'\n", ").encode(\n", " text='name'\n", ")\n", "\n", "points.interactive()" ] }, { "cell_type": "markdown", "id": "217c0995", "metadata": {}, "source": [ "## Integrated Gradients (from Sequence Embeddings)" ] }, { "cell_type": "markdown", "id": "502e1207", "metadata": {}, "source": [ "Create baseline reference sequences using [SEP] tokens. The reference sequence should have the same length as the baseline sequence." ] }, { "cell_type": "code", "execution_count": 14, "id": "7c492e35", "metadata": {}, "outputs": [], "source": [ "from transformers import BertTokenizer\n", "\n", "tokenizer = BertTokenizer.from_pretrained(\n", " \"Rostlab/prot_bert\", do_lower_case=False)\n", "\n", "def construct_input_ref_pair(input_ids, attention_mask):\n", " ref_token_id = tokenizer.pad_token_id # A token used for generating token reference\n", " sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.\n", " cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence \n", "\n", " ref_input_ids = input_ids.clone()\n", " ref_input_ids[attention_mask>0] = ref_token_id\n", " ref_input_ids[0] = cls_token_id\n", " ref_input_ids[ref_input_ids[attention_mask>0].shape[0]-1] = sep_token_id\n", " \n", " return input_ids.clone().unsqueeze(0), ref_input_ids.unsqueeze(0)" ] }, { "cell_type": "markdown", "id": "ab8c738a", "metadata": {}, "source": [ "Wrap the original model to get the model output for a particular molecular function. The molecular function is selected using `label_idx`." ] }, { "cell_type": "code", "execution_count": 15, "id": "78f82a51", "metadata": {}, "outputs": [], "source": [ "def get_forward_func_wrapper(label_idx):\n", " def wrapper(input_ids, additional_forward_args=None):\n", " batch = additional_forward_args\n", " model_out = model(batch, input_ids=input_ids)\n", " return model_out[:,label_idx]\n", " return wrapper" ] }, { "cell_type": "markdown", "id": "38a2718d", "metadata": {}, "source": [ "The `LayerIntegratedGradientsRevisited` class is created to resolve out of memory issues caused by large bert models:" ] }, { "cell_type": "code", "execution_count": 16, "id": "5640053d", "metadata": {}, "outputs": [], "source": [ "from captum.attr import GradientAttribution, LayerAttribution\n", "\n", "from captum._utils.gradient import _forward_layer_eval, _run_forward\n", "from captum._utils.common import (\n", " _extract_device\n", ")\n", "from torch.nn.parallel.scatter_gather import scatter\n", "\n", "class LayerIntegratedGradientsRevisited(LayerAttribution, GradientAttribution):\n", " def __init__(\n", " self,\n", " forward_func,\n", " layer,\n", " device_ids = None,\n", " multiply_by_inputs = True,\n", " ):\n", "\n", " r\"\"\"\n", " Args:\n", " forward_func (callable): The forward function of the model or any\n", " modification of it\n", " multiply_by_inputs (bool, optional): Indicates whether to factor\n", " model inputs' multiplier in the final attribution scores.\n", " More detailed can be found here:\n", " https://arxiv.org/abs/1711.06104\n", " In case of integrated gradients, if `multiply_by_inputs`\n", " is set to True, final sensitivity scores are being multiplied by\n", " (inputs - baselines).\n", " \"\"\"\n", " LayerAttribution.__init__(self, forward_func, layer, device_ids=device_ids)\n", " GradientAttribution.__init__(self, forward_func)\n", " self.multiply_by_inputs = multiply_by_inputs\n", "\n", " \n", " def attribute(\n", " self,\n", " inputs,\n", " baselines = None,\n", " target = None,\n", " additional_forward_args = None,\n", " n_steps = 50,\n", " internal_batch_size = None\n", " ):\n", "\n", " if self.device_ids is None:\n", " self.device_ids = getattr(self.forward_func, \"device_ids\", None)\n", "\n", " inputs_layer = _forward_layer_eval(\n", " self.forward_func,\n", " inputs,\n", " self.layer,\n", " device_ids=self.device_ids,\n", " additional_forward_args=additional_forward_args\n", " )\n", " \n", " inputs_layer = inputs_layer[0]\n", " \n", " baselines_layer = _forward_layer_eval(\n", " self.forward_func,\n", " baselines,\n", " self.layer,\n", " device_ids=self.device_ids,\n", " additional_forward_args=additional_forward_args\n", " )\n", " \n", " baselines_layer = baselines_layer[0]\n", " \n", "# print(torch.abs((inputs_layer-baselines_layer)).sum())\n", "\n", " # inputs -> these inputs are scaled\n", " def gradient_func(\n", " forward_fn,\n", " inputs,\n", " target = None,\n", " additional_forward_args = None,\n", " ):\n", " if self.device_ids is None or len(self.device_ids) == 0:\n", " scattered_inputs = (inputs,)\n", " else:\n", " # scatter method does not have a precise enough return type in its\n", " # stub, so suppress the type warning.\n", " scattered_inputs = scatter( # type:ignore\n", " inputs, target_gpus=self.device_ids\n", " )\n", "\n", " scattered_inputs_dict = {\n", " scattered_input[0].device: scattered_input\n", " for scattered_input in scattered_inputs\n", " }\n", "\n", " with torch.autograd.set_grad_enabled(True):\n", "\n", " def layer_forward_hook(\n", " module, hook_inputs, hook_outputs=None\n", " ):\n", " device = _extract_device(module, hook_inputs, hook_outputs)\n", " return scattered_inputs_dict[device]\n", "\n", " hook = None\n", " try:\n", " layer = self.layer\n", " hook = layer.register_forward_hook(layer_forward_hook)\n", " output = _run_forward(\n", " self.forward_func, baselines, target, \n", " additional_forward_args=additional_forward_args)\n", " finally:\n", " if hook is not None:\n", " hook.remove()\n", "\n", " assert output[0].numel() == 1, (\n", " \"Target not provided when necessary, cannot\"\n", " \" take gradient with respect to multiple outputs.\"\n", " )\n", " # torch.unbind(forward_out) is a list of scalar tensor tuples and\n", " # contains batch_size * #steps elements\n", " grads = torch.autograd.grad(torch.unbind(output), inputs)\n", " return grads\n", " \n", " \n", " alphas = [i*1.0/n_steps for i in range(n_steps+1)]\n", " grads = []\n", " final_grad = None\n", " for i, alpha in enumerate(alphas):\n", " \n", "# print(inputs_layer.shape, baselines_layer.shape, baselines.size())\n", " _inputs = (baselines_layer + alpha * (inputs_layer - baselines_layer)).requires_grad_()\n", " # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)\n", " grad = gradient_func(\n", " forward_fn=self.forward_func,\n", " inputs=_inputs,\n", " target=target,\n", " additional_forward_args=additional_forward_args,\n", " )\n", " grads.append(grad[0].detach())\n", " \n", " if i == n_steps:\n", " final_grad = grad[0].detach()\n", " \n", " grads = torch.stack(grads, dim=-1)\n", " \n", " #calculate integration using trapzoid rule\n", " integral = ((grads[:,:,:-1] + grads[:,:,1:]) / 2 ).sum(dim=-1)\n", " saliency = integral\n", " \n", " if self.multiply_by_inputs:\n", " saliency = saliency * (inputs_layer - baselines_layer)\n", " \n", " saliency = saliency.sum(dim=-1).squeeze()\n", " \n", " # calculate gradient norm\n", " gradient_norm = final_grad\n", " \n", " return saliency, gradient_norm\n", "\n", " def has_convergence_delta(self):\n", " return True\n", "\n", " def multiplies_by_inputs(self):\n", " return self.multiplies_by_inputs" ] }, { "cell_type": "code", "execution_count": 17, "id": "3324dab3", "metadata": {}, "outputs": [], "source": [ "from torch_geometric.data import DataLoader\n", "from torch.utils import data\n", "\n", "def get_ig_attribution(mf_term, data_indice):\n", " label_index = name_indices[mf_term]\n", " forward_func = get_forward_func_wrapper(label_index)\n", " lig2 = LayerIntegratedGradientsRevisited(forward_func, model.identity, multiply_by_inputs=True, device_ids=[0]) \n", " subset = data.Subset(dataset, [data_indice])\n", " batch_loader = DataLoader(subset, batch_size=len(subset), shuffle=False)\n", " batch = None\n", " for _batch in batch_loader:\n", " batch = _batch\n", " batch = [b.to(device) for b in batch]\n", " output = model(batch)[:, label_index]\n", " input_ids, ref_input_ids = construct_input_ref_pair(subset[0][0].input_ids, subset[0][0].attention_mask)\n", " input_ids = input_ids.to(device)\n", " ref_input_ids = ref_input_ids.to(device)\n", " sequence = tokenizer.convert_ids_to_tokens(input_ids[0])\n", " attr_node_embeddings, grad_norm = lig2.attribute(inputs=input_ids, baselines=ref_input_ids, additional_forward_args=batch, n_steps=50)\n", " grad_norm = grad_norm.norm(dim=1)\n", " return attr_node_embeddings.cpu().numpy(), grad_norm.cpu().numpy(), output.detach().cpu().numpy(), sequence" ] }, { "cell_type": "code", "execution_count": 18, "id": "92ad5848", "metadata": {}, "outputs": [], "source": [ "from sklearn import metrics\n", "\n", "def get_compiled_file(mf_term, protein_name, pred, binding_sites=None, folder='saliency_weights'):\n", " protein_index = protein_name_indices[protein_name]\n", " attrs, grad_norm, result, sequence = get_ig_attribution(mf_term, protein_index)\n", " auroc_attr_ig, binding_sites_vector = None, None\n", " if binding_sites is not None:\n", " binding_sites_vector = np.zeros(len(attrs))\n", " binding_sites_vector[binding_sites] = 1\n", " auroc_attr_ig = metrics.roc_auc_score(binding_sites_vector, attrs)\n", " sequence = ''.join(sequence[1:len(attrs) + 1])\n", " result = {'name': protein_name, \n", " 'mf-term': mf_term, \n", " 'sequence': sequence,\n", " 'binding_sites': binding_sites_vector, \n", " 'attribution_integrated_gradient': attrs,\n", " 'pred': pred,\n", " 'auroc': float(auroc_attr_ig) if auroc_attr_ig is not None else None\n", " }\n", "# if binding_sites is not None:\n", " import pickle\n", " with open(os.path.join(folder, protein_name+'.pkl'), 'wb') as f:\n", " pickle.dump(result, f)\n", " return result\n" ] }, { "cell_type": "code", "execution_count": 19, "id": "f3905a7e", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "mf_term = \"ATP binding\"\n", "results = get_umap_projection_and_cluster(mf_term)\n", "df = pd.DataFrame.from_dict(results)\n", "\n", "df.to_csv(os.path.join(\"./ATP_binding_Umap_and_Cluster.csv\"))" ] }, { "cell_type": "code", "execution_count": 20, "id": "e5d74f8d", "metadata": {}, "outputs": [], "source": [ "import altair as alt\n", "\n", "points = alt.Chart(df).mark_point(\n", " filled=True,\n", " size=36, \n", ").encode(\n", " alt.X('umap_x:Q',\n", " scale=alt.Scale(\n", " domain=(0, 10.5),\n", " clamp=True\n", " )\n", " ),\n", " alt.Y('umap_y:Q',\n", " scale=alt.Scale(\n", " domain=(5, 14),\n", " clamp=True\n", " )\n", " ),\n", " shape = alt.Shape(\n", " \"pred:N\",\n", " scale = alt.Scale(range=[\"triangle\", \"circle\"],zero=True)),\n", " color='cluster_id:N'\n", ").properties(\n", " width=600,\n", " height=600\n", ")" ] }, { "cell_type": "code", "execution_count": 21, "id": "c847773a", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.Chart(...)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "points.interactive()" ] }, { "cell_type": "markdown", "id": "f3d65e6a", "metadata": {}, "source": [ "## Run feature attribution via IG" ] }, { "cell_type": "code", "execution_count": 22, "id": "2fcedc2b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1E2Q-A True cluster_id: 8 \t 0.9073604060913706\n", "1QPG-A True cluster_id: 8 \t None\n", "2ORV-A True cluster_id: 8 \t None\n", "4AKE-A True cluster_id: 8 \t None\n", "2BBW-A True cluster_id: 8 \t None\n", "3ZLB-A True cluster_id: 8 \t None\n", "5NP8-A True cluster_id: 8 \t None\n", "2AKY-A True cluster_id: 8 \t None\n", "4Q1A-A True cluster_id: 8 \t None\n", "2C9Y-A True cluster_id: 8 \t None\n", "1ZD8-A True cluster_id: 8 \t None\n", "5JZV-A True cluster_id: 8 \t None\n", "3CH4-B True cluster_id: 8 \t None\n", "2FEM-A True cluster_id: 8 \t None\n", "1UKY-A True cluster_id: 8 \t None\n", "2A30-A True cluster_id: 8 \t None\n", "1FW8-A True cluster_id: 8 \t None\n", "1Z83-A True cluster_id: 8 \t None\n", "2TMK-A True cluster_id: 8 \t None\n", "2PAA-A True cluster_id: 8 \t 0.7678117048346056\n", "1P4S-A True cluster_id: 8 \t None\n", "2IYT-A True cluster_id: 8 \t None\n", "1TEV-A True cluster_id: 8 \t None\n", "4TMK-A True cluster_id: 8 \t None\n", "3IIK-A True cluster_id: 8 \t None\n" ] } ], "source": [ "sequences = []\n", " \n", "for i, r in df.iterrows():\n", " if r['cluster_id'] == 8:\n", " binding_sites = None\n", " if mf_term in binding_data and r['name'] in binding_data[mf_term]:\n", " binding_sites = binding_data[mf_term][r['name']][\"sites\"] \n", " d = get_compiled_file(mf_term, r['name'], r['pred'], binding_sites=binding_sites, folder='.')\n", " sequences.append(d)\n", " print(r['name'], r['pred'], \"cluster_id:\" , r['cluster_id'], '\\t', d['auroc'])\n" ] }, { "cell_type": "code", "execution_count": 26, "id": "a0e4ff26", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25it [00:00, 147271.91it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using 8 threads\n", "Read 25 sequences (type: Protein) from ./inputATP_binding_cluster_8.fasta\n", "not more sequences (25) than cluster-size (100), turn off mBed\n", "Calculating pairwise ktuple-distances...\n", "Ktuple-distance calculation progress: 0 % (0 out of 325)\n", "Ktuple-distance calculation progress: 1 % (5 out of 325)\n", "Ktuple-distance calculation progress: 2 % (7 out of 325)\n", "Ktuple-distance calculation progress: 3 % (10 out of 325)\n", "Ktuple-distance calculation progress: 19 % (63 out of 325)\n", "Ktuple-distance calculation progress: 30 % (100 out of 325)\n", "Ktuple-distance calculation progress: 34 % (111 out of 325)\n", "Ktuple-distance calculation progress: 36 % (118 out of 325)\n", "Ktuple-distance calculation progress: 42 % (138 out of 325)\n", "Ktuple-distance calculation progress: 47 % (154 out of 325)\n", "Ktuple-distance calculation progress: 56 % (183 out of 325)\n", "Ktuple-distance calculation progress: 57 % (187 out of 325)\n", "Ktuple-distance calculation progress: 60 % (197 out of 325)\n", "Ktuple-distance calculation progress: 71 % (232 out of 325)\n", "Ktuple-distance calculation progress: 74 % (242 out of 325)\n", "Ktuple-distance calculation progress: 75 % (245 out of 325)\n", "Ktuple-distance calculation progress: 77 % (251 out of 325)\n", "Ktuple-distance calculation progress: 78 % (254 out of 325)\n", "Ktuple-distance calculation progress: 79 % (258 out of 325)\n", "Ktuple-distance calculation progress: 80 % (260 out of 325)\n", "Ktuple-distance calculation progress: 84 % (274 out of 325)\n", "Pairwise distance matrix written to ./distmat\n", "Ktuple-distance calculation progress done. CPU time: 0.08u 0.00s 00:00:00.08 Elapsed: 00:00:00\n", "Guide-tree computation done.\n", "Progressive alignment progress: 4 % (1 out of 24)\n", "Progressive alignment progress: 8 % (2 out of 24)\n", "Progressive alignment progress: 12 % (3 out of 24)\n", "Progressive alignment progress: 16 % (4 out of 24)\n", "Progressive alignment progress: 20 % (5 out of 24)\n", "Progressive alignment progress: 25 % (6 out of 24)\n", "Progressive alignment progress: 29 % (7 out of 24)\n", "Progressive alignment progress: 33 % (8 out of 24)\n", "Progressive alignment progress: 37 % (9 out of 24)\n", "Progressive alignment progress: 41 % (10 out of 24)\n", "Progressive alignment progress: 45 % (11 out of 24)\n", "Progressive alignment progress: 50 % (12 out of 24)\n", "Progressive alignment progress: 54 % (13 out of 24)\n", "Progressive alignment progress: 58 % (14 out of 24)\n", "Progressive alignment progress: 62 % (15 out of 24)\n", "Progressive alignment progress: 66 % (16 out of 24)\n", "Progressive alignment progress: 70 % (17 out of 24)\n", "Progressive alignment progress: 75 % (18 out of 24)\n", "Progressive alignment progress: 79 % (19 out of 24)\n", "Progressive alignment progress: 83 % (20 out of 24)\n", "Progressive alignment progress: 87 % (21 out of 24)\n", "Progressive alignment progress: 91 % (22 out of 24)\n", "Progressive alignment progress: 95 % (23 out of 24)\n", "Progressive alignment progress: 100 % (24 out of 24)\n", "Progressive alignment progress done. CPU time: 2.04u 0.02s 00:00:02.06 Elapsed: 00:00:00\n", "Alignment written to ./outputATP_binding_cluster_8.fasta\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "25it [00:00, 132229.00it/s]\n" ] } ], "source": [ "sequences = msa_alignment(sequences, 'ATP_binding_cluster_8')" ] }, { "cell_type": "code", "execution_count": 28, "id": "77603e9b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25it [00:00, 2569.03it/s]\n" ] } ], "source": [ "all_results = []\n", "\n", "for i, d in tqdm(enumerate(sequences)):\n", " attribution = d[\"attribution_integrated_gradient\"]\n", " sequence = d[\"sequence\"]\n", " name = d[\"name\"]\n", " alignment = d[\"alignment_result\"]\n", " for j in range(len(sequence)):\n", " all_results.append({\n", " 'aa': sequence[j],\n", " 'attr': float(attribution[j]),\n", " 'j': j,\n", " 'j_aligned': alignment[j],\n", " 'i': i,\n", " 'name': name\n", " })\n", " \n", "import pandas as pd\n", "df = pd.DataFrame.from_dict(all_results)\n", "\n", "df.to_csv(os.path.join('./', \"ATP_binding_MSA_Cluster_8.csv\"))\n", " " ] }, { "cell_type": "markdown", "id": "2a0c19f1", "metadata": {}, "source": [ "## MSA Alignment" ] }, { "cell_type": "code", "execution_count": 25, "id": "d505c1c9", "metadata": {}, "outputs": [], "source": [ "# input sequences in the format of [{sequence: , name: }], file name marks mf function and cluster number\n", "# output sequences in the format of [{sequence: , name: , aligned_sequence}]\n", "from tqdm import tqdm\n", "\n", "def msa_alignment(sequences, filename_postfix):\n", " \n", " import os\n", " from Bio.Align.Applications import ClustalOmegaCommandline\n", "\n", " temp_dir = './'\n", "\n", " # Generate FASTA file\n", " infile = os.path.join(temp_dir, \"input\" + filename_postfix + \".fasta\")\n", " with open(infile, \"w+\") as f:\n", " for i, d in tqdm(enumerate(sequences)):\n", " sequence = d[\"sequence\"]\n", " name = d[\"name\"]\n", " line=f\">{name}\\n{sequence}\\n\"\n", " f.write(line)\n", "\n", " clustal_path = \"/home/ec2-user/SageMaker/efs/install/clustalo-1.2.4-Ubuntu-x86_64\"\n", "\n", " outfile = os.path.join(temp_dir, \"output\"+ filename_postfix +\".fasta\")\n", " distmat = os.path.join(temp_dir, \"distmat\")\n", " clusters = os.path.join(temp_dir, \"clusters\")\n", " guidetree = os.path.join(temp_dir, \"guidetree\")\n", " clustalo_cline = ClustalOmegaCommandline(clustal_path, \n", " infile=infile, \n", " outfile=outfile, \n", " verbose=True, \n", " force=True,\n", " distmat_full=True,\n", " distmat_out=distmat,\n", " # clusteringout=clusters,\n", " # guidetree_out=guidetree,\n", " percentid=True\n", " )\n", "\n", " stdout, stderr = clustalo_cline()\n", " print(stdout)\n", " print(stderr)\n", " \n", " from Bio import SeqIO\n", "\n", " ## get seq alignment index\n", " alignment_results = {}\n", "\n", " for record in SeqIO.parse(os.path.join(temp_dir, \"output\" + filename_postfix + \".fasta\"), \"fasta\"):\n", " name, sequence = record.id, record.seq\n", " alignment_results[name] = {}\n", " idx = 0\n", " for i, c in enumerate(sequence):\n", " if c != '-':\n", " alignment_results[name][idx] = i\n", " idx += 1\n", " \n", " for i, d in tqdm(enumerate(sequences)):\n", " name = d[\"name\"]\n", " d[\"alignment_result\"] = alignment_results[name]\n", " \n", " return sequences " ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_latest_p36_clone", "language": "python", "name": "conda_pytorch_latest_p36_clone" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 5 }