summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author    knsong <sunskn@163.com>  2018-02-17 15:56:32 +0800
committer knsong <sunskn@163.com>  2018-02-17 15:56:32 +0800
commit    ced55b009ae4fd6c0685543a013b1439da5879ba (patch)
tree      fd38d848b74660437109d066b11aa34e43860641
parent    a44c444ee4ae0e7c0aa77118213d34bb26e9f2e6 (diff)
download  caffe-ced55b009ae4fd6c0685543a013b1439da5879ba.tar.gz
          caffe-ced55b009ae4fd6c0685543a013b1439da5879ba.tar.bz2
          caffe-ced55b009ae4fd6c0685543a013b1439da5879ba.zip
Fix compatibility for ND convolution
-rw-r--r--  include/caffe/filler.hpp | 14
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp
index bb92ded7..e3e86a52 100644
--- a/include/caffe/filler.hpp
+++ b/include/caffe/filler.hpp
@@ -108,9 +108,9 @@ class PositiveUnitballFiller : public Filler<Dtype> {
caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
// We expect the filler to not be called very frequently, so we will
// just use a simple implementation
- int dim = blob->count() / blob->num();
+ int dim = blob->count() / blob->shape(0);
CHECK(dim);
- for (int i = 0; i < blob->num(); ++i) {
+ for (int i = 0; i < blob->shape(0); ++i) {
Dtype sum = 0;
for (int j = 0; j < dim; ++j) {
sum += data[i * dim + j];
@@ -147,8 +147,9 @@ class XavierFiller : public Filler<Dtype> {
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
- int fan_in = blob->count() / blob->num();
- int fan_out = blob->count() / blob->channels();
+ int fan_in = blob->count() / blob->shape(0);
+ // Compatible for ND Convolution
+ int fan_out = blob->count() / blob->shape(1);
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
@@ -189,8 +190,9 @@ class MSRAFiller : public Filler<Dtype> {
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
- int fan_in = blob->count() / blob->num();
- int fan_out = blob->count() / blob->channels();
+ int fan_in = blob->count() / blob->shape(0);
+ // Compatible for ND Convolution
+ int fan_out = blob->count() / blob->shape(1);
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {