diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2016-09-09 16:49:31 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2016-09-12 23:14:37 -0700 |
commit | a8ec123c00723df0d0ad897e1eea32a29201c81b (patch) | |
tree | 18a0759dc3cf5ee70afe49155766aa480a5d6a6d | |
parent | c8f446f640b12b0577063eca8fab004e73c0aefc (diff) | |
download | caffeonacl-a8ec123c00723df0d0ad897e1eea32a29201c81b.tar.gz caffeonacl-a8ec123c00723df0d0ad897e1eea32a29201c81b.tar.bz2 caffeonacl-a8ec123c00723df0d0ad897e1eea32a29201c81b.zip |
batch norm: auto-upgrade old layer definitions w/ param messages
automatically strip the `param` messages from old batch norm layer
definitions. the batch norm layer used to require manually masking its
state from the solver by setting `param { lr_mult: 0 }` messages for
each of its statistics. this is now handled automatically by the layer.
-rw-r--r-- | include/caffe/util/upgrade_proto.hpp | 6 | ||||
-rw-r--r-- | src/caffe/util/upgrade_proto.cpp | 34 |
2 files changed, 39 insertions, 1 deletions
diff --git a/include/caffe/util/upgrade_proto.hpp b/include/caffe/util/upgrade_proto.hpp index 14e1936a..b145822a 100644 --- a/include/caffe/util/upgrade_proto.hpp +++ b/include/caffe/util/upgrade_proto.hpp @@ -65,6 +65,12 @@ bool NetNeedsInputUpgrade(const NetParameter& net_param); // Perform all necessary transformations to upgrade input fields into layers. void UpgradeNetInput(NetParameter* net_param); +// Return true iff the Net contains batch norm layers with manual local LRs. +bool NetNeedsBatchNormUpgrade(const NetParameter& net_param); + +// Perform all necessary transformations to upgrade batch norm layers. +void UpgradeNetBatchNorm(NetParameter* net_param); + // Return true iff the solver contains any old solver_type specified as enums bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param); diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index 9e186915..a0aacbe9 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -14,7 +14,8 @@ namespace caffe { bool NetNeedsUpgrade(const NetParameter& net_param) { return NetNeedsV0ToV1Upgrade(net_param) || NetNeedsV1ToV2Upgrade(net_param) - || NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param); + || NetNeedsDataUpgrade(net_param) || NetNeedsInputUpgrade(net_param) + || NetNeedsBatchNormUpgrade(net_param); } bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) { @@ -71,6 +72,14 @@ bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) { LOG(WARNING) << "Note that future Caffe releases will only support " << "input layers and not input fields."; } + // NetParameter uses old style batch norm layers; try to upgrade it. 
+ if (NetNeedsBatchNormUpgrade(*param)) { + LOG(INFO) << "Attempting to upgrade batch norm layers using deprecated " + << "params: " << param_file; + UpgradeNetBatchNorm(param); + LOG(INFO) << "Successfully upgraded batch norm layers using deprecated " + << "params."; + } return success; } @@ -991,6 +1000,29 @@ void UpgradeNetInput(NetParameter* net_param) { net_param->clear_input_dim(); } +bool NetNeedsBatchNormUpgrade(const NetParameter& net_param) { + for (int i = 0; i < net_param.layer_size(); ++i) { + // Check if BatchNorm layers declare three parameters, as required by + // the previous BatchNorm layer definition. + if (net_param.layer(i).type() == "BatchNorm" + && net_param.layer(i).param_size() == 3) { + return true; + } + } + return false; +} + +void UpgradeNetBatchNorm(NetParameter* net_param) { + for (int i = 0; i < net_param->layer_size(); ++i) { + // Check if BatchNorm layers declare three parameters, as required by + // the previous BatchNorm layer definition. + if (net_param->layer(i).type() == "BatchNorm" + && net_param->layer(i).param_size() == 3) { + net_param->mutable_layer(i)->clear_param(); + } + } +} + // Return true iff the solver contains any old solver_type specified as enums bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) { if (solver_param.has_solver_type()) { |