#!/bin/bash
# Replace the currently installed PyTorch packages with the OSS PyTorch
# packages from the `pytorch` conda channel, pinned to the same version.
#
# Expects `python` (with `torch` importable), `conda` and `mamba` on PATH.
# Exits on the first failing command (set -e).
set -e

# Get version info from AWS-PyTorch.
# Keep only the part of the version string before any '+' suffix.
PT_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0]);")
# Python prints the literal string "True" or "False" here.
HAS_CUDA=$(python -c "import torch; print(torch.version.cuda is not None)")

# Install OSS PyTorch.
conda remove -y --force pytorch torchvision torchaudio

# Compare against "True" explicitly: "False" is also a non-empty string, so a
# plain non-emptiness test (-n) would select the CUDA branch unconditionally.
if [[ "$HAS_CUDA" == "True" ]]; then
  mamba install -y "pytorch=$PT_VERSION" torchvision torchaudio pytorch-cuda \
    -c pytorch -c nvidia
else
  mamba install -y "pytorch=$PT_VERSION" torchvision torchaudio -c pytorch
fi