ARG PYTHON=python3
ARG PYTHON_VERSION=3.8.13
ARG MAMBA_VERSION=4.12.0-2

# TODO: Declare and assign build-time variables for AWS PyTorch and other packages here
# TODO: Declare and assign build-time variable for S3 plugin

FROM ubuntu:20.04 AS common

LABEL maintainer="Amazon AI"
LABEL dlc_major_version="1"

# Re-declare the global ARGs so they are visible inside this stage.
ARG PYTHON
ARG PYTHON_VERSION
ARG OPEN_MPI_VERSION=4.0.1
ARG MAMBA_VERSION

# Build-time only (ARG, not ENV): stops apt-get from waiting on interactive
# prompts such as tzdata's region configuration on Ubuntu 20.04, without leaking
# into the runtime environment of the image.
ARG DEBIAN_FRONTEND=noninteractive

# PYTHONDONTWRITEBYTECODE: Python won't write .pyc/.pyo files when importing modules.
# PYTHONUNBUFFERED: stdin, stdout and stderr are fully unbuffered (good for logging).
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PYTHONIOENCODING=UTF-8 \
    LANG=C.UTF-8 \
    LC_ALL=C.UTF-8

# Append /usr/local/lib and /opt/conda/lib to the loader search path. The ":+"
# form inserts the separating colon only when LD_LIBRARY_PATH is already set.
# Writing "${LD_LIBRARY_PATH}:..." on an unset variable produces a leading empty
# entry, which the dynamic loader interprets as the current working directory.
ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/usr/local/lib:/opt/conda/lib"
ENV PATH=/opt/conda/bin:$PATH
ENV DGLBACKEND=pytorch
ENV DLC_CONTAINER_TYPE=training

WORKDIR /

# Refresh the package index, upgrade only systemd, and install the base
# toolchain, editors/utilities and GL/X11/glib runtime libraries. The apt
# caches are cleared in this same layer so they never reach the image.
# TODO: Remove systemd upgrade once it is updated in base image
RUN apt-get update \
 && apt-get upgrade -y --only-upgrade systemd \
 && apt-get install -y --no-install-recommends \
      build-essential \
      ca-certificates \
      cmake \
      curl \
      emacs \
      git \
      jq \
      libcurl4-openssl-dev \
      libglib2.0-0 \
      libgl1-mesa-glx \
      libsm6 \
      libssl-dev \
      libxext6 \
      libxrender-dev \
      software-properties-common \
      unzip \
      vim \
      wget \
      zlib1g-dev \
 && apt-get clean \
 && rm -rf /var/lib/apt/lists/*

# Build and install Open MPI from source into /home/.openmpi.
# The download directory on open-mpi.org is the release series (major.minor),
# so derive it from OPEN_MPI_VERSION (e.g. 4.0.1 -> 4.0). Previously it was
# hardcoded to v4.0, so bumping the ARG alone would request a non-existent URL.
# "all" is built in parallel first; "install" runs separately so the two goals
# never race each other.
RUN OMPI_SERIES="${OPEN_MPI_VERSION%.*}" \
 && wget https://www.open-mpi.org/software/ompi/v${OMPI_SERIES}/downloads/openmpi-${OPEN_MPI_VERSION}.tar.gz \
 && tar xzf openmpi-${OPEN_MPI_VERSION}.tar.gz \
 && cd openmpi-${OPEN_MPI_VERSION} \
 && ./configure --prefix=/home/.openmpi \
 && make -j "$(nproc)" all \
 && make install \
 && cd .. \
 && rm -rf openmpi-${OPEN_MPI_VERSION}.tar.gz openmpi-${OPEN_MPI_VERSION}

# Keep these ENV settings here, after the Open MPI build. Grouping them with
# the ENV block at the top of the stage was observed to make ompi_info fail
# (seen only in CPU containers).
ENV PATH="$PATH:/home/.openmpi/bin"
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/.openmpi/lib/"

# Smoke test: the build fails here unless ompi_info runs and reports the
# mpi_built_with_cuda_support value (which is also printed to the build log).
RUN ompi_info --parsable --all | grep mpi_built_with_cuda_support:value

# OpenSSH lets MPI reach other containers. Every existing StrictHostKeyChecking
# line in the system ssh_config is replaced by "StrictHostKeyChecking no", so
# outgoing connections never stop to ask for host-key confirmation.
RUN apt-get update \
 && apt-get install -y --no-install-recommends openssh-client openssh-server \
 && mkdir -p /var/run/sshd \
 && grep -v StrictHostKeyChecking /etc/ssh/ssh_config > /etc/ssh/ssh_config.new \
 && printf '%s\n' "    StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new \
 && mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config \
 && apt-get clean \
 && rm -rf /var/lib/apt/lists/*

# Let nodes talk to each other over SSH: downgrade pam_loginuid from
# "required" to "optional" for sshd sessions.
# NOTE(review): presumably because loginuid cannot be set inside a container — confirm.
RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd

# Give root a fresh passwordless RSA key pair that also authorizes itself, and
# skip host-key checks for every host in root's client config.
# NOTE(review): the private key is stored in this image layer and is the same in
# every container started from the image; only acceptable on a trusted network.
RUN rm -rf /root/.ssh/ \
 && mkdir -p /root/.ssh/ \
 && ssh-keygen -q -t rsa -N '' -f /root/.ssh/id_rsa \
 && cp /root/.ssh/id_rsa.pub /root/.ssh/authorized_keys \
 && printf "Host *\n StrictHostKeyChecking no\n" >> /root/.ssh/config

# Install Mambaforge into /opt/conda, then install Python and the base packages
# from conda-forge and DGL from the dglteam channel.
#  - Every conda install passes -y: without it conda asks for confirmation,
#    which a non-interactive docker build cannot answer.
#  - curl -f fails the build on an HTTP error instead of saving the error page
#    as the installer script.
#  - conda-content-trust and charset-normalizer are included in the Miniconda
#    base install but not in Mambaforge, so they are listed explicitly.
#  - `conda update conda` deliberately runs AFTER the conda-forge installs.
#    Upstream conda moved to 4.13, which is incompatible with mamba 0.22.1 and
#    makes the conda-forge installs fail. Running the update first would move
#    conda to 4.13 before those installs. Running it last keeps conda at 4.12
#    while other packages are still updated with conda 4.12.
RUN curl -fsSL -o ~/mambaforge.sh https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh \
 && chmod +x ~/mambaforge.sh \
 && ~/mambaforge.sh -b -p /opt/conda \
 && rm ~/mambaforge.sh \
 && /opt/conda/bin/conda install -y -c conda-forge \
    python=$PYTHON_VERSION \
    cython \
    mkl \
    mkl-include \
    parso \
    typing \
    h5py \
    requests \
    "pyopenssl>=17.5.0" \
    conda-content-trust \
    charset-normalizer \
 && /opt/conda/bin/conda install -c dglteam -y dgl \
 && /opt/conda/bin/conda update -y conda \
 && /opt/conda/bin/conda clean -ya

# Conda installs links for both libtinfo.so.6 and libtinfo.so.6.2, which causes
# the "/opt/conda/lib/libtinfo.so.6: no version information available" warning.
# Removing the libtinfo.so.6 link is needed only for ubuntu 20.04 + conda and
# can be reverted once conda fixes the issue.
RUN rm -rf /opt/conda/lib/libtinfo.so.6

# Upgrade pip, also expose conda's pip as /usr/local/bin/pip3, and install the
# common Python dependencies. --no-cache-dir on both pip invocations keeps
# pip's download cache out of this layer.
# NOTE(review): --trusted-host relaxes HTTPS verification for these hosts;
# confirm it is still required by the build environment.
RUN pip install --no-cache-dir --upgrade pip --trusted-host pypi.org --trusted-host files.pythonhosted.org \
 && ln -s /opt/conda/bin/pip /usr/local/bin/pip3 \
 && pip install --no-cache-dir -U \
    "awscli<2" \
    boto3 \
    click \
    "cryptography>3.2" \
    ipython==8.1.0 \
    numpy==1.22.2 \
    "opencv-python>=4.6,<5" \
    packaging \
    "Pillow>=9.0.0" \
    psutil==5.9.0 \
    "pyyaml>=5.4,<5.5" \
    scipy

# Ship the deep_learning_container.py helper script as an executable.
COPY deep_learning_container.py /usr/local/bin/deep_learning_container.py

RUN chmod +x /usr/local/bin/deep_learning_container.py

# -f makes curl fail the build on an HTTP error instead of silently saving the
# error response body as /license.txt.
# NOTE(review): this fetches the pytorch-1.12 license while the final stages
# install pytorch 1.13.*; confirm the intended license path.
RUN curl -fsSL -o /license.txt https://aws-dlc-licenses.s3.amazonaws.com/pytorch-1.12/license.txt

########################################################
#  _____ ____ ____    ___
# | ____/ ___|___ \  |_ _|_ __ ___   __ _  __ _  ___
# |  _|| |     __) |  | || '_ ` _ \ / _` |/ _` |/ _ \
# | |__| |___ / __/   | || | | | | | (_| | (_| |  __/
# |_____\____|_____| |___|_| |_| |_|\__,_|\__, |\___|
#                                         |___/
#  ____           _
# |  _ \ ___  ___(_)_ __   ___ 
# | |_) / _ \/ __| | '_ \ / _ \
# |  _ <  __/ (__| | |_) |  __/
# |_| \_\___|\___|_| .__/ \___|
#                  |_|
########################################################

FROM common AS ec2

# Re-declared so the OSS compliance step below can pass it to its script.
ARG PYTHON

# TODO: Redeclare build-time variables for AWS PyTorch and other Torch packages here
# TODO: Redeclare build-time variable for S3 plugin

# Uninstall pip-installed copies if already present (pip only warns for absent packages).
RUN pip uninstall -y torch torchvision torchaudio torchdata

# TODO: Install AWS-PyTorch and other torch packages here

# PyTorch and other packages are currently installed from the pytorch-nightly channel.
# -y: conda must not wait for confirmation in a non-interactive build.
# The version spec is quoted so the shell never treats "1.13.*" as a glob.
# conda clean in the same layer keeps the package cache out of the image.
RUN /opt/conda/bin/conda install -y -c pytorch-nightly \
    "pytorch=1.13.*" \
    torchvision \
    torchaudio \
    torchdata \
 && /opt/conda/bin/conda clean -ya

# TODO: Install S3 plugin

# Also expose the system CA bundle at the RHEL-style path /etc/pki/tls/certs/ca-bundle.crt.
RUN mkdir -p /etc/pki/tls/certs && cp /etc/ssl/certs/ca-certificates.crt /etc/pki/tls/certs/ca-bundle.crt

# Download AWS's OSS compliance tooling, install testOSSCompliance, run
# generate_oss_compliance.sh against /root and ${PYTHON}, then delete the tooling.
# curl -f fails fast on an HTTP error instead of handing an error page to unzip.
RUN HOME_DIR=/root \
 && curl -fsSL -o ${HOME_DIR}/oss_compliance.zip https://aws-dlinfra-utilities.s3.amazonaws.com/oss_compliance.zip \
 && unzip ${HOME_DIR}/oss_compliance.zip -d ${HOME_DIR}/ \
 && cp ${HOME_DIR}/oss_compliance/test/testOSSCompliance /usr/local/bin/testOSSCompliance \
 && chmod +x /usr/local/bin/testOSSCompliance \
 && chmod +x ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh \
 && ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} ${PYTHON} \
 && rm -rf ${HOME_DIR}/oss_compliance* \
 && rm -rf /tmp/tmp*

# Starts framework
CMD ["/bin/bash"]

#################################################################
#  ____                   __  __       _
# / ___|  __ _  __ _  ___|  \/  | __ _| | _____ _ __
# \___ \ / _` |/ _` |/ _ \ |\/| |/ _` | |/ / _ \ '__|
#  ___) | (_| | (_| |  __/ |  | | (_| |   <  __/ |
# |____/ \__,_|\__, |\___|_|  |_|\__,_|_|\_\___|_|
#              |___/
#  ___                              ____           _
# |_ _|_ __ ___   __ _  __ _  ___  |  _ \ ___  ___(_)_ __   ___
#  | || '_ ` _ \ / _` |/ _` |/ _ \ | |_) / _ \/ __| | '_ \ / _ \
#  | || | | | | | (_| | (_| |  __/ |  _ <  __/ (__| | |_) |  __/
# |___|_| |_| |_|\__,_|\__, |\___| |_| \_\___|\___|_| .__/ \___|
#                      |___/                        |_|
#################################################################

FROM common AS sagemaker

LABEL maintainer="Amazon AI"
LABEL dlc_major_version="1"

# Re-declared so the OSS compliance step below can pass it to its script.
ARG PYTHON

# TODO: Redeclare build-time variable for S3 plugin

# TODO: Install SMDEBUG VERSION here

ENV SAGEMAKER_TRAINING_MODULE=sagemaker_pytorch_container.training:main

# TODO: Redeclare ARGs for PyTorch and other Torch packages here

# Uninstall pip-installed copies if already present (pip only warns for absent packages).
RUN pip uninstall -y torch torchvision torchaudio torchdata

# TODO: Install AWS-PyTorch and other torch packages here

# PyTorch and other packages are currently installed from the pytorch-nightly channel.
# -y: conda must not wait for confirmation in a non-interactive build.
# The version spec is quoted so the shell never treats "1.13.*" as a glob.
# conda clean in the same layer keeps the package cache out of the image.
RUN /opt/conda/bin/conda install -y -c pytorch-nightly \
    "pytorch=1.13.*" \
    torchvision \
    torchaudio \
    torchdata \
 && /opt/conda/bin/conda clean -ya

# TODO: Install S3 plugin

# Also expose the system CA bundle at the RHEL-style path /etc/pki/tls/certs/ca-bundle.crt.
RUN mkdir -p /etc/pki/tls/certs && cp /etc/ssl/certs/ca-certificates.crt /etc/pki/tls/certs/ca-bundle.crt

WORKDIR /

# Copy workaround files for an incorrect hostname; the wrapper script is the ENTRYPOINT below.
COPY changehostname.c /
COPY start_with_right_hostname.sh /usr/local/bin/start_with_right_hostname.sh

RUN chmod +x /usr/local/bin/start_with_right_hostname.sh

# Install scikit-learn and pandas; clean the conda cache in the same layer.
RUN conda install -y -c conda-forge \
    scikit-learn \
    pandas \
 && conda clean -ya

# SageMaker SDK and training toolkit. --no-cache-dir on both pip invocations
# keeps pip's cache out of the layer. "0.*" is quoted so the shell never globs it.
# smdebug pip install is disabled until stable smdebug releases are available:
#   smdebug==${SMDEBUG_VERSION}
RUN pip install --no-cache-dir --upgrade pip --trusted-host pypi.org --trusted-host files.pythonhosted.org \
 && pip install --no-cache-dir -U \
    smclarify \
    "sagemaker>=2,<3" \
    "sagemaker-experiments==0.*" \
    "sagemaker-pytorch-training<3"

# TODO: Install SMDEBUG from source here

# Install extra packages.
# numba is held below 0.54 because of a numba/numpy compatibility issue, see
# https://github.com/numba/numba/issues/7339 (original note: "numba 0.54 only
# works with numpy>=1.20").
# NOTE(review): confirm that rationale against the numpy==1.22.2 pin in the common stage.
RUN pip install --no-cache-dir -U \
    "bokeh>=2.3,<3" \
    "imageio>=2.9,<3" \
    "opencv-python>=4.3,<5" \
    "plotly>=5.1,<6" \
    "seaborn>=0.11,<1" \
    "numba<0.54" \
    "shap>=0.39,<1"

# Download AWS's OSS compliance tooling, install testOSSCompliance, run
# generate_oss_compliance.sh against /root and ${PYTHON}, then delete the tooling.
# curl -f fails fast on an HTTP error instead of handing an error page to unzip.
RUN HOME_DIR=/root \
 && curl -fsSL -o ${HOME_DIR}/oss_compliance.zip https://aws-dlinfra-utilities.s3.amazonaws.com/oss_compliance.zip \
 && unzip ${HOME_DIR}/oss_compliance.zip -d ${HOME_DIR}/ \
 && cp ${HOME_DIR}/oss_compliance/test/testOSSCompliance /usr/local/bin/testOSSCompliance \
 && chmod +x /usr/local/bin/testOSSCompliance \
 && chmod +x ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh \
 && ${HOME_DIR}/oss_compliance/generate_oss_compliance.sh ${HOME_DIR} ${PYTHON} \
 && rm -rf ${HOME_DIR}/oss_compliance* \
 && rm -rf /tmp/tmp*

ENTRYPOINT ["bash", "-m", "start_with_right_hostname.sh"]
CMD ["/bin/bash"]