diff --git a/src/nvidia-cuda/install.sh b/src/nvidia-cuda/install.sh index ca309fb25..6de935540 100644 --- a/src/nvidia-cuda/install.sh +++ b/src/nvidia-cuda/install.sh @@ -44,10 +44,24 @@ export DEBIAN_FRONTEND=noninteractive check_packages wget ca-certificates +# Determine system architecture and set NVIDIA repository URL accordingly +ARCH=$(uname -m) +case $ARCH in + x86_64) + NVIDIA_ARCH="x86_64" + ;; + aarch64 | arm64) + NVIDIA_ARCH="arm64" + ;; + *) + echo "Unsupported architecture: $ARCH" + exit 1 + ;; +esac + # Add NVIDIA's package repository to apt so that we can download packages -# Always use the ubuntu2004 repo because the other repos (e.g., debian11) are missing packages # Updating the repo to ubuntu2204 as ubuntu 20.04 is going out of support. -NVIDIA_REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64" +NVIDIA_REPO_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/$NVIDIA_ARCH" KEYRING_PACKAGE="cuda-keyring_1.0-1_all.deb" KEYRING_PACKAGE_URL="$NVIDIA_REPO_URL/$KEYRING_PACKAGE" KEYRING_PACKAGE_PATH="$(mktemp -d)" @@ -62,6 +76,10 @@ nvtx_pkg="cuda-nvtx-${CUDA_VERSION/./-}" toolkit_pkg="cuda-toolkit-${CUDA_VERSION/./-}" if ! apt-cache show "$cuda_pkg"; then echo "The requested version of CUDA is not available: CUDA $CUDA_VERSION" + if [ "$NVIDIA_ARCH" = "arm64" ]; then + echo "Note: arm64 supports limited CUDA versions. Please check available versions:" + echo "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/arm64" + fi exit 1 fi @@ -93,6 +111,9 @@ if [ "$INSTALL_CUDNN" = "true" ]; then if ! apt-cache show "$cudnn_pkg_version"; then echo "The requested version of cuDNN is not available: cuDNN $CUDNN_VERSION for CUDA $CUDA_VERSION" + if [ "$NVIDIA_ARCH" = "arm64" ]; then + echo "Note: arm64 has limited cuDNN package availability" + fi exit 1 fi @@ -112,6 +133,9 @@ if [ "$INSTALL_CUDNNDEV" = "true" ]; then fi if ! apt-cache show "$cudnn_dev_pkg_version"; then echo "The requested version of cuDNN development package is not available: cuDNN $CUDNN_VERSION for CUDA $CUDA_VERSION" + if [ "$NVIDIA_ARCH" = "arm64" ]; then + echo "Note: arm64 has limited cuDNN development package availability" + fi exit 1 fi