In [None]:
import time
import boto3
import sagemaker
import json
import sys
from time import gmtime, strftime
from sagemaker import get_execution_role
from sagemaker.workflow.steps import TrainingStep
from sagemaker.estimator import Estimator
from sagemaker.workflow.steps import ProcessingStep
from sagemaker.processing import Processor, ProcessingOutput
from sagemaker.workflow.conditions import ConditionEquals
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.inputs import FileSystemInput
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.parameters import (
    ParameterInteger,
    ParameterString,
    ParameterFloat,
    ParameterBoolean
)

In [None]:
# Resolve the AWS account, region, execution role and default bucket for this notebook.
client = boto3.client("sts")
identity = client.get_caller_identity()
account = identity["Account"]

# `sess` is used directly (default bucket, S3 uploads); `pipeline_session` is handed to the
# estimators so that their fit() calls yield TrainingStep arguments for the pipeline.
sess = sagemaker.Session()
region = boto3.session.Session().region_name
role = get_execution_role()
default_bucket = sess.default_bucket()

print(sagemaker.__version__)
pipeline_session = PipelineSession()

In [None]:
#User Input - Static Environment Varaibles
prefix='xxxxxx'

file_system_id = "xxxxxx" 
fsx_mount_id = "xxxxxx" 

vpc_subnet_ids = ['xxxxxx'] 
security_group_ids = ['xxxxxx'] 

#Below variables do not need to be changed if using default settings
file_system_access_mode = 'ro'
file_system_type = 'FSxLustre'
file_system_directory_path = f'/{fsx_mount_id}/{prefix}/alphafold-genetic-db'

alphafold_image_uri = f'{account}.dkr.ecr.{region}.amazonaws.com/sagemaker-studio-alphafold:v2.3.0-estimator'
openfold_image_uri=f'{account}.dkr.ecr.{region}.amazonaws.com/sagemaker-studio-openfold:v1.0.1'

In [None]:
# Dynamic variables: pipeline parameters supplied per execution through pipeline.start(parameters=...).

# Input sequence: the FASTA file name and the S3 URI it was uploaded to.
fasta_file = ParameterString(name="FastaFileName")
fasta_input = ParameterString(name="FastaInputS3URI")

# Folding options, forwarded to the jobs as hyperparameters.
pipeline_db_preset = ParameterString(
    name="db_preset",
    default_value='full_dbs',
    enum_values=['full_dbs', 'reduced_dbs'],
)
max_template_date = ParameterString(name="MaxTemplateDate")
model_preset = ParameterString(name="ModelPreset")
num_multimer_predictions_per_model = ParameterString(name="NumMultimerPredictionsPerModel")

# Instance types: one for the MSA step, one shared by the AlphaFold/OpenFold steps.
msa_instance_type = ParameterString(name="MSAInstanceType", default_value='ml.m5.4xlarge')
instance_type = ParameterString(name="InferenceInstanceType", default_value='ml.g5.4xlarge')

In [None]:
# --- Step 1: multiple sequence alignment (MSA) ---
# Input channels: the input FASTA from S3 and the genetic databases from FSx.
pipeline_fasta_msa = sagemaker.inputs.TrainingInput(
    fasta_input,
    distribution="FullyReplicated",
    s3_data_type="S3Prefix",
    input_mode='File',
)
genetic_db = FileSystemInput(
    file_system_id=file_system_id,
    file_system_type=file_system_type,
    directory_path=file_system_directory_path,
    file_system_access_mode=file_system_access_mode,
)
pipeline_data_channels_msa = {"genetic": genetic_db, 'fasta': pipeline_fasta_msa}

# Pipeline parameters forwarded to run_create_alignment.sh as hyperparameters.
parameters = dict(
    DB_PRESET=pipeline_db_preset,
    FASTA_SUFFIX=fasta_file,
    MAX_TEMPLATE_DATE=max_template_date,
    MODEL_PRESET=model_preset,
    NUM_MULTIMER_PREDICTIONS_PER_MODEL=num_multimer_predictions_per_model,
)

output_path = 's3://%s/%s/job-output/' % (default_bucket, prefix)

pipeline_msa = Estimator(
    # Code and container
    source_dir='./source_dir',
    entry_point='run_create_alignment.sh',
    image_uri=alphafold_image_uri,
    hyperparameters=parameters,
    # Compute; volume_size is in GB
    instance_count=1,
    instance_type=msa_instance_type,
    volume_size=3000,
    # Identity, session and networking
    role=role,
    sagemaker_session=pipeline_session,
    subnets=vpc_subnet_ids,
    security_group_ids=security_group_ids,
    # Job naming, monitoring and output
    base_job_name='msa-default-run',
    debugger_hook_config=False,
    enable_sagemaker_metrics=True,
    output_path=output_path,
)

# With a PipelineSession, fit() is used only to build the arguments for the TrainingStep.
pipeline_msa_args = pipeline_msa.fit(inputs=pipeline_data_channels_msa)
step_msa = TrainingStep(name="RunMSA", step_args=pipeline_msa_args)

In [None]:
# --- Step 2: AlphaFold structure prediction, consuming the MSA step's output ---
pipeline_fasta_alphafold = sagemaker.inputs.TrainingInput(
    fasta_input,
    distribution="FullyReplicated",
    s3_data_type="S3Prefix",
    input_mode='File',
)
genetic_db = FileSystemInput(
    file_system_id=file_system_id,
    file_system_type=file_system_type,
    directory_path=file_system_directory_path,
    file_system_access_mode=file_system_access_mode,
)

# The MSA job's model artifacts (resolved at run time) are mounted as the 'msa' channel.
model_data = step_msa.properties.ModelArtifacts.S3ModelArtifacts
msa = sagemaker.inputs.TrainingInput(
    model_data,
    distribution="FullyReplicated",
    s3_data_type="S3Prefix",
    input_mode='File',
)

pipeline_data_channels_alphafold = {
    "genetic": genetic_db,
    'fasta': pipeline_fasta_alphafold,
    'msa': msa,
}

# Pipeline parameters forwarded to run_alphafold.sh as hyperparameters.
parameters = dict(
    DB_PRESET=pipeline_db_preset,
    FASTA_SUFFIX=fasta_file,
    MAX_TEMPLATE_DATE=max_template_date,
    MODEL_PRESET=model_preset,
    NUM_MULTIMER_PREDICTIONS_PER_MODEL=num_multimer_predictions_per_model,
)

output_path = 's3://%s/%s/job-output/' % (default_bucket, prefix)

pipeline_alphafold_default = Estimator(
    # Code and container
    source_dir='./source_dir',
    entry_point='run_alphafold.sh',
    image_uri=alphafold_image_uri,
    hyperparameters=parameters,
    # Compute
    instance_count=1,
    instance_type=instance_type,
    # Identity, session and networking
    role=role,
    sagemaker_session=pipeline_session,
    subnets=vpc_subnet_ids,
    security_group_ids=security_group_ids,
    # Job naming, monitoring and output
    base_job_name='alphafold-default-run',
    debugger_hook_config=False,
    enable_sagemaker_metrics=True,
    output_path=output_path,
)

pipeline_alphafold_default_args = pipeline_alphafold_default.fit(inputs=pipeline_data_channels_alphafold)
step_alphafold = TrainingStep(name="RunAlphaFold", step_args=pipeline_alphafold_default_args)
# Explicit ordering after the MSA step (its artifact is also referenced above).
step_alphafold.add_depends_on([step_msa])

In [None]:
# --- Step 3: OpenFold structure prediction, consuming the MSA step's output ---
genetic_db = FileSystemInput(
     file_system_id=file_system_id,
     file_system_type=file_system_type,
     directory_path=file_system_directory_path,
     file_system_access_mode=file_system_access_mode
)

pipeline_fasta_openfold = sagemaker.inputs.TrainingInput(fasta_input,
                                       distribution="FullyReplicated",
                                       s3_data_type="S3Prefix",
                                       input_mode='File'
                                      )

# The OpenFold weights are uploaded when this cell runs (at pipeline-definition time),
# and every execution then reads them from S3 through the 'param' channel.
s3_param_openfold = sess.upload_data(path='./source_dir/finetuning_ptm_2.pt',
                                     key_prefix=f'{prefix}/openfold_params')

param_openfold = sagemaker.inputs.TrainingInput(s3_param_openfold,
                                       distribution="FullyReplicated",
                                       s3_data_type="S3Prefix",
                                       input_mode='File'
                                      )

model_data = step_msa.properties.ModelArtifacts.S3ModelArtifacts

# Wrap the MSA step's artifacts in an explicit TrainingInput, as the AlphaFold step does,
# so this channel's distribution/data type/input mode are stated rather than left to SDK defaults.
msa_openfold = sagemaker.inputs.TrainingInput(model_data,
                                              distribution="FullyReplicated",
                                              s3_data_type="S3Prefix",
                                              input_mode='File'
                                             )

pipeline_data_channels_openfold = {"genetic": genetic_db,
                                   'fasta': pipeline_fasta_openfold,
                                   'param': param_openfold,
                                   'msa': msa_openfold}

# Only the database preset is forwarded to run_openfold.sh.
parameters={
            'DB_PRESET': pipeline_db_preset,
           }

# NOTE(review): unlike the MSA/AlphaFold steps, this path does not include `prefix`,
# so OpenFold outputs (and, via code_location, its uploaded source) land under
# 'protein-folding/openfold' regardless of the configured prefix. Confirm that is intended.
output_path='s3://%s/%s/job-output/'%(default_bucket, "protein-folding/openfold")

pipeline_openfold = Estimator(
                      source_dir='./source_dir',
                      entry_point='run_openfold.sh',
                      role=role,
                      image_uri=openfold_image_uri,
                      instance_count=1,
                      instance_type=instance_type,
                      sagemaker_session=pipeline_session,
                      subnets=vpc_subnet_ids,
                      security_group_ids=security_group_ids,
                      debugger_hook_config=False,
                      base_job_name='openfold-default-run',
                      hyperparameters=parameters,
                      enable_sagemaker_metrics=True,
                      output_path=output_path,
                      code_location=output_path)

pipeline_openfold_args = pipeline_openfold.fit(inputs=pipeline_data_channels_openfold)

step_openfold = TrainingStep(
    name="RunOpenFold",
    step_args=pipeline_openfold_args,
)
# Explicit ordering after the MSA step (its artifact is also referenced above).
step_openfold.add_depends_on([step_msa])

In [None]:
# Assemble the MSA, AlphaFold and OpenFold steps into one pipeline, then create or update it.
# Plain literal: the original f-string had no placeholders.
pipeline_name = "ProteinFoldWorkflow"
pipeline = Pipeline(
    name=pipeline_name,
    # Declare every pipeline parameter the steps reference, so each can be set per execution.
    parameters=[
        fasta_file,
        fasta_input,
        max_template_date,
        model_preset,
        num_multimer_predictions_per_model,
        instance_type,
        msa_instance_type,
        pipeline_db_preset,
    ],
    steps=[step_msa, step_alphafold, step_openfold],
)

# upsert() creates the pipeline if it does not exist, otherwise updates its definition.
pipeline.upsert(role_arn=role,
                description='Protein_Workflow_MSA_Alphafold_Openfold')

In [None]:
#User inputs for pipeline run 
fasta_file = 'T1030.fasta' #Default pipeline execution name will drop .fasta
!mkdir ./sequence_input/
!curl 'https://www.predictioncenter.org/casp14/target.cgi?target=T1030&view=sequence' > ./sequence_input/T1030.fasta 

In [None]:
# Upload the local FASTA file; its S3 URI is passed to the pipeline as FastaInputS3URI.
pathName = f'./sequence_input/{fasta_file}'
s3_fasta = sess.upload_data(path=pathName,
                            key_prefix='alphafoldv2/sequence_input')

# Keys must match the `name=` of each pipeline parameter declared earlier.
PipelineParameters={
            'FastaInputS3URI':s3_fasta,
            'db_preset': 'full_dbs',
            'FastaFileName': fasta_file,
            'MaxTemplateDate': '2020-05-14',
            'ModelPreset': 'monomer',
            'NumMultimerPredictionsPerModel': '5',
            'InferenceInstanceType':'ml.g5.4xlarge',
            'MSAInstanceType':'ml.m5.4xlarge'
        }

# Display name = file name up to its first "." (e.g. 'T1030.fasta' -> 'T1030').
# partition() returns the whole name when there is no "."; the previous
# fasta_file[:fasta_file.find(".")] got -1 from find() and dropped the last character.
experiment_name = fasta_file.partition(".")[0]

execution = pipeline.start(execution_display_name=experiment_name,
                           execution_description=f'This pipeline was executed via SageMaker SDK and is running an inference for {experiment_name}',
                           parameters=PipelineParameters
                          )