Diffstat (limited to 'src/caffe')
-rw-r--r--  src/caffe/layer_factory.cpp               |    1
-rw-r--r--  src/caffe/layers/clip_layer.cpp           |   51
-rw-r--r--  src/caffe/layers/clip_layer.cu            |   67
-rw-r--r--  src/caffe/layers/embed_layer.cu           |    5
-rw-r--r--  src/caffe/layers/hdf5_data_layer.cpp      |    2
-rw-r--r--  src/caffe/layers/hdf5_data_layer.cu       |    2
-rw-r--r--  src/caffe/layers/hdf5_output_layer.cpp    |    2
-rw-r--r--  src/caffe/layers/hdf5_output_layer.cu     |    2
-rw-r--r--  src/caffe/layers/inner_product_layer.cpp  |    2
-rw-r--r--  src/caffe/layers/pooling_layer.cpp        |   23
-rw-r--r--  src/caffe/layers/swish_layer.cpp          |   68
-rw-r--r--  src/caffe/layers/swish_layer.cu           |   54
-rw-r--r--  src/caffe/net.cpp                         |   21
-rw-r--r--  src/caffe/proto/caffe.proto               |   31
-rw-r--r--  src/caffe/solver.cpp                      |   10
-rw-r--r--  src/caffe/solvers/sgd_solver.cpp          |   16
-rw-r--r--  src/caffe/test/test_filler.cpp            |  447
-rw-r--r--  src/caffe/test/test_hdf5_output_layer.cpp |    2
-rw-r--r--  src/caffe/test/test_hdf5data_layer.cpp    |    2
-rw-r--r--  src/caffe/test/test_neuron_layer.cpp      |  140
-rw-r--r--  src/caffe/test/test_syncedmem.cpp         |    4
-rw-r--r--  src/caffe/util/hdf5.cpp                   |    2
-rw-r--r--  src/caffe/util/signal_handler.cpp         |    2
23 files changed, 845 insertions(+), 111 deletions(-)
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 9f9026b1..d9984431 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -7,6 +7,7 @@
#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
+#include "caffe/layers/clip_layer.hpp"
#include "caffe/layers/conv_layer.hpp"
#include "caffe/layers/deconv_layer.hpp"
#include "caffe/layers/lrn_layer.hpp"
diff --git a/src/caffe/layers/clip_layer.cpp b/src/caffe/layers/clip_layer.cpp
new file mode 100644
index 00000000..9d9a5967
--- /dev/null
+++ b/src/caffe/layers/clip_layer.cpp
@@ -0,0 +1,51 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layers/clip_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ const int count = bottom[0]->count();
+
+ Dtype min = this->layer_param_.clip_param().min();
+ Dtype max = this->layer_param_.clip_param().max();
+
+ for (int i = 0; i < count; ++i) {
+ top_data[i] = std::max(min, std::min(bottom_data[i], max));
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ const int count = bottom[0]->count();
+
+ Dtype min = this->layer_param_.clip_param().min();
+ Dtype max = this->layer_param_.clip_param().max();
+
+ for (int i = 0; i < count; ++i) {
+ bottom_diff[i] = top_diff[i] * (
+ bottom_data[i] >= min && bottom_data[i] <= max);
+ }
+ }
+}
+
+
+#ifdef CPU_ONLY
+STUB_GPU(ClipLayer);
+#endif
+
+INSTANTIATE_CLASS(ClipLayer);
+REGISTER_LAYER_CLASS(Clip);
+
+} // namespace caffe
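The new layer computes y = max(min, min(x, max)) elementwise, and Backward_cpu passes the gradient through unchanged wherever min <= x <= max and zeroes it elsewhere. In a net definition the layer is selected through the new clip_param message; the snippet below is a hypothetical sketch (layer and blob names are illustrative) clamping activations to [0, 6]:

layer {
  name: "clip1"
  type: "Clip"
  bottom: "conv1"
  top: "clip1"
  clip_param { min: 0 max: 6 }
}

Both min and max are required fields of ClipParameter (see the caffe.proto hunk below).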
diff --git a/src/caffe/layers/clip_layer.cu b/src/caffe/layers/clip_layer.cu
new file mode 100644
index 00000000..56f3be32
--- /dev/null
+++ b/src/caffe/layers/clip_layer.cu
@@ -0,0 +1,67 @@
+#include <vector>
+
+#include "caffe/layers/clip_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+__global__ void ClipForward(const int n, const float* in, float* out,
+ float p_min, float p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out[index] = fmaxf(p_min, fminf(in[index], p_max));
+ }
+}
+
+__global__ void ClipForward(const int n, const double* in, double* out,
+ double p_min, double p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out[index] = fmax(p_min, fmin(in[index], p_max));
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ Dtype p_min = this->layer_param_.clip_param().min();
+ Dtype p_max = this->layer_param_.clip_param().max();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ ClipForward<<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, bottom_data, top_data, p_min, p_max);
+ CUDA_POST_KERNEL_CHECK;
+}
+
+template <typename Dtype>
+__global__ void ClipBackward(const int n, const Dtype* in_diff,
+ const Dtype* in_data, Dtype* out_diff, Dtype p_min, Dtype p_max) {
+ CUDA_KERNEL_LOOP(index, n) {
+ out_diff[index] = in_diff[index] * (
+ in_data[index] >= p_min && in_data[index] <= p_max);
+ }
+}
+
+template <typename Dtype>
+void ClipLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ const int count = bottom[0]->count();
+ Dtype p_min = this->layer_param_.clip_param().min();
+ Dtype p_max = this->layer_param_.clip_param().max();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ ClipBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, top_diff, bottom_data, bottom_diff, p_min, p_max);
+ CUDA_POST_KERNEL_CHECK;
+ }
+}
+
+
+INSTANTIATE_LAYER_GPU_FUNCS(ClipLayer);
+
+
+} // namespace caffe
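ClipForward is written as two overloads so that the float version uses the single-precision intrinsics fminf/fmaxf and the double version uses fmin/fmax; ClipBackward needs no intrinsics, so it stays templated. A fully templated forward kernel would also work using plain comparisons; the following is an alternative sketch, not what the patch does, and note one semantic difference: the intrinsics clamp a NaN input to a bound, while comparisons propagate NaN.

template <typename Dtype>
__global__ void ClipForwardGeneric(const int n, const Dtype* in, Dtype* out,
    Dtype p_min, Dtype p_max) {
  CUDA_KERNEL_LOOP(index, n) {
    // Clamp with comparisons instead of type-specific intrinsics.
    const Dtype v = in[index] < p_min ? p_min : in[index];
    out[index] = v > p_max ? p_max : v;
  }
}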
diff --git a/src/caffe/layers/embed_layer.cu b/src/caffe/layers/embed_layer.cu
index 6324a3a8..3cf39fd9 100644
--- a/src/caffe/layers/embed_layer.cu
+++ b/src/caffe/layers/embed_layer.cu
@@ -15,6 +15,11 @@ __global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
const int n = top_index / N;
const int d = top_index % N;
const int index = static_cast<int>(bottom_data[n]);
+ #ifdef DEBUG
+ assert(index >= 0);
+ assert(index < K);
+ assert(static_cast<Dtype>(index) == bottom_data[n]);
+ #endif
const int weight_index = index * N + d;
top_data[top_index] = weight[weight_index];
}
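The added asserts compile only when DEBUG is defined, so release kernels pay nothing; in debug builds a device-side assert aborts the kernel if an embedding index is negative, out of range, or not integral. Assuming a stock Makefile-based Caffe build (an assumption; adjust to your setup), the macro comes from the debug switch:

# Makefile.config -- hypothetical, for a stock Caffe checkout
DEBUG := 1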
diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp
index 00716a92..7668854c 100644
--- a/src/caffe/layers/hdf5_data_layer.cpp
+++ b/src/caffe/layers/hdf5_data_layer.cpp
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
/*
TODO:
- load file in a separate thread ("prefetch")
@@ -184,3 +185,4 @@ INSTANTIATE_CLASS(HDF5DataLayer);
REGISTER_LAYER_CLASS(HDF5Data);
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/layers/hdf5_data_layer.cu b/src/caffe/layers/hdf5_data_layer.cu
index 33eebd41..70cd9f32 100644
--- a/src/caffe/layers/hdf5_data_layer.cu
+++ b/src/caffe/layers/hdf5_data_layer.cu
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
/*
TODO:
- only load parts of the file, in accordance with a prototxt param "max_mem"
@@ -34,3 +35,4 @@ void HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
INSTANTIATE_LAYER_GPU_FUNCS(HDF5DataLayer);
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/layers/hdf5_output_layer.cpp b/src/caffe/layers/hdf5_output_layer.cpp
index f8f1edcd..28c453a2 100644
--- a/src/caffe/layers/hdf5_output_layer.cpp
+++ b/src/caffe/layers/hdf5_output_layer.cpp
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
#include <vector>
#include "hdf5.h"
@@ -72,3 +73,4 @@ INSTANTIATE_CLASS(HDF5OutputLayer);
REGISTER_LAYER_CLASS(HDF5Output);
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/layers/hdf5_output_layer.cu b/src/caffe/layers/hdf5_output_layer.cu
index c1685cd3..891aea03 100644
--- a/src/caffe/layers/hdf5_output_layer.cu
+++ b/src/caffe/layers/hdf5_output_layer.cu
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
#include <vector>
#include "hdf5.h"
@@ -37,3 +38,4 @@ void HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
INSTANTIATE_LAYER_GPU_FUNCS(HDF5OutputLayer);
} // namespace caffe
+#endif // USE_HDF5
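With these guards, all four HDF5 layer translation units compile to empty files when USE_HDF5 is not defined, so the HDF5Data and HDF5Output layers are never registered and the library can link without libhdf5. How the macro gets defined depends on the build files; a hypothetical make-style flag would be:

# hypothetical -- the real switch lives in the project's build configuration
COMMON_FLAGS += -DUSE_HDF5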
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index e65349f0..57fdbe1f 100644
--- a/src/caffe/layers/inner_product_layer.cpp
+++ b/src/caffe/layers/inner_product_layer.cpp
@@ -42,7 +42,7 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
this->layer_param_.inner_product_param().weight_filler()));
weight_filler->Fill(this->blobs_[0].get());
- // If necessary, intiialize and fill the bias term
+ // If necessary, initialize and fill the bias term
if (bias_term_) {
vector<int> bias_shape(1, N_);
this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp
index 90897db0..32dc0482 100644
--- a/src/caffe/layers/pooling_layer.cpp
+++ b/src/caffe/layers/pooling_layer.cpp
@@ -35,6 +35,7 @@ void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
|| (!pool_param.has_stride_h() && !pool_param.has_stride_w()))
<< "Stride is stride OR stride_h and stride_w are required.";
global_pooling_ = pool_param.global_pooling();
+ round_mode_ = pool_param.round_mode();
if (global_pooling_) {
kernel_h_ = bottom[0]->height();
kernel_w_ = bottom[0]->width();
@@ -87,10 +88,22 @@ void PoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
kernel_h_ = bottom[0]->height();
kernel_w_ = bottom[0]->width();
}
- pooled_height_ = static_cast<int>(ceil(static_cast<float>(
- height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
- pooled_width_ = static_cast<int>(ceil(static_cast<float>(
- width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+ switch (round_mode_) {
+ case PoolingParameter_RoundMode_CEIL:
+ pooled_height_ = static_cast<int>(ceil(static_cast<float>(
+ height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
+ pooled_width_ = static_cast<int>(ceil(static_cast<float>(
+ width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+ break;
+ case PoolingParameter_RoundMode_FLOOR:
+ pooled_height_ = static_cast<int>(floor(static_cast<float>(
+ height_ + 2 * pad_h_ - kernel_h_) / stride_h_)) + 1;
+ pooled_width_ = static_cast<int>(floor(static_cast<float>(
+ width_ + 2 * pad_w_ - kernel_w_) / stride_w_)) + 1;
+ break;
+ default:
+ LOG(FATAL) << "Unknown rounding mode.";
+ }
if (pad_h_ || pad_w_) {
// If we have padding, ensure that the last pooling starts strictly
// inside the image (instead of at the padding); otherwise clip the last.
@@ -132,7 +145,7 @@ void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const int top_count = top[0]->count();
// We'll output the mask to top[1] if it's of size >1.
const bool use_top_mask = top.size() > 1;
- int* mask = NULL; // suppress warnings about uninitalized variables
+ int* mask = NULL; // suppress warnings about uninitialized variables
Dtype* top_mask = NULL;
// Different pooling methods. We explicitly do the switch outside the for
// loop to save time, although this results in more code.
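The round_mode switch above only changes how the output-size division is rounded. As a worked example with height 14, kernel_h 3, pad_h 0, stride_h 2: the quotient is (14 - 3) / 2 = 5.5, so CEIL gives 6 + 1 = 7 output rows while FLOOR gives 5 + 1 = 6, dropping the final window that does not fit a full stride. A hypothetical net-definition snippet (field values illustrative) selecting the new mode:

layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 3
    stride: 2
    round_mode: FLOOR
  }
}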
diff --git a/src/caffe/layers/swish_layer.cpp b/src/caffe/layers/swish_layer.cpp
new file mode 100644
index 00000000..28935679
--- /dev/null
+++ b/src/caffe/layers/swish_layer.cpp
@@ -0,0 +1,68 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+ sigmoid_bottom_vec_.clear();
+ sigmoid_bottom_vec_.push_back(sigmoid_input_.get());
+ sigmoid_top_vec_.clear();
+ sigmoid_top_vec_.push_back(sigmoid_output_.get());
+ sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ NeuronLayer<Dtype>::Reshape(bottom, top);
+ sigmoid_input_->ReshapeLike(*bottom[0]);
+ sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* sigmoid_input_data = sigmoid_input_->mutable_cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ caffe_copy(count, bottom_data, sigmoid_input_data);
+ caffe_scal(count, beta, sigmoid_input_data);
+ sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+ caffe_mul(count, bottom_data, sigmoid_output_->cpu_data(), top_data);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* top_data = top[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ for (int i = 0; i < count; ++i) {
+ const Dtype swish_x = top_data[i];
+ bottom_diff[i] = top_diff[i] * (beta * swish_x + sigmoid_output_data[i]
+ * (1. - beta * swish_x));
+ }
+ }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SwishLayer);
+#endif
+
+INSTANTIATE_CLASS(SwishLayer);
+REGISTER_LAYER_CLASS(Swish);
+
+} // namespace caffe
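The expression in Backward_cpu is the Swish derivative rewritten in terms of the forward output. With \sigma the logistic sigmoid and f(x) = x \, \sigma(\beta x), so that swish_x = f(x) is read straight from top_data:

  f'(x) = \sigma(\beta x) + \beta x \, \sigma(\beta x) \, (1 - \sigma(\beta x))
        = \beta f(x) + \sigma(\beta x) \, (1 - \beta f(x))

which is exactly beta * swish_x + sigmoid_output_data[i] * (1. - beta * swish_x).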
diff --git a/src/caffe/layers/swish_layer.cu b/src/caffe/layers/swish_layer.cu
new file mode 100644
index 00000000..c4fef53b
--- /dev/null
+++ b/src/caffe/layers/swish_layer.cu
@@ -0,0 +1,54 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* sigmoid_input_data = sigmoid_input_->mutable_gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ caffe_copy(count, bottom_data, sigmoid_input_data);
+ caffe_gpu_scal(count, beta, sigmoid_input_data);
+ sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+ caffe_gpu_mul(count, bottom_data, sigmoid_output_->gpu_data(), top_data);
+}
+
+template <typename Dtype>
+__global__ void SwishBackward(const int n, const Dtype* in_diff,
+ const Dtype* out_data, const Dtype* sigmoid_output_data, Dtype* out_diff,
+ const Dtype beta) {
+ CUDA_KERNEL_LOOP(index, n) {
+ const Dtype swish_x = out_data[index];
+ out_diff[index] = in_diff[index] * (beta * swish_x
+ + sigmoid_output_data[index] * (1 - beta * swish_x));
+ }
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* top_data = top[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ SwishBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, top_diff, top_data, sigmoid_output_data, bottom_diff, beta);
+ CUDA_POST_KERNEL_CHECK;
+ }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SwishLayer);
+
+} // namespace caffe
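Like Clip, the layer is selected by type and configured through its parameter message; the following net-definition snippet is a hypothetical sketch (names illustrative) using the beta of 1.5 that also appears in the tests below:

layer {
  name: "swish1"
  type: "Swish"
  bottom: "fc1"
  top: "swish1"
  swish_param { beta: 1.5 }
}

beta defaults to 1, and beta: 0 degenerates to f(x) = 0.5 * x because sigmoid(0) = 0.5, which the TestSwishAsLinear cases exercise.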
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index 353c2f95..5e844b03 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -5,7 +5,9 @@
#include <utility>
#include <vector>
+#ifdef USE_HDF5
#include "hdf5.h"
+#endif // USE_HDF5
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
@@ -164,7 +166,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
// loss. We can skip backward computation for blobs that don't contribute
// to the loss.
// Also checks if all bottom blobs don't need backward computation (possible
- // because the skip_propagate_down param) and so we can skip bacward
+ // because the skip_propagate_down param) and so we can skip backward
// computation for the entire layer
set<string> blobs_under_loss;
set<string> blobs_skip_backp;
@@ -768,7 +770,7 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
}
template <typename Dtype>
-void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) {
+void Net<Dtype>::CopyTrainedLayersFrom(const string& trained_filename) {
if (H5Fis_hdf5(trained_filename.c_str())) {
CopyTrainedLayersFromHDF5(trained_filename);
} else {
@@ -778,14 +780,15 @@ void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) {
template <typename Dtype>
void Net<Dtype>::CopyTrainedLayersFromBinaryProto(
- const string trained_filename) {
+ const string& trained_filename) {
NetParameter param;
ReadNetParamsFromBinaryFileOrDie(trained_filename, &param);
CopyTrainedLayersFrom(param);
}
template <typename Dtype>
-void Net<Dtype>::CopyTrainedLayersFromHDF5(const string trained_filename) {
+void Net<Dtype>::CopyTrainedLayersFromHDF5(const string& trained_filename) {
+#ifdef USE_HDF5
hid_t file_hid = H5Fopen(trained_filename.c_str(), H5F_ACC_RDONLY,
H5P_DEFAULT);
CHECK_GE(file_hid, 0) << "Couldn't open " << trained_filename;
@@ -832,6 +835,10 @@ void Net<Dtype>::CopyTrainedLayersFromHDF5(const string trained_filename) {
}
H5Gclose(data_hid);
H5Fclose(file_hid);
+#else
+ LOG(FATAL) << "CopyTrainedLayersFromHDF5 requires hdf5;"
+ << " compile with USE_HDF5.";
+#endif // USE_HDF5
}
template <typename Dtype>
@@ -848,6 +855,8 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
template <typename Dtype>
void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
+// This code is taken from https://github.com/sh1r0/caffe-android-lib
+#ifdef USE_HDF5
hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
H5P_DEFAULT);
CHECK_GE(file_hid, 0)
@@ -901,6 +910,10 @@ void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
H5Gclose(diff_hid);
}
H5Fclose(file_hid);
+// This code is taken from https://github.com/sh1r0/caffe-android-lib
+#else
+ LOG(FATAL) << "ToHDF5 requires hdf5; compile with USE_HDF5.";
+#endif // USE_HDF5
}
template <typename Dtype>
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 22764abc..3dcad697 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -187,7 +187,7 @@ message SolverParameter {
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
// The prefix for the snapshot.
- // If not set then is replaced by prototxt file path without extention.
+ // If not set then is replaced by prototxt file path without extension.
// If is set to directory then is augmented by prototxt file name
// without extention.
optional string snapshot_prefix = 15;
@@ -248,8 +248,8 @@ message SolverParameter {
// Path to caffemodel file(s) with pretrained weights to initialize finetuning.
// Tha same as command line --weights parameter for caffe train command.
- // If command line --weights parameter if specified, it has higher priority
- // and owerwrites this one(s).
+ // If command line --weights parameter is specified, it has higher priority
+ // and overwrites this one(s).
// If --snapshot command line parameter is specified, this one(s) are ignored.
// If several model files are expected, they can be listed in a one
// weights parameter separated by ',' (like in a command string) or
@@ -322,7 +322,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param)
+// LayerParameter next available layer-specific ID: 149 (last added: clip_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -378,6 +378,7 @@ message LayerParameter {
optional ArgMaxParameter argmax_param = 103;
optional BatchNormParameter batch_norm_param = 139;
optional BiasParameter bias_param = 141;
+ optional ClipParameter clip_param = 148;
optional ConcatParameter concat_param = 104;
optional ContrastiveLossParameter contrastive_loss_param = 105;
optional ConvolutionParameter convolution_param = 106;
@@ -415,6 +416,7 @@ message LayerParameter {
optional SoftmaxParameter softmax_param = 125;
optional SPPParameter spp_param = 132;
optional SliceParameter slice_param = 126;
+ optional SwishParameter swish_param = 147;
optional TanHParameter tanh_param = 127;
optional ThresholdParameter threshold_param = 128;
optional TileParameter tile_param = 138;
@@ -504,6 +506,12 @@ message ArgMaxParameter {
optional int32 axis = 3;
}
+// Message that stores parameters used by ClipLayer
+message ClipParameter {
+ required float min = 1;
+ required float max = 2;
+}
+
message ConcatParameter {
// The axis along which to concatenate -- may be negative to index from the
// end (e.g., -1 for the last axis). Other axes must have the
@@ -935,6 +943,12 @@ message PoolingParameter {
// If global_pooling then it will pool over the size of the bottom by doing
// kernel_h = bottom->height and kernel_w = bottom->width
optional bool global_pooling = 12 [default = false];
+ // How to calculate the output size - using ceil (default) or floor rounding.
+ enum RoundMode {
+ CEIL = 0;
+ FLOOR = 1;
+ }
+ optional RoundMode round_mode = 13 [default = CEIL];
}
message PowerParameter {
@@ -1156,6 +1170,15 @@ message SoftmaxParameter {
optional int32 axis = 2 [default = 1];
}
+// Message that stores parameters used by SwishLayer
+message SwishParameter {
+ // Beta parameter for the Swish activation function
+ // Described in:
+ // Prajit Ramachandran, Barret Zoph, Quoc V. Le. (2017). Searching for
+ // Activation Functions. https://arxiv.org/abs/1710.05941v2
+ optional float beta = 1 [default = 1];
+}
+
message TanHParameter {
enum Engine {
DEFAULT = 0;
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index d229acff..842312e0 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -78,7 +78,7 @@ template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
const int num_train_nets = param_.has_net() + param_.has_net_param() +
param_.has_train_net() + param_.has_train_net_param();
- const string& field_names = "net, net_param, train_net, train_net_param";
+ const string field_names = "net, net_param, train_net, train_net_param";
CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
<< "using one of these fields: " << field_names;
CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
@@ -266,10 +266,6 @@ void Solver<Dtype>::Step(int iters) {
}
ApplyUpdate();
- // Increment the internal iter_ counter -- its value should always indicate
- // the number of times the weights have been updated.
- ++iter_;
-
SolverAction::Enum request = GetRequestedAction();
// Save a snapshot if needed.
@@ -451,13 +447,13 @@ void Solver<Dtype>::CheckSnapshotWritePermissions() {
} else {
LOG(FATAL) << "Cannot write to snapshot prefix '"
<< param_.snapshot_prefix() << "'. Make sure "
- << "that the directory exists and is writeable.";
+ << "that the directory exists and is writable.";
}
}
}
template <typename Dtype>
-string Solver<Dtype>::SnapshotFilename(const string extension) {
+string Solver<Dtype>::SnapshotFilename(const string& extension) {
return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
+ extension;
}
diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp
index 1d52beb0..081c47eb 100644
--- a/src/caffe/solvers/sgd_solver.cpp
+++ b/src/caffe/solvers/sgd_solver.cpp
@@ -120,6 +120,10 @@ void SGDSolver<Dtype>::ApplyUpdate() {
ComputeUpdateValue(param_id, rate);
}
this->net_->Update();
+
+ // Increment the internal iter_ counter -- its value should always indicate
+ // the number of times the weights have been updated.
+ ++this->iter_;
}
template <typename Dtype>
@@ -285,6 +289,8 @@ void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
const string& model_filename) {
+// This code is taken from https://github.com/sh1r0/caffe-android-lib
+#ifdef USE_HDF5
string snapshot_filename =
Solver<Dtype>::SnapshotFilename(".solverstate.h5");
LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
@@ -306,6 +312,11 @@ void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
}
H5Gclose(history_hid);
H5Fclose(file_hid);
+// This code is taken from https://github.com/sh1r0/caffe-android-lib
+#else
+ LOG(FATAL) << "SnapshotSolverStateToHDF5 requires hdf5;"
+ << " compile with USE_HDF5.";
+#endif // USE_HDF5
}
template <typename Dtype>
@@ -330,6 +341,7 @@ void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
+#ifdef USE_HDF5
hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
this->iter_ = hdf5_load_int(file_hid, "iter");
@@ -351,6 +363,10 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
}
H5Gclose(history_hid);
H5Fclose(file_hid);
+#else
+ LOG(FATAL) << "RestoreSolverStateFromHDF5 requires hdf5;"
+ << " compile with USE_HDF5.";
+#endif // USE_HDF5
}
INSTANTIATE_CLASS(SGDSolver);
diff --git a/src/caffe/test/test_filler.cpp b/src/caffe/test/test_filler.cpp
index f84d707b..34f7007d 100644
--- a/src/caffe/test/test_filler.cpp
+++ b/src/caffe/test/test_filler.cpp
@@ -1,3 +1,5 @@
+#include <vector>
+
#include "gtest/gtest.h"
#include "caffe/filler.hpp"
@@ -10,11 +12,20 @@ template <typename Dtype>
class ConstantFillerTest : public ::testing::Test {
protected:
ConstantFillerTest()
- : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
filler_param_.set_value(10.);
filler_.reset(new ConstantFiller<Dtype>(filler_param_));
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
+ for (int i = 0; i < count; ++i) {
+ EXPECT_EQ(data[i], filler_param_.value());
+ }
}
virtual ~ConstantFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
@@ -25,12 +36,34 @@ class ConstantFillerTest : public ::testing::Test {
TYPED_TEST_CASE(ConstantFillerTest, TestDtypes);
TYPED_TEST(ConstantFillerTest, TestFill) {
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const TypeParam* data = this->blob_->cpu_data();
- for (int i = 0; i < count; ++i) {
- EXPECT_EQ(data[i], this->filler_param_.value());
- }
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(ConstantFillerTest, TestFill1D) {
+ vector<int> blob_shape(1, 15);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(ConstantFillerTest, TestFill2D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(ConstantFillerTest, TestFill5D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->test_params(blob_shape);
}
@@ -38,12 +71,22 @@ template <typename Dtype>
class UniformFillerTest : public ::testing::Test {
protected:
UniformFillerTest()
- : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
filler_param_.set_min(1.);
filler_param_.set_max(2.);
filler_.reset(new UniformFiller<Dtype>(filler_param_));
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
+ for (int i = 0; i < count; ++i) {
+ EXPECT_GE(data[i], filler_param_.min());
+ EXPECT_LE(data[i], filler_param_.max());
+ }
}
virtual ~UniformFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
@@ -54,23 +97,64 @@ class UniformFillerTest : public ::testing::Test {
TYPED_TEST_CASE(UniformFillerTest, TestDtypes);
TYPED_TEST(UniformFillerTest, TestFill) {
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const TypeParam* data = this->blob_->cpu_data();
- for (int i = 0; i < count; ++i) {
- EXPECT_GE(data[i], this->filler_param_.min());
- EXPECT_LE(data[i], this->filler_param_.max());
- }
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(UniformFillerTest, TestFill1D) {
+ vector<int> blob_shape(1, 15);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(UniformFillerTest, TestFill2D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(UniformFillerTest, TestFill5D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->test_params(blob_shape);
}
template <typename Dtype>
class PositiveUnitballFillerTest : public ::testing::Test {
protected:
PositiveUnitballFillerTest()
- : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
filler_.reset(new PositiveUnitballFiller<Dtype>(filler_param_));
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
filler_->Fill(blob_);
+ const int num = blob_->shape(0);
+ const int count = blob_->count();
+ const int dim = count / num;
+ const Dtype* data = blob_->cpu_data();
+ for (int i = 0; i < count; ++i) {
+ EXPECT_GE(data[i], 0);
+ EXPECT_LE(data[i], 1);
+ }
+ for (int i = 0; i < num; ++i) {
+ Dtype sum = Dtype(0);
+ for (int j = 0; j < dim; ++j) {
+ sum += data[i * dim + j];
+ }
+ EXPECT_GE(sum, 0.999);
+ EXPECT_LE(sum, 1.001);
+ }
}
virtual ~PositiveUnitballFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
@@ -81,35 +165,78 @@ class PositiveUnitballFillerTest : public ::testing::Test {
TYPED_TEST_CASE(PositiveUnitballFillerTest, TestDtypes);
TYPED_TEST(PositiveUnitballFillerTest, TestFill) {
- EXPECT_TRUE(this->blob_);
- const int num = this->blob_->num();
- const int count = this->blob_->count();
- const int dim = count / num;
- const TypeParam* data = this->blob_->cpu_data();
- for (int i = 0; i < count; ++i) {
- EXPECT_GE(data[i], 0);
- EXPECT_LE(data[i], 1);
- }
- for (int i = 0; i < num; ++i) {
- TypeParam sum = 0;
- for (int j = 0; j < dim; ++j) {
- sum += data[i * dim + j];
- }
- EXPECT_GE(sum, 0.999);
- EXPECT_LE(sum, 1.001);
- }
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(PositiveUnitballFillerTest, TestFill1D) {
+ vector<int> blob_shape(1, 15);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(PositiveUnitballFillerTest, TestFill2D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(PositiveUnitballFillerTest, TestFill5D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->test_params(blob_shape);
}
template <typename Dtype>
class GaussianFillerTest : public ::testing::Test {
protected:
GaussianFillerTest()
- : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
filler_param_.set_mean(10.);
filler_param_.set_std(0.1);
filler_.reset(new GaussianFiller<Dtype>(filler_param_));
+ }
+ virtual void test_params(const vector<int>& shape,
+ const Dtype tolerance = Dtype(5), const int repetitions = 100) {
+ // Tests for statistical properties should be run multiple times.
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ for (int i = 0; i < repetitions; ++i) {
+ test_params_iter(shape, tolerance);
+ }
+ }
+ virtual void test_params_iter(const vector<int>& shape,
+ const Dtype tolerance) {
+ // This test has a configurable tolerance parameter; by default it was
+ // equal to 5.0, which is very loose. This allows some tuning: e.g. for
+ // tests on smaller blobs the actual sample variance will be larger than
+ // desired, so the tolerance can be increased to account for that.
filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
+ Dtype mean = Dtype(0);
+ Dtype var = Dtype(0);
+ for (int i = 0; i < count; ++i) {
+ mean += data[i];
+ var += data[i] * data[i];
+ }
+ mean /= count;
+ var /= count;
+ var -= mean*mean;
+ EXPECT_GE(mean, filler_param_.mean() - filler_param_.std() * tolerance);
+ EXPECT_LE(mean, filler_param_.mean() + filler_param_.std() * tolerance);
+ Dtype target_var = filler_param_.std() * filler_param_.std();
+ EXPECT_GE(var, target_var / tolerance);
+ EXPECT_LE(var, target_var * tolerance);
}
virtual ~GaussianFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
@@ -120,41 +247,62 @@ class GaussianFillerTest : public ::testing::Test {
TYPED_TEST_CASE(GaussianFillerTest, TestDtypes);
TYPED_TEST(GaussianFillerTest, TestFill) {
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const TypeParam* data = this->blob_->cpu_data();
- TypeParam mean = 0.;
- TypeParam var = 0.;
- for (int i = 0; i < count; ++i) {
- mean += data[i];
- var += (data[i] - this->filler_param_.mean()) *
- (data[i] - this->filler_param_.mean());
- }
- mean /= count;
- var /= count;
- // Very loose test.
- EXPECT_GE(mean, this->filler_param_.mean() - this->filler_param_.std() * 5);
- EXPECT_LE(mean, this->filler_param_.mean() + this->filler_param_.std() * 5);
- TypeParam target_var = this->filler_param_.std() * this->filler_param_.std();
- EXPECT_GE(var, target_var / 5.);
- EXPECT_LE(var, target_var * 5.);
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ const TypeParam tolerance = TypeParam(3); // enough for a 120-element blob
+ this->test_params(blob_shape, tolerance);
+}
+
+TYPED_TEST(GaussianFillerTest, TestFill1D) {
+ vector<int> blob_shape(1, 125);
+ const TypeParam tolerance = TypeParam(3);
+ this->test_params(blob_shape, tolerance);
+}
+
+TYPED_TEST(GaussianFillerTest, TestFill2D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(15);
+ const TypeParam tolerance = TypeParam(3);
+ this->test_params(blob_shape, tolerance);
+}
+
+TYPED_TEST(GaussianFillerTest, TestFill5D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ const TypeParam tolerance = TypeParam(2);
+ this->test_params(blob_shape, tolerance);
}
template <typename Dtype>
class XavierFillerTest : public ::testing::Test {
protected:
XavierFillerTest()
- : blob_(new Blob<Dtype>(1000, 2, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
}
virtual void test_params(FillerParameter_VarianceNorm variance_norm,
+ Dtype n, const vector<int>& shape, const int repetitions = 100) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ for (int i = 0; i < repetitions; ++i) {
+ test_params_iter(variance_norm, n);
+ }
+ }
+ virtual void test_params_iter(FillerParameter_VarianceNorm variance_norm,
Dtype n) {
- this->filler_param_.set_variance_norm(variance_norm);
- this->filler_.reset(new XavierFiller<Dtype>(this->filler_param_));
- this->filler_->Fill(blob_);
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const Dtype* data = this->blob_->cpu_data();
+ filler_param_.set_variance_norm(variance_norm);
+ filler_.reset(new XavierFiller<Dtype>(filler_param_));
+ filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
Dtype mean = 0.;
Dtype ex2 = 0.;
for (int i = 0; i < count; ++i) {
@@ -177,33 +325,92 @@ class XavierFillerTest : public ::testing::Test {
TYPED_TEST_CASE(XavierFillerTest, TestDtypes);
TYPED_TEST(XavierFillerTest, TestFillFanIn) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = 2*4*5;
- this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
+ this->test_params(FillerParameter_VarianceNorm_FAN_IN, n, blob_shape);
}
+
TYPED_TEST(XavierFillerTest, TestFillFanOut) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = 1000*4*5;
- this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
+ this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n, blob_shape);
}
+
TYPED_TEST(XavierFillerTest, TestFillAverage) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
- this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
+ this->test_params(FillerParameter_VarianceNorm_AVERAGE, n, blob_shape);
+}
+
+TYPED_TEST(XavierFillerTest, TestFill1D) {
+ // This makes little sense, but at least we will know that we can fill it.
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape(1, 25);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new XavierFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
+}
+
+TYPED_TEST(XavierFillerTest, TestFill2D) {
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new XavierFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
+}
+
+TYPED_TEST(XavierFillerTest, TestFill5D) {
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new XavierFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
}
template <typename Dtype>
class MSRAFillerTest : public ::testing::Test {
protected:
MSRAFillerTest()
- : blob_(new Blob<Dtype>(1000, 2, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
}
virtual void test_params(FillerParameter_VarianceNorm variance_norm,
+ Dtype n, const vector<int>& shape, const int repetitions = 100) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ for (int i = 0; i < repetitions; ++i) {
+ test_params_iter(variance_norm, n);
+ }
+ }
+ virtual void test_params_iter(FillerParameter_VarianceNorm variance_norm,
Dtype n) {
- this->filler_param_.set_variance_norm(variance_norm);
- this->filler_.reset(new MSRAFiller<Dtype>(this->filler_param_));
- this->filler_->Fill(blob_);
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const Dtype* data = this->blob_->cpu_data();
+ filler_param_.set_variance_norm(variance_norm);
+ filler_.reset(new MSRAFiller<Dtype>(filler_param_));
+ filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
Dtype mean = 0.;
Dtype ex2 = 0.;
for (int i = 0; i < count; ++i) {
@@ -226,36 +433,92 @@ class MSRAFillerTest : public ::testing::Test {
TYPED_TEST_CASE(MSRAFillerTest, TestDtypes);
TYPED_TEST(MSRAFillerTest, TestFillFanIn) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = 2*4*5;
- this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
+ this->test_params(FillerParameter_VarianceNorm_FAN_IN, n, blob_shape);
}
+
TYPED_TEST(MSRAFillerTest, TestFillFanOut) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = 1000*4*5;
- this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
+ this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n, blob_shape);
}
+
TYPED_TEST(MSRAFillerTest, TestFillAverage) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
- this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
+ this->test_params(FillerParameter_VarianceNorm_AVERAGE, n, blob_shape);
+}
+
+TYPED_TEST(MSRAFillerTest, TestFill1D) {
+ // As with Xavier: no correctness check, just whether the blob can be filled.
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape(1, 25);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new MSRAFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
+}
+
+TYPED_TEST(MSRAFillerTest, TestFill2D) {
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new MSRAFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
+}
+
+TYPED_TEST(MSRAFillerTest, TestFill5D) {
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new MSRAFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
}
template <typename Dtype>
class BilinearFillerTest : public ::testing::Test {
protected:
- BilinearFillerTest() : filler_param_() {}
- virtual void test_params(const int n) {
- this->blob_ = new Blob<Dtype>(1000, 2, n, n);
- this->filler_.reset(new BilinearFiller<Dtype>(this->filler_param_));
- this->filler_->Fill(blob_);
- EXPECT_TRUE(this->blob_);
- const int outer_num = this->blob_->count(0, 2);
- const int inner_num = this->blob_->count(2, 4);
- const Dtype* data = this->blob_->cpu_data();
- int f = ceil(this->blob_->width() / 2.);
- Dtype c = (this->blob_->width() - 1) / (2. * f);
+ BilinearFillerTest()
+ : blob_(new Blob<Dtype>()),
+ filler_param_() {
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ filler_.reset(new BilinearFiller<Dtype>(filler_param_));
+ filler_->Fill(blob_);
+ CHECK_EQ(blob_->num_axes(), 4);
+ const int outer_num = blob_->count(0, 2);
+ const int inner_num = blob_->count(2, 4);
+ const Dtype* data = blob_->cpu_data();
+ int f = ceil(blob_->shape(3) / 2.);
+ Dtype c = (blob_->shape(3) - 1) / (2. * f);
for (int i = 0; i < outer_num; ++i) {
for (int j = 0; j < inner_num; ++j) {
- Dtype x = j % this->blob_->width();
- Dtype y = (j / this->blob_->width()) % this->blob_->height();
+ Dtype x = j % blob_->shape(3);
+ Dtype y = (j / blob_->shape(3)) % blob_->shape(2);
Dtype expected_value = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
const Dtype actual_value = data[i * inner_num + j];
EXPECT_NEAR(expected_value, actual_value, 0.01);
@@ -272,11 +535,21 @@ TYPED_TEST_CASE(BilinearFillerTest, TestDtypes);
TYPED_TEST(BilinearFillerTest, TestFillOdd) {
const int n = 7;
- this->test_params(n);
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(n);
+ blob_shape.push_back(n);
+ this->test_params(blob_shape);
}
TYPED_TEST(BilinearFillerTest, TestFillEven) {
const int n = 6;
- this->test_params(n);
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(n);
+ blob_shape.push_back(n);
+ this->test_params(blob_shape);
}
} // namespace caffe
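The refactor moves the per-shape checks into a shared test_params(shape) helper so each filler gets 1D/2D/4D/5D coverage from one routine, and the statistical fillers (Gaussian, Xavier, MSRA) now repeat 100 fills per shape to keep flakiness down. To run only this suite, a gtest filter works; the binary path below assumes a stock Makefile build and is illustrative:

.build_release/test/test_all.testbin --gtest_filter='*FillerTest*'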
diff --git a/src/caffe/test/test_hdf5_output_layer.cpp b/src/caffe/test/test_hdf5_output_layer.cpp
index f94dd57e..11d52310 100644
--- a/src/caffe/test/test_hdf5_output_layer.cpp
+++ b/src/caffe/test/test_hdf5_output_layer.cpp
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
#include <string>
#include <vector>
@@ -120,3 +121,4 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) {
}
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/test/test_hdf5data_layer.cpp b/src/caffe/test/test_hdf5data_layer.cpp
index 3977c486..0e5c398f 100644
--- a/src/caffe/test/test_hdf5data_layer.cpp
+++ b/src/caffe/test/test_hdf5data_layer.cpp
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
#include <string>
#include <vector>
@@ -163,3 +164,4 @@ TYPED_TEST(HDF5DataLayerTest, TestSkip) {
}
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index 180871a2..d1ecc37b 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -10,6 +10,7 @@
#include "caffe/layers/absval_layer.hpp"
#include "caffe/layers/bnll_layer.hpp"
+#include "caffe/layers/clip_layer.hpp"
#include "caffe/layers/dropout_layer.hpp"
#include "caffe/layers/elu_layer.hpp"
#include "caffe/layers/exp_layer.hpp"
@@ -19,6 +20,7 @@
#include "caffe/layers/prelu_layer.hpp"
#include "caffe/layers/relu_layer.hpp"
#include "caffe/layers/sigmoid_layer.hpp"
+#include "caffe/layers/swish_layer.hpp"
#include "caffe/layers/tanh_layer.hpp"
#include "caffe/layers/threshold_layer.hpp"
@@ -205,6 +207,66 @@ TYPED_TEST(NeuronLayerTest, TestAbsGradient) {
this->blob_top_vec_);
}
+TYPED_TEST(NeuronLayerTest, TestClip) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "clip_param { min: -1, max: 2 }", &layer_param));
+ ClipLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_GE(top_data[i], -1);
+ EXPECT_LE(top_data[i], 2);
+ EXPECT_TRUE(bottom_data[i] > -1 || top_data[i] == -1);
+ EXPECT_TRUE(bottom_data[i] < 2 || top_data[i] == 2);
+ EXPECT_TRUE(!(bottom_data[i] >= -1 && bottom_data[i] <= 2)
+ || top_data[i] == bottom_data[i]);
+ }
+}
+
+TYPED_TEST(NeuronLayerTest, TestClipGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "clip_param { min: -1, max: 2 }", &layer_param));
+ ClipLayer<Dtype> layer(layer_param);
+ // Unfortunately, it might happen that an input value lands exactly within
+ // the discontinuity region of the Clip function. In this case the numeric
+ // gradient is likely to differ significantly (i.e. by a value larger than
+ // checker tolerance) from the computed gradient. To handle such cases, we
+ // eliminate such values from the input blob before the gradient check.
+ const Dtype epsilon = 1e-2;
+ const Dtype min_range_start = layer_param.clip_param().min() - epsilon;
+ const Dtype min_range_end = layer_param.clip_param().min() + epsilon;
+ const Dtype max_range_start = layer_param.clip_param().max() - epsilon;
+ const Dtype max_range_end = layer_param.clip_param().max() + epsilon;
+ // The input blob is owned by the NeuronLayerTest object, so we begin with
+ // creating a temporary blob and copying the input data there.
+ Blob<Dtype> temp_bottom;
+ temp_bottom.ReshapeLike(*this->blob_bottom_);
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ Dtype* temp_data_mutable = temp_bottom.mutable_cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ if (bottom_data[i] >= min_range_start &&
+ bottom_data[i] <= min_range_end) {
+ temp_data_mutable[i] = bottom_data[i] - epsilon;
+ } else if (bottom_data[i] >= max_range_start &&
+ bottom_data[i] <= max_range_end) {
+ temp_data_mutable[i] = bottom_data[i] + epsilon;
+ } else {
+ temp_data_mutable[i] = bottom_data[i];
+ }
+ }
+ vector<Blob<Dtype>*> temp_bottom_vec;
+ temp_bottom_vec.push_back(&temp_bottom);
+ GradientChecker<Dtype> checker(epsilon, 1e-3);
+ checker.CheckGradientEltwise(&layer, temp_bottom_vec, this->blob_top_vec_);
+}
+
TYPED_TEST(NeuronLayerTest, TestReLU) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
@@ -344,6 +406,84 @@ TYPED_TEST(NeuronLayerTest, TestSigmoidGradient) {
this->blob_top_vec_);
}
+TYPED_TEST(NeuronLayerTest, TestSwish) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ SwishLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / (1. + exp(-bottom_data[i])));
+ }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishWithBeta) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "swish_param { beta: 1.5 }", &layer_param));
+ SwishLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / (1. + exp(-1.5 *
+ bottom_data[i])));
+ }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishAsLinear) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "swish_param { beta: 0.0 }", &layer_param));
+ SwishLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / 2.0);
+ }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ SwishLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+ checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishWithBetaGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "swish_param { beta: 1.5 }", &layer_param));
+ SwishLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+ checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishAsLinearGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "swish_param { beta: 0.0 }", &layer_param));
+ SwishLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+ checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
TYPED_TEST(NeuronLayerTest, TestTanH) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
diff --git a/src/caffe/test/test_syncedmem.cpp b/src/caffe/test/test_syncedmem.cpp
index 16dfb582..2ca9ca2f 100644
--- a/src/caffe/test/test_syncedmem.cpp
+++ b/src/caffe/test/test_syncedmem.cpp
@@ -80,7 +80,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
char* recovered_value = new char[10];
caffe_gpu_memcpy(10, gpu_data, recovered_value);
for (int i = 0; i < mem.size(); ++i) {
- EXPECT_EQ((static_cast<char*>(recovered_value))[i], 1);
+ EXPECT_EQ(recovered_value[i], 1);
}
// do another round
cpu_data = mem.mutable_cpu_data();
@@ -94,7 +94,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
// check if values are the same
caffe_gpu_memcpy(10, gpu_data, recovered_value);
for (int i = 0; i < mem.size(); ++i) {
- EXPECT_EQ((static_cast<char*>(recovered_value))[i], 2);
+ EXPECT_EQ(recovered_value[i], 2);
}
delete[] recovered_value;
}
diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp
index ed737429..cefd853d 100644
--- a/src/caffe/util/hdf5.cpp
+++ b/src/caffe/util/hdf5.cpp
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
#include "caffe/util/hdf5.hpp"
#include <string>
@@ -207,3 +208,4 @@ string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
}
} // namespace caffe
+#endif // USE_HDF5
diff --git a/src/caffe/util/signal_handler.cpp b/src/caffe/util/signal_handler.cpp
index 5d764ec5..9658fb39 100644
--- a/src/caffe/util/signal_handler.cpp
+++ b/src/caffe/util/signal_handler.cpp
@@ -48,7 +48,7 @@ namespace {
void UnhookHandler() {
if (already_hooked_up) {
struct sigaction sa;
- // Setup the sighub handler
+ // Setup the sighup handler
sa.sa_handler = SIG_DFL;
// Restart the system call, if at all possible
sa.sa_flags = SA_RESTART;