summaryrefslogtreecommitdiff
path: root/torch/cuda
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2018-02-08 00:07:53 -0500
committerGitHub <noreply@github.com>2018-02-08 00:07:53 -0500
commit2d84cb4b04e79cdda295eb601331f8f4f1f12825 (patch)
tree9378adf70acbcf9af628eb41792b781a69c6a995 /torch/cuda
parent2c27bae80256022c20e6b7c783c6e0adcb32c3bc (diff)
downloadpytorch-2d84cb4b04e79cdda295eb601331f8f4f1f12825.tar.gz
pytorch-2d84cb4b04e79cdda295eb601331f8f4f1f12825.tar.bz2
pytorch-2d84cb4b04e79cdda295eb601331f8f4f1f12825.zip
warn that CUDA capability 3.0 and 5.0 is no longer supported (#5125)
Diffstat (limited to 'torch/cuda')
-rw-r--r--torch/cuda/__init__.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 60e3ab8915..fce8309473 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -91,21 +91,29 @@ of the CUDA driver.""".format(str(torch._C._cuda_getDriverVersion())))
def _check_capability():
- error_str = """
+ incorrect_binary_warn = """
Found GPU%d %s which requires CUDA_VERSION >= %d for
optimal performance and fast startup time, but your PyTorch was compiled
with CUDA_VERSION %d. Please install the correct PyTorch binary
using instructions from http://pytorch.org
"""
+ old_gpu_warn = """
+ Found GPU%d %s which is of cuda capability %d.%d.
+ PyTorch no longer supports this GPU because it is too old.
+ """
+
CUDA_VERSION = torch._C._cuda_getCompiledVersion()
for d in range(device_count()):
- major = get_device_capability(d)[0]
+ capability = get_device_capability(d)
+ major = capability[0]
name = get_device_name(d)
if CUDA_VERSION < 8000 and major >= 6:
- warnings.warn(error_str % (d, name, 8000, CUDA_VERSION))
+ warnings.warn(incorrect_binary_warn % (d, name, 8000, CUDA_VERSION))
elif CUDA_VERSION < 9000 and major >= 7:
- warnings.warn(error_str % (d, name, 9000, CUDA_VERSION))
+ warnings.warn(incorrect_binary_warn % (d, name, 9000, CUDA_VERSION))
+ elif capability == (3, 0) or capability == (5, 0) or major < 3:
+ warnings.warn(old_gpu_warn % (d, name, major, capability[1]))
def _lazy_call(callable):