25 lines
941 B
Bash
25 lines
941 B
Bash
#!/bin/bash
#
# Install the cuDNN redistributable archive matching ${CUDA_VERSION} into the
# CUDA toolkit under /usr/local/cuda, then refresh the dynamic linker cache.
#
# Environment:
#   CUDNN_VERSION  If empty or unset, the script does nothing. NOTE: its value
#                  is only used as an on/off switch; the cuDNN release that is
#                  installed is chosen from CUDA_VERSION (see
#                  cudnn_archive_name).
#   CUDA_VERSION   CUDA toolkit version string, e.g. "12.8.1" or "11.8".
#                  Only its leading characters are inspected.
#
# cuDNN license: https://developer.nvidia.com/cudnn/license_agreement

set -euo pipefail

readonly CUDA_ROOT=/usr/local/cuda
readonly CUDNN_REDIST_URL=https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64

#######################################
# Map a CUDA toolkit version to the name of the cuDNN archive built for it.
# Arguments: $1 - CUDA version string (e.g. "12.8.1"); may be empty
# Outputs:   archive base name (without ".tar.xz") on stdout
# Returns:   0 on success, 1 if the CUDA version is not supported
#######################################
cudnn_archive_name() {
  local cuda_version=${1:-}
  # Prefix matches, equivalent to comparing the first 4 (or 2) characters.
  case "${cuda_version}" in
    12.9* | 12.8* | 12.6*)
      printf '%s\n' "cudnn-linux-x86_64-9.10.2.21_cuda12-archive"
      ;;
    11*)
      printf '%s\n' "cudnn-linux-x86_64-9.1.0.70_cuda11-archive"
      ;;
    *)
      return 1
      ;;
  esac
}

#######################################
# Download cuDNN for ${CUDA_VERSION}, copy its headers and libraries into
# ${CUDA_ROOT}/include and ${CUDA_ROOT}/lib64, and run ldconfig.
# Globals:   CUDNN_VERSION, CUDA_VERSION (read); CUDA_ROOT, CUDNN_REDIST_URL
# Returns:   0 on success or when CUDNN_VERSION is empty. Exits 1 on an
#            unsupported CUDA version; any failed step aborts via set -e.
#######################################
install_cudnn() {
  if [[ -z "${CUDNN_VERSION:-}" ]]; then
    return 0
  fi

  local cudnn_name
  if ! cudnn_name=$(cudnn_archive_name "${CUDA_VERSION:-}"); then
    printf 'Unsupported CUDA version %s\n' "${CUDA_VERSION:-}" >&2
    exit 1
  fi

  local workdir
  workdir=$(mktemp -d)
  # Expand the path now: ${workdir} is local and out of scope when EXIT fires.
  # shellcheck disable=SC2064
  trap "rm -rf -- $(printf '%q' "${workdir}")" EXIT

  local archive="${workdir}/${cudnn_name}.tar.xz"
  # --fail: an HTTP error must fail here instead of saving the error page as
  # the archive; -sS: quiet progress but still report errors.
  curl --fail --retry 3 -sSL -o "${archive}" \
    "${CUDNN_REDIST_URL}/${cudnn_name}.tar.xz"
  tar -xf "${archive}" -C "${workdir}"

  # Glob left unquoted on purpose so every entry is copied; a non-matching
  # glob makes cp fail, which aborts the script under set -e.
  cp -a -- "${workdir}/${cudnn_name}/include/"* "${CUDA_ROOT}/include/"
  cp -a -- "${workdir}/${cudnn_name}/lib/"* "${CUDA_ROOT}/lib64/"

  ldconfig
}

install_cudnn