diff options
-rw-r--r-- | src/caffe/layers/scale_layer.cpp | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/caffe/layers/scale_layer.cpp b/src/caffe/layers/scale_layer.cpp index ecdbb123..e652dad6 100644 --- a/src/caffe/layers/scale_layer.cpp +++ b/src/caffe/layers/scale_layer.cpp @@ -56,9 +56,17 @@ void ScaleLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, bias_bottom_vec_.resize(1); bias_bottom_vec_[0] = bottom[0]; bias_layer_->SetUp(bias_bottom_vec_, top); - bias_param_id_ = this->blobs_.size(); - this->blobs_.resize(bias_param_id_ + 1); - this->blobs_[bias_param_id_] = bias_layer_->blobs()[0]; + if (this->blobs_.size() + bottom.size() < 3) { + // case: blobs.size == 1 && bottom.size == 1 + // or blobs.size == 0 && bottom.size == 2 + bias_param_id_ = this->blobs_.size(); + this->blobs_.resize(bias_param_id_ + 1); + this->blobs_[bias_param_id_] = bias_layer_->blobs()[0]; + } else { + // bias param already initialized + bias_param_id_ = this->blobs_.size() - 1; + bias_layer_->blobs()[0] = this->blobs_[bias_param_id_]; + } bias_propagate_down_.resize(1, false); } this->param_propagate_down_.resize(this->blobs_.size(), true); |