summaryrefslogtreecommitdiff
path: root/torch/_thnn
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2017-02-12 02:09:33 +0100
committerAdam Paszke <adam.paszke@gmail.com>2017-02-14 21:28:50 +0100
commit8c8dc791ef3da241bf8c7122372a26abf3cbe60f (patch)
tree276fb46ff3e127997609f31212c52e2b5f957dae /torch/_thnn
parent63edca44f28c04343e32c58439f0f422ab758fd1 (diff)
downloadpytorch-8c8dc791ef3da241bf8c7122372a26abf3cbe60f.tar.gz
pytorch-8c8dc791ef3da241bf8c7122372a26abf3cbe60f.tar.bz2
pytorch-8c8dc791ef3da241bf8c7122372a26abf3cbe60f.zip
Load half and double THCUNN backends
Diffstat (limited to 'torch/_thnn')
-rw-r--r--torch/_thnn/__init__.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/torch/_thnn/__init__.py b/torch/_thnn/__init__.py
index dd41e47a27..f84c6fbf25 100644
--- a/torch/_thnn/__init__.py
+++ b/torch/_thnn/__init__.py
@@ -58,7 +58,10 @@ for t in ['Float', 'Double']:
type2backend.backends['torch.{}Tensor'.format(t)] = backend
type2backend.backends[getattr(torch, '{}Tensor'.format(t))] = backend
-backend = Backend('Cuda', 'torch._thnn._THCUNN', _thcunn_headers, (THNNCudaBackendStateMixin,))
-type2backend.backends['THNNCudaBackend'] = backend
-type2backend.backends['torch.cuda.FloatTensor'] = backend
-type2backend.backends[torch.cuda.FloatTensor] = backend
+
+for t in ['Half', '', 'Double']:
+ backend = Backend('Cuda' + t, 'torch._thnn._THCUNN', _thcunn_headers, (THNNCudaBackendStateMixin,))
+ type2backend.backends['THNNCuda{}Backend'.format(t)] = backend
+ py_name = 'Float' if t == '' else t
+ type2backend.backends['torch.cuda.{}Tensor'.format(py_name)] = backend
+ type2backend.backends[getattr(torch.cuda, '{}Tensor'.format(py_name))] = backend