author     Evan Shelhamer <shelhamer@imaginarynumber.net>  2015-08-25 20:24:57 -0300
committer  Evan Shelhamer <shelhamer@imaginarynumber.net>  2015-08-25 20:24:57 -0300
commit     80579b8aa657dc15fd2164e97f103ce364ea77bd (patch)
tree       298d1b0307a514d298390801fe1298d4c274c5d0 /src/caffe
parent     4092b700207e8946729d97a666e196944d4dec1e (diff)
parent     ac9e29fd7b90a665a956f460715669bf05445a13 (diff)
Merge pull request #2032 from jeffdonahue/embed-layer
Embed layer for a lookup table of one-hot encodings
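
The layer is a lookup table: each bottom value is an integer index into a K x N weight matrix, and the selected row becomes the output. That is mathematically the same as running InnerProductLayer on a one-hot encoding of the index, but without ever materializing the one-hot vector. A minimal standalone sketch of that equivalence (illustrative only, not part of the patch; the function name is made up):

#include <cassert>
#include <vector>

// Sketch: multiplying a one-hot row vector (1 at position `index`) by a
// K x N weight matrix just selects row `index`; EmbedLayer performs that
// row selection directly instead of forming the product.
std::vector<float> one_hot_times_weights(int index,
                                         const std::vector<float>& W,  // K x N, row-major
                                         int K, int N) {
  assert(index >= 0 && index < K);
  std::vector<float> out(N);
  for (int d = 0; d < N; ++d) {
    out[d] = W[index * N + d];
  }
  return out;
}
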
Diffstat (limited to 'src/caffe')
-rw-r--r--  src/caffe/layers/embed_layer.cpp     122
-rw-r--r--  src/caffe/layers/embed_layer.cu       85
-rw-r--r--  src/caffe/proto/caffe.proto           18
-rw-r--r--  src/caffe/test/test_embed_layer.cpp  183
4 files changed, 407 insertions(+), 1 deletion(-)
diff --git a/src/caffe/layers/embed_layer.cpp b/src/caffe/layers/embed_layer.cpp
new file mode 100644
index 00000000..be6b2cd2
--- /dev/null
+++ b/src/caffe/layers/embed_layer.cpp
@@ -0,0 +1,122 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ N_ = this->layer_param_.embed_param().num_output();
+ CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
+ K_ = this->layer_param_.embed_param().input_dim();
+ CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
+ bias_term_ = this->layer_param_.embed_param().bias_term();
+ // Check if we need to set up the weights
+ if (this->blobs_.size() > 0) {
+ LOG(INFO) << "Skipping parameter initialization";
+ } else {
+ if (bias_term_) {
+ this->blobs_.resize(2);
+ } else {
+ this->blobs_.resize(1);
+ }
+ // Initialize the weights --
+ // transposed from InnerProductLayer for spatial locality.
+ vector<int> weight_shape(2);
+ weight_shape[0] = K_;
+ weight_shape[1] = N_;
+ this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
+ // fill the weights
+ shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
+ this->layer_param_.embed_param().weight_filler()));
+ weight_filler->Fill(this->blobs_[0].get());
+ // If necessary, initialize and fill the bias term
+ if (bias_term_) {
+ vector<int> bias_shape(1, N_);
+ this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
+ shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
+ this->layer_param_.embed_param().bias_filler()));
+ bias_filler->Fill(this->blobs_[1].get());
+ }
+ } // parameter initialization
+ this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ // Figure out the dimensions
+ M_ = bottom[0]->count();
+ vector<int> top_shape = bottom[0]->shape();
+ top_shape.push_back(N_);
+ top[0]->Reshape(top_shape);
+ // Set up the bias multiplier
+ if (bias_term_) {
+ vector<int> bias_shape(1, M_);
+ bias_multiplier_.Reshape(bias_shape);
+ caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
+ }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ const Dtype* weight = this->blobs_[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ int index;
+ for (int n = 0; n < M_; ++n) {
+ index = static_cast<int>(bottom_data[n]);
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, K_);
+ DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input";
+ caffe_copy(N_, weight + index * N_, top_data + n * N_);
+ }
+ if (bias_term_) {
+ const Dtype* bias = this->blobs_[1]->cpu_data();
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+ bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
+ }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+ if (this->param_propagate_down_[0]) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ // Gradient with respect to weight
+ Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
+ int index;
+ for (int n = 0; n < M_; ++n) {
+ index = static_cast<int>(bottom_data[n]);
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, K_);
+ DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
+ << "non-integer input";
+ caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
+ }
+ }
+ if (bias_term_ && this->param_propagate_down_[1]) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+ caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+ bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
+ }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(EmbedLayer);
+#endif
+
+INSTANTIATE_CLASS(EmbedLayer);
+REGISTER_LAYER_CLASS(Embed);
+
+} // namespace caffe
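
The weight gradient in Backward_cpu is a scatter-add: each row of top_diff is accumulated into the weight row selected by the corresponding input index, so entries that repeat an index sum their contributions (hence caffe_axpy rather than a copy); the bias gradient is simply the column sum of top_diff, computed by the gemv against the all-ones bias multiplier. A serial sketch of the weight scatter-add, assuming the same row-major K x N weight layout (the function name is illustrative):

#include <vector>

// Sketch of the weight-gradient update done by EmbedLayer::Backward_cpu:
// weight_diff[index(n), :] += top_diff[n, :] for every bottom entry n.
// Entries that share an index accumulate, matching the caffe_axpy loop.
void embed_backward_weight(const std::vector<int>& indices,
                           const std::vector<float>& top_diff,  // M x N, row-major
                           int N,
                           std::vector<float>* weight_diff) {   // K x N, row-major
  for (size_t n = 0; n < indices.size(); ++n) {
    const int index = indices[n];
    for (int d = 0; d < N; ++d) {
      (*weight_diff)[index * N + d] += top_diff[n * N + d];
    }
  }
}
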
diff --git a/src/caffe/layers/embed_layer.cu b/src/caffe/layers/embed_layer.cu
new file mode 100644
index 00000000..672fb9c6
--- /dev/null
+++ b/src/caffe/layers/embed_layer.cu
@@ -0,0 +1,85 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/gpu_util.cuh"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
+ const Dtype* weight, const int M, const int N, const int K,
+ Dtype* top_data) {
+ CUDA_KERNEL_LOOP(top_index, nthreads) {
+ const int n = top_index / N;
+ const int d = top_index % N;
+ const int index = static_cast<int>(bottom_data[n]);
+ const int weight_index = index * N + d;
+ top_data[top_index] = weight[weight_index];
+ }
+}
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+ const Dtype* top_diff, const int M, const int N, const int K,
+ Dtype* weight_diff);
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+ const Dtype* top_diff, const int M, const int N, const int K,
+ Dtype* weight_diff) {
+ CUDA_KERNEL_LOOP(top_index, nthreads) {
+ const int n = top_index / N;
+ const int d = top_index % N;
+ const int index = static_cast<int>(bottom_data[n]);
+ const int weight_index = index * N + d;
+ caffe_gpu_atomic_add(top_diff[top_index], weight_diff + weight_index);
+ }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const Dtype* weight = this->blobs_[0]->gpu_data();
+ const int count = top[0]->count();
+ EmbedForward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, bottom_data, weight, M_, N_, K_, top_data);
+ if (bias_term_) {
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+ bias_multiplier_.gpu_data(),
+ this->blobs_[1]->gpu_data(), Dtype(1), top_data);
+ }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+ if (this->param_propagate_down_[0]) {
+ const int top_count = top[0]->count();
+ const int count = this->blobs_[0]->count();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
+ EmbedBackward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(top_count), CAFFE_CUDA_NUM_THREADS>>>(
+ top_count, bottom_data, top_diff, M_, N_, K_, weight_diff);
+ }
+ if (bias_term_ && this->param_propagate_down_[1]) {
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+ caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+ bias_multiplier_.gpu_data(), Dtype(1), bias_diff);
+ }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(EmbedLayer);
+
+} // namespace caffe
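
The CUDA kernels assign one thread per element of the top blob and recover the (entry, column) pair from the flat thread index. Because different bottom entries can select the same weight row, backward threads can collide on the same weight_diff element, which is why the kernel uses caffe_gpu_atomic_add rather than a plain store. A serial sketch of what EmbedBackward does per thread (illustrative only, not part of the patch):

// Serial equivalent of the EmbedBackward kernel: one "thread" per top element.
// When two bottom entries hold the same index, two iterations target the same
// weight_diff element; on the GPU that read-modify-write must be atomic.
void embed_backward_serial(const float* bottom_data, const float* top_diff,
                           int M, int N, float* weight_diff) {
  for (int top_index = 0; top_index < M * N; ++top_index) {
    const int n = top_index / N;                          // which bottom entry
    const int d = top_index % N;                          // which embedding column
    const int index = static_cast<int>(bottom_data[n]);   // which weight row
    weight_diff[index * N + d] += top_diff[top_index];    // atomic on the GPU
  }
}
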
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index d4c97d2b..35264610 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -301,7 +301,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 137 (last added: reduction_param)
+// LayerParameter next available layer-specific ID: 138 (last added: embed_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -357,6 +357,7 @@ message LayerParameter {
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
optional EltwiseParameter eltwise_param = 110;
+ optional EmbedParameter embed_param = 137;
optional ExpParameter exp_param = 111;
optional FlattenParameter flatten_param = 135;
optional HDF5DataParameter hdf5_data_param = 112;
@@ -562,6 +563,21 @@ message EltwiseParameter {
optional bool stable_prod_grad = 3 [default = true];
}
+// Message that stores parameters used by EmbedLayer
+message EmbedParameter {
+ optional uint32 num_output = 1; // The number of outputs for the layer
+ // The input is given as integers to be interpreted as one-hot
+ // vector indices with dimension input_dim. Hence input_dim should be
+ // 1 greater than the maximum possible input value.
+ optional uint32 input_dim = 2;
+
+ optional bool bias_term = 3 [default = true]; // Whether to use a bias term
+ optional FillerParameter weight_filler = 4; // The filler for the weight
+ optional FillerParameter bias_filler = 5; // The filler for the bias
+
+}
+
+// Message that stores parameters used by ExpLayer
message ExpParameter {
// ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
// Or if base is set to the default (-1), base is set to e,
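
The new EmbedParameter message is wired into LayerParameter as field 137 and is filled in like any other layer-specific message through the generated protobuf API. A small sketch, mirroring what the tests below do (the concrete values are arbitrary):

#include "caffe/proto/caffe.pb.h"

// Sketch: filling in the new embed_param field of a LayerParameter.
// input_dim must exceed the largest index the layer will ever see.
caffe::LayerParameter MakeEmbedParam() {
  caffe::LayerParameter layer_param;
  layer_param.set_name("embed1");
  layer_param.set_type("Embed");
  caffe::EmbedParameter* embed_param = layer_param.mutable_embed_param();
  embed_param->set_num_output(16);    // N_: embedding dimension
  embed_param->set_input_dim(1000);   // K_: vocabulary size
  embed_param->set_bias_term(false);
  embed_param->mutable_weight_filler()->set_type("uniform");
  return layer_param;
}
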
diff --git a/src/caffe/test/test_embed_layer.cpp b/src/caffe/test/test_embed_layer.cpp
new file mode 100644
index 00000000..7a4fb980
--- /dev/null
+++ b/src/caffe/test/test_embed_layer.cpp
@@ -0,0 +1,183 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+#ifndef CPU_ONLY
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+#endif
+
+template <typename TypeParam>
+class EmbedLayerTest : public MultiDeviceTest<TypeParam> {
+ typedef typename TypeParam::Dtype Dtype;
+ protected:
+ EmbedLayerTest()
+ : blob_bottom_(new Blob<Dtype>(4, 1, 1, 1)),
+ blob_top_(new Blob<Dtype>()) {
+ // fill the values
+ FillerParameter filler_param;
+ UniformFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ blob_bottom_vec_.push_back(blob_bottom_);
+ blob_top_vec_.push_back(blob_top_);
+ }
+ virtual ~EmbedLayerTest() { delete blob_bottom_; delete blob_top_; }
+ Blob<Dtype>* const blob_bottom_;
+ Blob<Dtype>* const blob_top_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(EmbedLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(EmbedLayerTest, TestSetUp) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ embed_param->set_num_output(10);
+ embed_param->set_input_dim(5);
+ shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(this->blob_top_->num_axes(), 5);
+ EXPECT_EQ(this->blob_top_->shape(0), 4);
+ EXPECT_EQ(this->blob_top_->shape(1), 1);
+ EXPECT_EQ(this->blob_top_->shape(2), 1);
+ EXPECT_EQ(this->blob_top_->shape(3), 1);
+ EXPECT_EQ(this->blob_top_->shape(4), 10);
+}
+
+TYPED_TEST(EmbedLayerTest, TestForward) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ const int kNumOutput = 10;
+ const int kInputDim = 5;
+ embed_param->set_num_output(kNumOutput);
+ embed_param->set_input_dim(kInputDim);
+ embed_param->mutable_weight_filler()->set_type("uniform");
+ embed_param->mutable_weight_filler()->set_min(-10);
+ embed_param->mutable_weight_filler()->set_max(10);
+ embed_param->set_bias_term(false);
+ shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(1, layer->blobs().size());
+ vector<int> weight_shape(2);
+ weight_shape[0] = kInputDim;
+ weight_shape[1] = kNumOutput;
+ ASSERT_TRUE(weight_shape == layer->blobs()[0]->shape());
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ this->blob_bottom_->mutable_cpu_data()[i] = caffe_rng_rand() % kInputDim;
+ }
+ layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ vector<int> weight_offset(2, 0);
+ vector<int> top_offset(5, 0);
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ weight_offset[0] = static_cast<int>(this->blob_bottom_->cpu_data()[i]);
+ weight_offset[1] = 0;
+ top_offset[0] = i;
+ top_offset[4] = 0;
+ for (int j = 0; j < kNumOutput; ++j) {
+ EXPECT_EQ(layer->blobs()[0]->data_at(weight_offset),
+ this->blob_top_->data_at(top_offset));
+ ++top_offset[4];
+ ++weight_offset[1];
+ }
+ }
+}
+
+TYPED_TEST(EmbedLayerTest, TestForwardWithBias) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ const int kNumOutput = 10;
+ const int kInputDim = 5;
+ embed_param->set_num_output(kNumOutput);
+ embed_param->set_input_dim(kInputDim);
+ embed_param->mutable_weight_filler()->set_type("uniform");
+ embed_param->mutable_weight_filler()->set_min(-10);
+ embed_param->mutable_weight_filler()->set_max(10);
+ embed_param->mutable_bias_filler()->CopyFrom(embed_param->weight_filler());
+ embed_param->set_bias_term(true);
+ shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(2, layer->blobs().size());
+ vector<int> weight_shape(2);
+ weight_shape[0] = kInputDim;
+ weight_shape[1] = kNumOutput;
+ ASSERT_TRUE(weight_shape == layer->blobs()[0]->shape());
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ this->blob_bottom_->mutable_cpu_data()[i] = caffe_rng_rand() % kInputDim;
+ }
+ layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ vector<int> bias_offset(1, 0);
+ vector<int> weight_offset(2, 0);
+ vector<int> top_offset(5, 0);
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ weight_offset[0] = static_cast<int>(this->blob_bottom_->cpu_data()[i]);
+ weight_offset[1] = 0;
+ top_offset[0] = i;
+ top_offset[4] = 0;
+ bias_offset[0] = 0;
+ for (int j = 0; j < kNumOutput; ++j) {
+ EXPECT_EQ(layer->blobs()[0]->data_at(weight_offset) +
+ layer->blobs()[1]->data_at(bias_offset),
+ this->blob_top_->data_at(top_offset));
+ ++top_offset[4];
+ ++weight_offset[1];
+ ++bias_offset[0];
+ }
+ }
+}
+
+TYPED_TEST(EmbedLayerTest, TestGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ embed_param->set_num_output(10);
+ embed_param->set_input_dim(5);
+ embed_param->set_bias_term(false);
+ embed_param->mutable_weight_filler()->set_type("uniform");
+ embed_param->mutable_weight_filler()->set_min(-10);
+ embed_param->mutable_weight_filler()->set_max(10);
+ EmbedLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3);
+ this->blob_bottom_->mutable_cpu_data()[0] = 4;
+ this->blob_bottom_->mutable_cpu_data()[1] = 2;
+ this->blob_bottom_->mutable_cpu_data()[2] = 2;
+ this->blob_bottom_->mutable_cpu_data()[3] = 3;
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_, -2);
+}
+
+TYPED_TEST(EmbedLayerTest, TestGradientWithBias) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ embed_param->set_num_output(10);
+ embed_param->set_input_dim(5);
+ embed_param->set_bias_term(true);
+ embed_param->mutable_weight_filler()->set_type("uniform");
+ embed_param->mutable_weight_filler()->set_min(-10);
+ embed_param->mutable_weight_filler()->set_max(10);
+ embed_param->mutable_bias_filler()->CopyFrom(embed_param->weight_filler());
+ EmbedLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3);
+ this->blob_bottom_->mutable_cpu_data()[0] = 4;
+ this->blob_bottom_->mutable_cpu_data()[1] = 2;
+ this->blob_bottom_->mutable_cpu_data()[2] = 2;
+ this->blob_bottom_->mutable_cpu_data()[3] = 3;
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_, -2);
+}
+
+} // namespace caffe
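
REGISTER_LAYER_CLASS(Embed) makes the layer creatable by type name, so beyond the unit tests above it can be exercised through the usual layer factory. A rough end-to-end sketch under the standard Caffe C++ API (shapes and index values are illustrative; error handling omitted):

#include <vector>
#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/proto/caffe.pb.h"

using namespace caffe;  // sketch only

void RunEmbedExample() {
  LayerParameter param;
  param.set_type("Embed");
  param.mutable_embed_param()->set_num_output(8);
  param.mutable_embed_param()->set_input_dim(100);
  param.mutable_embed_param()->mutable_weight_filler()->set_type("uniform");
  // bias_term defaults to true; the unset bias_filler falls back to constant 0.
  shared_ptr<Layer<float> > layer = LayerRegistry<float>::CreateLayer(param);

  Blob<float> bottom(4, 1, 1, 1);  // four integer-valued indices
  Blob<float> top;
  float* in = bottom.mutable_cpu_data();
  in[0] = 3; in[1] = 7; in[2] = 7; in[3] = 0;

  std::vector<Blob<float>*> bottom_vec(1, &bottom), top_vec(1, &top);
  layer->SetUp(bottom_vec, top_vec);    // top reshaped to 4 x 1 x 1 x 1 x 8
  layer->Forward(bottom_vec, top_vec);  // each output row is one embedding vector
}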