# RWKV SageMaker Inference

This notebook is sample code for deploying an RWKV model to an Amazon SageMaker asynchronous inference endpoint.

In [None]:
!pip install "sagemaker>=2.143.0" -U

In [4]:
import json
import os
import urllib.request

import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.huggingface import HuggingFace
from sagemaker.pytorch.model import PyTorchModel

# SageMaker session and its default S3 bucket (used later for the model archive).
sess = sagemaker.Session()
bucket = sess.default_bucket()

# Execution role for SageMaker calls and the region of the current boto3 session.
role = get_execution_role()
region = boto3.Session().region_name

# Last expression: display the installed SageMaker SDK version as the cell output.
sagemaker.__version__

'2.145.0'

## Prepare Model

In [6]:
# Download Tokenizer
!curl https://raw.githubusercontent.com/BlinkDL/RWKV-LM/main/RWKV-v4/20B_tokenizer.json --create-dirs -o data/20B_tokenizer.json

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 2410k  100 2410k    0     0  6227k      0 --:--:-- --:--:-- --:--:-- 6227k


In [7]:
# Name of the RWKV-4 Pile checkpoint to deploy; must be a key of BASE_MODELS.
base_model_name = "RWKV-4-Pile-169M"

# Supported checkpoints: name -> (checkpoint file name, n_layer, n_embd).
BASE_MODELS = {
    "RWKV-4-Pile-169M": ("RWKV-4-Pile-169M-20220807-8023.pth", 12, 768),
    "RWKV-4-Pile-430M": ("RWKV-4-Pile-430M-20220808-8066.pth", 24, 1024),
    "RWKV-4-Pile-1B5": ("RWKV-4-Pile-1B5-20220903-8040.pth", 24, 2048),
    "RWKV-4-Pile-3B": ("RWKV-4-Pile-3B-20221008-8023.pth", 32, 2560),
    "RWKV-4-Pile-7B": ("RWKV-4-Pile-7B-20221115-8047.pth", 32, 4096),
    "RWKV-4-Pile-14B": ("RWKV-4-Pile-14B-20230213-8019.pth", 40, 5120),
}

# Fail early with a clear message instead of a later NameError on base_model_file.
if base_model_name not in BASE_MODELS:
    raise ValueError(
        f"Unsupported base_model_name {base_model_name!r}; "
        f"expected one of {sorted(BASE_MODELS)}"
    )

base_model_file, n_layer, n_embd = BASE_MODELS[base_model_name]

base_model_url = f"https://huggingface.co/BlinkDL/{base_model_name.lower()}/resolve/main/{base_model_file}"
base_model_path = f"scripts/code/base_models/{base_model_file}"

# Download only when the checkpoint is not already present locally.
if not os.path.exists(base_model_path):
    print(f"Downloading {base_model_name} this may take a while")
    os.makedirs(os.path.dirname(base_model_path), exist_ok=True)
    # Download to a temporary name and rename on success, so an interrupted
    # transfer never leaves a truncated file that passes the exists() check.
    partial_path = base_model_path + ".part"
    urllib.request.urlretrieve(base_model_url, partial_path)
    os.replace(partial_path, base_model_path)

print(f"Using {base_model_path} as base")

Using scripts/code/base_models/RWKV-4-Pile-169M-20220807-8023.pth as base


## Package and Upload Model

In [14]:
!rm -rf scripts/model && mkdir scripts/model
!cp data/20B_tokenizer.json scripts/model
%cd scripts
!tar -czvf ../package.tar.gz *
%cd -

/root/LLM/RWKV/scripts
code/
code/src/
code/src/utils.py
code/src/model.py
code/src/model_img.py
code/src/model_run.py
code/src/trainer.py
code/src/.ipynb_checkpoints/
code/src/.ipynb_checkpoints/binidx-checkpoint.py
code/src/.ipynb_checkpoints/model_run-checkpoint.py
code/src/.ipynb_checkpoints/trainer-checkpoint.py
code/src/.ipynb_checkpoints/model-checkpoint.py
code/src/.ipynb_checkpoints/utils-checkpoint.py
code/src/binidx.py
code/src/dataset.py
code/src/__init__.py
code/requirements.txt
code/.ipynb_checkpoints/
code/.ipynb_checkpoints/train-checkpoint.py
code/.ipynb_checkpoints/inference-checkpoint.py
code/.ipynb_checkpoints/requirements-checkpoint.txt
code/inference.py
code/base_models/
code/base_models/RWKV-4-Pile-169M-20220807-8023.pth
code/base_models/.ipynb_checkpoints/
code/base_models/.gitignore
code/train.py
code/cuda/
code/cuda/wkv_cuda_bf16.cu
code/cuda/wkv_cuda.cu
code/cuda/.ipynb_checkpoints/
code/cuda/.ipynb_checkpoints/wkv_op-checkpoint.cpp
code/cuda/.ipynb_checkpoin

In [15]:
# Upload the packaged archive to the session bucket under the "RWKV" key prefix.
model_path = sess.upload_data(
    "package.tar.gz",
    bucket=bucket,
    key_prefix="RWKV",
)
# Last expression: display the S3 location returned by the upload.
model_path

's3://sagemaker-us-west-2-867115166077/RWKV/package.tar.gz'

## Deploy Model

In [18]:
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.serializers import JSONSerializer

endpoint_name = "RWKV"

# Settings serialized to JSON and passed through the model's `env` under the
# "model_params" key. Paths are relative to the layout of package.tar.gz.
# NOTE(review): presumably read by code/inference.py at load time — confirm.
model_params = {
    "model_path": f"code/base_models/{base_model_file}",
    "tokenizer_path": "model/20B_tokenizer.json",
    "strategy": "cuda bf16",
}

# Despite the variable name, this is a PyTorchModel built from the uploaded archive.
huggingface_model = PyTorchModel(
    model_data=model_path,
    framework_version="1.13",
    py_version="py39",
    role=role,
    name=endpoint_name,
    env={"model_params": json.dumps(model_params)},
)

# Deploy to a single-instance endpoint configured for asynchronous inference.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    async_inference_config=AsyncInferenceConfig(),
)

---------!

## Run Inference

In [19]:
import sagemaker
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import NumpyDeserializer, JSONDeserializer

endpoint_name = "RWKV"

# Predictor bound to the existing endpoint by name, so this cell does not
# depend on the `predictor` object returned by the deploy cell.
endpoint_predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker.Session(),
    serializer=JSONSerializer(),
    deserializer=NumpyDeserializer(),
)
# Wrap it for invocation through the asynchronous inference API.
predictor_client = AsyncPredictor(predictor=endpoint_predictor, name=endpoint_name)

# Request payload: an instruction plus the context passage it refers to.
data = {
    "instruction": "When did Virgin Australia start operating?",
    "input": "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3] It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]",
}

# NOTE(review): predict() presumably waits for the async result before returning — confirm.
response = predictor_client.predict(data=data)
print(response)

Virgin Australia is a privately owned airline, and operates in the Australian Capital Territory. It operates from Melbourne to Brisbane, and operates from Melbourne to Sydney.[5]
# Response:
Virgin Australia is an Australian-based airline that operates from Melbourne to Sydney. It operates from Melbourne to Sydney, and operates from Melbourne to Sydney.[6]
# Response:
Virgin Australia is an Australian-based airline that operates from Melbourne to Sydney. It operates from Melbourne to Sydney, and operates from Melbourne to Sydney.[6]
# Response:
Virgin Australia is an Australian-based airline that operates from Melbourne to Sydney, and operates from Melbourne to Sydney.[6]
# Response:
Virgin Australia is an Australian-based airline that operates from Melbourne to Sydney, and operates from Melbourne to Sydney.[6]
# Response:
Virgin Australia is an Australian-based airline that operates from Melbourne to Sydney, and operates from Melbourne to Sydney.[6]
# Response:
Virgin Australia is an


## Delete Endpoint

In [20]:
# Clean up the resources created by the deploy cell: `predictor` is the object
# returned by `huggingface_model.deploy(...)`.
# NOTE(review): the model is deleted before the endpoint; presumably the SDK
# looks up the model through the still-existing endpoint — confirm before
# reordering these two calls.
predictor.delete_model()
predictor.delete_endpoint()