summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2016-09-09 16:49:31 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2016-09-12 23:14:37 -0700
commita8ec123c00723df0d0ad897e1eea32a29201c81b (patch)
tree18a0759dc3cf5ee70afe49155766aa480a5d6a6d
parentc8f446f640b12b0577063eca8fab004e73c0aefc (diff)
downloadcaffeonacl-a8ec123c00723df0d0ad897e1eea32a29201c81b.tar.gz
caffeonacl-a8ec123c00723df0d0ad897e1eea32a29201c81b.tar.bz2
caffeonacl-a8ec123c00723df0d0ad897e1eea32a29201c81b.zip
batch norm: auto-upgrade old layer definitions w/ param messages
automatically strip old batch norm layer definitions including `param` messages. the batch norm layer used to require manually masking its state from the solver by setting `param { lr_mult: 0 }` messages for each of its statistics. this is now handled automatically by the layer.
-rw-r--r--include/caffe/util/upgrade_proto.hpp6
-rw-r--r--src/caffe/util/upgrade_proto.cpp34
2 files changed, 39 insertions, 1 deletion
diff --git a/include/caffe/util/upgrade_proto.hpp b/include/caffe/util/upgrade_proto.hpp
index 14e1936a..b145822a 100644
--- a/include/caffe/util/upgrade_proto.hpp
+++ b/include/caffe/util/upgrade_proto.hpp
@@ -65,6 +65,12 @@ bool NetNeedsInputUpgrade(const NetParameter& net_param);
// Perform all necessary transformations to upgrade input fields into layers.
void UpgradeNetInput(NetParameter* net_param);
+// Return true iff the Net contains batch norm layers with manual local LRs.
+bool NetNeedsBatchNormUpgrade(const NetParameter& net_param);
+
+// Perform all necessary transformations to upgrade batch norm layers.
+void UpgradeNetBatchNorm(NetParameter* net_param);
+
// Return true iff the solver contains any old solver_type specified as enums
bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param);
diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp
index 9e186915..a0aacbe9 100644
--- a/src/caffe/util/upgrade_proto.cpp
+++ b/src/caffe/util/upgrade_proto.cpp
@@ -14,7 +14,8 @@ namespace caffe {
bool NetNeedsUpgrade(const NetParameter& net_param) {
return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param)
- || NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param);
+ || NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param)
+ || NetNeedsBatchNormUpgrade(net_param);
}
bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
@@ -71,6 +72,14 @@ bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
LOG(WARNING) << "Note that future Caffe releases will only support "
<< "input layers and not input fields.";
}
+ // NetParameter uses old style batch norm layers; try to upgrade it.
+ if (NetNeedsBatchNormUpgrade(*param)) {
+ LOG(INFO) << "Attempting to upgrade batch norm layers using deprecated "
+ << "params: " << param_file;
+ UpgradeNetBatchNorm(param);
+ LOG(INFO) << "Successfully upgraded batch norm layers using deprecated "
+ << "params.";
+ }
return success;
}
@@ -991,6 +1000,29 @@ void UpgradeNetInput(NetParameter* net_param) {
net_param->clear_input_dim();
}
+bool NetNeedsBatchNormUpgrade(const NetParameter& net_param) {
+ for (int i = 0; i < net_param.layer_size(); ++i) {
+ // Check if BatchNorm layers declare three parameters, as required by
+ // the previous BatchNorm layer definition.
+ if (net_param.layer(i).type() == "BatchNorm"
+ && net_param.layer(i).param_size() == 3) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void UpgradeNetBatchNorm(NetParameter* net_param) {
+ for (int i = 0; i < net_param->layer_size(); ++i) {
+ // Check if BatchNorm layers declare three parameters, as required by
+ // the previous BatchNorm layer definition.
+ if (net_param->layer(i).type() == "BatchNorm"
+ && net_param->layer(i).param_size() == 3) {
+ net_param->mutable_layer(i)->clear_param();
+ }
+ }
+}
+
// Return true iff the solver contains any old solver_type specified as enums
bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) {
if (solver_param.has_solver_type()) {