diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-08-19 14:22:47 -0700 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-08-19 14:56:55 -0700 |
commit | e055ffbdc7a09abd58802a116b0d885cd6e31b95 (patch) | |
tree | c235f9e58e87ea20ca945044e3e8cec53432c765 /torch/_thnn | |
parent | 53f00ae429aa1bd18b407ffd17d06c9e85578edf (diff) | |
download | pytorch-e055ffbdc7a09abd58802a116b0d885cd6e31b95.tar.gz pytorch-e055ffbdc7a09abd58802a116b0d885cd6e31b95.tar.bz2 pytorch-e055ffbdc7a09abd58802a116b0d885cd6e31b95.zip |
Add nn
Diffstat (limited to 'torch/_thnn')
-rw-r--r-- | torch/_thnn/thcunn.py | 1 | ||||
-rw-r--r-- | torch/_thnn/thnn.py | 1 |
2 files changed, 2 insertions, 0 deletions
diff --git a/torch/_thnn/thcunn.py b/torch/_thnn/thcunn.py index 6883b6198b..44d871af77 100644 --- a/torch/_thnn/thcunn.py +++ b/torch/_thnn/thcunn.py @@ -14,3 +14,4 @@ for function in generic_functions: backend = load_backend('Cuda', torch._thnn._THCUNN, generic_functions, (THNNCudaBackendStateMixin,)) type2backend['torch.cuda.FloatTensor'] = backend +type2backend[torch.cuda.FloatTensor] = backend diff --git a/torch/_thnn/thnn.py b/torch/_thnn/thnn.py index 6cf8e6cf0d..d96d4432ed 100644 --- a/torch/_thnn/thnn.py +++ b/torch/_thnn/thnn.py @@ -6,4 +6,5 @@ generic_functions = parse_header(THNN_H_PATH) for t in ['Float', 'Double']: backend = load_backend(t, torch._thnn._THNN, generic_functions) type2backend['torch.' + t + 'Tensor'] = backend + type2backend[getattr(torch, t + 'Tensor')] = backend |