FROM nvidia/cuda:11.7.1-base-ubuntu20.04 AS base_image

ENV DEBIAN_FRONTEND=noninteractive \
    LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/lib"

FROM base_image AS ec2

LABEL maintainer="Amazon AI"
LABEL dlc_major_version="1"

# Specify accept-bind-to-port LABEL for inference pipelines to use SAGEMAKER_BIND_TO_PORT
# https://docs.aws.amazon.com/sagemaker/latest/dg/inference-pipeline-real-time.html
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true
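# Illustrative usage (not executed at build time): with this label set, SageMaker can
# pass the listening port to the container at runtime, e.g.
#   docker run -e SAGEMAKER_BIND_TO_PORT=8080 <image> serve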

ARG MMS_VERSION=1.1.11
ARG PYTHON=python3
ARG PYTHON_VERSION=3.9.13
ARG MAMBA_VERSION=22.9.0-1
ARG OPEN_MPI_VERSION=4.1.4

# Nvidia software versions
ARG CUBLAS_VERSION=11.10.3.66
ARG CUDNN_VERSION=8.5.0.96

# PyTorch Binaries and versions.
ARG TORCH_URL=https://aws-pytorch-unified-cicd-binaries.s3.us-west-2.amazonaws.com/r1.13.1_ec2/20221219-193736/54406b8eed7fbd61be629cb06229dfb7b6b2954e/torch-1.13.1%2Bcu117-cp39-cp39-linux_x86_64.whl
ARG TORCHVISION_VERSION=0.14.1+cu117
ARG TORCHAUDIO_VERSION=0.13.1+cu117

# HF ARGS
ARG TRANSFORMERS_VERSION
ARG DIFFUSERS_VERSION=0.16.1

# Keep apt non-interactive during the build
ARG DEBIAN_FRONTEND=noninteractive

# See http://bugs.python.org/issue19846
ENV LANG=C.UTF-8
ENV LD_LIBRARY_PATH="/opt/conda/lib:${LD_LIBRARY_PATH}"
ENV PATH=/opt/conda/bin:$PATH
ENV TEMP=/home/model-server/tmp
# Set MKL_THREADING_LAYER=GNU to prevent issues between torch and numpy/mkl
ENV MKL_THREADING_LAYER=GNU
ENV DLC_CONTAINER_TYPE=inference

ENV TORCH_CUDA_ARCH_LIST="5.0 7.0+PTX 7.5+PTX 8.0"
ENV NCCL_VERSION=2.14.3
ENV NVML_VERSION=11.7.91

RUN apt-get update \
 # TODO: Remove systemd upgrade once it is updated in base image
 && apt-get -y upgrade --only-upgrade systemd openssl cryptsetup \
 && apt-get install -y --no-install-recommends software-properties-common \
# && add-apt-repository ppa:openjdk-r/ppa \
 && apt-get update \
 && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
    build-essential \
    ca-certificates \
    cmake \
    cuda-command-line-tools-11-7 \
    cuda-cudart-11-7 \
    cuda-libraries-11-7 \
    cuda-libraries-dev-11-7 \
    curl \
    emacs \
    git \
    jq \
    libcublas-11-7=${CUBLAS_VERSION}-1 \
    libcublas-dev-11-7=${CUBLAS_VERSION}-1 \
    libcudnn8=${CUDNN_VERSION}-1+cuda11.7 \
    libcufft-dev-11-7 \
    libcurand-dev-11-7 \
    libcurl4-openssl-dev \
    libcusolver-dev-11-7 \
    libcusparse-dev-11-7 \
    cuda-nvml-dev-11-7=${NVML_VERSION}-1 \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libgomp1 \
    libibverbs-dev \
    libnuma1 \
    libnuma-dev \
    libsm6 \
    libssl1.1 \
    libssl-dev \
    libxext6 \
    libxrender-dev \
    openjdk-11-jdk \
    openssl \
    vim \
    wget \
    unzip \
    zlib1g-dev \
    libsndfile1-dev \
    ffmpeg \
    openssh-client \
    openssh-server \
 && rm -rf /var/lib/apt/lists/* \
 && apt-get clean

# Build and install the aws-sdk-cpp "s3;transfer" modules, needed by torchdata
RUN git clone --recurse-submodules https://github.com/aws/aws-sdk-cpp \
 && cd aws-sdk-cpp/ \
 && mkdir sdk-build \
 && cd sdk-build \
 && cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_ONLY="s3;transfer" -DAUTORUN_UNIT_TESTS=OFF \
 && make \
 && make install \
 && cd ../.. \
 && rm -rf aws-sdk-cpp
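
# Sanity check sketch (illustrative; the SDK's default install prefix is /usr/local,
# though the exact library directory can vary):
#   ls /usr/local/lib | grep aws-cpp-sdk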

# Install NCCL
RUN cd /tmp \
 && git clone https://github.com/NVIDIA/nccl.git -b v${NCCL_VERSION}-1 \
 && cd nccl \
 && make -j64 src.build BUILDDIR=/usr/local \
 && rm -rf /tmp/nccl
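
# BUILDDIR=/usr/local places libnccl directly under /usr/local/lib, which is already on
# LD_LIBRARY_PATH (set in base_image). Illustrative check: ls /usr/local/lib/libnccl*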

# https://github.com/docker-library/openjdk/issues/261 https://github.com/docker-library/openjdk/pull/263/files
RUN keytool -importkeystore -srckeystore /etc/ssl/certs/java/cacerts -destkeystore /etc/ssl/certs/java/cacerts.jks -deststoretype JKS -srcstorepass changeit -deststorepass changeit -noprompt; \
    mv /etc/ssl/certs/java/cacerts.jks /etc/ssl/certs/java/cacerts; \
    /var/lib/dpkg/info/ca-certificates-java.postinst configure;

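# Build Open MPI from source with CUDA support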
RUN wget --quiet https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OPEN_MPI_VERSION}.tar.gz \
 && gunzip -c openmpi-${OPEN_MPI_VERSION}.tar.gz | tar xf - \
 && cd openmpi-${OPEN_MPI_VERSION} \
 && ./configure --prefix=/home/.openmpi --with-cuda \
 && make all install \
 && cd .. \
 && rm openmpi-${OPEN_MPI_VERSION}.tar.gz \
 && rm -rf openmpi-${OPEN_MPI_VERSION}

ENV PATH="$PATH:/home/.openmpi/bin"
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/.openmpi/lib/"
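
# Illustrative runtime check that the Open MPI install resolves on PATH:
#   mpirun --version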

# Install conda via the conda-forge Mambaforge installer
RUN curl -L -o ~/mambaforge.sh https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh \
 && chmod +x ~/mambaforge.sh \
 && ~/mambaforge.sh -b -p /opt/conda \
 && rm ~/mambaforge.sh \
 && /opt/conda/bin/conda config --set ssl_verify False \
 && /opt/conda/bin/conda update -y conda \
 && /opt/conda/bin/conda install -c conda-forge python=${PYTHON_VERSION} \
    cython \
    mkl \
    mkl-include \
    parso \
    scipy \
    typing \
    h5py \
    requests \
    libgcc \
    # These two ship in the Miniconda base but not in Mambaforge, so install them explicitly
    conda-content-trust \
    charset-normalizer \
 && /opt/conda/bin/conda install -c pytorch magma-cuda117 \
 && /opt/conda/bin/conda clean -ya

# Upgrade pip and symlink it for OS-level use as pip3
RUN pip install --upgrade pip --trusted-host pypi.org --trusted-host files.pythonhosted.org \
 && ln -s /opt/conda/bin/pip /usr/local/bin/pip3

WORKDIR /root

# Remove any PyTorch that conda may have pulled in via magma-cuda117, so the AWS binaries below are used instead
RUN pip uninstall -y torch torchvision torchaudio model-archiver multi-model-server

# Install AWS-PyTorch, and other torch packages
RUN pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cu117 -U \
    "awscli<2" \
    boto3 \
    "cryptography>3.2" \
    enum-compat==0.0.3 \
    numpy==1.22.2 \
    opencv-python \
    packaging \
    "Pillow>=9.0.0" \
    "pyyaml>=5.4,<5.5" \
    captum \
    ${TORCH_URL} \
    torchaudio==${TORCHAUDIO_VERSION} \
    torchvision==${TORCHVISION_VERSION}
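
# Minimal smoke test for the CUDA-enabled build (illustrative; needs a GPU at runtime):
#   python -c "import torch; print(torch.__version__, torch.cuda.is_available())"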

# Conda installs symlinks for both libtinfo.so.6 and libtinfo.so.6.2,
# which causes a "/opt/conda/lib/libtinfo.so.6: no version information available" warning.
# Remove the libtinfo.so.6 link. This is only needed for Ubuntu 20.04 with conda, and can be
# reverted once conda fixes the issue: https://github.com/conda/conda/issues/9680
RUN rm -rf /opt/conda/lib/libtinfo.so.6

# Provide the CA bundle at the path aws-sdk-cpp expects for downloads
RUN mkdir -p /etc/pki/tls/certs && cp /etc/ssl/certs/ca-certificates.crt /etc/pki/tls/certs/ca-bundle.crt

WORKDIR /

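# Install Multi Model Server and the SageMaker inference toolkit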
RUN pip install --no-cache-dir \
    multi-model-server==$MMS_VERSION \
    sagemaker-inference

# Clean up temporary files left in /tmp by earlier build steps
RUN rm -rf /tmp/tmp*

RUN useradd -m model-server \
 && mkdir -p /home/model-server/tmp /opt/ml/model \
 && chown -R model-server /home/model-server /opt/ml/model

COPY mms-entrypoint.py /usr/local/bin/dockerd-entrypoint.py
COPY config.properties /etc/sagemaker-mms.properties

RUN chmod +x /usr/local/bin/dockerd-entrypoint.py

COPY deep_learning_container.py /usr/local/bin/deep_learning_container.py

RUN chmod +x /usr/local/bin/deep_learning_container.py

#################################
# Hugging Face specific section #
#################################

RUN curl -o /license.txt https://aws-dlc-licenses.s3.amazonaws.com/pytorch-1.13/license.txt

# Install the Hugging Face libraries and their dependencies
# NOTE: specifiers containing ">" or "<" must be quoted, or the shell treats them as redirections
RUN pip install --no-cache-dir \
    transformers[sentencepiece,audio,vision]==${TRANSFORMERS_VERSION} \
    diffusers==${DIFFUSERS_VERSION} \
    "accelerate>=0.11.0" \
    "protobuf>=3.19.5,<=3.20.2" \
    "numpy>=1.21.5" \
    "sagemaker-huggingface-inference-toolkit<3"


RUN HOME_DIR=/root \
 && curl -o ${HOME_DIR}/oss_compliance.zip https://aws-dlinfra-utilities.s3.amazonaws.com/oss_compliance.zip \
 && unzip ${HOME_DIR}/oss_compliance.zip -d ${HOME_DIR}/ \
 && cp ${HOME_DIR}/oss_compliance/test/testOSSCompliance /usr/local/bin/testOSSCompliance \
 && chmod +x /usr/local/bin/testOSSCompliance \
 && chmod +x ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh \
 && ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} ${PYTHON} \
 && rm -rf ${HOME_DIR}/oss_compliance*

# Remove the cache; clearing it is needed for security verification
RUN rm -rf /root/.cache

EXPOSE 8080 8081
ENTRYPOINT ["python", "/usr/local/bin/dockerd-entrypoint.py"]
CMD ["serve"]
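
# Illustrative local build-and-run (tag and TRANSFORMERS_VERSION value are placeholders):
#   docker build --build-arg TRANSFORMERS_VERSION=4.26.1 -t hf-pytorch-inference-gpu .
#   docker run --gpus all -p 8080:8080 hf-pytorch-inference-gpu serve
#   curl http://localhost:8080/ping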