summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/caffe/layers/scale_layer.cpp14
1 files changed, 11 insertions, 3 deletions
diff --git a/src/caffe/layers/scale_layer.cpp b/src/caffe/layers/scale_layer.cpp
index ecdbb123..e652dad6 100644
--- a/src/caffe/layers/scale_layer.cpp
+++ b/src/caffe/layers/scale_layer.cpp
@@ -56,9 +56,17 @@ void ScaleLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
bias_bottom_vec_.resize(1);
bias_bottom_vec_[0] = bottom[0];
bias_layer_->SetUp(bias_bottom_vec_, top);
- bias_param_id_ = this->blobs_.size();
- this->blobs_.resize(bias_param_id_ + 1);
- this->blobs_[bias_param_id_] = bias_layer_->blobs()[0];
+ if (this->blobs_.size() + bottom.size() < 3) {
+ // case: blobs.size == 1 && bottom.size == 1
+ // or blobs.size == 0 && bottom.size == 2
+ bias_param_id_ = this->blobs_.size();
+ this->blobs_.resize(bias_param_id_ + 1);
+ this->blobs_[bias_param_id_] = bias_layer_->blobs()[0];
+ } else {
+ // bias param already initialized
+ bias_param_id_ = this->blobs_.size() - 1;
+ bias_layer_->blobs()[0] = this->blobs_[bias_param_id_];
+ }
bias_propagate_down_.resize(1, false);
}
this->param_propagate_down_.resize(this->blobs_.size(), true);