author    Jeff Donahue <jeff.donahue@gmail.com>  2016-02-25 11:17:19 -0800
committer Jeff Donahue <jeff.donahue@gmail.com>  2016-02-25 11:17:19 -0800
commit    fe0f44112a153377ff4c418adefc8c690b872c37 (patch)
tree      2ca424883ec440cb10567c54bfe3d95e9bb0ae15 /src
parent    3b2e733f38d2b8ba59ecfb9461d89449bd5a0be7 (diff)
parent    8f847fa8fae0460c6bf8e8d7a9bcf96a44305033 (diff)
Merge pull request #3612 from kashefy/tied_weights_ip_transpose
Tied weights with transpose flag for InnerProduct layer
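
The use case behind this patch is tying the weights of two InnerProduct layers, for example the encoder and decoder of an autoencoder: with transpose set, the decoder stores its weight blob as K x N, which matches the encoder's N x K blob, so both layers can share a single parameter blob by name. A minimal prototxt sketch of that setup (layer names, blob names, and sizes below are illustrative, not part of this patch), assuming a 784-dimensional input and a 64-dimensional code:

    layer {
      name: "encode"
      type: "InnerProduct"
      bottom: "data"            # 784-dim input
      top: "code"
      param { name: "tied_w" }  # shared weight blob, shape 64 x 784
      param { name: "enc_b" }
      inner_product_param {
        num_output: 64
        weight_filler { type: "xavier" }
        bias_filler { type: "constant" }
      }
    }
    layer {
      name: "decode"
      type: "InnerProduct"
      bottom: "code"
      top: "reconstruction"
      param { name: "tied_w" }  # same blob; this layer uses its transpose
      param { name: "dec_b" }
      inner_product_param {
        num_output: 784
        transpose: true         # weight stored as K x N = 64 x 784, matching "tied_w"
        bias_filler { type: "constant" }
      }
    }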
Diffstat (limited to 'src')
-rw-r--r--  src/caffe/layers/inner_product_layer.cpp    |  42
-rw-r--r--  src/caffe/layers/inner_product_layer.cu     |  31
-rw-r--r--  src/caffe/proto/caffe.proto                 |   5
-rw-r--r--  src/caffe/test/test_inner_product_layer.cpp | 240
4 files changed, 303 insertions(+), 15 deletions(-)
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index d9088805..e65349f0 100644
--- a/src/caffe/layers/inner_product_layer.cpp
+++ b/src/caffe/layers/inner_product_layer.cpp
@@ -11,6 +11,7 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int num_output = this->layer_param_.inner_product_param().num_output();
bias_term_ = this->layer_param_.inner_product_param().bias_term();
+ transpose_ = this->layer_param_.inner_product_param().transpose();
N_ = num_output;
const int axis = bottom[0]->CanonicalAxisIndex(
this->layer_param_.inner_product_param().axis());
@@ -27,10 +28,15 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
} else {
this->blobs_.resize(1);
}
- // Intialize the weight
+ // Initialize the weights
vector<int> weight_shape(2);
- weight_shape[0] = N_;
- weight_shape[1] = K_;
+ if (transpose_) {
+ weight_shape[0] = K_;
+ weight_shape[1] = N_;
+ } else {
+ weight_shape[0] = N_;
+ weight_shape[1] = K_;
+ }
this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
@@ -80,7 +86,8 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
const Dtype* weight = this->blobs_[0]->cpu_data();
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, transpose_ ? CblasNoTrans : CblasTrans,
+ M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
if (bias_term_) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
@@ -97,8 +104,17 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* bottom_data = bottom[0]->cpu_data();
// Gradient with respect to weight
- caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
- top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ if (transpose_) {
+ caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+ K_, N_, M_,
+ (Dtype)1., bottom_data, top_diff,
+ (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ } else {
+ caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+ N_, K_, M_,
+ (Dtype)1., top_diff, bottom_data,
+ (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ }
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->cpu_diff();
@@ -110,9 +126,17 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
if (propagate_down[0]) {
const Dtype* top_diff = top[0]->cpu_diff();
// Gradient with respect to bottom data
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
- top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
- bottom[0]->mutable_cpu_diff());
+ if (transpose_) {
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans,
+ M_, K_, N_,
+ (Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
+ (Dtype)0., bottom[0]->mutable_cpu_diff());
+ } else {
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+ M_, K_, N_,
+ (Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
+ (Dtype)0., bottom[0]->mutable_cpu_diff());
+ }
}
}
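
For reference, writing X for the bottom data (shape M x K), Y for the top data (M x N), and G_Z = \partial L / \partial Z for the diff of a blob Z, the GEMM calls above (mirrored by the GPU path below) compute:

    transpose == false, W \in \mathbb{R}^{N \times K}:
        Y = X W^{\top}, \qquad G_W \mathrel{+}= G_Y^{\top} X, \qquad G_X = G_Y W

    transpose == true, W \in \mathbb{R}^{K \times N}:
        Y = X W, \qquad G_W \mathrel{+}= X^{\top} G_Y, \qquad G_X = G_Y W^{\top}

The weight-gradient GEMMs use beta = 1 and accumulate into the weight diff, while the bottom-gradient GEMMs use beta = 0 and overwrite the bottom diff. Both branches compute the same function whenever the two stored weight matrices are transposes of each other, which is what makes the weight tying above possible.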
diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu
index dc25aa33..a58b56e3 100644
--- a/src/caffe/layers/inner_product_layer.cu
+++ b/src/caffe/layers/inner_product_layer.cu
@@ -19,7 +19,9 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
caffe_gpu_axpy<Dtype>(N_, bias_multiplier_.cpu_data()[0],
this->blobs_[1]->gpu_data(), top_data);
} else {
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
+ caffe_gpu_gemm<Dtype>(CblasNoTrans,
+ transpose_ ? CblasNoTrans : CblasTrans,
+ M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
if (bias_term_)
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
@@ -36,8 +38,17 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
// Gradient with respect to weight
- caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
- top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ if (transpose_) {
+ caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+ K_, N_, M_,
+ (Dtype)1., bottom_data, top_diff,
+ (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ } else {
+ caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+ N_, K_, M_,
+ (Dtype)1., top_diff, bottom_data,
+ (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ }
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->gpu_diff();
@@ -49,9 +60,17 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
if (propagate_down[0]) {
const Dtype* top_diff = top[0]->gpu_diff();
// Gradient with respect to bottom data
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
- top_diff, this->blobs_[0]->gpu_data(), (Dtype)0.,
- bottom[0]->mutable_gpu_diff());
+ if (transpose_) {
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans,
+ M_, K_, N_,
+ (Dtype)1., top_diff, this->blobs_[0]->gpu_data(),
+ (Dtype)0., bottom[0]->mutable_gpu_diff());
+ } else {
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+ M_, K_, N_,
+ (Dtype)1., top_diff, this->blobs_[0]->gpu_data(),
+ (Dtype)0., bottom[0]->mutable_gpu_diff());
+ }
}
}
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 6493a72d..7edb6ae8 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -786,6 +786,11 @@ message InnerProductParameter {
// all preceding axes are retained in the output.
// May be negative to index from the end (e.g., -1 for the last axis).
optional int32 axis = 5 [default = 1];
+ // Specify whether to transpose the weight matrix or not.
+ // If transpose == true, all operations are performed on the transpose of
+ // the weight matrix. The weight matrix itself is never physically
+ // transposed; rather, the transpose flag of each GEMM call is toggled accordingly.
+ optional bool transpose = 6 [default = false];
}
// Message that stores parameters used by LogLayer
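
As a usage note for the new field, a minimal, hypothetical layer definition (names and sizes are illustrative, not taken from this patch) is shown below. With transpose: true the weight blob is created with shape K x N rather than N x K, so existing N x K weights from a snapshot would have to be transposed before they could be loaded into such a layer.

    layer {
      name: "ip1"
      type: "InnerProduct"
      bottom: "pool2"
      top: "ip1"
      inner_product_param {
        num_output: 500
        transpose: true   # GEMMs use the transpose; weight blob stored as K x 500
        weight_filler { type: "xavier" }
      }
    }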
diff --git a/src/caffe/test/test_inner_product_layer.cpp b/src/caffe/test/test_inner_product_layer.cpp
index b888b510..f1ec2333 100644
--- a/src/caffe/test/test_inner_product_layer.cpp
+++ b/src/caffe/test/test_inner_product_layer.cpp
@@ -60,6 +60,50 @@ TYPED_TEST(InnerProductLayerTest, TestSetUp) {
EXPECT_EQ(this->blob_top_->channels(), 10);
}
+/** @brief TestSetUp while toggling the transpose flag
+ */
+TYPED_TEST(InnerProductLayerTest, TestSetUpTranposeFalse) {
+ typedef typename TypeParam::Dtype Dtype;
+ this->blob_bottom_vec_.push_back(this->blob_bottom_);
+ LayerParameter layer_param;
+ InnerProductParameter* inner_product_param =
+ layer_param.mutable_inner_product_param();
+ inner_product_param->set_num_output(10);
+ inner_product_param->set_transpose(false);
+ shared_ptr<InnerProductLayer<Dtype> > layer(
+ new InnerProductLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(2, this->blob_top_->num());
+ EXPECT_EQ(1, this->blob_top_->height());
+ EXPECT_EQ(1, this->blob_top_->width());
+ EXPECT_EQ(10, this->blob_top_->channels());
+ EXPECT_EQ(2, layer->blobs()[0]->num_axes());
+ EXPECT_EQ(10, layer->blobs()[0]->shape(0));
+ EXPECT_EQ(60, layer->blobs()[0]->shape(1));
+}
+
+/** @brief TestSetUp while toggling the transpose flag
+ */
+TYPED_TEST(InnerProductLayerTest, TestSetUpTranposeTrue) {
+ typedef typename TypeParam::Dtype Dtype;
+ this->blob_bottom_vec_.push_back(this->blob_bottom_);
+ LayerParameter layer_param;
+ InnerProductParameter* inner_product_param =
+ layer_param.mutable_inner_product_param();
+ inner_product_param->set_num_output(10);
+ inner_product_param->set_transpose(true);
+ shared_ptr<InnerProductLayer<Dtype> > layer(
+ new InnerProductLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(2, this->blob_top_->num());
+ EXPECT_EQ(1, this->blob_top_->height());
+ EXPECT_EQ(1, this->blob_top_->width());
+ EXPECT_EQ(10, this->blob_top_->channels());
+ EXPECT_EQ(2, layer->blobs()[0]->num_axes());
+ EXPECT_EQ(60, layer->blobs()[0]->shape(0));
+ EXPECT_EQ(10, layer->blobs()[0]->shape(1));
+}
+
TYPED_TEST(InnerProductLayerTest, TestForward) {
typedef typename TypeParam::Dtype Dtype;
this->blob_bottom_vec_.push_back(this->blob_bottom_);
@@ -91,6 +135,79 @@ TYPED_TEST(InnerProductLayerTest, TestForward) {
}
}
+/**
+ * @brief Initialize an IP layer without transpose and with random weights,
+ * run Forward, and save the result.
+ * Then initialize another IP layer with transpose,
+ * manually copy and transpose the weights from the first IP layer into it,
+ * and run Forward on the same input, checking that the two results match.
+ */
+TYPED_TEST(InnerProductLayerTest, TestForwardTranspose) {
+ typedef typename TypeParam::Dtype Dtype;
+ this->blob_bottom_vec_.push_back(this->blob_bottom_);
+ bool IS_VALID_CUDA = false;
+#ifndef CPU_ONLY
+ IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2;
+#endif
+ if (Caffe::mode() == Caffe::CPU ||
+ sizeof(Dtype) == 4 || IS_VALID_CUDA) {
+ LayerParameter layer_param;
+ InnerProductParameter* inner_product_param =
+ layer_param.mutable_inner_product_param();
+ inner_product_param->set_num_output(10);
+ inner_product_param->mutable_weight_filler()->set_type("uniform");
+ inner_product_param->mutable_bias_filler()->set_type("uniform");
+ inner_product_param->mutable_bias_filler()->set_min(1);
+ inner_product_param->mutable_bias_filler()->set_max(2);
+ inner_product_param->set_transpose(false);
+ shared_ptr<InnerProductLayer<Dtype> > layer(
+ new InnerProductLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ const int count = this->blob_top_->count();
+ Blob<Dtype>* const top = new Blob<Dtype>();
+ top->ReshapeLike(*this->blob_top_);
+ caffe_copy(count, this->blob_top_->cpu_data(), top->mutable_cpu_data());
+ this->blob_top_vec_.clear();
+ this->blob_top_vec_.push_back(new Blob<Dtype>());
+ inner_product_param->set_transpose(true);
+ shared_ptr<InnerProductLayer<Dtype> > ip_t(
+ new InnerProductLayer<Dtype>(layer_param));
+ ip_t->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ const int count_w = layer->blobs()[0]->count();
+ EXPECT_EQ(count_w, ip_t->blobs()[0]->count());
+ // manually copy and transpose the weights from 1st IP layer into 2nd
+ const Dtype* w = layer->blobs()[0]->cpu_data();
+ Dtype* w_t = ip_t->blobs()[0]->mutable_cpu_data();
+ const int width = layer->blobs()[0]->shape(1);
+ const int width_t = ip_t->blobs()[0]->shape(1);
+ for (int i = 0; i < count_w; ++i) {
+ int r = i / width;
+ int c = i % width;
+ w_t[c*width_t+r] = w[r*width+c]; // copy while transposing
+ }
+ // copy bias from 1st IP layer to 2nd IP layer
+ ASSERT_EQ(layer->blobs()[1]->count(), ip_t->blobs()[1]->count());
+ caffe_copy(layer->blobs()[1]->count(), layer->blobs()[1]->cpu_data(),
+ ip_t->blobs()[1]->mutable_cpu_data());
+ ip_t->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(count, this->blob_top_->count())
+ << "Invalid count for top blob for IP with transpose.";
+ Blob<Dtype>* const top_t = new Blob<Dtype>();
+ top_t->ReshapeLike(*this->blob_top_vec_[0]);
+ caffe_copy(count,
+ this->blob_top_vec_[0]->cpu_data(),
+ top_t->mutable_cpu_data());
+ const Dtype* data = top->cpu_data();
+ const Dtype* data_t = top_t->cpu_data();
+ for (int i = 0; i < count; ++i) {
+ EXPECT_FLOAT_EQ(data[i], data_t[i]);
+ }
+ } else {
+ LOG(ERROR) << "Skipping test due to old architecture.";
+ }
+}
+
TYPED_TEST(InnerProductLayerTest, TestForwardNoBatch) {
typedef typename TypeParam::Dtype Dtype;
this->blob_bottom_vec_.push_back(this->blob_bottom_nobatch_);
@@ -148,4 +265,127 @@ TYPED_TEST(InnerProductLayerTest, TestGradient) {
}
}
+TYPED_TEST(InnerProductLayerTest, TestGradientTranspose) {
+ typedef typename TypeParam::Dtype Dtype;
+ this->blob_bottom_vec_.push_back(this->blob_bottom_);
+ bool IS_VALID_CUDA = false;
+#ifndef CPU_ONLY
+ IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2;
+#endif
+ if (Caffe::mode() == Caffe::CPU ||
+ sizeof(Dtype) == 4 || IS_VALID_CUDA) {
+ LayerParameter layer_param;
+ InnerProductParameter* inner_product_param =
+ layer_param.mutable_inner_product_param();
+ inner_product_param->set_num_output(11);
+ inner_product_param->mutable_weight_filler()->set_type("gaussian");
+ inner_product_param->mutable_bias_filler()->set_type("gaussian");
+ inner_product_param->mutable_bias_filler()->set_min(1);
+ inner_product_param->mutable_bias_filler()->set_max(2);
+ inner_product_param->set_transpose(true);
+ InnerProductLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3);
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+ } else {
+ LOG(ERROR) << "Skipping test due to old architecture.";
+ }
+}
+
+TYPED_TEST(InnerProductLayerTest, TestBackwardTranspose) {
+ typedef typename TypeParam::Dtype Dtype;
+ this->blob_bottom_vec_.push_back(this->blob_bottom_);
+ bool IS_VALID_CUDA = false;
+#ifndef CPU_ONLY
+ IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2;
+#endif
+ if (Caffe::mode() == Caffe::CPU ||
+ sizeof(Dtype) == 4 || IS_VALID_CUDA) {
+ LayerParameter layer_param;
+ InnerProductParameter* inner_product_param =
+ layer_param.mutable_inner_product_param();
+ inner_product_param->set_num_output(10);
+ inner_product_param->mutable_weight_filler()->set_type("uniform");
+ inner_product_param->mutable_bias_filler()->set_type("uniform");
+ inner_product_param->mutable_bias_filler()->set_min(1);
+ inner_product_param->mutable_bias_filler()->set_max(2);
+ inner_product_param->set_transpose(false);
+ shared_ptr<InnerProductLayer<Dtype> > layer(
+ new InnerProductLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // copy top blob
+ Blob<Dtype>* const top = new Blob<Dtype>();
+ top->CopyFrom(*this->blob_top_, false, true);
+ // fake top diff
+ Blob<Dtype>* const diff = new Blob<Dtype>();
+ diff->ReshapeLike(*this->blob_top_);
+ {
+ FillerParameter filler_param;
+ UniformFiller<Dtype> filler(filler_param);
+ filler.Fill(diff);
+ }
+ caffe_copy(this->blob_top_vec_[0]->count(),
+ diff->cpu_data(),
+ this->blob_top_vec_[0]->mutable_cpu_diff());
+ vector<bool> propagate_down(1, true);
+ layer->Backward(this->blob_top_vec_,
+ propagate_down,
+ this->blob_bottom_vec_);
+ // copy first ip's weights and their diffs
+ Blob<Dtype>* const w = new Blob<Dtype>();
+ w->CopyFrom(*layer->blobs()[0], false, true);
+ w->CopyFrom(*layer->blobs()[0], true, true);
+ // copy bottom diffs
+ Blob<Dtype>* const bottom_diff = new Blob<Dtype>();
+ bottom_diff->CopyFrom(*this->blob_bottom_vec_[0], true, true);
+ // repeat original top with transposed IP
+ this->blob_top_vec_.clear();
+ this->blob_top_vec_.push_back(new Blob<Dtype>());
+ inner_product_param->set_transpose(true);
+ shared_ptr<InnerProductLayer<Dtype> > ip_t(
+ new InnerProductLayer<Dtype>(layer_param));
+ ip_t->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ // manually copy and transpose the weights from 1st IP layer into 2nd
+ {
+ const Dtype* w_src = w->cpu_data();
+ Dtype* w_t = ip_t->blobs()[0]->mutable_cpu_data();
+ const int width = layer->blobs()[0]->shape(1);
+ const int width_t = ip_t->blobs()[0]->shape(1);
+ for (int i = 0; i < layer->blobs()[0]->count(); ++i) {
+ int r = i / width;
+ int c = i % width;
+ w_t[c*width_t+r] = w_src[r*width+c]; // copy while transposing
+ }
+ // copy bias from 1st IP layer to 2nd IP layer
+ ASSERT_EQ(layer->blobs()[1]->count(), ip_t->blobs()[1]->count());
+ caffe_copy(layer->blobs()[1]->count(), layer->blobs()[1]->cpu_data(),
+ ip_t->blobs()[1]->mutable_cpu_data());
+ }
+ ip_t->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ caffe_copy(this->blob_top_vec_[0]->count(),
+ diff->cpu_data(),
+ this->blob_top_vec_[0]->mutable_cpu_diff());
+ ip_t->Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_);
+ const Dtype* data = w->cpu_diff();
+ const Dtype* data_t = ip_t->blobs()[0]->cpu_diff();
+ const int WIDTH = layer->blobs()[0]->shape(1);
+ const int WIDTH_T = ip_t->blobs()[0]->shape(1);
+ for (int i = 0; i < layer->blobs()[0]->count(); ++i) {
+ int r = i / WIDTH;
+ int c = i % WIDTH;
+ EXPECT_NE(Dtype(0.), data[r*WIDTH+c]);
+ EXPECT_FLOAT_EQ(data[r*WIDTH+c], data_t[c*WIDTH_T+r]);
+ }
+ data = bottom_diff->cpu_diff();
+ data_t = this->blob_bottom_vec_[0]->cpu_diff();
+ for (int i = 0; i < this->blob_bottom_vec_[0]->count(); ++i) {
+ EXPECT_NE(Dtype(0.), data[i]);
+ EXPECT_FLOAT_EQ(data[i], data_t[i]);
+ }
+ } else {
+ LOG(ERROR) << "Skipping test due to old architecture.";
+ }
+}
+
} // namespace caffe