---
# Kubeflow MPI Operator job: one launcher pod runs mpirun against the worker
# pods declared under mpiReplicaSpecs.
apiVersion: kubeflow.org/v2beta1
kind: MPIJob
metadata:
  # The launcher's MASTER_ADDR value is derived from this name
  # ("<name>-worker-0"); keep the two in sync when renaming the job.
  name: fairscale-1b-pretraining-efa
spec:
  runPolicy:
    # Only pods that are still running are cleaned up once the job finishes.
    cleanPodPolicy: Running
  # MPI slots advertised per worker pod.
  # NOTE(review): the launcher calls mpirun with -np 16 while this value times
  # the 2 worker replicas yields 2 slots - confirm the intended rank count.
  slotsPerWorker: 1
  mpiReplicaSpecs:
    Launcher:
      # Single pod that starts mpirun; the training script itself is executed
      # by the MPI ranks that mpirun spawns.
      replicas: 1
      template:
        spec:
          restartPolicy: OnFailure
          containers:
          - name: fairscale-training-launcher
            image: 231748552833.dkr.ecr.us-east-1.amazonaws.com/aws/pt21.06-fairscale:latest
            # imagePullPolicy is a container-level field; it is not part of
            # the pod spec, so it has to be declared here to take effect.
            imagePullPolicy: Always
            env:
            # Kubernetes does not run env values through a shell: only
            # $(VAR) references to variables defined earlier in this list are
            # substituted. Every $NAME / $((...)) value below is therefore
            # stored as literal text.
            # NOTE(review): confirm whether /fairdata/train.sh relies on any
            # of these; mpirun only forwards the variables named with -x.
            - name: GPUS_PER_NODE
              value: "8"
            # Must match "<job name>-worker-0".
            - name: MASTER_ADDR
              value: fairscale-1b-pretraining-efa-worker-0
            - name: MASTER_PORT
              value: "6002"
            - name: NNODES
              value: '$OMPI_MCA_orte_num_nodes'
            - name: NODE_RANK
              value: '$OMPI_COMM_WORLD_RANK'
            - name: WORLD_SIZE
              value: '$(($GPUS_PER_NODE*$NNODES))'
            # Library search path covering Open MPI, NCCL, EFA and the
            # aws-ofi-nccl plugin; the trailing $LD_LIBRARY_PATH stays literal.
            - name: LD_LIBRARY_PATH
              value: '/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/nvidia/lib:$LD_LIBRARY_PATH'
            # The leading $PATH stays literal, so this value replaces the
            # image's PATH rather than extending it.
            - name: PATH
              value: '$PATH:/opt/amazon/efa/bin:/usr/local/bin:/usr/bin:/opt/conda/bin'
            - name: XLA_FLAGS
              value: "--xla_gpu_cuda_data_dir=/usr/local/cuda"
            - name: TF_XLA_FLAGS
              value: "--tf_xla_cpu_global_jit"
            - name: DISTRIBUTED_ARGS
              value: "--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
            command:
            - /opt/amazon/openmpi/bin/mpirun
            - --allow-run-as-root
            - --tag-output
            # NOTE(review): 16 ranks are requested, but the job declares
            # slotsPerWorker: 1 with 2 worker replicas (2 slots in total).
            # Confirm whether -np or slotsPerWorker should change.
            - -np
            - "16"
            - -bind-to
            - none
            - -map-by
            - slot
            # Each -x names one launcher environment variable to export to
            # the spawned ranks.
            - -x
            - PATH
            - -x
            - LD_LIBRARY_PATH
            - -x
            - MASTER_ADDR
            - -x
            - MASTER_PORT
            - /fairdata/train.sh

    Worker:
      # Pods that host the MPI ranks started by the launcher. No command is
      # set here, so the container runs the image's default entrypoint.
      replicas: 2
      template:
        spec:
          containers:
          - name: fairscale-worker
            image: 231748552833.dkr.ecr.us-east-1.amazonaws.com/aws/pt21.06-fairscale:latest
            # imagePullPolicy is a container-level field; it is not part of
            # the pod spec, so it has to be declared here to take effect.
            imagePullPolicy: Always
            resources:
              # Requests mirror limits for every resource: 8 GPUs, 4 EFA
              # devices, 2Mi huge pages and main memory per worker pod.
              limits:
                nvidia.com/gpu: 8
                hugepages-2Mi: 15120Mi
                memory: 784000Mi
                vpc.amazonaws.com/efa: 4
              requests:
                nvidia.com/gpu: 8
                hugepages-2Mi: 15120Mi
                memory: 784000Mi
                vpc.amazonaws.com/efa: 4