# Using ESMFold to Predict Protein Structures on SageMaker

NOTE: The authors recommend running this notebook in Amazon SageMaker Studio with the following environment settings: 
* **PyTorch 1.13 Python 3.9 CPU-optimized** image 
* **Python 3** kernel 
* **ml.r5.xlarge** instance type 

For improved performance, you may also use the **PyTorch 1.13 Python 3.9 GPU-optimized** image and a **ml.g4dn.2xlarge** instance type.

---

Understanding the structure of proteins like antibodies is important for understanding their function. However, determining these structures in a laboratory can be difficult and expensive. Recently, AI-driven protein folding algorithms have enabled biologists to predict these structures from their amino acid sequences instead.

In this notebook, we will use the [ESMFold](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1) protein folding algorithm to predict the structure of Herceptin (Trastuzumab), an important breast cancer therapy. Herceptin is a [monoclonal antibody](https://www.cancer.org/treatment/treatments-and-side-effects/treatment-types/immunotherapy/monoclonal-antibodies.html) (mAb) that binds to the HER2 receptor, inhibiting cancer cell growth. The following diagram shows several of the common elements of monoclonal antibodies.

![A diagram of the major structural elements of an antibody](img/antibody.png)

In this notebook, we'll focus on predicting the structure of the heavy chain region.

In [None]:
%pip install -U -q -r esmfold-requirements.txt --disable-pip-version-check

## 1. Download and Visualize the Experimentally Determined Herceptin Protein Structure

In [None]:
from Bio.PDB import PDBList, MMCIFParser
import os
import py3Dmol
from prothelpers.structure import atoms_to_pdb
import warnings

target_id = "1N8Z"

if not os.path.isdir("data"):
    os.mkdir("data")

pdbl = PDBList()
filename = pdbl.retrieve_pdb_file(target_id, pdir="data", file_format="mmCif")
parser = MMCIFParser()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    structure_1N8Z = parser.get_structure(target_id, filename)
pdb_string = atoms_to_pdb(structure_1N8Z[0])

view = py3Dmol.view(width=600, height=400)
view.addModel(pdb_string)
view.setStyle({"chain": "A"}, {"cartoon": {"color": "orange", "opacity": 0.5}})
view.setStyle({"chain": "B"}, {"cartoon": {"color": "blue"}})
view.setStyle({"chain": "C"}, {"cartoon": {"color": "green", "opacity": 0.5}})
view.zoomTo()
view.show()

In the image above, the light chain (A) is orange, the heavy chain (B) is blue, and the HER2 antigen (chain C) is green. In this notebook, we will use ESMFold to predict the structure of chain B from its amino acid sequence. Then, we will compare the prediction to the experimentally determined structure shown above.

Extract the structure and sequence of chain B for later use.

In [None]:
from prothelpers.structure import get_aa_seq

experimental_structure = atoms_to_pdb(structure_1N8Z[0]["B"])
with open("data/experimental.pdb", "w") as f:
    f.write(experimental_structure)

experimental_sequence = get_aa_seq(structure_1N8Z[0]["B"])
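The `get_aa_seq` helper comes from the workshop's `prothelpers` package, so its internals are not shown here. A plausible core step is mapping each residue's three-letter name (which Biopython exposes through `residue.get_resname()`) to its one-letter code. The sketch below is a hypothetical stand-in, not the `prothelpers` implementation: it keeps only the 20 standard amino acids and silently skips everything else, such as waters (`HOH`) or ligands.

```python
# Hypothetical stand-in for prothelpers.structure.get_aa_seq; the real helper may differ.
THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}


def residues_to_sequence(residue_names):
    """Convert three-letter residue names into a one-letter sequence.

    Names outside the 20 standard amino acids (waters, ligands, modified
    residues) are skipped rather than mapped to a placeholder such as "X".
    """
    return "".join(
        THREE_TO_ONE[name.upper()]
        for name in residue_names
        if name.upper() in THREE_TO_ONE
    )


print(residues_to_sequence(["GLU", "VAL", "GLN", "HOH", "LEU"]))  # EVQL
```

With Biopython you could call it as `residues_to_sequence(r.get_resname() for r in structure_1N8Z[0]["B"])`. Skipping non-standard residues is a design choice; Biopython's `Bio.SeqUtils.seq1` instead maps unknown names to `X`.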

## 2. Make an In-Notebook ESMFold Prediction
We'll use the ESMFold model to predict the structure of the heavy chain and compare it to the experimental result. First, we load the pre-trained ESMFold model and tokenizer from the Hugging Face Hub. This takes about one minute.

In [None]:
from transformers import AutoTokenizer, EsmForProteinFolding

tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained(
    "facebook/esmfold_v1", low_cpu_mem_usage=True
)

If a GPU is available, we convert the ESM language-model component to half precision and enable TF32 matrix multiplication to improve performance. Alternatively, we can run on a CPU at lower cost but with slower performance.

In [None]:
import torch

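if torch.cuda.is_available():
    device = torch.device("cuda")
    # Run the ESM language-model stem in half precision to save GPU memory and time
    model.esm = model.esm.half()
    # Allow TensorFloat-32 matrix multiplication on GPUs that support it
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    print("CUDA not detected. Using CPU-only parameters")
    device = torch.device("cpu")
    # Half precision is poorly supported on CPU, so keep full precision
    model.esm = model.esm.float()
    torch.backends.cuda.matmul.allow_tf32 = False

model = model.to(device)
# Process the folding trunk in chunks of 64 to lower peak memory at some cost in speed
model.trunk.set_chunk_size(64)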

Next, we tokenize the sequence to convert it into a numerical format that ESMFold can use for prediction.

In [None]:
tokenized_input = tokenizer(
    [experimental_sequence], return_tensors="pt", add_special_tokens=False
)["input_ids"]
tokenized_input = tokenized_input.to(device)

print(f"The human-readable sequence is {experimental_sequence}")
print(f"The tokenized representation of the sequence is {tokenized_input}")
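To make the idea of "tokenization" concrete, the sketch below maps each residue to an integer ID. The vocabulary is a toy one invented for illustration; it is not ESMFold's real vocabulary, so the IDs will not match the tensor printed above. What does carry over is the shape of the process: the ESM tokenizer also emits one token per residue, and passing `add_special_tokens=False` (as above) leaves out its `<cls>` start and `<eos>` end tokens.

```python
# Toy vocabulary for illustration only; it is NOT ESMFold's real vocabulary.
TOY_VOCAB = {"<cls>": 0, "<pad>": 1, "<eos>": 2}
TOY_VOCAB.update({aa: i + 3 for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")})


def toy_tokenize(sequence, add_special_tokens=False):
    """Map each residue to an integer ID, one token per amino acid.

    Raises KeyError for any residue letter outside the toy vocabulary.
    """
    ids = [TOY_VOCAB[aa] for aa in sequence]
    if add_special_tokens:
        # Wrap the sequence with start and end markers.
        ids = [TOY_VOCAB["<cls>"]] + ids + [TOY_VOCAB["<eos>"]]
    return ids


print(toy_tokenize("EVQL"))                           # [6, 20, 16, 12]
print(toy_tokenize("EVQL", add_special_tokens=True))  # [0, 6, 20, 16, 12, 2]
```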

Finally, we pass the sequence to the model's `infer_pdb` method, which encodes the sequence itself, predicts the 3D structure, and returns it as a PDB-formatted string. This takes about three minutes on a non-accelerated instance type and several seconds on an accelerated one.

In [None]:
%%time
print(f"Predicting the structure of protein sequence {experimental_sequence}")
with torch.no_grad():
    notebook_prediction = model.infer_pdb(experimental_sequence)

with open("data/prediction.pdb", "w") as f:
    f.write(notebook_prediction)

# Release cached GPU memory that is no longer needed
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Prediction complete")
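The predicted PDB file also carries ESMFold's per-residue confidence (pLDDT) in the B-factor column; the `output_fn` defined later in this notebook does the same by passing `b_factors=outputs["plddt"]`. The stdlib-only sketch below reads that column from the fixed-width PDB `ATOM` layout (atom name in columns 13-16, residue number in 23-26, B-factor in 61-66). The helper names are hypothetical, and the sketch makes no assumption about whether the values are on a 0-1 or 0-100 scale, so check a few lines of your own file first. It also assumes a single-chain file, as ESMFold produces here.

```python
def per_residue_bfactors(pdb_string, atom_name="CA"):
    """Return {residue_number: B-factor} for one atom type per residue.

    Reads fixed PDB columns: atom name 13-16, residue number 23-26,
    B-factor 61-66. Assumes a single chain (residue numbers are unique).
    """
    values = {}
    for line in pdb_string.splitlines():
        if not line.startswith("ATOM"):
            continue  # skip HETATM, TER, END, REMARK, ...
        if line[12:16].strip() != atom_name:
            continue
        values[int(line[22:26])] = float(line[60:66])
    return values


def mean_bfactor(pdb_string, atom_name="CA"):
    """Average B-factor (for ESMFold output: mean pLDDT) over matching atoms."""
    values = per_residue_bfactors(pdb_string, atom_name)
    if not values:
        raise ValueError("no matching ATOM records found")
    return sum(values.values()) / len(values)
```

For example, `with open("data/prediction.pdb") as f: print(mean_bfactor(f.read()))` gives a single confidence summary for the prediction, and `per_residue_bfactors` shows which regions (often flexible loops) the model is least sure about.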

We can measure the accuracy of our prediction by comparing it to the experimental structure and calculating a [TM-score](https://zhanggroup.org/TM-score/), a similarity score between 0 and 1 in which higher values indicate closer agreement.

In [None]:
from prothelpers.usalign import tmscore

tmscore("data/prediction.pdb", "data/experimental.pdb", pymol="data/superimposed")
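For intuition about what that number means: for a given superposition, TM-score = (1/L) · Σ 1 / (1 + (dᵢ/d₀)²), where L is the length of the reference (target) structure, dᵢ is the distance between the i-th aligned residue pair, and d₀ = 1.24·(L − 15)^(1/3) − 1.8 Å. Tools such as US-align (which the `tmscore` helper wraps) search for the superposition that maximizes this sum; scores above about 0.5 generally indicate the same fold. The sketch below only evaluates the formula for distances you already have. The 0.5 Å floor on d₀ for short chains is an assumption modeled on common TM-score tools.

```python
def tm_d0(target_length):
    """Distance scale d0 in angstroms: 1.24 * (L - 15) ** (1/3) - 1.8.

    For short chains the formula becomes tiny or undefined, so (as an
    assumption mirroring common TM-score tools) d0 is floored at 0.5.
    """
    if target_length <= 21:
        return 0.5
    return max(0.5, 1.24 * (target_length - 15) ** (1.0 / 3.0) - 1.8)


def tm_score_from_distances(distances, target_length):
    """TM-score of ONE given superposition (no optimization is done here).

    distances: distance in angstroms for each aligned residue pair after
    superposition. Unaligned residues contribute 0 because the sum is
    divided by the full target length, not by the number of aligned pairs.
    """
    d0 = tm_d0(target_length)
    return sum(1.0 / (1.0 + (d / d0) ** 2) for d in distances) / target_length
```

Two properties follow directly from the formula: a residue pair that sits exactly d₀ apart contributes half as much as a perfectly overlapping pair, and aligning only half of the target perfectly caps the score at 0.5.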

A TM-score of around 0.8 isn't perfect, but it is good enough for many analytical tasks, such as ligand-binding analysis. Let's visualize how the experimentally determined and predicted structures align.

In [None]:
view = py3Dmol.view(width=600, height=400)

# Model 0: experimentally determined chain B
with open("data/experimental.pdb") as ifile:
    experimental = ifile.read()
view.addModel(experimental)

# Model 1: ESMFold prediction superimposed onto the experimental structure
with open("data/superimposed.pdb") as ifile:
    superimposed = ifile.read()
view.addModel(superimposed)

view.setStyle({"model": 0}, {"cartoon": {"color": "blue"}})
view.setStyle({"model": 1}, {"cartoon": {"color": "red"}})
view.zoomTo()
view.show()

The experimental structure (blue) and the prediction (red) show a high, but not perfect, degree of overlap. Protein structure prediction is a rapidly evolving field, and many research teams are developing ever more accurate algorithms!

## 3. Deploy ESMFold as a SageMaker Inference API

Running model inference in a notebook is fine for experimentation, but what if you need to integrate your model with an application or an MLOps pipeline? In these cases, a better option is to deploy your model as an API endpoint. In this example, we'll deploy ESMFold as a real-time inference endpoint on an accelerated instance.

In [None]:
import sagemaker
import boto3
import os

boto_session = boto3.Session()
sm_sess = sagemaker.Session()
sm_client = boto_session.client("sagemaker")
region = sm_sess.boto_region_name
bucket = sm_sess.default_bucket()
prefix = "sagemaker/ESMfold"
role = sagemaker.get_execution_role()

if not os.path.isdir("code"):
    os.mkdir("code")

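SageMaker's pre-built Hugging Face inference container, which runs on PyTorch, makes it easy to deploy deep learning models for common tasks. To predict protein structures, however, we need a custom `inference.py` script that defines four handlers: `model_fn` loads the model, `input_fn` parses the request, `predict_fn` runs the prediction, and `output_fn` formats the result as PDB text.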

In [None]:
%%writefile code/inference.py

import json
import os
import traceback
from typing import Any, List, Tuple, Union

import torch
import transformers
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein

MODEL_NAME = "facebook/esmfold_v1"


def model_fn(model_dir: str) -> Tuple[Any, Any]:
    """Load the tokenizer and model from the extracted model artifact."""
    try:
        model_path = os.path.join(model_dir, "esmfold_v1")
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        model = transformers.EsmForProteinFolding.from_pretrained(
            model_path, low_cpu_mem_usage=True
        )

        if torch.cuda.is_available():
            model.to("cuda")
            model.esm = model.esm.half()
            torch.backends.cuda.matmul.allow_tf32 = True
        else:
            model.to("cpu")
            model.esm = model.esm.float()
        model.trunk.set_chunk_size(64)

        return tokenizer, model
    except Exception:
        traceback.print_exc()
        raise


def input_fn(
    request_body: Union[str, bytes], request_content_type: str = "text/csv"
) -> Union[str, List[str]]:
    """Parse the request body into one sequence (or a list of sequences)."""
    print(request_content_type)

    if request_content_type == "text/csv":
        # The serving stack may deliver the raw body as bytes
        if isinstance(request_body, (bytes, bytearray)):
            request_body = request_body.decode("utf-8")
        sequence = request_body
    elif request_content_type == "application/json":
        sequence = json.loads(request_body)
    else:
        raise ValueError("Unsupported content type: {}".format(request_content_type))

    print("Input protein sequence: ", sequence)
    return sequence


def predict_fn(input_data: Union[str, List[str]], tokenizer_model: tuple) -> Any:
    """Tokenize the sequence(s) and run ESMFold."""
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        esm_tokenizer, esm_model = tokenizer_model
        tokenized_input = esm_tokenizer(
            input_data, return_tensors="pt", add_special_tokens=False
        )["input_ids"].to(device)

        with torch.no_grad():
            output = esm_model(tokenized_input)
        return output
    except Exception:
        traceback.print_exc()
        raise


def output_fn(outputs: Any, response_content_type: str = "text/csv"):
    """Transform the model output into PDB-formatted strings."""
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        aa = outputs["aatype"][i]
        pred_pos = final_atom_positions[i]
        mask = final_atom_mask[i]
        resid = outputs["residue_index"][i] + 1
        pred = OFProtein(
            aatype=aa,
            atom_positions=pred_pos,
            atom_mask=mask,
            residue_index=resid,
            # Store per-residue confidence (pLDDT) in the B-factor column
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))

    if response_content_type == "text/csv":
        return pdbs
    elif response_content_type == "application/json":
        return json.dumps(pdbs)
    else:
        raise ValueError("Unsupported content type: {}".format(response_content_type))
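To see how these four handlers fit together without deploying anything, the sketch below wires up mock versions in the order the SageMaker serving stack calls them: `model_fn` once when the model is loaded, then `input_fn`, `predict_fn`, and `output_fn` for each request. All `mock_*` names and `handle_request` are hypothetical and the "prediction" is a placeholder string. It also illustrates why the `application/json` branch matters: as far as we know, the SageMaker Python SDK's JSON serializer (the Hugging Face predictor's default) encodes a plain Python string as a JSON string, so `input_fn` must `json.loads` it.

```python
import json

# Hypothetical mock handlers with the same shapes as those in inference.py.
# No model is loaded: the "prediction" is a placeholder string, not a structure.


def mock_model_fn(model_dir):
    return ("tokenizer placeholder", "model placeholder")


def mock_input_fn(request_body, request_content_type):
    if request_content_type == "application/json":
        return json.loads(request_body)
    if request_content_type == "text/csv":
        return request_body
    raise ValueError(f"Unsupported content type: {request_content_type}")


def mock_predict_fn(input_data, tokenizer_model):
    # The real predict_fn tokenizes input_data and runs ESMFold here.
    return [f"REMARK placeholder structure for {input_data}"]


def mock_output_fn(prediction, accept):
    if accept == "application/json":
        return json.dumps(prediction)
    raise ValueError(f"Unsupported accept type: {accept}")


def handle_request(model_dir, body, content_type, accept):
    """Chain the handlers in the order the SageMaker serving stack calls them."""
    loaded = mock_model_fn(model_dir)  # a real server does this once at startup
    data = mock_input_fn(body, content_type)
    prediction = mock_predict_fn(data, loaded)
    return mock_output_fn(prediction, accept)


# JSON-encoding the Python string EVQL yields the JSON text "EVQL" (with quotes).
response = handle_request(
    "/opt/ml/model", json.dumps("EVQL"), "application/json", "application/json"
)
print(json.loads(response)[0])  # REMARK placeholder structure for EVQL
```

The client side mirrors the last step: it decodes the JSON list of PDB strings, which is why the endpoint call later in this notebook indexes the result with `[0]`.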

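We also list the additional packages our inference code needs in a `requirements.txt` file, which the container installs when the endpoint starts. We pin `transformers` to 4.24.0 here because the `EsmForProteinFolding` class is not available in the container's preinstalled `transformers` 4.17.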

In [None]:
%%writefile code/requirements.txt

transformers==4.24.0
accelerate==0.17.0
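Because the container image (configured below with `transformers_version="4.17"`) ships an older `transformers` than the version pinned above, it can be worth confirming at startup which version was actually installed, for example by logging a check from `model_fn`. The stdlib-only sketch below is one way to do that; `version_tuple`, `meets_minimum`, and `installed_at_least` are hypothetical helper names, not part of SageMaker or `transformers`. It deliberately simplifies version semantics by ignoring pre-release suffixes.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_tuple(text):
    """'4.24.0' -> (4, 24, 0).

    Any suffix such as 'rc1' or '.dev0' is ignored, so '4.24.0rc1' counts
    as 4.24.0 (a simplification of real version ordering).
    """
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        raise ValueError(f"cannot parse version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """Compare release numbers, padding with zeros so '4.24' == '4.24.0'."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_at_least(package, minimum):
    """True if `package` is installed with a version >= `minimum`."""
    try:
        return meets_minimum(version(package), minimum)
    except PackageNotFoundError:
        return False


print(installed_at_least("transformers", "4.24.0"))
```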

For this example, we use a model artifact that was previously downloaded from the Hugging Face Hub and uploaded to S3.

In [None]:
from sagemaker.huggingface import HuggingFaceModel
from datetime import datetime

public_model_url = "s3://aws-hcls-ml/workshop/esmfold/esmfold_v1.tar.gz"

huggingface_model = HuggingFaceModel(
    model_data=public_model_url,
    # %H%M%S is a portable timestamp; the previous %s is a platform-specific extension
    name="esmfold-v1-model-" + datetime.now().strftime("%Y%m%d%H%M%S"),
    transformers_version="4.17",
    pytorch_version="1.10",
    py_version="py38",
    role=role,
    source_dir="code",
    entry_point="inference.py",
)
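If you would rather host your own copy of the weights than use the public artifact, the archive layout has to match what `model_fn` expects: SageMaker extracts `model.tar.gz` into `model_dir`, and `model_fn` then loads from `os.path.join(model_dir, "esmfold_v1")`. The sketch below (hypothetical helper name) packs a local directory under a single top-level `esmfold_v1/` folder using only the standard library. It assumes you first saved the weights locally, for example with the real `model.save_pretrained("esmfold_v1")` and `tokenizer.save_pretrained("esmfold_v1")` methods.

```python
import os
import tarfile


def package_model_dir(local_model_dir, output_path, arcname="esmfold_v1"):
    """Pack local_model_dir into a .tar.gz whose single top-level folder is
    `arcname`. SageMaker extracts the archive into model_dir, so the files
    end up at os.path.join(model_dir, "esmfold_v1"), where model_fn looks."""
    if not os.path.isdir(local_model_dir):
        raise FileNotFoundError(local_model_dir)
    with tarfile.open(output_path, "w:gz") as tar:
        tar.add(local_model_dir, arcname=arcname)  # adds the directory recursively
    return output_path
```

After packaging, something like `sm_sess.upload_data("esmfold_v1.tar.gz", bucket=bucket, key_prefix=prefix)` returns an S3 URI that you could pass as `model_data` in place of `public_model_url`.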

Next, we deploy the model to a real-time inference endpoint hosted on an accelerated instance type. This may take 10 minutes or more, depending on the availability of ml.g4dn.2xlarge instances in your region.

In [None]:
endpoint_name = "esmfold-v1-rt-endpoint-" + datetime.now().strftime("%Y%m%d%H%M%S")
%store endpoint_name
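Timestamped names keep repeated runs from colliding, but SageMaker endpoint names must follow its naming rules: at most 63 characters of letters, digits, and hyphens, starting and ending with a letter or digit. The lowercase `%s` strftime code is a platform-specific extension (seconds since the epoch on glibc) rather than a standard directive, so `%H%M%S` is the safer choice. The sketch below is a hypothetical helper that builds and validates such a name before you call `deploy`; the regular expression is our transcription of the documented `EndpointName` pattern.

```python
import re
from datetime import datetime

# EndpointName rules from the SageMaker API reference: at most 63 characters,
# letters, digits and hyphens, starting and ending with a letter or digit.
NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")
MAX_NAME_LENGTH = 63


def timestamped_name(base, now=None):
    """Append a portable YYYYmmddHHMMSS timestamp and validate the result."""
    if now is None:
        now = datetime.now()
    name = f"{base}-{now.strftime('%Y%m%d%H%M%S')}"
    if len(name) > MAX_NAME_LENGTH or not NAME_PATTERN.match(name):
        raise ValueError(f"invalid SageMaker endpoint name: {name!r}")
    return name


print(timestamped_name("esmfold-v1-rt-endpoint", datetime(2023, 3, 1, 12, 30, 5)))
# esmfold-v1-rt-endpoint-20230301123005
```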

In [None]:
%%time

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.2xlarge",
    endpoint_name=endpoint_name,
)

Submit the chain B sequence to the endpoint and print the first few rows of the predicted structure file.

In [None]:
endpoint_prediction = predictor.predict(experimental_sequence)[0]
print(endpoint_prediction[:900])

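Finally, save the predicted structure to a file and compare it to the notebook prediction. The two predictions should be nearly identical, with a TM-score close to 1. Small differences can arise, for example, if the notebook ran on a CPU in full precision while the endpoint runs on a GPU in half precision.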

In [None]:
from prothelpers import usalign

with open("data/endpoint_prediction.pdb", "w") as f:
    f.write(endpoint_prediction)

usalign.tmscore(
    "data/endpoint_prediction.pdb", "data/prediction.pdb", pymol="data/esm_superimposed"
)

In [None]:
view2 = py3Dmol.view(width=600, height=400)

with open("data/prediction.pdb") as ifile:
    prediction = ifile.read()
view2.addModel(prediction)

with open("data/esm_superimposed.pdb") as ifile:
    esm_superimposed = ifile.read()
view2.addModel(esm_superimposed)

view2.setStyle({"model": 0}, {"cartoon": {"color": "red"}})
view2.setStyle({"model": 1}, {"cartoon": {"color": "blue"}})
view2.zoomTo()
view2.show()
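Alongside the TM-score, a root-mean-square deviation (RMSD) over alpha-carbon (CA) atoms gives a distance in angstroms that is easy to interpret: for two nearly identical predictions it should be close to zero. The stdlib-only sketch below (hypothetical helper names) reads CA coordinates from fixed PDB columns (x, y, z in columns 31-54) and computes RMSD over residue numbers present in both files. It performs no fitting, so it assumes the two files are already superimposed, as `data/esm_superimposed.pdb` is onto `data/prediction.pdb`, and that each file holds a single chain.

```python
import math


def ca_coordinates(pdb_string):
    """Return {residue_number: (x, y, z)} for CA atoms of a single-chain PDB.

    Reads fixed columns: atom name 13-16, residue number 23-26, x/y/z 31-54.
    """
    coords = {}
    for line in pdb_string.splitlines():
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            coords[int(line[22:26])] = (
                float(line[30:38]),
                float(line[38:46]),
                float(line[46:54]),
            )
    return coords


def ca_rmsd(pdb_a, pdb_b):
    """RMSD over CA atoms shared by both files. No superposition is performed."""
    a, b = ca_coordinates(pdb_a), ca_coordinates(pdb_b)
    shared = sorted(a.keys() & b.keys())
    if not shared:
        raise ValueError("no shared CA atoms")
    total = sum(math.dist(a[k], b[k]) ** 2 for k in shared)
    return math.sqrt(total / len(shared))
```

Usage would look like `ca_rmsd(open("data/prediction.pdb").read(), open("data/esm_superimposed.pdb").read())`. Unlike TM-score, RMSD is dominated by the worst-fitting regions, so the two measures complement each other.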

If you don't plan to run the next module, you can uncomment this final cell to delete the endpoint and its SageMaker model and stop further charges.

In [None]:
# try:
#     predictor.delete_endpoint()
#     predictor.delete_model()
# except Exception:
#     pass