author | Will Feng <willfeng@fb.com> | 2019-02-22 07:54:47 -0800
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-22 08:00:25 -0800
commit | be6ad7ddde64cebeec4cd9537ddc3bca9cbe18e5 (patch)
tree | 21bc96f596acc775281eb36457176bd6b1d3515d /torch
parent | 562fa55f3dfc2d3ce1cac19e3ec9dab202e06b3b (diff)
Rename BatchNorm running_variance to running_var (#17371)
Summary:
Currently there is a mismatch in naming between the Python BatchNorm `running_var` and the C++ BatchNorm `running_variance`, which causes JIT model parameter loading to fail (https://github.com/pytorch/vision/pull/728#issuecomment-466067138):
```
terminate called after throwing an instance of 'c10::Error'
what(): No such serialized tensor 'running_variance' (read at /home/shahriar/Build/pytorch/torch/csrc/api/src/serialize/input-archive.cpp:27)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x85 (0x7f2d92d32f95 in /usr/local/lib/libc10.so)
frame #1: torch::serialize::InputArchive::read(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, at::Tensor&, bool) + 0xdeb (0x7f2d938551ab in /usr/local/lib/libtorch.so.1)
frame #2: torch::nn::Module::load(torch::serialize::InputArchive&) + 0x98 (0x7f2d9381cd08 in /usr/local/lib/libtorch.so.1)
frame #3: torch::nn::Module::load(torch::serialize::InputArchive&) + 0xf9 (0x7f2d9381cd69 in /usr/local/lib/libtorch.so.1)
frame #4: torch::nn::Module::load(torch::serialize::InputArchive&) + 0xf9 (0x7f2d9381cd69 in /usr/local/lib/libtorch.so.1)
frame #5: torch::nn::operator>>(torch::serialize::InputArchive&, std::shared_ptr<torch::nn::Module> const&) + 0x32 (0x7f2d9381c7b2 in /usr/local/lib/libtorch.so.1)
frame #6: <unknown function> + 0x2b16c (0x5645f4d1916c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
frame #7: <unknown function> + 0x27a3c (0x5645f4d15a3c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
frame #8: <unknown function> + 0x2165c (0x5645f4d0f65c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
frame #9: <unknown function> + 0x1540b (0x5645f4d0340b in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
frame #10: __libc_start_main + 0xf3 (0x7f2d051dd223 in /usr/lib/libc.so.6)
frame #11: <unknown function> + 0x1381e (0x5645f4d0181e in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
```
Renaming the C++ BatchNorm `running_variance` to `running_var` fixes this problem.
This is a BC-breaking change, but it should be easy for end users to rename `running_variance` to `running_var` at their call sites.
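The failure mode can be sketched with a simplified, hypothetical lookup (plain Python standing in for the actual `torch::serialize::InputArchive` logic): the archive stores buffers under their Python-side names, so a C++ module that reads back `running_variance` finds nothing.

```python
# Illustrative sketch only -- a dict stands in for the serialized archive;
# this is not the real libtorch serialization code.

# Buffers as saved by the Python BatchNorm module (Python-side names):
archive = {"running_mean": 0.0, "running_var": 1.0}

def read_tensor(archive, name):
    """Mimics InputArchive::read: fail loudly when the key is absent."""
    if name not in archive:
        raise KeyError(f"No such serialized tensor '{name}'")
    return archive[name]

# Before this patch, the C++ module registered its buffer as
# 'running_variance', so loading a Python-saved model raised an error:
try:
    read_tensor(archive, "running_variance")
except KeyError as exc:
    print(exc)

# After the rename, both sides agree on 'running_var' and the read succeeds:
print(read_tensor(archive, "running_var"))
```

The fix below simply makes the C++ side use the same buffer name that the Python side already serializes.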
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17371
Reviewed By: goldsborough
Differential Revision: D14172775
Pulled By: yf225
fbshipit-source-id: b9d3729ec79272a8084269756f28a8f7c4dd16b6
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/api/include/torch/nn/modules/batchnorm.h | 2
-rw-r--r-- | torch/csrc/api/src/nn/modules/batchnorm.cpp | 6
2 files changed, 4 insertions, 4 deletions
```
diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
index 2f7dd3be1a..782a5d950d 100644
--- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
@@ -88,7 +88,7 @@ class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
   /// The running variance.
   /// Only defined if the `stateful` option was `true` upon construction.
-  Tensor running_variance;
+  Tensor running_var;
 };
 
 /// A `ModuleHolder` subclass for `BatchNormImpl`.
diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp
index fa2789c73d..8a542e5cef 100644
--- a/torch/csrc/api/src/nn/modules/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp
@@ -28,8 +28,8 @@ void BatchNormImpl::reset() {
   if (options.stateful_) {
     running_mean =
         register_buffer("running_mean", torch::zeros({options.features_}));
-    running_variance =
-        register_buffer("running_variance", torch::ones({options.features_}));
+    running_var =
+        register_buffer("running_var", torch::ones({options.features_}));
   }
 }
 
@@ -47,7 +47,7 @@ Tensor BatchNormImpl::forward(const Tensor& input) {
       "Calling BatchNorm::forward is only permitted when "
       "the 'stateful' option is true (was false). "
       "Use BatchNorm::pure_forward instead.");
-  return pure_forward(input, running_mean, running_variance);
+  return pure_forward(input, running_mean, running_var);
 }
 
 Tensor BatchNormImpl::pure_forward(
```