 include/caffe/common_layers.hpp                   |  67
 include/caffe/test/test_gradient_check_util.hpp   |  11
 include/caffe/util/gpu_util.cuh                   |  35
 python/caffe/draw.py                              |   6
 src/caffe/layers/concat_layer.cpp                 |  13
 src/caffe/layers/concat_layer.cu                  |  17
 src/caffe/layers/embed_layer.cpp                  | 122
 src/caffe/layers/embed_layer.cu                   |  85
 src/caffe/layers/tile_layer.cpp                   |  62
 src/caffe/layers/tile_layer.cu                    |  67
 src/caffe/net.cpp                                 |  43
 src/caffe/proto/caffe.proto                       |  29
 src/caffe/test/test_concat_layer.cpp              |   9
 src/caffe/test/test_embed_layer.cpp               | 183
 src/caffe/test/test_tile_layer.cpp                | 162
 src/caffe/util/insert_splits.cpp                  |   3
 16 files changed, 880 insertions(+), 34 deletions(-)
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index d2c0ce6d..8e64b3e5 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -181,6 +181,44 @@ class EltwiseLayer : public Layer<Dtype> {
};
/**
+ * @brief A layer for learning "embeddings" of one-hot vector input.
+ * Equivalent to an InnerProductLayer with one-hot vectors as input, but
+ * for efficiency the input is the "hot" index of each column itself.
+ *
+ * TODO(dox): thorough documentation for Forward, Backward, and proto params.
+ */
+template <typename Dtype>
+class EmbedLayer : public Layer<Dtype> {
+ public:
+ explicit EmbedLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual inline const char* type() const { return "Embed"; }
+ virtual inline int ExactNumBottomBlobs() const { return 1; }
+ virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ int M_;
+ int K_;
+ int N_;
+ bool bias_term_;
+ Blob<Dtype> bias_multiplier_;
+};
+
+/**
* @brief Takes two+ Blobs, interprets last Blob as a selector and
* filter remaining Blobs accordingly with selector data (0 means that
* the corresponding item has to be filtered, non-zero means that corresponding
@@ -606,6 +644,35 @@ class SliceLayer : public Layer<Dtype> {
vector<int> slice_point_;
};
+/**
+ * @brief Copy a Blob along specified dimensions.
+ */
+template <typename Dtype>
+class TileLayer : public Layer<Dtype> {
+ public:
+ explicit TileLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual inline const char* type() const { return "Tile"; }
+ virtual inline int ExactNumBottomBlobs() const { return 1; }
+ virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ unsigned int axis_, tiles_, outer_dim_, inner_dim_;
+};
+
} // namespace caffe
#endif // CAFFE_COMMON_LAYERS_HPP_
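EmbedLayer's forward pass is a row lookup into its K_ x N_ weight table, which is what makes it equivalent to an InnerProductLayer fed one-hot vectors. A minimal standalone sketch of that equivalence (hypothetical helper, not the layer code):

    #include <cassert>
    #include <vector>

    // top[n, :] = W[indices[n], :] -- the only row of W that a one-hot vector
    // with "hot" position indices[n] would select in a matrix product.
    std::vector<float> embed_forward(const std::vector<float>& weight,  // K*N, row-major
                                     const std::vector<int>& indices,   // M hot indices
                                     int K, int N) {
      std::vector<float> top(indices.size() * N);
      for (size_t n = 0; n < indices.size(); ++n) {
        const int idx = indices[n];
        assert(idx >= 0 && idx < K);
        for (int d = 0; d < N; ++d) {
          top[n * N + d] = weight[idx * N + d];
        }
      }
      return top;
    }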
diff --git a/include/caffe/test/test_gradient_check_util.hpp b/include/caffe/test/test_gradient_check_util.hpp
index cc5dcbad..25f35d15 100644
--- a/include/caffe/test/test_gradient_check_util.hpp
+++ b/include/caffe/test/test_gradient_check_util.hpp
@@ -45,6 +45,10 @@ class GradientChecker {
void CheckGradientEltwise(Layer<Dtype>* layer,
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
+ // Checks the gradient of a single output with respect to particular input
+ // blob(s). If check_bottom = i >= 0, check only the ith bottom Blob.
+ // If check_bottom == -1, check everything -- all bottom Blobs and all
+ // param Blobs. Otherwise (if check_bottom < -1), check only param Blobs.
void CheckGradientSingle(Layer<Dtype>* layer,
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top,
int check_bottom, int top_id, int top_data_id, bool element_wise = false);
@@ -83,21 +87,22 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
// First, figure out what blobs we need to check against, and zero init
// parameter blobs.
vector<Blob<Dtype>*> blobs_to_check;
- vector<bool> propagate_down(bottom.size(), check_bottom < 0);
+ vector<bool> propagate_down(bottom.size(), check_bottom == -1);
for (int i = 0; i < layer->blobs().size(); ++i) {
Blob<Dtype>* blob = layer->blobs()[i].get();
caffe_set(blob->count(), static_cast<Dtype>(0), blob->mutable_cpu_diff());
blobs_to_check.push_back(blob);
}
- if (check_bottom < 0) {
+ if (check_bottom == -1) {
for (int i = 0; i < bottom.size(); ++i) {
blobs_to_check.push_back(bottom[i]);
}
- } else {
+ } else if (check_bottom >= 0) {
CHECK_LT(check_bottom, bottom.size());
blobs_to_check.push_back(bottom[check_bottom]);
propagate_down[check_bottom] = true;
}
+ CHECK_GT(blobs_to_check.size(), 0) << "No blobs to check.";
// Compute the gradient analytically using Backward
Caffe::set_random_seed(seed_);
// Ignore the loss from the layer (it's just the weighted sum of the losses
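The tightened check_bottom logic distinguishes three modes instead of two; roughly (a sketch of the intended call patterns, with layer, bottom_vec, and top_vec standing in for a concrete test setup):

    GradientChecker<Dtype> checker(1e-2, 1e-3);
    // check_bottom == -1 (the default): check all bottom Blobs and all params.
    checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec);
    // check_bottom >= 0: check params plus only that bottom Blob.
    checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec, 1);
    // check_bottom < -1: check params only; used by the EmbedLayer tests
    // below, whose integer-index input has no defined gradient.
    checker.CheckGradientExhaustive(&layer, bottom_vec, top_vec, -2);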
diff --git a/include/caffe/util/gpu_util.cuh b/include/caffe/util/gpu_util.cuh
new file mode 100644
index 00000000..994202f2
--- /dev/null
+++ b/include/caffe/util/gpu_util.cuh
@@ -0,0 +1,35 @@
+#ifndef CAFFE_UTIL_GPU_UTIL_H_
+#define CAFFE_UTIL_GPU_UTIL_H_
+
+namespace caffe {
+
+template <typename Dtype>
+inline __device__ Dtype caffe_gpu_atomic_add(const Dtype val, Dtype* address);
+
+template <>
+inline __device__
+float caffe_gpu_atomic_add(const float val, float* address) {
+ return atomicAdd(address, val);
+}
+
+// double atomicAdd implementation taken from:
+// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#axzz3PVCpVsEG
+template <>
+inline __device__
+double caffe_gpu_atomic_add(const double val, double* address) {
+ unsigned long long int* address_as_ull = // NOLINT(runtime/int)
+ // NOLINT_NEXT_LINE(runtime/int)
+ reinterpret_cast<unsigned long long int*>(address);
+ unsigned long long int old = *address_as_ull; // NOLINT(runtime/int)
+ unsigned long long int assumed; // NOLINT(runtime/int)
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __double_as_longlong(val + __longlong_as_double(assumed)));
+ } while (assumed != old);
+ return __longlong_as_double(old);
+}
+
+} // namespace caffe
+
+#endif // CAFFE_UTIL_GPU_UTIL_H_
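caffe_gpu_atomic_add exists because CUDA exposes a hardware atomicAdd for float but not, on the compute capabilities current at the time, for double; the double overload emulates it by reading the 64-bit word, proposing old + val, and retrying with atomicCAS until no other thread has intervened. The same retry pattern, sketched on the host with std::atomic (an analogy, not the GPU code):

    #include <atomic>

    double atomic_add_double(std::atomic<double>* address, double val) {
      double old = address->load();
      // compare_exchange_weak refreshes `old` with the current value on
      // failure, so each retry proposes a sum based on the latest contents.
      while (!address->compare_exchange_weak(old, old + val)) {
      }
      return old;  // like atomicAdd, returns the value seen just before the add
    }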
diff --git a/python/caffe/draw.py b/python/caffe/draw.py
index 324929de..a002b60b 100644
--- a/python/caffe/draw.py
+++ b/python/caffe/draw.py
@@ -40,7 +40,7 @@ def get_edge_label(layer):
if layer.type == 'Data':
edge_label = 'Batch ' + str(layer.data_param.batch_size)
- elif layer.type == 'Convolution':
+ elif layer.type == 'Convolution' or layer.type == 'Deconvolution':
edge_label = str(layer.convolution_param.num_output)
elif layer.type == 'InnerProduct':
edge_label = str(layer.inner_product_param.num_output)
@@ -74,7 +74,7 @@ def get_layer_label(layer, rankdir):
# horizontal space is not; separate words with newlines
separator = '\\n'
- if layer.type == 'Convolution':
+ if layer.type == 'Convolution' or layer.type == 'Deconvolution':
# Outer double quotes needed or else colon characters don't parse
# properly
node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
@@ -109,7 +109,7 @@ def choose_color_by_layertype(layertype):
"""Define colors for nodes based on the layer type.
"""
color = '#6495ED' # Default
- if layertype == 'Convolution':
+ if layertype == 'Convolution' or layertype == 'Deconvolution':
color = '#FF5050'
elif layertype == 'Pooling':
color = '#FF9900'
diff --git a/src/caffe/layers/concat_layer.cpp b/src/caffe/layers/concat_layer.cpp
index 1cac8fc3..95fba105 100644
--- a/src/caffe/layers/concat_layer.cpp
+++ b/src/caffe/layers/concat_layer.cpp
@@ -76,13 +76,14 @@ void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
int offset_concat_axis = 0;
const int top_concat_axis = top[0]->shape(concat_axis_);
for (int i = 0; i < bottom.size(); ++i) {
- if (!propagate_down[i]) { continue; }
- Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
- for (int n = 0; n < num_concats_; ++n) {
- caffe_copy(bottom_concat_axis * concat_input_size_, top_diff +
- (n * top_concat_axis + offset_concat_axis) * concat_input_size_,
- bottom_diff + n * bottom_concat_axis * concat_input_size_);
+ if (propagate_down[i]) {
+ Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
+ for (int n = 0; n < num_concats_; ++n) {
+ caffe_copy(bottom_concat_axis * concat_input_size_, top_diff +
+ (n * top_concat_axis + offset_concat_axis) * concat_input_size_,
+ bottom_diff + n * bottom_concat_axis * concat_input_size_);
+ }
}
offset_concat_axis += bottom_concat_axis;
}
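The restructuring here is more than style: offset_concat_axis must advance past every bottom's extent along the concat axis, even for bottoms that skip gradient computation, or the remaining bottoms would copy their diffs from the wrong slice of top_diff. With the old `continue`, the increment at the end of the loop body was skipped. A schematic of the offsets for a hypothetical three-input concat along the channel axis:

    // bottoms with 2, 3 and 4 channels concatenated -> top has 9 channels.
    // bottom[0] owns top channels [0, 2), bottom[1] owns [2, 5),
    // bottom[2] owns [5, 9), regardless of which bottoms need gradients.
    const int bottom_channels[] = {2, 3, 4};
    int offset = 0;
    for (int i = 0; i < 3; ++i) {
      // copy top_diff channels [offset, offset + bottom_channels[i]) only
      // when propagate_down[i], but advance the offset unconditionally:
      offset += bottom_channels[i];
    }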
diff --git a/src/caffe/layers/concat_layer.cu b/src/caffe/layers/concat_layer.cu
index 8f2e85d8..3c64c7ef 100644
--- a/src/caffe/layers/concat_layer.cu
+++ b/src/caffe/layers/concat_layer.cu
@@ -53,15 +53,16 @@ void ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const int top_concat_axis = top[0]->shape(concat_axis_);
const bool kForward = false;
for (int i = 0; i < bottom.size(); ++i) {
- if (!propagate_down[i]) { continue; }
- Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
- const int bottom_concat_size = bottom_concat_axis * concat_input_size_;
- const int nthreads = bottom_concat_size * num_concats_;
- Concat<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
- <<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
- nthreads, top_diff, kForward, num_concats_, concat_input_size_,
- top_concat_axis, bottom_concat_axis, offset_concat_axis, bottom_diff);
+ if (propagate_down[i]) {
+ Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
+ const int bottom_concat_size = bottom_concat_axis * concat_input_size_;
+ const int nthreads = bottom_concat_size * num_concats_;
+ Concat<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
+ nthreads, top_diff, kForward, num_concats_, concat_input_size_,
+ top_concat_axis, bottom_concat_axis, offset_concat_axis, bottom_diff);
+ }
offset_concat_axis += bottom_concat_axis;
}
}
diff --git a/src/caffe/layers/embed_layer.cpp b/src/caffe/layers/embed_layer.cpp
new file mode 100644
index 00000000..be6b2cd2
--- /dev/null
+++ b/src/caffe/layers/embed_layer.cpp
@@ -0,0 +1,122 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ N_ = this->layer_param_.embed_param().num_output();
+ CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
+ K_ = this->layer_param_.embed_param().input_dim();
+ CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
+ bias_term_ = this->layer_param_.embed_param().bias_term();
+ // Check if we need to set up the weights
+ if (this->blobs_.size() > 0) {
+ LOG(INFO) << "Skipping parameter initialization";
+ } else {
+ if (bias_term_) {
+ this->blobs_.resize(2);
+ } else {
+ this->blobs_.resize(1);
+ }
+ // Initialize the weights --
+ // transposed from InnerProductLayer for spatial locality.
+ vector<int> weight_shape(2);
+ weight_shape[0] = K_;
+ weight_shape[1] = N_;
+ this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
+ // fill the weights
+ shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
+ this->layer_param_.embed_param().weight_filler()));
+ weight_filler->Fill(this->blobs_[0].get());
+ // If necessary, initialize and fill the bias term
+ if (bias_term_) {
+ vector<int> bias_shape(1, N_);
+ this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
+ shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
+ this->layer_param_.embed_param().bias_filler()));
+ bias_filler->Fill(this->blobs_[1].get());
+ }
+ } // parameter initialization
+ this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ // Figure out the dimensions
+ M_ = bottom[0]->count();
+ vector<int> top_shape = bottom[0]->shape();
+ top_shape.push_back(N_);
+ top[0]->Reshape(top_shape);
+ // Set up the bias multiplier
+ if (bias_term_) {
+ vector<int> bias_shape(1, M_);
+ bias_multiplier_.Reshape(bias_shape);
+ caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
+ }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ const Dtype* weight = this->blobs_[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ int index;
+ for (int n = 0; n < M_; ++n) {
+ index = static_cast<int>(bottom_data[n]);
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, K_);
+ DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n]) << "non-integer input";
+ caffe_copy(N_, weight + index * N_, top_data + n * N_);
+ }
+ if (bias_term_) {
+ const Dtype* bias = this->blobs_[1]->cpu_data();
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+ bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
+ }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+ if (this->param_propagate_down_[0]) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ // Gradient with respect to weight
+ Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
+ int index;
+ for (int n = 0; n < M_; ++n) {
+ index = static_cast<int>(bottom_data[n]);
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, K_);
+ DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
+ << "non-integer input";
+ caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
+ }
+ }
+ if (bias_term_ && this->param_propagate_down_[1]) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+ caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+ bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
+ }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(EmbedLayer);
+#endif
+
+INSTANTIATE_CLASS(EmbedLayer);
+REGISTER_LAYER_CLASS(Embed);
+
+} // namespace caffe
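In matrix form the CPU forward pass is top = onehot(bottom) * W + 1_{M x 1} * b, which is why the bias can be added with a single GEMM against the all-ones bias_multiplier_. The weight gradient accumulates rather than overwrites because the same index may occur more than once in a batch; a tiny numeric sketch with hypothetical values:

    // K = 3 rows, N = 2 columns; both inputs select row 2, so their top
    // diffs must be summed into the same weight row.
    float weight_diff[3][2] = {{0, 0}, {0, 0}, {0, 0}};
    const int indices[2] = {2, 2};
    const float top_diff[2][2] = {{0.5f, -1.0f}, {0.25f, 2.0f}};
    for (int n = 0; n < 2; ++n) {
      for (int d = 0; d < 2; ++d) {
        weight_diff[indices[n]][d] += top_diff[n][d];  // caffe_axpy in the layer
      }
    }
    // weight_diff[2] ends up {0.75, 1.0}; rows 0 and 1 stay zero.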
diff --git a/src/caffe/layers/embed_layer.cu b/src/caffe/layers/embed_layer.cu
new file mode 100644
index 00000000..672fb9c6
--- /dev/null
+++ b/src/caffe/layers/embed_layer.cu
@@ -0,0 +1,85 @@
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/gpu_util.cuh"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
+ const Dtype* weight, const int M, const int N, const int K,
+ Dtype* top_data) {
+ CUDA_KERNEL_LOOP(top_index, nthreads) {
+ const int n = top_index / N;
+ const int d = top_index % N;
+ const int index = static_cast<int>(bottom_data[n]);
+ const int weight_index = index * N + d;
+ top_data[top_index] = weight[weight_index];
+ }
+}
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+ const Dtype* top_diff, const int M, const int N, const int K,
+ Dtype* weight_diff);
+
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+ const Dtype* top_diff, const int M, const int N, const int K,
+ Dtype* weight_diff) {
+ CUDA_KERNEL_LOOP(top_index, nthreads) {
+ const int n = top_index / N;
+ const int d = top_index % N;
+ const int index = static_cast<int>(bottom_data[n]);
+ const int weight_index = index * N + d;
+ caffe_gpu_atomic_add(top_diff[top_index], weight_diff + weight_index);
+ }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const Dtype* weight = this->blobs_[0]->gpu_data();
+ const int count = top[0]->count();
+ EmbedForward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, bottom_data, weight, M_, N_, K_, top_data);
+ if (bias_term_) {
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
+ bias_multiplier_.gpu_data(),
+ this->blobs_[1]->gpu_data(), Dtype(1), top_data);
+ }
+}
+
+template <typename Dtype>
+void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
+ if (this->param_propagate_down_[0]) {
+ const int top_count = top[0]->count();
+ const int count = this->blobs_[0]->count();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
+ EmbedBackward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(top_count), CAFFE_CUDA_NUM_THREADS>>>(
+ top_count, bottom_data, top_diff, M_, N_, K_, weight_diff);
+ }
+ if (bias_term_ && this->param_propagate_down_[1]) {
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bias_diff = this->blobs_[1]->mutable_gpu_diff();
+ caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
+ bias_multiplier_.gpu_data(), Dtype(1), bias_diff);
+ }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(EmbedLayer);
+
+} // namespace caffe
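On the GPU every element of the top blob gets its own thread, so when two inputs share an index (as in the duplicate-index sketch above), two threads target the same weight_diff entries and the accumulation must go through caffe_gpu_atomic_add. The flat-index arithmetic each thread performs, pulled out as a plain function for illustration:

    // Map a flat top index in [0, M*N) to the weight element it touches
    // (same arithmetic as EmbedForward / EmbedBackward; illustrative only).
    int weight_index_for(int top_index, const float* bottom_data, int N) {
      const int n = top_index / N;                       // which input element
      const int d = top_index % N;                       // which of the N outputs
      const int row = static_cast<int>(bottom_data[n]);  // hot index in [0, K)
      return row * N + d;                                // offset into the K x N table
    }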
diff --git a/src/caffe/layers/tile_layer.cpp b/src/caffe/layers/tile_layer.cpp
new file mode 100644
index 00000000..f55008cc
--- /dev/null
+++ b/src/caffe/layers/tile_layer.cpp
@@ -0,0 +1,62 @@
+#include <vector>
+
+#include "caffe/common_layers.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void TileLayer<Dtype>::Reshape(
+ const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+ const TileParameter& tile_param = this->layer_param_.tile_param();
+ axis_ = bottom[0]->CanonicalAxisIndex(tile_param.axis());
+ CHECK(tile_param.has_tiles()) << "Number of tiles must be specified";
+ tiles_ = tile_param.tiles();
+ CHECK_GT(tiles_, 0) << "Number of tiles must be positive.";
+ vector<int> top_shape = bottom[0]->shape();
+ top_shape[axis_] = bottom[0]->shape(axis_) * tiles_;
+ top[0]->Reshape(top_shape);
+ outer_dim_ = bottom[0]->count(0, axis_);
+ inner_dim_ = bottom[0]->count(axis_);
+}
+
+template <typename Dtype>
+void TileLayer<Dtype>::Forward_cpu(
+ const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ for (int i = 0; i < outer_dim_; ++i) {
+ for (int t = 0; t < tiles_; ++t) {
+ caffe_copy(inner_dim_, bottom_data, top_data);
+ top_data += inner_dim_;
+ }
+ bottom_data += inner_dim_;
+ }
+}
+
+template <typename Dtype>
+void TileLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ if (!propagate_down[0]) { return; }
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ for (int i = 0; i < outer_dim_; ++i) {
+ caffe_copy(inner_dim_, top_diff, bottom_diff);
+ top_diff += inner_dim_;
+ for (int t = 1; t < tiles_; ++t) {
+ caffe_axpy(inner_dim_, Dtype(1), top_diff, bottom_diff);
+ top_diff += inner_dim_;
+ }
+ bottom_diff += inner_dim_;
+ }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(TileLayer);
+#endif
+
+INSTANTIATE_CLASS(TileLayer);
+REGISTER_LAYER_CLASS(Tile);
+
+} // namespace caffe
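Concretely, for a bottom of shape (2, 3) tiled twice along axis 1, outer_dim_ is 2 and inner_dim_ is 3, and the top has shape (2, 6): each length-3 block is copied twice, and Backward sums the two copies' diffs back into the single bottom block. A small sketch with hypothetical values:

    // bottom (2 x 3) tiled 2x along axis 1 -> top (2 x 6).
    const float bottom[2][3] = {{1, 2, 3}, {4, 5, 6}};
    float top[2][6];
    const int outer_dim = 2, inner_dim = 3, tiles = 2;
    for (int i = 0; i < outer_dim; ++i) {
      for (int t = 0; t < tiles; ++t) {
        for (int d = 0; d < inner_dim; ++d) {
          top[i][t * inner_dim + d] = bottom[i][d];
        }
      }
    }
    // top is {{1,2,3,1,2,3}, {4,5,6,4,5,6}}; the backward pass adds the two
    // copies of each diff back into one bottom element.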
diff --git a/src/caffe/layers/tile_layer.cu b/src/caffe/layers/tile_layer.cu
new file mode 100644
index 00000000..7fd3bc47
--- /dev/null
+++ b/src/caffe/layers/tile_layer.cu
@@ -0,0 +1,67 @@
+#include <vector>
+
+#include "caffe/common_layers.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void Tile(const int nthreads, const Dtype* bottom_data,
+ const int tile_size, const int num_tiles, const int bottom_tile_axis,
+ Dtype* top_data) {
+ CUDA_KERNEL_LOOP(index, nthreads) {
+ const int d = index % tile_size;
+ const int b = (index / tile_size / num_tiles) % bottom_tile_axis;
+ const int n = index / tile_size / num_tiles / bottom_tile_axis;
+ const int bottom_index = (n * bottom_tile_axis + b) * tile_size + d;
+ top_data[index] = bottom_data[bottom_index];
+ }
+}
+
+template <typename Dtype>
+void TileLayer<Dtype>::Forward_gpu(
+ const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const int bottom_tile_axis = bottom[0]->shape(axis_);
+ const int nthreads = top[0]->count();
+ Tile<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
+ nthreads, bottom_data, inner_dim_, tiles_, bottom_tile_axis, top_data);
+}
+
+template <typename Dtype>
+__global__ void TileBackward(const int nthreads, const Dtype* top_diff,
+ const int tile_size, const int num_tiles, const int bottom_tile_axis,
+ Dtype* bottom_diff) {
+ CUDA_KERNEL_LOOP(index, nthreads) {
+ const int d = index % tile_size;
+ const int b = (index / tile_size) % bottom_tile_axis;
+ const int n = index / tile_size / bottom_tile_axis;
+ bottom_diff[index] = 0;
+ int top_index = (n * num_tiles * bottom_tile_axis + b) * tile_size + d;
+ for (int t = 0; t < num_tiles; ++t) {
+ bottom_diff[index] += top_diff[top_index];
+ top_index += bottom_tile_axis * tile_size;
+ }
+ }
+}
+
+template <typename Dtype>
+void TileLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ if (!propagate_down[0]) { return; }
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ const int bottom_tile_axis = bottom[0]->shape(axis_);
+ const int tile_size = inner_dim_ / bottom_tile_axis;
+ const int nthreads = bottom[0]->count();
+ TileBackward<Dtype> // NOLINT_NEXT_LINE(whitespace/operators)
+ <<<CAFFE_GET_BLOCKS(nthreads), CAFFE_CUDA_NUM_THREADS>>>(
+ nthreads, top_diff, tile_size, tiles_, bottom_tile_axis, bottom_diff);
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(TileLayer);
+
+} // namespace caffe
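In TileBackward each thread owns one bottom element and walks its num_tiles copies in the top blob: the copies start at (n * num_tiles * bottom_tile_axis + b) * tile_size + d and are spaced bottom_tile_axis * tile_size apart. A small helper showing the same arithmetic (illustrative only; B is the bottom extent along the tiled axis):

    // Flat top index of copy t of the bottom element at (outer n, axis b, inner d).
    int top_index_of_copy(int n, int b, int d, int t,
                          int B, int tile_size, int num_tiles) {
      return (n * num_tiles * B + b) * tile_size + d + t * B * tile_size;
    }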
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 7875285f..f1fc63ab 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -424,7 +424,8 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
blob_name_to_idx->find(blob_name) != blob_name_to_idx->end()) {
// If we are not doing in-place computation but have duplicated blobs,
// raise an error.
- LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
+ LOG(FATAL) << "Top blob '" << blob_name
+ << "' produced by multiple sources.";
} else {
// Normal output.
if (Caffe::root_solver()) {
@@ -468,8 +469,8 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
const LayerParameter& layer_param = param.layer(layer_id);
const string& blob_name = layer_param.bottom(bottom_id);
if (available_blobs->find(blob_name) == available_blobs->end()) {
- LOG(FATAL) << "Unknown blob input " << blob_name
- << " (at index " << bottom_id << ") to layer " << layer_id;
+ LOG(FATAL) << "Unknown bottom blob '" << blob_name << "' (layer '"
+ << layer_param.name() << "', bottom index " << bottom_id << ")";
}
const int blob_id = (*blob_name_to_idx)[blob_name];
if (Caffe::root_solver()) {
@@ -545,10 +546,19 @@ void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
ParamSpec_DimCheckMode_PERMISSIVE)) {
// Permissive dimension checking -- only check counts are the same.
CHECK_EQ(this_blob->count(), owner_blob->count())
- << "Shared parameter blobs must have the same count.";
+ << "Cannot share param '" << param_name << "' owned by layer '"
+ << layer_names_[owner_layer_id] << "' with layer '"
+ << layer_names_[layer_id] << "'; count mismatch. Owner layer param "
+ << "shape is " << owner_blob->shape_string() << "; sharing layer "
+ << "shape is " << this_blob->shape_string();
} else {
// Strict dimension checking -- all dims must be the same.
- CHECK(this_blob->shape() == owner_blob->shape());
+ CHECK(this_blob->shape() == owner_blob->shape())
+ << "Cannot share param '" << param_name << "' owned by layer '"
+ << layer_names_[owner_layer_id] << "' with layer '"
+ << layer_names_[layer_id] << "'; shape mismatch. Owner layer param "
+ << "shape is " << owner_blob->shape_string() << "; sharing layer "
+ << "expects shape " << this_blob->shape_string();
}
const int learnable_param_id = learnable_param_ids_[owner_net_param_id];
learnable_param_ids_.push_back(learnable_param_id);
@@ -775,7 +785,11 @@ void Net<Dtype>::ShareTrainedLayersWith(const Net* other) {
<< "Incompatible number of blobs for layer " << source_layer_name;
for (int j = 0; j < target_blobs.size(); ++j) {
Blob<Dtype>* source_blob = source_layer->blobs()[j].get();
- CHECK(target_blobs[j]->shape() == source_blob->shape());
+ CHECK(target_blobs[j]->shape() == source_blob->shape())
+ << "Cannot share param " << j << " weights from layer '"
+ << source_layer_name << "'; shape mismatch. Source param shape is "
+ << source_blob->shape_string() << "; target param shape is "
+ << target_blobs[j]->shape_string();
target_blobs[j]->ShareData(*source_blob);
}
}
@@ -839,6 +853,17 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
<< "Incompatible number of blobs for layer " << source_layer_name;
for (int j = 0; j < target_blobs.size(); ++j) {
+ if (!target_blobs[j]->ShapeEquals(source_layer.blobs(j))) {
+ Blob<Dtype> source_blob;
+ const bool kReshape = true;
+ source_blob.FromProto(source_layer.blobs(j), kReshape);
+ LOG(FATAL) << "Cannot copy param " << j << " weights from layer '"
+ << source_layer_name << "'; shape mismatch. Source param shape is "
+ << source_blob.shape_string() << "; target param shape is "
+ << target_blobs[j]->shape_string() << ". "
+ << "To learn this layer's parameters from scratch rather than "
+ << "copying from a saved net, rename the layer.";
+ }
const bool kReshape = false;
target_blobs[j]->FromProto(source_layer.blobs(j), kReshape);
}
@@ -924,12 +949,6 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
DLOG(INFO) << "Serializing " << layers_.size() << " layers";
for (int i = 0; i < layers_.size(); ++i) {
LayerParameter* layer_param = param->add_layer();
- for (int j = 0; j < bottom_id_vecs_[i].size(); ++j) {
- layer_param->add_bottom(blob_names_[bottom_id_vecs_[i][j]]);
- }
- for (int j = 0; j < top_id_vecs_[i].size(); ++j) {
- layer_param->add_top(blob_names_[top_id_vecs_[i][j]]);
- }
layers_[i]->ToProto(layer_param, write_diff);
}
}
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index d4c97d2b..aa299f86 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -301,7 +301,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 137 (last added: reduction_param)
+// LayerParameter next available layer-specific ID: 139 (last added: tile_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -357,6 +357,7 @@ message LayerParameter {
optional DropoutParameter dropout_param = 108;
optional DummyDataParameter dummy_data_param = 109;
optional EltwiseParameter eltwise_param = 110;
+ optional EmbedParameter embed_param = 137;
optional ExpParameter exp_param = 111;
optional FlattenParameter flatten_param = 135;
optional HDF5DataParameter hdf5_data_param = 112;
@@ -382,6 +383,7 @@ message LayerParameter {
optional SliceParameter slice_param = 126;
optional TanHParameter tanh_param = 127;
optional ThresholdParameter threshold_param = 128;
+ optional TileParameter tile_param = 138;
optional WindowDataParameter window_data_param = 129;
}
@@ -562,6 +564,21 @@ message EltwiseParameter {
optional bool stable_prod_grad = 3 [default = true];
}
+// Message that stores parameters used by EmbedLayer
+message EmbedParameter {
+ optional uint32 num_output = 1; // The number of outputs for the layer
+ // The input is given as integers to be interpreted as one-hot
+ // vector indices with dimension input_dim. Hence input_dim should be
+ // 1 greater than the maximum possible input value.
+ optional uint32 input_dim = 2;
+
+ optional bool bias_term = 3 [default = true]; // Whether to use a bias term
+ optional FillerParameter weight_filler = 4; // The filler for the weight
+ optional FillerParameter bias_filler = 5; // The filler for the bias
+
+}
+
+// Message that stores parameters used by ExpLayer
message ExpParameter {
// ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0.
// Or if base is set to the default (-1), base is set to e,
@@ -903,6 +920,16 @@ message TanHParameter {
optional Engine engine = 1 [default = DEFAULT];
}
+// Message that stores parameters used by TileLayer
+message TileParameter {
+ // The index of the axis to tile.
+ optional int32 axis = 1 [default = 1];
+
+ // The number of copies (tiles) of the blob to output.
+ optional int32 tiles = 2;
+}
+
+// Message that stores parameters used by ThresholdLayer
message ThresholdParameter {
optional float threshold = 1 [default = 0]; // Strictly positive values
}
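The new messages are populated like any other layer-specific parameter; a minimal C++ setup mirroring the unit tests below (assumes the generated caffe.pb.h; the concrete values are only illustrative):

    LayerParameter embed_layer_param;
    EmbedParameter* embed_param = embed_layer_param.mutable_embed_param();
    embed_param->set_num_output(10);  // N: the embedding dimension
    embed_param->set_input_dim(5);    // K: one greater than the largest index
    embed_param->mutable_weight_filler()->set_type("uniform");

    LayerParameter tile_layer_param;
    TileParameter* tile_param = tile_layer_param.mutable_tile_param();
    tile_param->set_axis(1);   // tile along the channel axis
    tile_param->set_tiles(3);  // the output is 3x larger along that axis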
diff --git a/src/caffe/test/test_concat_layer.cpp b/src/caffe/test/test_concat_layer.cpp
index 662a50fa..088e0a41 100644
--- a/src/caffe/test/test_concat_layer.cpp
+++ b/src/caffe/test/test_concat_layer.cpp
@@ -173,4 +173,13 @@ TYPED_TEST(ConcatLayerTest, TestGradientChannels) {
this->blob_top_vec_);
}
+TYPED_TEST(ConcatLayerTest, TestGradientChannelsBottomOneOnly) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ ConcatLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-2);
+ checker.CheckGradient(&layer, this->blob_bottom_vec_0_,
+ this->blob_top_vec_, 1);
+}
+
} // namespace caffe
diff --git a/src/caffe/test/test_embed_layer.cpp b/src/caffe/test/test_embed_layer.cpp
new file mode 100644
index 00000000..7a4fb980
--- /dev/null
+++ b/src/caffe/test/test_embed_layer.cpp
@@ -0,0 +1,183 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+#ifndef CPU_ONLY
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+#endif
+
+template <typename TypeParam>
+class EmbedLayerTest : public MultiDeviceTest<TypeParam> {
+ typedef typename TypeParam::Dtype Dtype;
+ protected:
+ EmbedLayerTest()
+ : blob_bottom_(new Blob<Dtype>(4, 1, 1, 1)),
+ blob_top_(new Blob<Dtype>()) {
+ // fill the values
+ FillerParameter filler_param;
+ UniformFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ blob_bottom_vec_.push_back(blob_bottom_);
+ blob_top_vec_.push_back(blob_top_);
+ }
+ virtual ~EmbedLayerTest() { delete blob_bottom_; delete blob_top_; }
+ Blob<Dtype>* const blob_bottom_;
+ Blob<Dtype>* const blob_top_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(EmbedLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(EmbedLayerTest, TestSetUp) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ embed_param->set_num_output(10);
+ embed_param->set_input_dim(5);
+ shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(this->blob_top_->num_axes(), 5);
+ EXPECT_EQ(this->blob_top_->shape(0), 4);
+ EXPECT_EQ(this->blob_top_->shape(1), 1);
+ EXPECT_EQ(this->blob_top_->shape(2), 1);
+ EXPECT_EQ(this->blob_top_->shape(3), 1);
+ EXPECT_EQ(this->blob_top_->shape(4), 10);
+}
+
+TYPED_TEST(EmbedLayerTest, TestForward) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ const int kNumOutput = 10;
+ const int kInputDim = 5;
+ embed_param->set_num_output(kNumOutput);
+ embed_param->set_input_dim(kInputDim);
+ embed_param->mutable_weight_filler()->set_type("uniform");
+ embed_param->mutable_weight_filler()->set_min(-10);
+ embed_param->mutable_weight_filler()->set_max(10);
+ embed_param->set_bias_term(false);
+ shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(1, layer->blobs().size());
+ vector<int> weight_shape(2);
+ weight_shape[0] = kInputDim;
+ weight_shape[1] = kNumOutput;
+ ASSERT_TRUE(weight_shape == layer->blobs()[0]->shape());
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ this->blob_bottom_->mutable_cpu_data()[i] = caffe_rng_rand() % kInputDim;
+ }
+ layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ vector<int> weight_offset(2, 0);
+ vector<int> top_offset(5, 0);
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ weight_offset[0] = static_cast<int>(this->blob_bottom_->cpu_data()[i]);
+ weight_offset[1] = 0;
+ top_offset[0] = i;
+ top_offset[4] = 0;
+ for (int j = 0; j < kNumOutput; ++j) {
+ EXPECT_EQ(layer->blobs()[0]->data_at(weight_offset),
+ this->blob_top_->data_at(top_offset));
+ ++top_offset[4];
+ ++weight_offset[1];
+ }
+ }
+}
+
+TYPED_TEST(EmbedLayerTest, TestForwardWithBias) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ const int kNumOutput = 10;
+ const int kInputDim = 5;
+ embed_param->set_num_output(kNumOutput);
+ embed_param->set_input_dim(kInputDim);
+ embed_param->mutable_weight_filler()->set_type("uniform");
+ embed_param->mutable_weight_filler()->set_min(-10);
+ embed_param->mutable_weight_filler()->set_max(10);
+ embed_param->mutable_bias_filler()->CopyFrom(embed_param->weight_filler());
+ embed_param->set_bias_term(true);
+ shared_ptr<EmbedLayer<Dtype> > layer(new EmbedLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(2, layer->blobs().size());
+ vector<int> weight_shape(2);
+ weight_shape[0] = kInputDim;
+ weight_shape[1] = kNumOutput;
+ ASSERT_TRUE(weight_shape == layer->blobs()[0]->shape());
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ this->blob_bottom_->mutable_cpu_data()[i] = caffe_rng_rand() % kInputDim;
+ }
+ layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ vector<int> bias_offset(1, 0);
+ vector<int> weight_offset(2, 0);
+ vector<int> top_offset(5, 0);
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ weight_offset[0] = static_cast<int>(this->blob_bottom_->cpu_data()[i]);
+ weight_offset[1] = 0;
+ top_offset[0] = i;
+ top_offset[4] = 0;
+ bias_offset[0] = 0;
+ for (int j = 0; j < kNumOutput; ++j) {
+ EXPECT_EQ(layer->blobs()[0]->data_at(weight_offset) +
+ layer->blobs()[1]->data_at(bias_offset),
+ this->blob_top_->data_at(top_offset));
+ ++top_offset[4];
+ ++weight_offset[1];
+ ++bias_offset[0];
+ }
+ }
+}
+
+TYPED_TEST(EmbedLayerTest, TestGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ embed_param->set_num_output(10);
+ embed_param->set_input_dim(5);
+ embed_param->set_bias_term(false);
+ embed_param->mutable_weight_filler()->set_type("uniform");
+ embed_param->mutable_weight_filler()->set_min(-10);
+ embed_param->mutable_weight_filler()->set_max(10);
+ EmbedLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3);
+ this->blob_bottom_->mutable_cpu_data()[0] = 4;
+ this->blob_bottom_->mutable_cpu_data()[1] = 2;
+ this->blob_bottom_->mutable_cpu_data()[2] = 2;
+ this->blob_bottom_->mutable_cpu_data()[3] = 3;
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_, -2);
+}
+
+TYPED_TEST(EmbedLayerTest, TestGradientWithBias) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ EmbedParameter* embed_param = layer_param.mutable_embed_param();
+ embed_param->set_num_output(10);
+ embed_param->set_input_dim(5);
+ embed_param->set_bias_term(true);
+ embed_param->mutable_weight_filler()->set_type("uniform");
+ embed_param->mutable_weight_filler()->set_min(-10);
+ embed_param->mutable_weight_filler()->set_max(10);
+ embed_param->mutable_bias_filler()->CopyFrom(embed_param->weight_filler());
+ EmbedLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3);
+ this->blob_bottom_->mutable_cpu_data()[0] = 4;
+ this->blob_bottom_->mutable_cpu_data()[1] = 2;
+ this->blob_bottom_->mutable_cpu_data()[2] = 2;
+ this->blob_bottom_->mutable_cpu_data()[3] = 3;
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_, -2);
+}
+
+} // namespace caffe
diff --git a/src/caffe/test/test_tile_layer.cpp b/src/caffe/test/test_tile_layer.cpp
new file mode 100644
index 00000000..540aac3c
--- /dev/null
+++ b/src/caffe/test/test_tile_layer.cpp
@@ -0,0 +1,162 @@
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class TileLayerTest : public MultiDeviceTest<TypeParam> {
+ typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+ TileLayerTest()
+ : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+ blob_top_(new Blob<Dtype>()) {}
+ virtual void SetUp() {
+ blob_bottom_vec_.push_back(blob_bottom_);
+ blob_top_vec_.push_back(blob_top_);
+ FillerParameter filler_param;
+ filler_param.set_mean(0.0);
+ filler_param.set_std(1.0);
+ GaussianFiller<Dtype> filler(filler_param);
+ filler.Fill(blob_bottom_);
+ }
+
+ virtual ~TileLayerTest() {
+ delete blob_bottom_;
+ delete blob_top_;
+ }
+
+ Blob<Dtype>* const blob_bottom_;
+ Blob<Dtype>* const blob_top_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(TileLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(TileLayerTest, TestTrivialSetup) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ const int kNumTiles = 1;
+ layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+ for (int i = 0; i < this->blob_bottom_->num_axes(); ++i) {
+ layer_param.mutable_tile_param()->set_axis(i);
+ TileLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(this->blob_top_->num_axes(), this->blob_bottom_->num_axes());
+ for (int j = 0; j < this->blob_bottom_->num_axes(); ++j) {
+ EXPECT_EQ(this->blob_top_->shape(j), this->blob_bottom_->shape(j));
+ }
+ }
+}
+
+TYPED_TEST(TileLayerTest, TestSetup) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ const int kNumTiles = 3;
+ layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+ for (int i = 0; i < this->blob_bottom_->num_axes(); ++i) {
+ layer_param.mutable_tile_param()->set_axis(i);
+ TileLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(this->blob_top_->num_axes(), this->blob_bottom_->num_axes());
+ for (int j = 0; j < this->blob_bottom_->num_axes(); ++j) {
+ const int top_dim =
+ ((i == j) ? kNumTiles : 1) * this->blob_bottom_->shape(j);
+ EXPECT_EQ(top_dim, this->blob_top_->shape(j));
+ }
+ }
+}
+
+TYPED_TEST(TileLayerTest, TestForwardNum) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ const int kTileAxis = 0;
+ const int kNumTiles = 3;
+ layer_param.mutable_tile_param()->set_axis(kTileAxis);
+ layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+ TileLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ for (int n = 0; n < this->blob_top_->num(); ++n) {
+ for (int c = 0; c < this->blob_top_->channels(); ++c) {
+ for (int h = 0; h < this->blob_top_->height(); ++h) {
+ for (int w = 0; w < this->blob_top_->width(); ++w) {
+ const int bottom_n = n % this->blob_bottom_->num();
+ EXPECT_EQ(this->blob_bottom_->data_at(bottom_n, c, h, w),
+ this->blob_top_->data_at(n, c, h, w));
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(TileLayerTest, TestForwardChannels) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ const int kNumTiles = 3;
+ layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+ TileLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ for (int n = 0; n < this->blob_top_->num(); ++n) {
+ for (int c = 0; c < this->blob_top_->channels(); ++c) {
+ for (int h = 0; h < this->blob_top_->height(); ++h) {
+ for (int w = 0; w < this->blob_top_->width(); ++w) {
+ const int bottom_c = c % this->blob_bottom_->channels();
+ EXPECT_EQ(this->blob_bottom_->data_at(n, bottom_c, h, w),
+ this->blob_top_->data_at(n, c, h, w));
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(TileLayerTest, TestTrivialGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ const int kNumTiles = 1;
+ layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+ TileLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-2);
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
+TYPED_TEST(TileLayerTest, TestGradientNum) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ const int kTileAxis = 0;
+ const int kNumTiles = 3;
+ layer_param.mutable_tile_param()->set_axis(kTileAxis);
+ layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+ TileLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-2);
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
+TYPED_TEST(TileLayerTest, TestGradientChannels) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ const int kTileAxis = 1;
+ const int kNumTiles = 3;
+ layer_param.mutable_tile_param()->set_axis(kTileAxis);
+ layer_param.mutable_tile_param()->set_tiles(kNumTiles);
+ TileLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-2);
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
+} // namespace caffe
diff --git a/src/caffe/util/insert_splits.cpp b/src/caffe/util/insert_splits.cpp
index 416f80ab..475a2a9f 100644
--- a/src/caffe/util/insert_splits.cpp
+++ b/src/caffe/util/insert_splits.cpp
@@ -32,7 +32,8 @@ void InsertSplits(const NetParameter& param, NetParameter* param_split) {
const string& blob_name = layer_param.bottom(j);
if (blob_name_to_last_top_idx.find(blob_name) ==
blob_name_to_last_top_idx.end()) {
- LOG(FATAL) << "Unknown blob input " << blob_name << " to layer " << j;
+ LOG(FATAL) << "Unknown bottom blob '" << blob_name << "' (layer '"
+ << layer_param.name() << "', bottom index " << j << ")";
}
const pair<int, int>& bottom_idx = make_pair(i, j);
const pair<int, int>& top_idx = blob_name_to_last_top_idx[blob_name];