summaryrefslogtreecommitdiff
path: root/torch/_thnn
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-08-19 14:22:47 -0700
committerAdam Paszke <adam.paszke@gmail.com>2016-08-19 14:56:55 -0700
commite055ffbdc7a09abd58802a116b0d885cd6e31b95 (patch)
treec235f9e58e87ea20ca945044e3e8cec53432c765 /torch/_thnn
parent53f00ae429aa1bd18b407ffd17d06c9e85578edf (diff)
downloadpytorch-e055ffbdc7a09abd58802a116b0d885cd6e31b95.tar.gz
pytorch-e055ffbdc7a09abd58802a116b0d885cd6e31b95.tar.bz2
pytorch-e055ffbdc7a09abd58802a116b0d885cd6e31b95.zip
Add nn
Diffstat (limited to 'torch/_thnn')
-rw-r--r--torch/_thnn/thcunn.py1
-rw-r--r--torch/_thnn/thnn.py1
2 files changed, 2 insertions, 0 deletions
diff --git a/torch/_thnn/thcunn.py b/torch/_thnn/thcunn.py
index 6883b6198b..44d871af77 100644
--- a/torch/_thnn/thcunn.py
+++ b/torch/_thnn/thcunn.py
@@ -14,3 +14,4 @@ for function in generic_functions:
backend = load_backend('Cuda', torch._thnn._THCUNN, generic_functions, (THNNCudaBackendStateMixin,))
type2backend['torch.cuda.FloatTensor'] = backend
+type2backend[torch.cuda.FloatTensor] = backend
diff --git a/torch/_thnn/thnn.py b/torch/_thnn/thnn.py
index 6cf8e6cf0d..d96d4432ed 100644
--- a/torch/_thnn/thnn.py
+++ b/torch/_thnn/thnn.py
@@ -6,4 +6,5 @@ generic_functions = parse_header(THNN_H_PATH)
for t in ['Float', 'Double']:
backend = load_backend(t, torch._thnn._THNN, generic_functions)
type2backend['torch.' + t + 'Tensor'] = backend
+ type2backend[getattr(torch, t + 'Tensor')] = backend