author    | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-08-25 20:24:57 -0300
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-08-25 20:24:57 -0300
commit    | 80579b8aa657dc15fd2164e97f103ce364ea77bd (patch)
tree      | 298d1b0307a514d298390801fe1298d4c274c5d0 /src/caffe
parent    | 4092b700207e8946729d97a666e196944d4dec1e (diff)
parent    | ac9e29fd7b90a665a956f460715669bf05445a13 (diff)
download  | caffeonacl-80579b8aa657dc15fd2164e97f103ce364ea77bd.tar.gz
          | caffeonacl-80579b8aa657dc15fd2164e97f103ce364ea77bd.tar.bz2
          | caffeonacl-80579b8aa657dc15fd2164e97f103ce364ea77bd.zip
Merge pull request #2032 from jeffdonahue/embed-layer
Embed layer for lookup table of one-hot encodings
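The layer added by this merge maps each integer-valued bottom entry to a row of a K x N weight matrix, i.e. the product of a one-hot vector with the weight matrix without ever materializing the one-hot encoding. As a rough illustration only (the layer and blob names below are invented, not part of this patch), a net might declare it with the EmbedParameter fields introduced in caffe.proto:

```
layer {
  name: "embedding"        # hypothetical name
  type: "Embed"            # registered by REGISTER_LAYER_CLASS(Embed)
  bottom: "indices"        # integer-valued indices in [0, input_dim)
  top: "embedded"
  embed_param {
    input_dim: 10000       # must be 1 greater than the largest index value
    num_output: 128        # embedding dimension
    bias_term: false
    weight_filler { type: "uniform" min: -0.1 max: 0.1 }
  }
}
```

Per the Reshape logic in embed_layer.cpp below, the top blob keeps the bottom's shape and appends num_output as a new trailing axis, so a bottom of shape (32, 1) would yield a top of shape (32, 1, 128) in this sketch.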
Diffstat (limited to 'src/caffe')
-rw-r--r-- | src/caffe/layers/embed_layer.cpp    | 122
-rw-r--r-- | src/caffe/layers/embed_layer.cu     |  85
-rw-r--r-- | src/caffe/proto/caffe.proto         |  18
-rw-r--r-- | src/caffe/test/test_embed_layer.cpp | 183
4 files changed, 407 insertions(+), 1 deletion(-)
diff --git a/src/caffe/layers/embed_layer.cpp b/src/caffe/layers/embed_layer.cpp
new file mode 100644
index 00000000..be6b2cd2
--- /dev/null
+++ b/src/caffe/layers/embed_layer.cpp
@@ -0,0 +1,122 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  N_ = this->layer_param_.embed_param().num_output();
+  CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
+  K_ = this->layer_param_.embed_param().input_dim();
+  CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
+  bias_term_ = this->layer_param_.embed_param().bias_term();
+  // Check if we need to set up the weights
+  if (this->blobs_.size() > 0) {
+    LOG(INFO) << "Skipping parameter initialization";
+  } else {
+    if (bias_term_) {
+      this->blobs_.resize(2);
+    } else {
+      this->blobs_.resize(1);
+    }
+    // Initialize the weights --
+    // transposed from InnerProductLayer for spatial locality.
+    vector<int> weight_shape(2);
+    weight_shape[0] = K_;
+    weight_shape[1] = N_;
+    this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
+    // fill the weights
+    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
+        this->layer_param_.embed_param().weight_filler()));
+    weight_filler->Fill(this->blobs_[0].get());
+    // If necessary, initialize and fill the bias term
+    if (bias_term_) {
+      vector<int> bias_shape(1, N_);
+      this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
+      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
+          this->layer_param_.embed_param().bias_filler()));
+      bias_filler->Fill(this->blobs_[1].get());
+    }
+  }  // parameter initialization
+  this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  // Figure out the dimensions
+  M_ = bottom[0]->count();
+  vector<int> top_shape = bottom[0]->shape();
+  top_shape.push_back(N_);
+  top[0]->Reshape(top_shape);
+  // Set up the bias multiplier
+  if (bias_term_) {
+    vector<int> bias_shape(1, M_);
+    bias_multiplier_.Reshape(bias_shape);
+    caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* weight = this->blobs_[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  int index;
+  for (int n = 0; n < M_; ++n) {
+    index = static_cast<int>(bottom_data[n]);
+    DCHECK_GE(index, 0);
+    DCHECK_LT(index, K_);
+    DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input";
+    caffe_copy(N_, weight + index * N_, top_data + n * N_);
+  }
+  if (bias_term_) {
+    const Dtype* bias = this->blobs_[1]->cpu_data();
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+        bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+  if (this->param_propagate_down_[0]) {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    const Dtype* bottom_data = bottom[0]->cpu_data();
+    // Gradient with respect to weight
+    Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
+    int index;
+    for (int n = 0; n < M_; ++n) {
+      index = static_cast<int>(bottom_data[n]);
+      DCHECK_GE(index, 0);
+      DCHECK_LT(index, K_);
+      DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
+          << "non-integer input";
+      caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
+    }
+  }
+  if (bias_term_ && this->param_propagate_down_[1]) {
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+    caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+        bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(EmbedLayer);
+#endif
+
+INSTANTIATE_CLASS(EmbedLayer);
+REGISTER_LAYER_CLASS(Embed);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/embed_layer.cu b/src/caffe/layers/embed_layer.cu
new file mode 100644
index 00000000..672fb9c6
--- /dev/null
+++ b/src/caffe/layers/embed_layer.cu
@@ -0,0 +1,85 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/gpu_util.cuh"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* weight, const int M, const int N, const int K,
+    Dtype* top_data) {
+  CUDA_KERNEL_LOOP(top_index, nthreads) {
+    const int n = top_index / N;
+    const int d = top_index % N;
+    const int index = static_cast<int>(bottom_data[n]);
+    const int weight_index = index * N + d;
+    top_data[top_index] = weight[weight_index];
+  }
+}
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_diff, const int M, const int N, const int K,
+    Dtype* weight_diff);
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_diff, const int M, const int N, const int K,
+    Dtype* weight_diff) {
+  CUDA_KERNEL_LOOP(top_index, nthreads) {
+    const int n = top_index / N;
+    const int d = top_index % N;
+    const int index = static_cast<int>(bottom_data[n]);
+    const int weight_index = index * N + d;
+    caffe_gpu_atomic_add(top_diff[top_index], weight_diff + weight_index);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const Dtype* weight = this->blobs_[0]->gpu_data();
+  const int count = top[0]->count();
+  EmbedForward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+      <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, weight, M_, N_, K_, top_data);
+  if (bias_term_) {
+    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+        bias_multiplier_.gpu_data(),
+        this->blobs_[1]->gpu_data(), Dtype(1), top_data);
+  }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+  if (this->param_propagate_down_[0]) {
+    const int top_count = top[0]->count();
+    const int count = this->blobs_[0]->count();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    const Dtype* bottom_data = bottom[0]->gpu_data();
+    Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
+    EmbedBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
+        <<<CAFFE_GET_BLOCKS(top_count), CAFFE_CUDA_NUM_THREADS>>>(
+        top_count, bottom_data, top_diff, M_, N_, K_, weight_diff);
+  }
+  if (bias_term_ && this->param_propagate_down_[1]) {
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+    caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+        bias_multiplier_.gpu_data(), Dtype(1), bias_diff);
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(EmbedLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index d4c97d2b..35264610 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -301,7 +301,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 137 (last added: reduction_param)
+// LayerParameter next available layer-specific ID: 138 (last added: embed_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -357,6 +357,7 @@ message LayerParameter {
   optional DropoutParameter dropout_param = 108;
   optional DummyDataParameter dummy_data_param = 109;
   optional EltwiseParameter eltwise_param = 110;
+  optional EmbedParameter embed_param = 137;
   optional ExpParameter exp_param = 111;
   optional FlattenParameter flatten_param = 135;
   optional HDF5DataParameter hdf5_data_param = 112;
@@ -562,6 +563,21 @@ message EltwiseParameter {
   optional bool stable_prod_grad = 3 [default = true];
 }
 
+// Message that stores parameters used by EmbedLayer
+message EmbedParameter {
+  optional uint32 num_output = 1; // The number of outputs for the layer
+  // The input is given as integers to be interpreted as one-hot
+  // vector indices with dimension num_input.  Hence num_input should be
+  // 1 greater than the maximum possible input value.
+  optional uint32 input_dim = 2;
+
+  optional bool bias_term = 3 [default = true]; // Whether to use a bias term
+  optional FillerParameter weight_filler = 4; // The filler for the weight
+  optional FillerParameter bias_filler = 5; // The filler for the bias
+
+}
+
+// Message that stores parameters used by ExpLayer
 message ExpParameter {
   // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
   // Or if base is set to the default (-1), base is set to e,
diff --git a/src/caffe/test/test_embed_layer.cpp b/src/caffe/test/test_embed_layer.cpp
new file mode 100644
index 00000000..7a4fb980
--- /dev/null
+++ b/src/caffe/test/test_embed_layer.cpp
@@ -0,0 +1,183 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+#ifndef CPU_ONLY
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+#endif
+
+template <typename TypeParam>
+class EmbedLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+ protected:
+  EmbedLayerTest()
+      : blob_bottom_(new Blob<Dtype>(4, 1, 1, 1)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    UniformFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~EmbedLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(EmbedLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(EmbedLayerTest, TestSetUp) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  embed_param->set_num_output(10);
+  embed_param->set_input_dim(5);
+  shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(this->blob_top_->num_axes(), 5);
+  EXPECT_EQ(this->blob_top_->shape(0), 4);
+  EXPECT_EQ(this->blob_top_->shape(1), 1);
+  EXPECT_EQ(this->blob_top_->shape(2), 1);
+  EXPECT_EQ(this->blob_top_->shape(3), 1);
+  EXPECT_EQ(this->blob_top_->shape(4), 10);
+}
+
+TYPED_TEST(EmbedLayerTest, TestForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  const int kNumOutput = 10;
+  const int kInputDim = 5;
+  embed_param->set_num_output(kNumOutput);
+  embed_param->set_input_dim(kInputDim);
+  embed_param->mutable_weight_filler()->set_type("uniform");
+  embed_param->mutable_weight_filler()->set_min(-10);
+  embed_param->mutable_weight_filler()->set_max(10);
+  embed_param->set_bias_term(false);
+  shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(1, layer->blobs().size());
+  vector<int> weight_shape(2);
+  weight_shape[0] = kInputDim;
+  weight_shape[1] = kNumOutput;
+  ASSERT_TRUE(weight_shape == layer->blobs()[0]->shape());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    this->blob_bottom_->mutable_cpu_data()[i] = caffe_rng_rand() % kInputDim;
+  }
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  vector<int> weight_offset(2, 0);
+  vector<int> top_offset(5, 0);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    weight_offset[0] = static_cast<int>(this->blob_bottom_->cpu_data()[i]);
+    weight_offset[1] = 0;
+    top_offset[0] = i;
+    top_offset[4] = 0;
+    for (int j = 0; j < kNumOutput; ++j) {
+      EXPECT_EQ(layer->blobs()[0]->data_at(weight_offset),
+                this->blob_top_->data_at(top_offset));
+      ++top_offset[4];
+      ++weight_offset[1];
+    }
+  }
+}
+
+TYPED_TEST(EmbedLayerTest, TestForwardWithBias) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  const int kNumOutput = 10;
+  const int kInputDim = 5;
+  embed_param->set_num_output(kNumOutput);
+  embed_param->set_input_dim(kInputDim);
+  embed_param->mutable_weight_filler()->set_type("uniform");
+  embed_param->mutable_weight_filler()->set_min(-10);
+  embed_param->mutable_weight_filler()->set_max(10);
+  embed_param->mutable_bias_filler()->CopyFrom(embed_param->weight_filler());
+  embed_param->set_bias_term(true);
+  shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  ASSERT_EQ(2, layer->blobs().size());
+  vector<int> weight_shape(2);
+  weight_shape[0] = kInputDim;
+  weight_shape[1] = kNumOutput;
+  ASSERT_TRUE(weight_shape == layer->blobs()[0]->shape());
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    this->blob_bottom_->mutable_cpu_data()[i] = caffe_rng_rand() % kInputDim;
+  }
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  vector<int> bias_offset(1, 0);
+  vector<int> weight_offset(2, 0);
+  vector<int> top_offset(5, 0);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    weight_offset[0] = static_cast<int>(this->blob_bottom_->cpu_data()[i]);
+    weight_offset[1] = 0;
+    top_offset[0] = i;
+    top_offset[4] = 0;
+    bias_offset[0] = 0;
+    for (int j = 0; j < kNumOutput; ++j) {
+      EXPECT_EQ(layer->blobs()[0]->data_at(weight_offset) +
+                layer->blobs()[1]->data_at(bias_offset),
+                this->blob_top_->data_at(top_offset));
+      ++top_offset[4];
+      ++weight_offset[1];
+      ++bias_offset[0];
+    }
+  }
+}
+
+TYPED_TEST(EmbedLayerTest, TestGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  embed_param->set_num_output(10);
+  embed_param->set_input_dim(5);
+  embed_param->set_bias_term(false);
+  embed_param->mutable_weight_filler()->set_type("uniform");
+  embed_param->mutable_weight_filler()->set_min(-10);
+  embed_param->mutable_weight_filler()->set_max(10);
+  EmbedLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  this->blob_bottom_->mutable_cpu_data()[0] = 4;
+  this->blob_bottom_->mutable_cpu_data()[1] = 2;
+  this->blob_bottom_->mutable_cpu_data()[2] = 2;
+  this->blob_bottom_->mutable_cpu_data()[3] = 3;
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, -2);
+}
+
+TYPED_TEST(EmbedLayerTest, TestGradientWithBias) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  EmbedParameter* embed_param = layer_param.mutable_embed_param();
+  embed_param->set_num_output(10);
+  embed_param->set_input_dim(5);
+  embed_param->set_bias_term(true);
+  embed_param->mutable_weight_filler()->set_type("uniform");
+  embed_param->mutable_weight_filler()->set_min(-10);
+  embed_param->mutable_weight_filler()->set_max(10);
+  embed_param->mutable_bias_filler()->CopyFrom(embed_param->weight_filler());
+  EmbedLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  this->blob_bottom_->mutable_cpu_data()[0] = 4;
+  this->blob_bottom_->mutable_cpu_data()[1] = 2;
+  this->blob_bottom_->mutable_cpu_data()[2] = 2;
+  this->blob_bottom_->mutable_cpu_data()[3] = 3;
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_, -2);
+}
+
+}  // namespace caffe