# Build arguments declared before FROM live in the global scope: they can be
# referenced by FROM lines, but they are not visible inside a build stage
# until they are re-declared there.
ARG SAGEMAKER_XGBOOST_VERSION=1.7-1
ARG PYTHON_VERSION=3.8

FROM xgboost-container-base:${SAGEMAKER_XGBOOST_VERSION}-cpu-py3

# Re-declare (without a value) every global ARG this stage references so the
# defaults above, or any build-time override, are inherited by the stage.
# SAGEMAKER_XGBOOST_VERSION is used in the MMS resource paths and
# PYTHON_VERSION in the site-packages paths of the DMLC patch. Without the
# re-declaration, ${PYTHON_VERSION} expands to an empty string in this stage
# unless the base image happens to export an ENV of the same name.
ARG SAGEMAKER_XGBOOST_VERSION
ARG PYTHON_VERSION

########################
# Install dependencies #
########################
# Install the Python packages listed in requirements.txt, then drop the copied
# file. --no-cache-dir keeps pip's download cache out of the image layer.
COPY requirements.txt /requirements.txt
RUN python3 -m pip install --no-cache-dir -r /requirements.txt && \
    rm /requirements.txt

# Install smdebug from source, at ref 1.0.29 of the sagemaker-debugger repo
RUN python3 -m pip install --no-cache-dir \
    git+https://github.com/awslabs/sagemaker-debugger.git@1.0.29


###########################
# Copy wheel to container #
###########################
# Copy the wheel under its real file name. pip parses the version from the
# wheel file name, and newer pip releases reject a wheel whose file name
# version disagrees with the version in its metadata.
COPY dist/sagemaker_xgboost_container-2.0-py2.py3-none-any.whl /sagemaker_xgboost_container-2.0-py2.py3-none-any.whl

# One layer: clean up, install, and delete the wheel again.
#  - The numpy-1.21.2 dist-info directory is removed before installing
#    (presumably stale metadata left in the base image -- TODO confirm).
#  - `typing` is uninstalled afterwards (presumably the PyPI backport that
#    shadows the standard-library module -- TODO confirm).
RUN rm -rf /miniconda3/lib/python3.8/site-packages/numpy-1.21.2.dist-info && \
    python3 -m pip install --no-cache-dir /sagemaker_xgboost_container-2.0-py2.py3-none-any.whl && \
    python3 -m pip uninstall -y typing && \
    rm /sagemaker_xgboost_container-2.0-py2.py3-none-any.whl

##############
# DMLC PATCH #
##############
# TODO: remove after making contributions back to xgboost for tracker.py
# Place the repository's patched tracker.py at the dmlc_tracker location under
# the interpreter's site-packages.
# NOTE(review): ${PYTHON_VERSION} must be visible inside this build stage. An
# ARG declared before FROM needs an in-stage `ARG PYTHON_VERSION`
# re-declaration (or the base image must export it); otherwise it expands to
# an empty string in both paths below.
COPY src/sagemaker_xgboost_container/dmlc_patch/tracker.py /miniconda3/lib/python${PYTHON_VERSION}/site-packages/xgboost/dmlc-core/tracker/dmlc_tracker/tracker.py

# Include DMLC python code in PYTHONPATH to use RabitTracker
ENV PYTHONPATH="${PYTHONPATH}:/miniconda3/lib/python${PYTHON_VERSION}/site-packages/xgboost/dmlc-core/tracker"

#######
# MMS #
#######
# Create the model-server account with a home directory (-m), add a tmp
# directory inside it, and give the account ownership of the whole tree.
RUN useradd -m model-server && \
    mkdir -p /home/model-server/tmp && \
    chown -R model-server /home/model-server

# Copy MMS configs
# The file copied is config.properties.tmp while XGBOOST_MMS_CONFIG names
# config.properties (presumably generated from the .tmp file at start-up --
# TODO confirm against the serving code).
COPY docker/${SAGEMAKER_XGBOOST_VERSION}/resources/mms/config.properties.tmp /home/model-server/
ENV XGBOOST_MMS_CONFIG="/home/model-server/config.properties"

# Create the MMS plugin directory and the directory for models
RUN mkdir -p /tmp/plugins /opt/ml/models && \
    chmod +rwx /opt/ml/models

# Copy execution parameters endpoint plugin for MMS and mark it executable
COPY docker/${SAGEMAKER_XGBOOST_VERSION}/resources/mms/endpoints-1.0.jar /tmp/plugins/
RUN chmod +x /tmp/plugins/endpoints-1.0.jar

# Copy Dask configs
# -p (as used by the other mkdir calls in this file) keeps the build from
# failing if /etc/dask already exists in the base image.
RUN mkdir -p /etc/dask
COPY docker/configs/dask_configs.yaml /etc/dask/

# Required label for multi-model loading
LABEL com.amazonaws.sagemaker.capabilities.multi-models=true

#####################
# Required ENV vars #
#####################
# Set SageMaker training environment variables
# SM_INPUT gets its own instruction: the config-file paths below expand
# $SM_INPUT, and substitution inside an ENV instruction only sees values set
# by earlier instructions.
ENV SM_INPUT=/opt/ml/input
ENV SM_INPUT_TRAINING_CONFIG_FILE=$SM_INPUT/config/hyperparameters.json \
    SM_INPUT_DATA_CONFIG_FILE=$SM_INPUT/config/inputdataconfig.json \
    SM_CHECKPOINT_CONFIG_FILE=$SM_INPUT/config/checkpointconfig.json
# See: https://github.com/dmlc/xgboost/issues/7982#issuecomment-1379390906 https://github.com/dmlc/xgboost/pull/8257
# NCCL treats the value as an interface-name prefix, so "eth" selects eth0, eth1, ...
ENV NCCL_SOCKET_IFNAME=eth

# Set SageMaker serving environment variables
ENV SM_MODEL_DIR=/opt/ml/model

# Set SageMaker entrypoints
ENV SAGEMAKER_TRAINING_MODULE=sagemaker_xgboost_container.training:main \
    SAGEMAKER_SERVING_MODULE=sagemaker_xgboost_container.serving:main

# Point TEMP at the tmp directory created under the model-server home
ENV TEMP="/home/model-server/tmp"

# Document the listening port and label the image as accepting a bind-to-port
# setting from the platform
EXPOSE 8080
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port="true"