# RWKV SageMaker Inference

This notebook contains sample code for deploying an RWKV model to an Amazon SageMaker endpoint and running inference against it.

In [None]:
!pip install "sagemaker>=2.143.0" -U

In [None]:
import sagemaker, boto3, json
from sagemaker import get_execution_role
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.huggingface import HuggingFace

# AWS context shared by the rest of the notebook.
region = boto3.Session().region_name  # region of the default boto3 session
role = get_execution_role()           # IAM role later passed to the model
sess = sagemaker.Session()
bucket = sess.default_bucket()        # bucket later used for the archive upload

# Show the installed SDK version as the cell output.
sagemaker.__version__

## Prepare Model

## Package and Upload Model

In [None]:
# Start from an empty scripts/model directory (any previous one is removed).
!rm -rf scripts/model && mkdir scripts/model
# Archive the contents of scripts/ into package.tar.gz one level above it.
# The shell glob `*` does not match dotfiles, so hidden files are left out.
%cd scripts
!tar -czvf ../package.tar.gz *
# Return to the previous working directory so later relative paths
# (e.g. 'package.tar.gz') resolve from the original location.
%cd -

In [None]:
# Upload the archive to the session's default bucket under the "RWKV" key
# prefix. The returned value is later passed to the model as `model_data`.
model_path = sess.upload_data(
    "package.tar.gz",
    bucket=bucket,
    key_prefix="RWKV",
)
model_path

## Deploy Model

In [None]:
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.serializers import JSONSerializer

# The same name is used for both the SageMaker model and the endpoint.
endpoint_name = "RWKV"

# Settings handed to the container as a JSON string through the
# "model_params" environment variable.
model_params = {
    "base_model": "RWKV/rwkv-4-169m-pile",
    "peft": False,
    "prompt_template": "alpaca",
}

# NOTE: despite its name, this object is a PyTorchModel.
huggingface_model = PyTorchModel(
    name=endpoint_name,
    role=role,
    model_data=model_path,
    framework_version="1.13",
    py_version="py39",
    env={"model_params": json.dumps(model_params)},
)

# Create the endpoint; the returned predictor is reused by the cleanup cell.
predictor = huggingface_model.deploy(
    endpoint_name=endpoint_name,
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
    serializer=JSONSerializer(),
    # Uncomment to pass an asynchronous inference configuration instead:
    # async_inference_config=AsyncInferenceConfig()
)

## Run Inference

In [None]:
import sagemaker
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Name of the endpoint to call; redefined here so this cell does not depend
# on the deploy cell having been run in the same kernel session.
endpoint_name = "RWKV"

# Attach to the existing endpoint, sending and receiving JSON.
predictor_client = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)
# Wrap the client like this when the endpoint was deployed with an
# asynchronous inference configuration:
# predictor_client = AsyncPredictor(
#     predictor=predictor_client,
#     name=endpoint_name
# )

# Request payload: the prompt fields plus generation settings.
data = {
    "instruction": "When did Virgin Australia start operating?",
    "input": "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3] It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]",
    "max_new_tokens": 256,
    "temperature": 0.3,
    "do_sample": True,
    "pad_token_id": 1,
    "bos_token_id": 0,
    # Fixed key: was misspelled "eos_token_is", so the end-of-sequence id was
    # not sent under the name used by its sibling *_token_id settings.
    "eos_token_id": 0,
}
response = predictor_client.predict(
    data=data
)
print(response)

In [None]:
# With Boto3

import boto3
import json

# Call the same endpoint through the low-level SageMaker runtime client.
endpoint_name = "RWKV"
sagemaker_client = boto3.client('sagemaker-runtime')

# Request payload: the prompt fields plus generation settings.
data = {
    "instruction": "When did Virgin Australia start operating?",
    "input": "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3] It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]",
    "max_new_tokens": 256,
    "temperature": 0.3,
    "do_sample": True,
    "pad_token_id": 1,
    "bos_token_id": 0,
    # Fixed key: was misspelled "eos_token_is", so the end-of-sequence id was
    # not sent under the name used by its sibling *_token_id settings.
    "eos_token_id": 0,
}

# The body is sent as a JSON string; both request and response are JSON.
response = sagemaker_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='application/json',
    Accept='application/json',
    Body=json.dumps(data)
)

# The response body is a stream: read it fully, then decode the JSON.
result = json.loads(response['Body'].read())
print(result)

## Benchmark Speed

In [None]:
%timeit response = predictor_client.predict(data=data)

## Delete Endpoint

In [None]:
# Clean up the resources created by the deploy cell. This uses the `predictor`
# returned by `deploy()`, so that cell must have run in this kernel session.
# NOTE(review): the model is deleted before the endpoint; presumably
# delete_model() looks the model up through the still-existing endpoint —
# confirm against the SageMaker SDK before reordering.
predictor.delete_model()
predictor.delete_endpoint()