# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Helper functions for the AWS-Alphafold notebook."""

from datetime import datetime
import json
import os
import re
import string
from string import ascii_uppercase, ascii_lowercase
import uuid

import boto3
import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
import pandas as pd
import py3Dmol
import sagemaker
from Bio import AlignIO, SeqIO
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

boto_session = boto3.session.Session()
sm_session = sagemaker.session.Session(boto_session)
region = boto_session.region_name
s3 = boto_session.client("s3", region_name=region)
batch = boto_session.client("batch", region_name=region)
cfn = boto_session.client("cloudformation", region_name=region)
logs_client = boto_session.client("logs")


def create_job_name(suffix=None):
    """Define a simple job identifier."""
    if suffix is None:
        return datetime.now().strftime("%Y%m%dT%H%M%S")
    # Ensure that the suffix conforms to the Batch naming requirements
    # (only letters, numbers, hyphens, and underscores are allowed).
    suffix = re.sub(r"\W", "_", suffix)
    return datetime.now().strftime("%Y%m%dT%H%M%S") + "_" + suffix


def upload_fasta_to_s3(
    sequences,
    ids,
    bucket=None,
    job_name=None,
    region="us-east-1",
):
    """Create a FASTA file and upload it to S3."""
    # Resolve the defaults at call time; evaluating them in the signature
    # would freeze the bucket name and job ID when the module is imported.
    if bucket is None:
        bucket = sm_session.default_bucket()
    if job_name is None:
        job_name = str(uuid.uuid4())
    file_out = "_tmp.fasta"
    with open(file_out, "w") as f_out:
        for i, seq in enumerate(sequences):
            seq_record = SeqRecord(Seq(seq), id=ids[i])
            SeqIO.write(seq_record, f_out, "fasta")
    object_key = f"{job_name}/{job_name}.fasta"
    s3.upload_file(file_out, bucket, object_key)
    os.remove(file_out)
    s3_uri = f"s3://{bucket}/{object_key}"
    print(f"Sequence file uploaded to {s3_uri}")
    return object_key


def list_alphafold_stacks():
    """List CloudFormation stacks created from the AlphaFold Batch template."""
    af_stacks = []
    for stack in cfn.list_stacks(
        StackStatusFilter=["CREATE_COMPLETE", "UPDATE_COMPLETE"]
    )["StackSummaries"]:
        if "alphafold-cfn-batch.yaml" in stack.get("TemplateDescription", ""):
            af_stacks.append(stack)
    return af_stacks


def get_batch_resources(stack_name):
    """Get the resource names of the Batch resources for running AlphaFold jobs."""
    stack_resources = cfn.list_stack_resources(StackName=stack_name)
    cpu_job_queue_spot = None
    for resource in stack_resources["StackResourceSummaries"]:
        if resource["LogicalResourceId"] == "GPUFoldingJobDefinition":
            gpu_job_definition = resource["PhysicalResourceId"]
        if resource["LogicalResourceId"] == "PrivateGPUJobQueue":
            gpu_job_queue = resource["PhysicalResourceId"]
        if resource["LogicalResourceId"] == "CPUFoldingJobDefinition":
            cpu_job_definition = resource["PhysicalResourceId"]
        if resource["LogicalResourceId"] == "PrivateCPUJobQueueOnDemand":
            # The on-demand CPU queue also serves as the download job queue.
            cpu_job_queue_od = download_job_queue = resource["PhysicalResourceId"]
        if resource["LogicalResourceId"] == "PrivateCPUJobQueueSpot":
            cpu_job_queue_spot = resource["PhysicalResourceId"]
        if resource["LogicalResourceId"] == "CPUDownloadJobDefinition":
            download_job_definition = resource["PhysicalResourceId"]
    return {
        "gpu_job_definition": gpu_job_definition,
        "gpu_job_queue": gpu_job_queue,
        "cpu_job_definition": cpu_job_definition,
        "cpu_job_queue_od": cpu_job_queue_od,
        "cpu_job_queue_spot": cpu_job_queue_spot,
        "download_job_definition": download_job_definition,
        "download_job_queue": download_job_queue,
    }
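
# Example usage of the setup helpers (a sketch; the sequence, ID, and the
# presence of a deployed alphafold-cfn-batch stack are hypothetical):
#
#   job_name = create_job_name("my_target")
#   object_key = upload_fasta_to_s3(
#       sequences=["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"],  # hypothetical sequence
#       ids=["my_target"],
#       job_name=job_name,
#   )
#   stack_name = list_alphafold_stacks()[0]["StackName"]
#   batch_resources = get_batch_resources(stack_name)
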
""" job_description = batch.describe_jobs(jobs=[jobId]) output = { "jobArn": job_description["jobs"][0]["jobArn"], "jobName": job_description["jobs"][0]["jobName"], "jobId": job_description["jobs"][0]["jobId"], "status": job_description["jobs"][0]["status"], "createdAt": datetime.utcfromtimestamp( job_description["jobs"][0]["createdAt"] / 1000 ).strftime("%Y-%m-%dT%H:%M:%SZ"), "dependsOn": job_description["jobs"][0]["dependsOn"], "tags": job_description["jobs"][0]["tags"], } if output["status"] in ["STARTING", "RUNNING", "SUCCEEDED", "FAILED"]: output["logStreamName"] = job_description["jobs"][0]["container"][ "logStreamName" ] return output def get_batch_logs(logStreamName): """ Retrieve and format logs for batch job. """ try: response = logs_client.get_log_events( logGroupName="/aws/batch/job", logStreamName=logStreamName ) except logs_client.meta.client.exceptions.ResourceNotFoundException: return f"Log stream {logStreamName} does not exist. Please try again in a few minutes" logs = pd.DataFrame.from_dict(response["events"]) logs.timestamp = logs.timestamp.transform( lambda x: datetime.fromtimestamp(x / 1000) ) logs.drop("ingestionTime", axis=1, inplace=True) return logs def download_dir(client, bucket, local="data", prefix=""): """Recursively download files from S3.""" paginator = client.get_paginator("list_objects_v2") file_count = 0 for result in paginator.paginate(Bucket=bucket, Delimiter="/", Prefix=prefix): if result.get("CommonPrefixes") is not None: for subdir in result.get("CommonPrefixes"): download_dir(client, bucket, local, subdir.get("Prefix")) for file in result.get("Contents", []): dest_pathname = os.path.join(local, file.get("Key")) if not os.path.exists(os.path.dirname(dest_pathname)): os.makedirs(os.path.dirname(dest_pathname)) client.download_file(bucket, file.get("Key"), dest_pathname) file_count += 1 print(f"{file_count} files downloaded from s3.") return local def download_results(bucket, job_name, local="data"): """Download MSA information from S3""" return download_dir(s3, bucket, local, job_name) def reduce_stockholm_file(sto_file): """Read in a .sto file and parse format it into a numpy array of the same length as the first (target) sequence """ msa = AlignIO.read(sto_file, "stockholm") msa_arr = np.array([list(rec) for rec in msa]) return msa_arr[:, msa_arr[0, :] != "-"] def plot_msa_array(msa_arr, id=None): total_msa_size = len(msa_arr) if total_msa_size > 1: aa_map = {res: i for i, res in enumerate("ABCDEFGHIJKLMNOPQRSTUVWXYZ-")} msa_arr = np.array([[aa_map[aa] for aa in seq] for seq in msa_arr]) plt.figure(figsize=(12, 3)) plt.title( f"Per-Residue Count of Non-Gap Amino Acids in the MSA for Sequence {id}" ) plt.plot(np.sum(msa_arr != aa_map["-"], axis=0), color="black") plt.ylabel("Non-Gap Count") plt.yticks(range(0, total_msa_size + 1, max(1, int(total_msa_size / 3)))) return plt else: print("Unable to display MSA of length 1") return None def plot_msa_folder(msa_folder, id=None): combined_msa = None with os.scandir(msa_folder) as it: for obj in it: obj_path = os.path.splitext(obj.path) if "pdb_hits" not in obj_path[0] and obj_path[1] == ".sto": msa_arr = reduce_stockholm_file(obj.path) if combined_msa is None: combined_msa = msa_arr else: combined_msa = np.concatenate((combined_msa, msa_arr), axis=0) if combined_msa is not None: print(f"Total number of aligned sequences is {len(combined_msa)}") plot_msa_array(combined_msa, id).show() return None else: return None def plot_msa_output_folder(path, id=None): """Plot MSAs in a folder that may have multiple 
chain folders""" plots = [] monomer = True with os.scandir(path) as it: for obj in it: if obj.is_dir(): monomer = False plot_msa_folder(obj.path, id + " " + obj.name) if monomer: plot_msa_folder(path, id) return None def display_structure( pdb_path, color="lDDT", show_sidechains=False, show_mainchains=False, chains=1, vmin=50, vmax=90, ): """ Display the predicted structure in a Jupyter notebook cell """ if color not in ["chain", "lDDT", "rainbow"]: raise ValueError("Color must be 'LDDT' (default), 'chain', or 'rainbow'") plot_pdb( pdb_path, show_sidechains=show_sidechains, show_mainchains=show_mainchains, color=color, chains=chains, vmin=vmin, vmax=vmax, ).show() if color == "lDDT": plot_plddt_legend().show() def submit_batch_alphafold_job( job_name, fasta_paths, s3_bucket, data_dir="/mnt/data_dir/fsx", output_dir="alphafold", bfd_database_path="/mnt/bfd_database_path/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt", mgnify_database_path="/mnt/mgnify_database_path/mgy_clusters_2018_12.fa", pdb70_database_path="/mnt/pdb70_database_path/pdb70", obsolete_pdbs_path="/mnt/obsolete_pdbs_path/obsolete.dat", template_mmcif_dir="/mnt/template_mmcif_dir/mmcif_files", pdb_seqres_database_path="/mnt/pdb_seqres_database_path/pdb_seqres.txt", small_bfd_database_path="/mnt/small_bfd_database_path/bfd-first_non_consensus_sequences.fasta", uniclust30_database_path="/mnt/uniclust30_database_path/uniclust30_2018_08/uniclust30_2018_08", uniprot_database_path="/mnt/uniprot_database_path/uniprot.fasta", uniref90_database_path="/mnt/uniref90_database_path/uniref90.fasta", max_template_date=datetime.now().strftime("%Y-%m-%d"), db_preset="reduced_dbs", model_preset="monomer", benchmark=False, use_precomputed_msas=False, features_paths=None, run_features_only=False, logtostderr=True, cpu=4, memory=16, gpu=1, depends_on=None, stack_name=None, use_spot_instances=False, run_relax=True, num_multimer_predictions_per_model=1 ): if stack_name is None: stack_name = list_alphafold_stacks()[0]["StackName"] batch_resources = get_batch_resources(stack_name) container_overrides = { "command": [ f"--fasta_paths={fasta_paths}", f"--uniref90_database_path={uniref90_database_path}", f"--mgnify_database_path={mgnify_database_path}", f"--data_dir={data_dir}", f"--template_mmcif_dir={template_mmcif_dir}", f"--obsolete_pdbs_path={obsolete_pdbs_path}", f"--output_dir={output_dir}", f"--max_template_date={max_template_date}", f"--db_preset={db_preset}", f"--model_preset={model_preset}", f"--s3_bucket={s3_bucket}", f"--run_relax={run_relax}", ], "resourceRequirements": [ {"value": str(cpu), "type": "VCPU"}, {"value": str(memory * 1000), "type": "MEMORY"}, ], } if model_preset == "multimer": container_overrides["command"].append( f"--uniprot_database_path={uniprot_database_path}" ) container_overrides["command"].append( f"--pdb_seqres_database_path={pdb_seqres_database_path}" ) container_overrides["command"].append( f"--num_multimer_predictions_per_model={num_multimer_predictions_per_model}" ) print("If multimer prediction failes due to Amber relaxation, re-run with run_relax=False") else: container_overrides["command"].append( f"--pdb70_database_path={pdb70_database_path}" ) if db_preset == "reduced_dbs": container_overrides["command"].append( f"--small_bfd_database_path={small_bfd_database_path}" ) else: container_overrides["command"].append( f"--uniclust30_database_path={uniclust30_database_path}" ) container_overrides["command"].append( f"--bfd_database_path={bfd_database_path}" ) if benchmark: 
container_overrides["command"].append("--benchmark") if use_precomputed_msas: container_overrides["command"].append("--use_precomputed_msas") if features_paths is not None: container_overrides["command"].append(f"--features_paths={features_paths}") if run_features_only: container_overrides["command"].append("--run_features_only") if logtostderr: container_overrides["command"].append("--logtostderr") if gpu > 0: if use_spot_instances: print("Spot instance queue not available for GPU jobs. Using on-demand queue instead.") job_definition = batch_resources["gpu_job_definition"] job_queue = batch_resources["gpu_job_queue"] container_overrides["resourceRequirements"].append( {"value": str(gpu), "type": "GPU"} ) else: job_definition = batch_resources["cpu_job_definition"] if use_spot_instances and batch_resources["cpu_job_queue_spot"] is not None: job_queue = batch_resources["cpu_job_queue_spot"] elif use_spot_instances and batch_resources["cpu_job_queue_spot"] is None: print("Spot instance queue not available. Using on-demand queue instead.") job_queue = batch_resources["cpu_job_queue_od"] else: job_queue = batch_resources["cpu_job_queue_od"] print(container_overrides) if depends_on is None: response = batch.submit_job( jobDefinition=job_definition, jobName=job_name, jobQueue=job_queue, containerOverrides=container_overrides, ) else: response = batch.submit_job( jobDefinition=job_definition, jobName=job_name, jobQueue=job_queue, containerOverrides=container_overrides, dependsOn=[{"jobId": depends_on, "type": "SEQUENTIAL"}], ) return response def get_run_metrics(bucket, job_name): timings_uri = sagemaker.s3.s3_path_join(bucket, job_name, "timings.json") ranking_uri = sagemaker.s3.s3_path_join(bucket, job_name, "ranking_debug.json") downloader = sagemaker.s3.S3Downloader() timing_dict = json.loads(downloader.read_file(f"s3://{timings_uri}")) ranking_dict = json.loads(downloader.read_file(f"s3://{ranking_uri}")) timing_df = pd.DataFrame.from_dict( timing_dict, orient="index", columns=["duration_sec"] ) ranking_plddts_df = pd.DataFrame.from_dict( ranking_dict["plddts"], orient="index", columns=["plddts"] ) order_df = pd.DataFrame.from_dict(ranking_dict["order"]) return (timing_df, ranking_plddts_df, order_df) def validate_input(input_sequences): output = [] for sequence in input_sequences: sequence = sequence.upper().strip() if re.search("[^ARNDCQEGHILKMFPSTWYV]", sequence): raise ValueError( f"Input sequence contains invalid amino acid symbols." f"{sequence}" ) output.append(sequence) if len(output) == 1: print("Using the monomer models.") model_preset = "monomer" return output, model_preset elif len(output) > 1: print("Using the multimer models.") model_preset = "multimer" return output, model_preset else: raise ValueError("Please provide at least one input sequence.") ### --------------------------------------------- # Original Copyright 2021 Sergey Ovchinnikov https://github.com/sokrypton/ColabFold # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
### ---------------------------------------------
# Original Copyright 2021 Sergey Ovchinnikov https://github.com/sokrypton/ColabFold
# Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.

pymol_color_list = [
    "#33ff33", "#00ffff", "#ff33cc", "#ffff00", "#ff9999", "#e5e5e5", "#7f7fff", "#ff7f00",
    "#7fff7f", "#199999", "#ff007f", "#ffdd5e", "#8c3f99", "#b2b2b2", "#007fff", "#c4b200",
    "#8cb266", "#00bfbf", "#b27f7f", "#fcd1a5", "#ff7f7f", "#ffbfdd", "#7fffff", "#ffff7f",
    "#00ff7f", "#337fcc", "#d8337f", "#bfff3f", "#ff7fff", "#d8d8ff", "#3fffbf", "#b78c4c",
    "#339933", "#66b2b2", "#ba8c84", "#84bf00", "#b24c66", "#7f7f7f", "#3f3fa5", "#a5512b",
]

alphabet_list = list(ascii_uppercase + ascii_lowercase)


def plot_pdb(
    pred_output_path,
    show_sidechains=False,
    show_mainchains=False,
    color="lDDT",
    chains=None,
    Ls=None,
    vmin=50,
    vmax=90,
    color_HP=False,
    size=(800, 480),
):
    """Create a 3D view of a PDB structure.

    Copied from https://github.com/sokrypton/ColabFold/blob/main/beta/colabfold.py
    """
    if chains is None:
        chains = 1 if Ls is None else len(Ls)
    view = py3Dmol.view(
        js="https://3dmol.org/build/3Dmol.js", width=size[0], height=size[1]
    )
    with open(pred_output_path, "r") as f:
        view.addModel(f.read(), "pdb")
    if color == "lDDT":
        # Color each residue by the pLDDT value stored in the B-factor column.
        view.setStyle(
            {
                "cartoon": {
                    "colorscheme": {
                        "prop": "b",
                        "gradient": "roygb",
                        "min": vmin,
                        "max": vmax,
                    }
                }
            }
        )
    elif color == "rainbow":
        view.setStyle({"cartoon": {"color": "spectrum"}})
    elif color == "chain":
        for n, chain, chain_color in zip(
            range(chains), alphabet_list, pymol_color_list
        ):
            view.setStyle({"chain": chain}, {"cartoon": {"color": chain_color}})
    if show_sidechains:
        BB = ["C", "O", "N"]
        HP = [
            "ALA",
            "GLY",
            "VAL",
            "ILE",
            "LEU",
            "PHE",
            "MET",
            "PRO",
            "TRP",
            "CYS",
            "TYR",
        ]
        if color_HP:
            view.addStyle(
                {"and": [{"resn": HP}, {"atom": BB, "invert": True}]},
                {"stick": {"colorscheme": "yellowCarbon", "radius": 0.3}},
            )
            view.addStyle(
                {"and": [{"resn": HP, "invert": True}, {"atom": BB, "invert": True}]},
                {"stick": {"colorscheme": "whiteCarbon", "radius": 0.3}},
            )
            view.addStyle(
                {"and": [{"resn": "GLY"}, {"atom": "CA"}]},
                {"sphere": {"colorscheme": "yellowCarbon", "radius": 0.3}},
            )
            view.addStyle(
                {"and": [{"resn": "PRO"}, {"atom": ["C", "O"], "invert": True}]},
                {"stick": {"colorscheme": "yellowCarbon", "radius": 0.3}},
            )
        else:
            view.addStyle(
                {
                    "and": [
                        {"resn": ["GLY", "PRO"], "invert": True},
                        {"atom": BB, "invert": True},
                    ]
                },
                {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}},
            )
            view.addStyle(
                {"and": [{"resn": "GLY"}, {"atom": "CA"}]},
                {"sphere": {"colorscheme": "WhiteCarbon", "radius": 0.3}},
            )
            view.addStyle(
                {"and": [{"resn": "PRO"}, {"atom": ["C", "O"], "invert": True}]},
                {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}},
            )
    if show_mainchains:
        BB = ["C", "O", "N", "CA"]
        view.addStyle(
            {"atom": BB}, {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}}
        )
    view.zoomTo()
    return view


def plot_plddt_legend(dpi=100):
    """Create the pLDDT color legend for the 3D plot.

    Copied from https://github.com/sokrypton/ColabFold/blob/main/beta/colabfold.py
    """
    thresh = [
        "plDDT:",
        "Very low (<50)",
        "Low (60)",
        "OK (70)",
        "Confident (80)",
        "Very high (>90)",
    ]
    plt.figure(figsize=(1, 0.1), dpi=dpi)
    for c in ["#FFFFFF", "#FF0000", "#FFFF00", "#00FF00", "#00FFFF", "#0000FF"]:
        plt.bar(0, 0, color=c)
    plt.legend(
        thresh,
        frameon=False,
        loc="center",
        ncol=6,
        handletextpad=1,
        columnspacing=1,
        markerscale=0.5,
    )
    plt.axis(False)
    return plt


### ---------------------------------------------
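
# Example usage of the ColabFold-derived plotting helpers (a sketch; the PDB
# path and chain count are hypothetical):
#
#   view = plot_pdb("data/my_job/ranked_0.pdb", color="chain", chains=2)
#   view.show()
#   plot_plddt_legend().show()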