# Cerebras SageMaker Finetuning

This notebook shows how to fine-tune a Cerebras-GPT model with LoRA and deploy it on Amazon SageMaker.

In [None]:
!pip install -U "sagemaker>=2.143.0"

In [None]:
import sagemaker, boto3, json
from sagemaker import get_execution_role
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.huggingface import HuggingFace

# SageMaker session and its default S3 bucket; both are reused by the upload cells below
sess = sagemaker.Session()
bucket = sess.default_bucket()

# Execution role handed to the training job and the model, plus the region of the boto3 session
role = get_execution_role()
region = boto3.Session().region_name

# Show the installed SDK version as the cell output
sagemaker.__version__

## Upload Data

We will use databricks-dolly-15k as a sample dataset to fine-tune the model. (License: [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://creativecommons.org/licenses/by-sa/3.0/legalcode))

You may also choose to use a custom dataset.

In [None]:
# Download the Dolly JSON Lines file into data/ (-L follows redirects, --create-dirs creates the folder if needed)
!curl -L https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl --create-dirs -o data/databricks-dolly-15k.jsonl

In [None]:
# Peek at the first two lines of the downloaded file to check its structure
!head -n 2 data/databricks-dolly-15k.jsonl

In [None]:
# Convert the .jsonl file into a single .json array of records, renaming the
# "context" and "response" columns to "input" and "output" along the way
import pandas as pd

df = (
    pd.read_json('data/databricks-dolly-15k.jsonl', orient='records', lines=True)
    .rename(columns={"context": "input", "response": "output"})
)
df.to_json("data/databricks-dolly-15k.json", orient='records')

In [None]:
# Upload the converted dataset to S3 under the "Dolly" key prefix
local_dataset_path = "./data/databricks-dolly-15k.json"
input_train = sess.upload_data(path=local_dataset_path, key_prefix="Dolly")

# Display the value returned by the upload; it is passed to the training job later
input_train

## Fine-tuning

Fine-tuning took approximately 4 hours per epoch on an ml.p3.2xlarge instance.

In [None]:
# Hyperparameters forwarded to the fine-tuning entry point
hyperparameters = dict(
    base_model='cerebras/Cerebras-GPT-6.7B',
    # dataset file name matches the one uploaded for the 'train' channel
    data_path='/opt/ml/input/data/train/databricks-dolly-15k.json',
    num_epochs=3,  # default 3
    cutoff_len=512,
    group_by_length=True,
    output_dir='/opt/ml/model',
    # passed as a string; presumably parsed by the training script -- verify against finetune.py
    lora_target_modules='[c_attn]',
    lora_r=8,
    micro_batch_size=8,
    prompt_template_name='alpaca_short',
)

In [None]:
# Training job definition for the fine-tuning script
huggingface_estimator = HuggingFace(
    # --- what to run ---
    entry_point="finetune.py",
    source_dir="./scripts/code",
    hyperparameters=hyperparameters,
    # --- framework versions ---
    transformers_version="4.26",
    pytorch_version="1.13",
    py_version="py39",
    # --- where to run ---
    base_job_name="Cerebras",
    role=role,
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    volume_size=200,
    # --- spot training ---
    use_spot_instances=True,
    max_wait=86400,  # presumably seconds (24 h) -- see the SageMaker SDK estimator docs
)

# Launch training with the uploaded dataset as the "train" input channel
training_inputs = {"train": input_train}
huggingface_estimator.fit(training_inputs)

## Download and Extract Model

In [None]:
import boto3
import sagemaker

def get_latest_training_job_artifact(base_job_name, sagemaker_client=None):
    """Return the model artifact location of the newest completed training job.

    Args:
        base_job_name: Substring that the training job name must contain.
        sagemaker_client: Optional SageMaker client (a boto3 client or an
            object exposing ``list_training_jobs`` and
            ``describe_training_job``). A boto3 client is created when omitted.

    Returns:
        The ``ModelArtifacts.S3ModelArtifacts`` value from the description of
        the most recently created matching job whose status is ``Completed``.

    Raises:
        IndexError: If no completed training job matches ``base_job_name``.
    """
    if sagemaker_client is None:
        sagemaker_client = boto3.client('sagemaker')

    # Restrict the search to completed jobs: a job that is still running or
    # that failed has no usable model artifact to download.
    response = sagemaker_client.list_training_jobs(
        NameContains=base_job_name,
        StatusEquals='Completed',
        SortBy='CreationTime',
        SortOrder='Descending',
        MaxResults=1,  # only the newest job is needed
    )
    summaries = response['TrainingJobSummaries']
    if not summaries:
        # Same exception type as indexing an empty list, but with a usable message
        raise IndexError(
            f"No completed training job found with a name containing '{base_job_name}'"
        )

    training_job_name = summaries[0]['TrainingJobName']
    training_job_description = sagemaker_client.describe_training_job(
        TrainingJobName=training_job_name
    )
    return training_job_description['ModelArtifacts']['S3ModelArtifacts']

try:
    model_data = huggingface_estimator.model_data
except:
    # Retrieve artifact url when kernel is restarted
    model_data = get_latest_training_job_artifact('Cerebras')
    
!aws s3 cp {model_data} cerebras.tar.gz

In [None]:
# Recreate an empty scripts/model directory so leftovers from a previous run are removed
!rm -rf scripts/model && mkdir scripts/model
# Unpack the downloaded artifact into scripts/model (--no-same-owner: do not restore the archived file ownership)
!tar -xf cerebras.tar.gz -C scripts/model --no-same-owner

## Package and Upload Model

In [None]:
# Archive from inside scripts/ so the entries in the tarball are relative to that directory
%cd scripts
# Bundle every non-hidden entry of scripts/ into package.tar.gz one level up
!tar -czvf ../package.tar.gz *
# Return to the previous working directory
%cd -

In [None]:
# Upload the packaged archive to the bucket under the "Cerebras" key prefix
model_path = sess.upload_data(
    path='package.tar.gz',
    bucket=bucket,
    key_prefix="Cerebras",
)

# Display the value returned by the upload; the deploy cell uses it as model_data
model_path

## Deploy Model

In [None]:
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.serializers import JSONSerializer

endpoint_name = "Cerebras"

# Settings handed to the serving container as one JSON-encoded environment variable
model_params = {
    "base_model": "cerebras/Cerebras-GPT-6.7B",
    "lora_weights": "model", # path relative to model package
    "load_8bit": "True",
    "prompt_template": "alpaca_short",
}

# Model definition pointing at the uploaded package; the model shares its name with the endpoint
huggingface_model = PyTorchModel(
    name=endpoint_name,
    role=role,
    model_data=model_path,
    framework_version="1.13",
    py_version="py39",
    env={"model_params": json.dumps(model_params)},
)

# deploy model to SageMaker Inference as an asynchronous endpoint
predictor = huggingface_model.deploy(
    endpoint_name=endpoint_name,
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
    serializer=JSONSerializer(),
    async_inference_config=AsyncInferenceConfig(),
)

## Run Inference

In [None]:
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import NumpyDeserializer

endpoint_name = "Cerebras"

# Rebuild a predictor from the endpoint name alone, so this cell does not depend
# on the `predictor` object created by the deploy cell
base_predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=JSONSerializer(),
    deserializer=NumpyDeserializer(),
)
predictor_client = AsyncPredictor(predictor=base_predictor, name=endpoint_name)

# Request payload: an instruction, its supporting context, and a generation length limit
data = {
    "instruction": "When did Virgin Australia start operating?",
    "input": "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3] It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]",
    "max_new_tokens": 512,
}

# Send the request through the async predictor and print what comes back
response = predictor_client.predict(data=data)
print(response)

## Delete Endpoint

In [None]:
# Clean up using the predictor returned by the deploy cell: remove the model first, then the endpoint
predictor.delete_model()
predictor.delete_endpoint()