diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2017-02-12 02:09:33 +0100 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2017-02-14 21:28:50 +0100 |
commit | 8c8dc791ef3da241bf8c7122372a26abf3cbe60f (patch) | |
tree | 276fb46ff3e127997609f31212c52e2b5f957dae /torch/_thnn | |
parent | 63edca44f28c04343e32c58439f0f422ab758fd1 (diff) | |
download | pytorch-8c8dc791ef3da241bf8c7122372a26abf3cbe60f.tar.gz pytorch-8c8dc791ef3da241bf8c7122372a26abf3cbe60f.tar.bz2 pytorch-8c8dc791ef3da241bf8c7122372a26abf3cbe60f.zip |
Load half and double THCUNN backends
Diffstat (limited to 'torch/_thnn')
-rw-r--r-- | torch/_thnn/__init__.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/torch/_thnn/__init__.py b/torch/_thnn/__init__.py index dd41e47a27..f84c6fbf25 100644 --- a/torch/_thnn/__init__.py +++ b/torch/_thnn/__init__.py @@ -58,7 +58,10 @@ for t in ['Float', 'Double']: type2backend.backends['torch.{}Tensor'.format(t)] = backend type2backend.backends[getattr(torch, '{}Tensor'.format(t))] = backend -backend = Backend('Cuda', 'torch._thnn._THCUNN', _thcunn_headers, (THNNCudaBackendStateMixin,)) -type2backend.backends['THNNCudaBackend'] = backend -type2backend.backends['torch.cuda.FloatTensor'] = backend -type2backend.backends[torch.cuda.FloatTensor] = backend + +for t in ['Half', '', 'Double']: + backend = Backend('Cuda' + t, 'torch._thnn._THCUNN', _thcunn_headers, (THNNCudaBackendStateMixin,)) + type2backend.backends['THNNCuda{}Backend'.format(t)] = backend + py_name = 'Float' if t == '' else t + type2backend.backends['torch.cuda.{}Tensor'.format(py_name)] = backend + type2backend.backends[getattr(torch.cuda, '{}Tensor'.format(py_name))] = backend |