summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorWill Feng <willfeng@fb.com>2019-02-22 07:54:47 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-22 08:00:25 -0800
commitbe6ad7ddde64cebeec4cd9537ddc3bca9cbe18e5 (patch)
tree21bc96f596acc775281eb36457176bd6b1d3515d /torch
parent562fa55f3dfc2d3ce1cac19e3ec9dab202e06b3b (diff)
downloadpytorch-be6ad7ddde64cebeec4cd9537ddc3bca9cbe18e5.tar.gz
pytorch-be6ad7ddde64cebeec4cd9537ddc3bca9cbe18e5.tar.bz2
pytorch-be6ad7ddde64cebeec4cd9537ddc3bca9cbe18e5.zip
Rename BatchNorm running_variance to running_var (#17371)
Summary: Currently there is a mismatch in naming between Python BatchNorm `running_var` and C++ BatchNorm `running_variance`, which causes JIT model parameters loading to fail (https://github.com/pytorch/vision/pull/728#issuecomment-466067138): ``` terminate called after throwing an instance of 'c10::Error' what(): No such serialized tensor 'running_variance' (read at /home/shahriar/Build/pytorch/torch/csrc/api/src/serialize/input-archive.cpp:27) frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x85 (0x7f2d92d32f95 in /usr/local/lib/libc10.so) frame #1: torch::serialize::InputArchive::read(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, at::Tensor&, bool) + 0xdeb (0x7f2d938551ab in /usr/local/lib/libtorch.so.1) frame #2: torch::nn::Module::load(torch::serialize::InputArchive&) + 0x98 (0x7f2d9381cd08 in /usr/local/lib/libtorch.so.1) frame #3: torch::nn::Module::load(torch::serialize::InputArchive&) + 0xf9 (0x7f2d9381cd69 in /usr/local/lib/libtorch.so.1) frame #4: torch::nn::Module::load(torch::serialize::InputArchive&) + 0xf9 (0x7f2d9381cd69 in /usr/local/lib/libtorch.so.1) frame #5: torch::nn::operator>>(torch::serialize::InputArchive&, std::shared_ptr<torch::nn::Module> const&) + 0x32 (0x7f2d9381c7b2 in /usr/local/lib/libtorch.so.1) frame #6: <unknown function> + 0x2b16c (0x5645f4d1916c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) frame #7: <unknown function> + 0x27a3c (0x5645f4d15a3c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) frame #8: <unknown function> + 0x2165c (0x5645f4d0f65c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) frame #9: <unknown function> + 0x1540b (0x5645f4d0340b in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) frame #10: __libc_start_main + 0xf3 (0x7f2d051dd223 in /usr/lib/libc.so.6) frame #11: <unknown function> + 0x1381e (0x5645f4d0181e in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest) ``` Renaming C++ BatchNorm `running_variance` to `running_var` should fix this problem. This is a BC-breaking change, but it should be easy for end user to rename `running_variance` to `running_var` in their call sites. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17371 Reviewed By: goldsborough Differential Revision: D14172775 Pulled By: yf225 fbshipit-source-id: b9d3729ec79272a8084269756f28a8f7c4dd16b6
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/api/include/torch/nn/modules/batchnorm.h2
-rw-r--r--torch/csrc/api/src/nn/modules/batchnorm.cpp6
2 files changed, 4 insertions, 4 deletions
diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
index 2f7dd3be1a..782a5d950d 100644
--- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
@@ -88,7 +88,7 @@ class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {
/// The running variance.
/// Only defined if the `stateful` option was `true` upon construction.
- Tensor running_variance;
+ Tensor running_var;
};
/// A `ModuleHolder` subclass for `BatchNormImpl`.
diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp
index fa2789c73d..8a542e5cef 100644
--- a/torch/csrc/api/src/nn/modules/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp
@@ -28,8 +28,8 @@ void BatchNormImpl::reset() {
if (options.stateful_) {
running_mean =
register_buffer("running_mean", torch::zeros({options.features_}));
- running_variance =
- register_buffer("running_variance", torch::ones({options.features_}));
+ running_var =
+ register_buffer("running_var", torch::ones({options.features_}));
}
}
@@ -47,7 +47,7 @@ Tensor BatchNormImpl::forward(const Tensor& input) {
"Calling BatchNorm::forward is only permitted when "
"the 'stateful' option is true (was false). "
"Use BatchNorm::pure_forward instead.");
- return pure_forward(input, running_mean, running_variance);
+ return pure_forward(input, running_mean, running_var);
}
Tensor BatchNormImpl::pure_forward(