# RWKV SageMaker Fine-Tuning

This notebook contains sample code to fine-tune and deploy RWKV on Amazon SageMaker.

In [None]:
!pip install transformers

In [None]:
import sagemaker, boto3, json, os
import pandas as pd
import numpy as np
from sagemaker import get_execution_role
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.huggingface import HuggingFace

# Shared AWS context used by later cells. Each call may talk to AWS, so keep this order.
role = get_execution_role()  # IAM role passed to the training job and to the deployed model
region = boto3.Session().region_name
sess = sagemaker.Session()
bucket = sess.default_bucket()  # default SageMaker bucket; the model package is uploaded here

sagemaker.__version__  # displayed so the SDK version in use is recorded in the cell output

## Download and Tokenize Dataset

We use databricks-dolly-15k as the sample dataset for fine-tuning the model (license: [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://creativecommons.org/licenses/by-sa/3.0/legalcode)).

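You may also use a custom dataset; the tokenization step below expects a JSON Lines file whose records have `instruction`, `context`, and `response` fields.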

In [None]:
# Download Databricks Dolly 15k dataset
!curl -L https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl --create-dirs -o data/databricks-dolly-15k.jsonl

In [None]:
# Preview the first two records (one JSON object per line) to check the field names.
!head -n 2 data/databricks-dolly-15k.jsonl

In [None]:
# Download Tokenizer
!curl https://raw.githubusercontent.com/BlinkDL/RWKV-LM/main/RWKV-v4/20B_tokenizer.json --create-dirs -o data/20B_tokenizer.json

In [None]:
# Tokenize Input
from transformers import PreTrainedTokenizerFast

def generate_prompt(instruction, input, output):
    """Render one instruction-tuning example as a training prompt.

    The preamble wording and the "# Input:" section depend on whether ``input`` is the
    empty string: a non-empty input gets its own section, an empty one is omitted.
    Every prompt ends with the "<|endoftext|>" marker followed by a newline.
    """
    if input != "":
        preamble = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
        sections = [("# Instruction:", instruction), ("# Input:", input), ("# Response:", output)]
    else:
        preamble = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        sections = [("# Instruction:", instruction), ("# Response:", output)]

    lines = [preamble]
    for heading, body in sections:
        lines.append(heading)
        lines.append(f"{body}")
    lines.append("<|endoftext|>")
    return "\n".join(lines) + "\n"

def tokenize_and_save(input_file, tokenizer_file='data/20B_tokenizer.json', output_file='data/train.npy'):
    """Turn a Dolly-style JSON Lines file into a flat uint16 token array saved with ``np.save``.

    Each record's ``instruction``, ``context`` and ``response`` fields (the last two are
    renamed to ``input`` / ``output``) are stripped, rendered with ``generate_prompt`` and
    concatenated into one corpus, which is tokenized with a single ``encode`` call.

    Args:
        input_file: Path to the JSON Lines dataset.
        tokenizer_file: Tokenizer JSON loaded with ``PreTrainedTokenizerFast``.
        output_file: Destination path of the saved ``.npy`` array.

    Raises:
        ValueError: If any token id is too large to be stored as uint16.
    """
    dataset = pd.read_json(input_file, orient='records', lines=True)
    dataset = dataset.rename(columns={"context": "input", "response": "output"})
    # Null text fields become "" so .strip() works and an absent input is simply omitted.
    dataset = dataset.fillna({"instruction": "", "input": "", "output": ""})
    records = dataset.to_dict(orient='records')

    # Build the corpus with one join instead of repeated `+=` on an ever-growing string.
    corpus = "".join(
        generate_prompt(
            item["instruction"].strip(),
            item["input"].strip(),
            item["output"].strip(),
        )
        for item in records
        if item  # skip empty records
    )

    # Global numpy display setting; only affects how arrays are printed.
    np.set_printoptions(precision=4, suppress=True, linewidth=200)

    tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
    token_ids = tokenizer.encode(corpus)
    print(f'Tokenized length = {len(token_ids)}')

    # uint16 keeps the file small but only holds ids up to 65535; check explicitly instead of
    # relying on numpy to wrap or raise (which depends on the numpy version).
    uint16_max = np.iinfo(np.uint16).max
    if token_ids and max(token_ids) > uint16_max:
        raise ValueError(
            f"Token id {max(token_ids)} does not fit in uint16 (max {uint16_max}); "
            "use a wider dtype for this tokenizer."
        )
    np.save(output_file, np.array(token_ids, dtype='uint16'), allow_pickle=False)


tokenize_and_save('data/databricks-dolly-15k.jsonl')

In [None]:
# Upload the tokenized corpus under the "Dolly" key prefix; the returned S3 URI is
# displayed and is passed to the training job as its `train` channel.
input_train = sess.upload_data("./data/train.npy", key_prefix="Dolly")
input_train

## Prepare Model

In [None]:
base_model_name = "RWKV-4-Pile-169M"

# Checkpoint file name, layer count and embedding width for each supported RWKV-4 Pile base model.
RWKV_BASE_MODELS = {
    "RWKV-4-Pile-169M": ("RWKV-4-Pile-169M-20220807-8023.pth", 12, 768),
    "RWKV-4-Pile-430M": ("RWKV-4-Pile-430M-20220808-8066.pth", 24, 1024),
    "RWKV-4-Pile-1B5": ("RWKV-4-Pile-1B5-20220903-8040.pth", 24, 2048),
    "RWKV-4-Pile-3B": ("RWKV-4-Pile-3B-20221008-8023.pth", 32, 2560),
    "RWKV-4-Pile-7B": ("RWKV-4-Pile-7B-20221115-8047.pth", 32, 4096),
    "RWKV-4-Pile-14B": ("RWKV-4-Pile-14B-20230213-8019.pth", 40, 5120),
}

# Look the name up rather than branching, so an unsupported name fails here instead of leaving
# base_model_file / n_layer / n_embd undefined or stale from an earlier run of this cell.
try:
    base_model_file, n_layer, n_embd = RWKV_BASE_MODELS[base_model_name]
except KeyError:
    raise ValueError(
        f"Unsupported base_model_name {base_model_name!r}; expected one of {list(RWKV_BASE_MODELS)}"
    ) from None

base_model_url = f"https://huggingface.co/BlinkDL/{base_model_name.lower()}/resolve/main/{base_model_file}"
base_model_path = f"scripts/code/base_models/{base_model_file}"

## Train

In [None]:
# Arguments forwarded to train.py inside the training container (key order kept as-is).
hyperparameters = dict(
    # Base checkpoint: where it is expected inside the container and the URL it comes from.
    load_model=f"/tmp/{base_model_file}",
    model_url=base_model_url,
    # Container paths: /opt/ml/model is archived by SageMaker as the job's model artifact;
    # the `train` channel passed to fit() provides train.npy under /opt/ml/input/data/train.
    proj_dir='/opt/ml/model',
    data_file='/opt/ml/input/data/train/train.npy',
    data_type='numpy',
    vocab_size=50277,
    ctx_len=384,
    epoch_save=5,
    epoch_count=3,
    max_epochs=3,  # Added to stop training
    # Architecture must match the base checkpoint selected in the model cell.
    n_layer=n_layer,
    n_embd=n_embd,
    epoch_steps=1000,
    micro_bsz=8,
    accelerator="gpu",
    devices=1,
    precision='bf16',
    strategy="deepspeed_stage_2",
    enable_progress_bar=False,  # Added to Reduce Log
)

In [None]:
# Training job on one ml.g5.2xlarge instance; train.py is taken from ./scripts/code.
estimator_config = dict(
    base_job_name="RWKV",  # also the name fragment used to find the job after a kernel restart
    role=role,
    entry_point='train.py',
    source_dir='./scripts/code',
    instance_type='ml.g5.2xlarge',
    instance_count=1,
    volume_size=200,
    transformers_version='4.26',
    pytorch_version='1.13',
    py_version='py39',
    # Optional: use_spot_instances=True to train on Spot capacity.
    hyperparameters=hyperparameters,
)
huggingface_estimator = HuggingFace(**estimator_config)
huggingface_estimator.fit({'train': input_train})

## Download and Extract Model

In [None]:
import boto3
import sagemaker

def get_latest_training_job_artifact(base_job_name, sagemaker_client=None):
    """Return the S3 model-artifact URI of the newest completed training job matching a name.

    Args:
        base_job_name: Substring matched against training job names (``NameContains``).
        sagemaker_client: Optional boto3 SageMaker client; ``boto3.client('sagemaker')``
            is created when omitted.

    Returns:
        The ``S3ModelArtifacts`` URI of the most recently created completed job.

    Raises:
        LookupError: If no completed training job name contains ``base_job_name``.
    """
    if sagemaker_client is None:
        sagemaker_client = boto3.client('sagemaker')
    # Only completed jobs are guaranteed to have model artifacts; without this filter a newer
    # failed or in-progress job could be picked and its description may lack 'ModelArtifacts'.
    response = sagemaker_client.list_training_jobs(
        NameContains=base_job_name,
        StatusEquals='Completed',
        SortBy='CreationTime',
        SortOrder='Descending',
        MaxResults=1,
    )
    summaries = response['TrainingJobSummaries']
    if not summaries:
        raise LookupError(
            f"No completed SageMaker training job found whose name contains {base_job_name!r}"
        )
    job_name = summaries[0]['TrainingJobName']
    description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
    return description['ModelArtifacts']['S3ModelArtifacts']

try:
 model_data = huggingface_estimator.model_data
except:
 # Retrieve artifact url when kernel is restarted
 model_data = get_latest_training_job_artifact('RWKV')
 
!aws s3 cp {model_data} rwkv.tar.gz

In [None]:
import re
import tarfile

# List the archive members; the context manager closes the file handle.
with tarfile.open('rwkv.tar.gz', 'r:gz') as tarf:
    files = sorted(tarf.getnames())

# Pick the checkpoint with the highest epoch number, compared numerically. A plain string sort
# puts rwkv-10.pth before rwkv-2.pth, and a positional pick such as files[-2] depends on which
# other files the archive happens to contain.
# NOTE(review): assumes train.py saves checkpoints as rwkv-<epoch>.pth — confirm against scripts/code.
checkpoint_pattern = re.compile(r'(?:^|/)rwkv-(\d+)\.pth$')
checkpoints = []
for name in files:
    match = checkpoint_pattern.search(name)
    if match:
        checkpoints.append((int(match.group(1)), name))
if not checkpoints:
    raise FileNotFoundError(f"No rwkv-<epoch>.pth checkpoint found in rwkv.tar.gz: {files}")
target_file = max(checkpoints)[1]
print(files, target_file)

In [None]:
# Stage the selected checkpoint and the tokenizer in scripts/model for packaging.
# The directory is recreated so files from a previous run are not packaged as well.
!rm -rf scripts/model && mkdir scripts/model
# Extract only `target_file`; --no-same-owner makes the extracted file owned by the current user.
!tar -xf rwkv.tar.gz -C scripts/model --no-same-owner {target_file}
!cp data/20B_tokenizer.json scripts/model
!ls scripts/model

## Package and Upload Model

In [None]:
# Archive the contents of scripts/ (e.g. code/ and model/) from inside that directory so they
# sit at the root of package.tar.gz. base_models is excluded — presumably to keep the base
# checkpoint location (scripts/code/base_models) out of the deployment package.
%cd scripts
!tar --exclude base_models -czvf ../package.tar.gz *
%cd -

In [None]:
# Upload the deployment package to the default bucket under the "RWKV" key prefix and show its URI.
model_path = sess.upload_data(path='package.tar.gz', key_prefix='RWKV', bucket=bucket)
model_path

## Deploy Model

In [None]:
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.serializers import JSONSerializer

endpoint_name = "RWKV"

# Settings for the inference handler, passed as a JSON string in the `model_params`
# environment variable. Paths are relative to the extracted model package.
inference_params = {
    "model_path": "model/" + target_file,
    "tokenizer_path": "model/20B_tokenizer.json",
    "strategy": "cuda bf16",
}

huggingface_model = PyTorchModel(
    model_data=model_path,
    framework_version="1.13",
    py_version='py39',
    role=role,
    name=endpoint_name,
    env={"model_params": json.dumps(inference_params)},
)

# Create an asynchronous-inference endpoint (default AsyncInferenceConfig) named after the model.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type='ml.g5.2xlarge',
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    async_inference_config=AsyncInferenceConfig(),
)

## Run Inference

In [None]:
import sagemaker
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import NumpyDeserializer, JSONDeserializer

endpoint_name = "RWKV"

# Rebuild a client from the endpoint name alone, so this cell does not depend on `predictor`
# from the deploy cell (useful after a kernel restart).
sync_predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker.Session(),
    serializer=JSONSerializer(),
    deserializer=NumpyDeserializer(),
)
predictor_client = AsyncPredictor(predictor=sync_predictor, name=endpoint_name)

# Request payload: the same instruction/input fields used to build the training prompts.
data = {
    "instruction": "When did Virgin Australia start operating?",
    "input": "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3] It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]",
}
response = predictor_client.predict(data=data)
print(response)

## Delete Endpoint

In [None]:
# Clean up: delete the SageMaker model first, then the endpoint created by the deploy cell.
for cleanup_step in (predictor.delete_model, predictor.delete_endpoint):
    cleanup_step()