In [None]:
# Import necessary libraries 

import kfp
from kfp import components
from kubeflow.training.utils import utils
from kfp import dsl
from kfp import compiler

import os
import yaml
import json
from kubeflow.training import PyTorchJobClient
import time
import boto3
import kfp.components as comp

import boto3
import random, string
from datetime import datetime

In [None]:
# Load the PyTorchJob replica specs (master and worker) that the training step submits.
# These YAML files are produced by notebook '1_submit_pytorchdist_k8s.ipynb' -- run it
# first, or write the specs by hand.
master_spec_path = "pipeline_yaml_specifications/pipeline_master_spec.yml"
worker_spec_path = "pipeline_yaml_specifications/pipeline_worker_spec.yml"

with open(master_spec_path, 'r') as spec_file:
    master_spec_loaded = yaml.safe_load(spec_file)

with open(worker_spec_path, 'r') as spec_file:
    worker_spec_loaded = yaml.safe_load(spec_file)

In [None]:
# Point the KServe model-layout component at your container image.
# Replace the URI below with your own ECR image
# (format: <account-id>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>).
#
# This is done in Python rather than `!sed -i`: GNU and BSD/macOS sed disagree
# on `-i`, and IPython expands `{...}` inside `!` commands, so `{image_url}`
# would be rewritten if a Python variable named `image_url` existed.
kserve_layout_image_url = '458473390725.dkr.ecr.us-west-2.amazonaws.com/kserve_layout:latest'
kserve_layout_component_path = 'pipeline_components/kserve_layout_component.yaml'

with open(kserve_layout_component_path, 'r') as component_file:
    component_text = component_file.read()

if '{image_url}' in component_text:
    # Same substitution the sed command made: every `{image_url}` placeholder
    # becomes the single-quoted image URI.
    with open(kserve_layout_component_path, 'w') as component_file:
        component_file.write(
            component_text.replace('{image_url}', f"'{kserve_layout_image_url}'")
        )
else:
    # The placeholder is consumed by the first successful run. Editing the URI
    # above and re-running has no effect until the component file is restored.
    print(f"No {{image_url}} placeholder left in {kserve_layout_component_path}; "
          "the image URI was already substituted.")

In [None]:
# Kubernetes namespace that the training job and the KServe deployment target by default.
user_namespace = utils.get_default_target_namespace()

# Component factories for the three pipeline steps:
# PyTorchJob launcher, KServe model-layout step, and KServe deployer.
pytorch_job_op = comp.load_component_from_file(
    'pipeline_components/pytorch_component.yaml'
)
unix_model_layout_op = comp.load_component_from_file(
    'pipeline_components/kserve_layout_component.yaml'
)
# NOTE(review): fetched from the `master` branch at run time, so the component
# definition can change underneath this notebook; consider pinning a release tag.
kserve_op = comp.load_component_from_url(
    'https://raw.githubusercontent.com/kubeflow/pipelines/master/components/kserve/component.yaml'
)

In [None]:
# S3 bucket created in notebook 1_submit_pytorchdist_k8s.ipynb. The model-layout
# step writes to it and KServe's model URI points at it. Replace with your bucket.
s3_bucket_name = 'kserve-model-20230527042622'
# Alias under the name the pipeline definition uses, so that cell does not
# depend on a variable left over from another notebook's kernel.
kserve_s3_bucket_name = s3_bucket_name

# Unique KServe InferenceService name, suffixed with the local time as YYYYmmddHHMMSS.
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
kserve_model_endpoint_name = 'image-classify-' + timestamp

kserve_model_endpoint_name

In [None]:
# Name of the Kubernetes PyTorchJob custom resource. It carries a UTC timestamp
# (date, time and day-of-year) so every run gets a distinct job name.
_job_started_utc = time.gmtime()
pytorch_distributed_jobname = 'pytorch-cnn-dist-job-' + time.strftime(
    "%Y-%m-%d-%H-%M-%S-%j", _job_started_utc
)


# Pipeline: distributed PyTorch training (PyTorchJob via the Kubeflow Training
# Operator) -> repackage the trained model into the KServe layout in S3
# -> deploy it with KServe.
@dsl.pipeline(name="PyTorch Training pipeline", description="Sample training job test")
def pytorch_cnn_pipeline(action='apply',
                         model_name=kserve_model_endpoint_name,
                         model_uri=f's3://{s3_bucket_name}',
                         framework='pytorch',
                         region='us-west-2',
                         training_job_name=pytorch_distributed_jobname,
                         namespace=user_namespace,
                         ):
    """Train the CNN with a PyTorchJob, lay the model out for KServe, and deploy it.

    Args:
        action: Action passed to the KServe component (e.g. 'apply').
        model_name: Name of the KServe model/InferenceService to deploy.
        model_uri: S3 URI KServe loads the model from. Defaults to the bucket
            the layout step writes to.
        framework: Serving framework passed to the KServe component.
        region: Not used by any step. Kept so the pipeline's parameters are unchanged.
        training_job_name: Name of the PyTorchJob custom resource.
        namespace: Kubernetes namespace for the PyTorchJob and the KServe deployment.
    """
    train_task = pytorch_job_op(
        name=training_job_name,
        # Use the pipeline parameter (not the global) so the training job and
        # the KServe deployment always land in the same namespace.
        namespace=namespace,
        master_spec=json.dumps(master_spec_loaded),  # see pipeline_yaml_specifications/pipeline_master_spec.yml
        worker_spec=json.dumps(worker_spec_loaded),  # see pipeline_yaml_specifications/pipeline_worker_spec.yml
        delete_after_done=False,
    )

    unix_model_layout = unix_model_layout_op(
        bucket=s3_bucket_name,
        model_input='model_kserve.pth',
        model_archive_name='cifar',
    ).after(train_task)

    kserve_deploy = kserve_op(
        action=action,
        model_name=model_name,
        model_uri=model_uri,
        namespace=namespace,
        framework=framework,
    ).after(unix_model_layout)

    # Disable the KFP step cache so every run really retrains and redeploys.
    for task in (train_task, unix_model_layout, kserve_deploy):
        task.execution_options.caching_strategy.max_cache_staleness = "P0D"

In [None]:
# Compile the pipeline function into a workflow YAML package for the KFP client to submit.
pipeline_package_path = "pytorch_cnn_pipeline.yaml"
compiler.Compiler().compile(
    pipeline_func=pytorch_cnn_pipeline,
    package_path=pipeline_package_path,
)

In [None]:
# Talk to Kubeflow Pipelines through the SDK client.
client = kfp.Client()

# Experiment that groups the runs submitted from this notebook.
experiment = client.create_experiment(name="kubeflow")

# Submit the package compiled in the previous cell as a new run.
my_run = client.run_pipeline(
    experiment_id=experiment.id,
    job_name="pytorch_cnn_pipeline",
    pipeline_package_path="pytorch_cnn_pipeline.yaml",
)

# Follow the “Run details” link shown below this cell to open the run;
# select any step in the graph to see its logs.