summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/installation.md2
-rw-r--r--include/caffe/filler.hpp18
-rw-r--r--include/caffe/layers/swish_layer.hpp96
-rw-r--r--python/caffe/draw.py144
-rwxr-xr-xpython/draw_net.py6
-rw-r--r--src/caffe/layers/embed_layer.cu5
-rw-r--r--src/caffe/layers/swish_layer.cpp68
-rw-r--r--src/caffe/layers/swish_layer.cu54
-rw-r--r--src/caffe/proto/caffe.proto12
-rw-r--r--src/caffe/solvers/sgd_solver.cpp7
-rw-r--r--src/caffe/test/test_filler.cpp447
-rw-r--r--src/caffe/test/test_neuron_layer.cpp79
-rw-r--r--src/caffe/util/io.cpp2
-rw-r--r--tools/convert_imageset.cpp2
-rw-r--r--tools/device_query.cpp7
-rw-r--r--tools/finetune_net.cpp7
-rw-r--r--tools/net_speed_benchmark.cpp7
-rw-r--r--tools/test_net.cpp7
-rw-r--r--tools/train_net.cpp7
19 files changed, 807 insertions, 170 deletions
diff --git a/docs/installation.md b/docs/installation.md
index 2c4b30d3..c4822853 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -80,7 +80,7 @@ The main requirements are `numpy` and `boost.python` (provided by boost). `panda
You can install the dependencies with
- for req in $(cat requirements.txt); do pip install $req; done
+ pip install -r requirements.txt
but we suggest first installing the [Anaconda](https://store.continuum.io/cshop/anaconda/) Python distribution, which provides most of the necessary packages, as well as the `hdf5` library dependency.
diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp
index bb92ded7..a4477361 100644
--- a/include/caffe/filler.hpp
+++ b/include/caffe/filler.hpp
@@ -108,9 +108,9 @@ class PositiveUnitballFiller : public Filler<Dtype> {
caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());
// We expect the filler to not be called very frequently, so we will
// just use a simple implementation
- int dim = blob->count() / blob->num();
+ int dim = blob->count() / blob->shape(0);
CHECK(dim);
- for (int i = 0; i < blob->num(); ++i) {
+ for (int i = 0; i < blob->shape(0); ++i) {
Dtype sum = 0;
for (int j = 0; j < dim; ++j) {
sum += data[i * dim + j];
@@ -147,8 +147,11 @@ class XavierFiller : public Filler<Dtype> {
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
- int fan_in = blob->count() / blob->num();
- int fan_out = blob->count() / blob->channels();
+ int fan_in = blob->count() / blob->shape(0);
+ // Compatibility with ND blobs
+ int fan_out = blob->num_axes() > 1 ?
+ blob->count() / blob->shape(1) :
+ blob->count();
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
@@ -189,8 +192,11 @@ class MSRAFiller : public Filler<Dtype> {
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
- int fan_in = blob->count() / blob->num();
- int fan_out = blob->count() / blob->channels();
+ int fan_in = blob->count() / blob->shape(0);
+ // Compatibility with ND blobs
+ int fan_out = blob->num_axes() > 1 ?
+ blob->count() / blob->shape(1) :
+ blob->count();
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
diff --git a/include/caffe/layers/swish_layer.hpp b/include/caffe/layers/swish_layer.hpp
new file mode 100644
index 00000000..d538ff6d
--- /dev/null
+++ b/include/caffe/layers/swish_layer.hpp
@@ -0,0 +1,96 @@
+#ifndef CAFFE_SWISH_LAYER_HPP_
+#define CAFFE_SWISH_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+#include "caffe/layers/sigmoid_layer.hpp"
+
+namespace caffe {
+
+/**
+ * @brief Swish non-linearity @f$ y = x \sigma (\beta x) @f$.
+ * A novel activation function that tends to work better than ReLU [1].
+ *
+ * [1] Prajit Ramachandran, Barret Zoph, Quoc V. Le. "Searching for
+ * Activation Functions". arXiv preprint arXiv:1710.05941v2 (2017).
+ */
+template <typename Dtype>
+class SwishLayer : public NeuronLayer<Dtype> {
+ public:
+ /**
+ * @param param provides SwishParameter swish_param,
+ * with SwishLayer options:
+ * - beta (\b optional, default 1).
+ * the value @f$ \beta @f$ in the @f$ y = x \sigma (\beta x) @f$.
+ */
+ explicit SwishLayer(const LayerParameter& param)
+ : NeuronLayer<Dtype>(param),
+ sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
+ sigmoid_input_(new Blob<Dtype>()),
+ sigmoid_output_(new Blob<Dtype>()) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual inline const char* type() const { return "Swish"; }
+
+ protected:
+ /**
+ * @param bottom input Blob vector (length 1)
+ * -# @f$ (N \times C \times H \times W) @f$
+ * the inputs @f$ x @f$
+ * @param top output Blob vector (length 1)
+ * -# @f$ (N \times C \times H \times W) @f$
+ * the computed outputs @f$
+ * y = x \sigma (\beta x)
+ * @f$.
+ */
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ /**
+ * @brief Computes the error gradient w.r.t. the sigmoid inputs.
+ *
+ * @param top output Blob vector (length 1), providing the error gradient with
+ * respect to the outputs
+ * -# @f$ (N \times C \times H \times W) @f$
+ * containing error gradients @f$ \frac{\partial E}{\partial y} @f$
+ * with respect to computed outputs @f$ y @f$
+ * @param propagate_down see Layer::Backward.
+ * @param bottom input Blob vector (length 1)
+ * -# @f$ (N \times C \times H \times W) @f$
+ * the inputs @f$ x @f$; Backward fills their diff with
+ * gradients @f$
+ * \frac{\partial E}{\partial x}
+ * = \frac{\partial E}{\partial y}(\beta y +
+ * \sigma (\beta x)(1 - \beta y))
+ * @f$ if propagate_down[0]
+ */
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+ /// The internal SigmoidLayer
+ shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
+ /// sigmoid_input_ stores the input of the SigmoidLayer.
+ shared_ptr<Blob<Dtype> > sigmoid_input_;
+ /// sigmoid_output_ stores the output of the SigmoidLayer.
+ shared_ptr<Blob<Dtype> > sigmoid_output_;
+ /// bottom vector holder to call the underlying SigmoidLayer::Forward
+ vector<Blob<Dtype>*> sigmoid_bottom_vec_;
+ /// top vector holder to call the underlying SigmoidLayer::Forward
+ vector<Blob<Dtype>*> sigmoid_top_vec_;
+};
+
+} // namespace caffe
+
+#endif // CAFFE_SWISH_LAYER_HPP_
diff --git a/python/caffe/draw.py b/python/caffe/draw.py
index 8411a41d..0061f490 100644
--- a/python/caffe/draw.py
+++ b/python/caffe/draw.py
@@ -59,18 +59,60 @@ def get_edge_label(layer):
return edge_label
-def get_layer_label(layer, rankdir):
+def get_layer_lr_mult(layer):
+ """Get the learning rate multipliers.
+
+ Get the learning rate multipliers for the given layer. Assumes a
+ Convolution/Deconvolution/InnerProduct layer.
+
+ Parameters
+ ----------
+ layer : caffe_pb2.LayerParameter
+ A Convolution, Deconvolution, or InnerProduct layer.
+
+ Returns
+ -------
+ learning_rates : tuple of floats
+ the learning rate multipliers for the weights and biases.
+ """
+ if layer.type not in ['Convolution', 'Deconvolution', 'InnerProduct']:
+ raise ValueError("%s layers do not have a "
+ "learning rate multiplier" % layer.type)
+
+ if not hasattr(layer, 'param'):
+ return (1.0, 1.0)
+
+ params = getattr(layer, 'param')
+
+ if len(params) == 0:
+ return (1.0, 1.0)
+
+ if len(params) == 1:
+ lrm0 = getattr(params[0],'lr_mult', 1.0)
+ return (lrm0, 1.0)
+
+ if len(params) == 2:
+ lrm0, lrm1 = [getattr(p,'lr_mult', 1.0) for p in params]
+ return (lrm0, lrm1)
+
+ raise ValueError("Could not parse the learning rate multiplier")
+
+
+def get_layer_label(layer, rankdir, display_lrm=False):
"""Define node label based on layer type.
Parameters
----------
- layer : ?
+ layer : caffe_pb2.LayerParameter
rankdir : {'LR', 'TB', 'BT'}
Direction of graph layout.
+ display_lrm : boolean, optional
+ If True include the learning rate multipliers in the label (default is
+ False).
Returns
-------
- string :
+ node_label : string
A label for the current layer
"""
@@ -81,36 +123,54 @@ def get_layer_label(layer, rankdir):
else:
# If graph orientation is horizontal, vertical space is free and
# horizontal space is not; separate words with newlines
- separator = '\\n'
-
- if layer.type == 'Convolution' or layer.type == 'Deconvolution':
- # Outer double quotes needed or else colon characters don't parse
- # properly
- node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
- (layer.name,
- separator,
- layer.type,
- separator,
- layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size) else 1,
- separator,
- layer.convolution_param.stride[0] if len(layer.convolution_param.stride) else 1,
- separator,
- layer.convolution_param.pad[0] if len(layer.convolution_param.pad) else 0)
- elif layer.type == 'Pooling':
+ separator = r'\n'
+
+ # Initializes a list of descriptors that will be concatenated into the
+ # `node_label`
+ descriptors_list = []
+ # Add the layer's name
+ descriptors_list.append(layer.name)
+ # Add layer's type
+ if layer.type == 'Pooling':
pooling_types_dict = get_pooling_types_dict()
- node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
- (layer.name,
- separator,
- pooling_types_dict[layer.pooling_param.pool],
- layer.type,
- separator,
- layer.pooling_param.kernel_size,
- separator,
- layer.pooling_param.stride,
- separator,
- layer.pooling_param.pad)
+ layer_type = '(%s %s)' % (layer.type,
+ pooling_types_dict[layer.pooling_param.pool])
else:
- node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
+ layer_type = '(%s)' % layer.type
+ descriptors_list.append(layer_type)
+
+ # Describe parameters for spatial operation layers
+ if layer.type in ['Convolution', 'Deconvolution', 'Pooling']:
+ if layer.type == 'Pooling':
+ kernel_size = layer.pooling_param.kernel_size
+ stride = layer.pooling_param.stride
+ padding = layer.pooling_param.pad
+ else:
+ kernel_size = layer.convolution_param.kernel_size[0] if \
+ len(layer.convolution_param.kernel_size) else 1
+ stride = layer.convolution_param.stride[0] if \
+ len(layer.convolution_param.stride) else 1
+ padding = layer.convolution_param.pad[0] if \
+ len(layer.convolution_param.pad) else 0
+ spatial_descriptor = separator.join([
+ "kernel size: %d" % kernel_size,
+ "stride: %d" % stride,
+ "pad: %d" % padding,
+ ])
+ descriptors_list.append(spatial_descriptor)
+
+ # Add LR multiplier for learning layers
+ if display_lrm and layer.type in ['Convolution', 'Deconvolution', 'InnerProduct']:
+ lrm0, lrm1 = get_layer_lr_mult(layer)
+ if any([lrm0, lrm1]):
+ lr_mult = "lr mult: %.1f, %.1f" % (lrm0, lrm1)
+ descriptors_list.append(lr_mult)
+
+ # Concatenate the descriptors into one label
+ node_label = separator.join(descriptors_list)
+ # Outer double quotes needed or else colon characters don't parse
+ # properly
+ node_label = '"%s"' % node_label
return node_label
@@ -127,7 +187,7 @@ def choose_color_by_layertype(layertype):
return color
-def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
+def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None, display_lrm=False):
"""Create a data structure which represents the `caffe_net`.
Parameters
@@ -140,6 +200,9 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
+ display_lrm : boolean, optional
+ If True display the learning rate multipliers when relevant (default is
+ False).
Returns
-------
@@ -164,7 +227,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
included = included and not layer_phase.phase == phase
if not included:
continue
- node_label = get_layer_label(layer, rankdir)
+ node_label = get_layer_label(layer, rankdir, display_lrm=display_lrm)
node_name = "%s_%s" % (layer.name, layer.type)
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
layer.bottom[0] == layer.top[0]):
@@ -202,7 +265,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
return pydot_graph
-def draw_net(caffe_net, rankdir, ext='png', phase=None):
+def draw_net(caffe_net, rankdir, ext='png', phase=None, display_lrm=False):
"""Draws a caffe net and returns the image string encoded using the given
extension.
@@ -214,16 +277,20 @@ def draw_net(caffe_net, rankdir, ext='png', phase=None):
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
+ display_lrm : boolean, optional
+ If True display the learning rate multipliers for the learning layers
+ (default is False).
Returns
-------
string :
Postscript representation of the graph.
"""
- return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext)
+ return get_pydot_graph(caffe_net, rankdir, phase=phase,
+ display_lrm=display_lrm).create(format=ext)
-def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
+def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None, display_lrm=False):
"""Draws a caffe net, and saves it to file using the format given as the
file extension. Use '.raw' to output raw text that you can manually feed
to graphviz to draw graphs.
@@ -238,7 +305,10 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
+ display_lrm : boolean, optional
+ If True display the learning rate multipliers for the learning layers
+ (default is False).
"""
ext = filename[filename.rfind('.')+1:]
with open(filename, 'wb') as fid:
- fid.write(draw_net(caffe_net, rankdir, ext, phase))
+ fid.write(draw_net(caffe_net, rankdir, ext, phase, display_lrm))
diff --git a/python/draw_net.py b/python/draw_net.py
index dfe70d26..23cae30a 100755
--- a/python/draw_net.py
+++ b/python/draw_net.py
@@ -33,6 +33,10 @@ def parse_args():
'TEST, or ALL. If ALL, then all layers are drawn '
'regardless of phase.'),
default="ALL")
+ parser.add_argument('--display_lrm', action='store_true',
+ help=('Use this flag to visualize the learning rate '
+ 'multiplier, when non-zero, for the learning '
+ 'layers (Convolution, Deconvolution, InnerProduct).'))
args = parser.parse_args()
return args
@@ -51,7 +55,7 @@ def main():
elif args.phase != "ALL":
raise ValueError("Unknown phase: " + args.phase)
caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
- phase)
+ phase, args.display_lrm)
if __name__ == '__main__':
diff --git a/src/caffe/layers/embed_layer.cu b/src/caffe/layers/embed_layer.cu
index 6324a3a8..3cf39fd9 100644
--- a/src/caffe/layers/embed_layer.cu
+++ b/src/caffe/layers/embed_layer.cu
@@ -15,6 +15,11 @@ __global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
const int n = top_index / N;
const int d = top_index % N;
const int index = static_cast<int>(bottom_data[n]);
+ #ifdef DEBUG
+ assert(index >= 0);
+ assert(index < K);
+ assert(static_cast<Dtype>(index) == bottom_data[n]);
+ #endif
const int weight_index = index * N + d;
top_data[top_index] = weight[weight_index];
}
diff --git a/src/caffe/layers/swish_layer.cpp b/src/caffe/layers/swish_layer.cpp
new file mode 100644
index 00000000..28935679
--- /dev/null
+++ b/src/caffe/layers/swish_layer.cpp
@@ -0,0 +1,68 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+ sigmoid_bottom_vec_.clear();
+ sigmoid_bottom_vec_.push_back(sigmoid_input_.get());
+ sigmoid_top_vec_.clear();
+ sigmoid_top_vec_.push_back(sigmoid_output_.get());
+ sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ NeuronLayer<Dtype>::Reshape(bottom, top);
+ sigmoid_input_->ReshapeLike(*bottom[0]);
+ sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* sigmoid_input_data = sigmoid_input_->mutable_cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ caffe_copy(count, bottom_data, sigmoid_input_data);
+ caffe_scal(count, beta, sigmoid_input_data);
+ sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+ caffe_mul(count, bottom_data, sigmoid_output_->cpu_data(), top_data);
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* top_data = top[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ for (int i = 0; i < count; ++i) {
+ const Dtype swish_x = top_data[i];
+ bottom_diff[i] = top_diff[i] * (beta * swish_x + sigmoid_output_data[i]
+ * (1. - beta * swish_x));
+ }
+ }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(SwishLayer);
+#endif
+
+INSTANTIATE_CLASS(SwishLayer);
+REGISTER_LAYER_CLASS(Swish);
+
+} // namespace caffe
diff --git a/src/caffe/layers/swish_layer.cu b/src/caffe/layers/swish_layer.cu
new file mode 100644
index 00000000..c4fef53b
--- /dev/null
+++ b/src/caffe/layers/swish_layer.cu
@@ -0,0 +1,54 @@
+#include <cmath>
+#include <vector>
+
+#include "caffe/layers/swish_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* sigmoid_input_data = sigmoid_input_->mutable_gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ caffe_copy(count, bottom_data, sigmoid_input_data);
+ caffe_gpu_scal(count, beta, sigmoid_input_data);
+ sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+ caffe_gpu_mul(count, bottom_data, sigmoid_output_->gpu_data(), top_data);
+}
+
+template <typename Dtype>
+__global__ void SwishBackward(const int n, const Dtype* in_diff,
+ const Dtype* out_data, const Dtype* sigmoid_output_data, Dtype* out_diff,
+ const Dtype beta) {
+ CUDA_KERNEL_LOOP(index, n) {
+ const Dtype swish_x = out_data[index];
+ out_diff[index] = in_diff[index] * (beta * swish_x
+ + sigmoid_output_data[index] * (1 - beta * swish_x));
+ }
+}
+
+template <typename Dtype>
+void SwishLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ if (propagate_down[0]) {
+ const Dtype* top_data = top[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ const int count = bottom[0]->count();
+ Dtype beta = this->layer_param_.swish_param().beta();
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ SwishBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, top_diff, top_data, sigmoid_output_data, bottom_diff, beta);
+ CUDA_POST_KERNEL_CHECK;
+ }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(SwishLayer);
+
+} // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 22764abc..b9bb3f4d 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -322,7 +322,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param)
+// LayerParameter next available layer-specific ID: 148 (last added: swish_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -415,6 +415,7 @@ message LayerParameter {
optional SoftmaxParameter softmax_param = 125;
optional SPPParameter spp_param = 132;
optional SliceParameter slice_param = 126;
+ optional SwishParameter swish_param = 147;
optional TanHParameter tanh_param = 127;
optional ThresholdParameter threshold_param = 128;
optional TileParameter tile_param = 138;
@@ -1156,6 +1157,15 @@ message SoftmaxParameter {
optional int32 axis = 2 [default = 1];
}
+// Message that stores parameters used by SwishLayer
+message SwishParameter {
+ // Beta parameter for the Swish activation function
+ // Described in:
+ // Prajit Ramachandran, Barret Zoph, Quoc V. Le. (2017). Searching for
+ // Activation Functions. https://arxiv.org/abs/1710.05941v2
+ optional float beta = 1 [default = 1];
+}
+
message TanHParameter {
enum Engine {
DEFAULT = 0;
diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp
index ad6abe54..1d52beb0 100644
--- a/src/caffe/solvers/sgd_solver.cpp
+++ b/src/caffe/solvers/sgd_solver.cpp
@@ -30,12 +30,16 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
if (lr_policy == "fixed") {
rate = this->param_.base_lr();
} else if (lr_policy == "step") {
+ CHECK_GT(this->param_.stepsize(), 0);
this->current_step_ = this->iter_ / this->param_.stepsize();
+ CHECK_GE(this->param_.gamma(), 0);
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "exp") {
+ CHECK_GE(this->param_.gamma(), 0);
rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
} else if (lr_policy == "inv") {
+ CHECK_GE(this->param_.gamma(), 0);
rate = this->param_.base_lr() *
pow(Dtype(1) + this->param_.gamma() * this->iter_,
- this->param_.power());
@@ -46,6 +50,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
LOG(INFO) << "MultiStep Status: Iteration " <<
this->iter_ << ", step = " << this->current_step_;
}
+ CHECK_GE(this->param_.gamma(), 0);
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "poly") {
@@ -53,6 +58,8 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
(Dtype(this->iter_) / Dtype(this->param_.max_iter())),
this->param_.power());
} else if (lr_policy == "sigmoid") {
+ CHECK_GE(this->param_.gamma(), 0);
+ CHECK_GT(this->param_.stepsize(), 0);
rate = this->param_.base_lr() * (Dtype(1.) /
(Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
Dtype(this->param_.stepsize())))));
diff --git a/src/caffe/test/test_filler.cpp b/src/caffe/test/test_filler.cpp
index f84d707b..1e6b5c21 100644
--- a/src/caffe/test/test_filler.cpp
+++ b/src/caffe/test/test_filler.cpp
@@ -1,3 +1,5 @@
+#include <vector>
+
#include "gtest/gtest.h"
#include "caffe/filler.hpp"
@@ -10,11 +12,20 @@ template <typename Dtype>
class ConstantFillerTest : public ::testing::Test {
protected:
ConstantFillerTest()
- : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
filler_param_.set_value(10.);
filler_.reset(new ConstantFiller<Dtype>(filler_param_));
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
+ for (int i = 0; i < count; ++i) {
+ EXPECT_EQ(data[i], filler_param_.value());
+ }
}
virtual ~ConstantFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
@@ -25,12 +36,34 @@ class ConstantFillerTest : public ::testing::Test {
TYPED_TEST_CASE(ConstantFillerTest, TestDtypes);
TYPED_TEST(ConstantFillerTest, TestFill) {
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const TypeParam* data = this->blob_->cpu_data();
- for (int i = 0; i < count; ++i) {
- EXPECT_EQ(data[i], this->filler_param_.value());
- }
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(ConstantFillerTest, TestFill1D) {
+ vector<int> blob_shape(1, 15);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(ConstantFillerTest, TestFill2D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(ConstantFillerTest, TestFill5D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->test_params(blob_shape);
}
@@ -38,12 +71,22 @@ template <typename Dtype>
class UniformFillerTest : public ::testing::Test {
protected:
UniformFillerTest()
- : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
filler_param_.set_min(1.);
filler_param_.set_max(2.);
filler_.reset(new UniformFiller<Dtype>(filler_param_));
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
+ for (int i = 0; i < count; ++i) {
+ EXPECT_GE(data[i], filler_param_.min());
+ EXPECT_LE(data[i], filler_param_.max());
+ }
}
virtual ~UniformFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
@@ -54,23 +97,64 @@ class UniformFillerTest : public ::testing::Test {
TYPED_TEST_CASE(UniformFillerTest, TestDtypes);
TYPED_TEST(UniformFillerTest, TestFill) {
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const TypeParam* data = this->blob_->cpu_data();
- for (int i = 0; i < count; ++i) {
- EXPECT_GE(data[i], this->filler_param_.min());
- EXPECT_LE(data[i], this->filler_param_.max());
- }
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(UniformFillerTest, TestFill1D) {
+ vector<int> blob_shape(1, 15);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(UniformFillerTest, TestFill2D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(UniformFillerTest, TestFill5D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->test_params(blob_shape);
}
template <typename Dtype>
class PositiveUnitballFillerTest : public ::testing::Test {
protected:
PositiveUnitballFillerTest()
- : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
filler_.reset(new PositiveUnitballFiller<Dtype>(filler_param_));
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
filler_->Fill(blob_);
+ const int num = blob_->shape(0);
+ const int count = blob_->count();
+ const int dim = count / num;
+ const Dtype* data = blob_->cpu_data();
+ for (int i = 0; i < count; ++i) {
+ EXPECT_GE(data[i], 0);
+ EXPECT_LE(data[i], 1);
+ }
+ for (int i = 0; i < num; ++i) {
+ Dtype sum = Dtype(0);
+ for (int j = 0; j < dim; ++j) {
+ sum += data[i * dim + j];
+ }
+ EXPECT_GE(sum, 0.999);
+ EXPECT_LE(sum, 1.001);
+ }
}
virtual ~PositiveUnitballFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
@@ -81,35 +165,78 @@ class PositiveUnitballFillerTest : public ::testing::Test {
TYPED_TEST_CASE(PositiveUnitballFillerTest, TestDtypes);
TYPED_TEST(PositiveUnitballFillerTest, TestFill) {
- EXPECT_TRUE(this->blob_);
- const int num = this->blob_->num();
- const int count = this->blob_->count();
- const int dim = count / num;
- const TypeParam* data = this->blob_->cpu_data();
- for (int i = 0; i < count; ++i) {
- EXPECT_GE(data[i], 0);
- EXPECT_LE(data[i], 1);
- }
- for (int i = 0; i < num; ++i) {
- TypeParam sum = 0;
- for (int j = 0; j < dim; ++j) {
- sum += data[i * dim + j];
- }
- EXPECT_GE(sum, 0.999);
- EXPECT_LE(sum, 1.001);
- }
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(PositiveUnitballFillerTest, TestFill1D) {
+ vector<int> blob_shape(1, 15);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(PositiveUnitballFillerTest, TestFill2D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->test_params(blob_shape);
+}
+
+TYPED_TEST(PositiveUnitballFillerTest, TestFill5D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->test_params(blob_shape);
}
template <typename Dtype>
class GaussianFillerTest : public ::testing::Test {
protected:
GaussianFillerTest()
- : blob_(new Blob<Dtype>(2, 3, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
filler_param_.set_mean(10.);
filler_param_.set_std(0.1);
filler_.reset(new GaussianFiller<Dtype>(filler_param_));
+ }
+ virtual void test_params(const vector<int>& shape,
+ const Dtype tolerance = Dtype(5), const int repetitions = 100) {
+ // Tests for statistical properties should be ran multiple times.
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ for (int i = 0; i < repetitions; ++i) {
+ test_params_iter(shape, tolerance);
+ }
+ }
+ virtual void test_params_iter(const vector<int>& shape,
+ const Dtype tolerance) {
+ // This test has a configurable tolerance parameter - by default it was
+ // equal to 5.0 which is very loose - allowing some tuning (e.g. for tests
+ // on smaller blobs the actual variance will be larger than desired, so the
+ // tolerance can be increased to account for that).
filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
+ Dtype mean = Dtype(0);
+ Dtype var = Dtype(0);
+ for (int i = 0; i < count; ++i) {
+ mean += data[i];
+ var += data[i] * data[i];
+ }
+ mean /= count;
+ var /= count;
+ var -= mean*mean;
+ EXPECT_GE(mean, filler_param_.mean() - filler_param_.std() * tolerance);
+ EXPECT_LE(mean, filler_param_.mean() + filler_param_.std() * tolerance);
+ Dtype target_var = filler_param_.std() * filler_param_.std();
+ EXPECT_GE(var, target_var / tolerance);
+ EXPECT_LE(var, target_var * tolerance);
}
virtual ~GaussianFillerTest() { delete blob_; }
Blob<Dtype>* const blob_;
@@ -120,41 +247,62 @@ class GaussianFillerTest : public ::testing::Test {
TYPED_TEST_CASE(GaussianFillerTest, TestDtypes);
TYPED_TEST(GaussianFillerTest, TestFill) {
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const TypeParam* data = this->blob_->cpu_data();
- TypeParam mean = 0.;
- TypeParam var = 0.;
- for (int i = 0; i < count; ++i) {
- mean += data[i];
- var += (data[i] - this->filler_param_.mean()) *
- (data[i] - this->filler_param_.mean());
- }
- mean /= count;
- var /= count;
- // Very loose test.
- EXPECT_GE(mean, this->filler_param_.mean() - this->filler_param_.std() * 5);
- EXPECT_LE(mean, this->filler_param_.mean() + this->filler_param_.std() * 5);
- TypeParam target_var = this->filler_param_.std() * this->filler_param_.std();
- EXPECT_GE(var, target_var / 5.);
- EXPECT_LE(var, target_var * 5.);
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ const TypeParam tolerance = TypeParam(3); // enough for a 120-element blob
+ this->test_params(blob_shape, tolerance);
+}
+
+TYPED_TEST(GaussianFillerTest, TestFill1D) {
+ vector<int> blob_shape(1, 25);
+ const TypeParam tolerance = TypeParam(5);
+ this->test_params(blob_shape, tolerance);
+}
+
+TYPED_TEST(GaussianFillerTest, TestFill2D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ const TypeParam tolerance = TypeParam(5);
+ this->test_params(blob_shape, tolerance);
+}
+
+TYPED_TEST(GaussianFillerTest, TestFill5D) {
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ const TypeParam tolerance = TypeParam(2);
+ this->test_params(blob_shape, tolerance);
}
template <typename Dtype>
class XavierFillerTest : public ::testing::Test {
protected:
XavierFillerTest()
- : blob_(new Blob<Dtype>(1000, 2, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
}
virtual void test_params(FillerParameter_VarianceNorm variance_norm,
+ Dtype n, const vector<int>& shape, const int repetitions = 100) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ for (int i = 0; i < repetitions; ++i) {
+ test_params_iter(variance_norm, n);
+ }
+ }
+ virtual void test_params_iter(FillerParameter_VarianceNorm variance_norm,
Dtype n) {
- this->filler_param_.set_variance_norm(variance_norm);
- this->filler_.reset(new XavierFiller<Dtype>(this->filler_param_));
- this->filler_->Fill(blob_);
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const Dtype* data = this->blob_->cpu_data();
+ filler_param_.set_variance_norm(variance_norm);
+ filler_.reset(new XavierFiller<Dtype>(filler_param_));
+ filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
Dtype mean = 0.;
Dtype ex2 = 0.;
for (int i = 0; i < count; ++i) {
@@ -177,33 +325,92 @@ class XavierFillerTest : public ::testing::Test {
TYPED_TEST_CASE(XavierFillerTest, TestDtypes);
TYPED_TEST(XavierFillerTest, TestFillFanIn) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = 2*4*5;
- this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
+ this->test_params(FillerParameter_VarianceNorm_FAN_IN, n, blob_shape);
}
+
TYPED_TEST(XavierFillerTest, TestFillFanOut) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = 1000*4*5;
- this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
+ this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n, blob_shape);
}
+
TYPED_TEST(XavierFillerTest, TestFillAverage) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
- this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
+ this->test_params(FillerParameter_VarianceNorm_AVERAGE, n, blob_shape);
+}
+
+TYPED_TEST(XavierFillerTest, TestFill1D) {
+ // This makes little sense but at least we will know that we can fill it
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape(1, 25);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new XavierFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
+}
+
+TYPED_TEST(XavierFillerTest, TestFill2D) {
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new XavierFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
+}
+
+TYPED_TEST(XavierFillerTest, TestFill5D) {
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new XavierFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
}
template <typename Dtype>
class MSRAFillerTest : public ::testing::Test {
protected:
MSRAFillerTest()
- : blob_(new Blob<Dtype>(1000, 2, 4, 5)),
+ : blob_(new Blob<Dtype>()),
filler_param_() {
}
virtual void test_params(FillerParameter_VarianceNorm variance_norm,
+ Dtype n, const vector<int>& shape, const int repetitions = 100) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ for (int i = 0; i < repetitions; ++i) {
+ test_params_iter(variance_norm, n);
+ }
+ }
+ virtual void test_params_iter(FillerParameter_VarianceNorm variance_norm,
Dtype n) {
- this->filler_param_.set_variance_norm(variance_norm);
- this->filler_.reset(new MSRAFiller<Dtype>(this->filler_param_));
- this->filler_->Fill(blob_);
- EXPECT_TRUE(this->blob_);
- const int count = this->blob_->count();
- const Dtype* data = this->blob_->cpu_data();
+ filler_param_.set_variance_norm(variance_norm);
+ filler_.reset(new MSRAFiller<Dtype>(filler_param_));
+ filler_->Fill(blob_);
+ const int count = blob_->count();
+ const Dtype* data = blob_->cpu_data();
Dtype mean = 0.;
Dtype ex2 = 0.;
for (int i = 0; i < count; ++i) {
@@ -226,36 +433,92 @@ class MSRAFillerTest : public ::testing::Test {
TYPED_TEST_CASE(MSRAFillerTest, TestDtypes);
TYPED_TEST(MSRAFillerTest, TestFillFanIn) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = 2*4*5;
- this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
+ this->test_params(FillerParameter_VarianceNorm_FAN_IN, n, blob_shape);
}
+
TYPED_TEST(MSRAFillerTest, TestFillFanOut) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = 1000*4*5;
- this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
+ this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n, blob_shape);
}
+
TYPED_TEST(MSRAFillerTest, TestFillAverage) {
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
- this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
+ this->test_params(FillerParameter_VarianceNorm_AVERAGE, n, blob_shape);
+}
+
+TYPED_TEST(MSRAFillerTest, TestFill1D) {
+ // Like with Xavier - no checking for correctness, just if it can be filled.
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape(1, 25);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new MSRAFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
+}
+
+TYPED_TEST(MSRAFillerTest, TestFill2D) {
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape;
+ blob_shape.push_back(8);
+ blob_shape.push_back(3);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new MSRAFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
+}
+
+TYPED_TEST(MSRAFillerTest, TestFill5D) {
+ EXPECT_TRUE(this->blob_);
+ vector<int> blob_shape;
+ blob_shape.push_back(2);
+ blob_shape.push_back(3);
+ blob_shape.push_back(4);
+ blob_shape.push_back(5);
+ blob_shape.push_back(2);
+ this->blob_->Reshape(blob_shape);
+ this->filler_param_.set_variance_norm(FillerParameter_VarianceNorm_AVERAGE);
+ this->filler_.reset(new MSRAFiller<TypeParam>(this->filler_param_));
+ this->filler_->Fill(this->blob_);
}
template <typename Dtype>
class BilinearFillerTest : public ::testing::Test {
protected:
- BilinearFillerTest() : filler_param_() {}
- virtual void test_params(const int n) {
- this->blob_ = new Blob<Dtype>(1000, 2, n, n);
- this->filler_.reset(new BilinearFiller<Dtype>(this->filler_param_));
- this->filler_->Fill(blob_);
- EXPECT_TRUE(this->blob_);
- const int outer_num = this->blob_->count(0, 2);
- const int inner_num = this->blob_->count(2, 4);
- const Dtype* data = this->blob_->cpu_data();
- int f = ceil(this->blob_->width() / 2.);
- Dtype c = (this->blob_->width() - 1) / (2. * f);
+ BilinearFillerTest()
+ : blob_(new Blob<Dtype>()),
+ filler_param_() {
+ }
+ virtual void test_params(const vector<int>& shape) {
+ EXPECT_TRUE(blob_);
+ blob_->Reshape(shape);
+ filler_.reset(new BilinearFiller<Dtype>(filler_param_));
+ filler_->Fill(blob_);
+ CHECK_EQ(blob_->num_axes(), 4);
+ const int outer_num = blob_->count(0, 2);
+ const int inner_num = blob_->count(2, 4);
+ const Dtype* data = blob_->cpu_data();
+ int f = ceil(blob_->shape(3) / 2.);
+ Dtype c = (blob_->shape(3) - 1) / (2. * f);
for (int i = 0; i < outer_num; ++i) {
for (int j = 0; j < inner_num; ++j) {
- Dtype x = j % this->blob_->width();
- Dtype y = (j / this->blob_->width()) % this->blob_->height();
+ Dtype x = j % blob_->shape(3);
+ Dtype y = (j / blob_->shape(3)) % blob_->shape(2);
Dtype expected_value = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
const Dtype actual_value = data[i * inner_num + j];
EXPECT_NEAR(expected_value, actual_value, 0.01);
@@ -272,11 +535,21 @@ TYPED_TEST_CASE(BilinearFillerTest, TestDtypes);
TYPED_TEST(BilinearFillerTest, TestFillOdd) {
const int n = 7;
- this->test_params(n);
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(n);
+ blob_shape.push_back(n);
+ this->test_params(blob_shape);
}
TYPED_TEST(BilinearFillerTest, TestFillEven) {
const int n = 6;
- this->test_params(n);
+ vector<int> blob_shape;
+ blob_shape.push_back(1000);
+ blob_shape.push_back(2);
+ blob_shape.push_back(n);
+ blob_shape.push_back(n);
+ this->test_params(blob_shape);
}
} // namespace caffe
diff --git a/src/caffe/test/test_neuron_layer.cpp b/src/caffe/test/test_neuron_layer.cpp
index 180871a2..83d80fcd 100644
--- a/src/caffe/test/test_neuron_layer.cpp
+++ b/src/caffe/test/test_neuron_layer.cpp
@@ -19,6 +19,7 @@
#include "caffe/layers/prelu_layer.hpp"
#include "caffe/layers/relu_layer.hpp"
#include "caffe/layers/sigmoid_layer.hpp"
+#include "caffe/layers/swish_layer.hpp"
#include "caffe/layers/tanh_layer.hpp"
#include "caffe/layers/threshold_layer.hpp"
@@ -344,6 +345,84 @@ TYPED_TEST(NeuronLayerTest, TestSigmoidGradient) {
this->blob_top_vec_);
}
+TYPED_TEST(NeuronLayerTest, TestSwish) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ SwishLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / (1. + exp(-bottom_data[i])));
+ }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishWithBeta) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "swish_param { beta: 1.5 }", &layer_param));
+ SwishLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / (1. + exp(-1.5 *
+ bottom_data[i])));
+ }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishAsLinear) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "swish_param { beta: 0.0 }", &layer_param));
+ SwishLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // Now, check values
+ const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+ EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] / 2.0);
+ }
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ SwishLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+ checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishWithBetaGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "swish_param { beta: 1.5 }", &layer_param));
+ SwishLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+ checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
+TYPED_TEST(NeuronLayerTest, TestSwishAsLinearGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ CHECK(google::protobuf::TextFormat::ParseFromString(
+ "swish_param { beta: 0.0 }", &layer_param));
+ SwishLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 0., 0.01);
+ checker.CheckGradientEltwise(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
TYPED_TEST(NeuronLayerTest, TestTanH) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index 835d2d4e..5295d9dd 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -106,7 +106,7 @@ cv::Mat ReadImageToCVMat(const string& filename) {
static bool matchExt(const std::string & fn,
std::string en) {
size_t p = fn.rfind('.');
- std::string ext = p != fn.npos ? fn.substr(p) : fn;
+ std::string ext = p != fn.npos ? fn.substr(p+1) : fn;
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
std::transform(en.begin(), en.end(), en.begin(), ::tolower);
if ( ext == en )
diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp
index 90cdb15d..9c5d09f9 100644
--- a/tools/convert_imageset.cpp
+++ b/tools/convert_imageset.cpp
@@ -115,7 +115,7 @@ int main(int argc, char** argv) {
size_t p = fn.rfind('.');
if ( p == fn.npos )
LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
- enc = fn.substr(p);
+ enc = fn.substr(p+1);
std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
}
status = ReadImageToDatum(root_folder + lines[line_id].first,
diff --git a/tools/device_query.cpp b/tools/device_query.cpp
deleted file mode 100644
index 03799e52..00000000
--- a/tools/device_query.cpp
+++ /dev/null
@@ -1,7 +0,0 @@
-#include "caffe/common.hpp"
-
-int main(int argc, char** argv) {
- LOG(FATAL) << "Deprecated. Use caffe device_query "
- "[--device_id=0] instead.";
- return 0;
-}
diff --git a/tools/finetune_net.cpp b/tools/finetune_net.cpp
deleted file mode 100644
index 81c0c354..00000000
--- a/tools/finetune_net.cpp
+++ /dev/null
@@ -1,7 +0,0 @@
-#include "caffe/caffe.hpp"
-
-int main(int argc, char** argv) {
- LOG(FATAL) << "Deprecated. Use caffe train --solver=... "
- "[--weights=...] instead.";
- return 0;
-}
diff --git a/tools/net_speed_benchmark.cpp b/tools/net_speed_benchmark.cpp
deleted file mode 100644
index cd16e8d0..00000000
--- a/tools/net_speed_benchmark.cpp
+++ /dev/null
@@ -1,7 +0,0 @@
-#include "caffe/caffe.hpp"
-
-int main(int argc, char** argv) {
- LOG(FATAL) << "Deprecated. Use caffe time --model=... "
- "[--iterations=50] [--gpu] [--device_id=0]";
- return 0;
-}
diff --git a/tools/test_net.cpp b/tools/test_net.cpp
deleted file mode 100644
index 92e14eee..00000000
--- a/tools/test_net.cpp
+++ /dev/null
@@ -1,7 +0,0 @@
-#include "caffe/caffe.hpp"
-
-int main(int argc, char** argv) {
- LOG(FATAL) << "Deprecated. Use caffe test --model=... "
- "--weights=... instead.";
- return 0;
-}
diff --git a/tools/train_net.cpp b/tools/train_net.cpp
deleted file mode 100644
index 622bca31..00000000
--- a/tools/train_net.cpp
+++ /dev/null
@@ -1,7 +0,0 @@
-#include "caffe/caffe.hpp"
-
-int main(int argc, char** argv) {
- LOG(FATAL) << "Deprecated. Use caffe train --solver=... "
- "[--snapshot=...] instead.";
- return 0;
-}