Diffstat (limited to 'include/caffe')
-rw-r--r--  include/caffe/blob.hpp                  73
-rw-r--r--  include/caffe/caffe.hpp                 15
-rw-r--r--  include/caffe/common.hpp               112
-rw-r--r--  include/caffe/filler.hpp               147
-rw-r--r--  include/caffe/layer.hpp                136
-rw-r--r--  include/caffe/net.hpp                   95
-rw-r--r--  include/caffe/solver.hpp                64
-rw-r--r--  include/caffe/syncedmem.hpp             35
-rw-r--r--  include/caffe/util/im2col.hpp           30
-rw-r--r--  include/caffe/util/io.hpp               46
-rw-r--r--  include/caffe/util/math_functions.hpp  106
-rw-r--r--  include/caffe/vision_layers.hpp        405
12 files changed, 1264 insertions, 0 deletions
diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp
new file mode 100644
index 00000000..f31d3b0f
--- /dev/null
+++ b/include/caffe/blob.hpp
@@ -0,0 +1,73 @@
+// Copyright 2013 Yangqing Jia
+
+#ifndef CAFFE_BLOB_HPP_
+#define CAFFE_BLOB_HPP_
+
+#include "caffe/common.hpp"
+#include "caffe/syncedmem.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype>
+class Blob {
+ public:
+ Blob()
+ : num_(0), channels_(0), height_(0), width_(0), count_(0), data_(),
+ diff_() {}
+ explicit Blob(const int num, const int channels, const int height,
+ const int width);
+ virtual ~Blob() {}
+ void Reshape(const int num, const int channels, const int height,
+ const int width);
+ inline int num() const { return num_; }
+ inline int channels() const { return channels_; }
+ inline int height() const { return height_; }
+ inline int width() const { return width_; }
+ inline int count() const { return count_; }
+ inline int offset(const int n, const int c = 0, const int h = 0,
+ const int w = 0) const {
+ return ((n * channels_ + c) * height_ + h) * width_ + w;
+ }
+ // Copy from source. If copy_diff is false, we copy the data; if copy_diff
+ // is true, we copy the diff.
+ void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,
+ bool reshape = false);
+
+ inline Dtype data_at(const int n, const int c, const int h,
+ const int w) const {
+ return *(cpu_data() + offset(n, c, h, w));
+ }
+
+ inline Dtype diff_at(const int n, const int c, const int h,
+ const int w) const {
+ return *(cpu_diff() + offset(n, c, h, w));
+ }
+
+ const Dtype* cpu_data() const;
+ const Dtype* gpu_data() const;
+ const Dtype* cpu_diff() const;
+ const Dtype* gpu_diff() const;
+ Dtype* mutable_cpu_data();
+ Dtype* mutable_gpu_data();
+ Dtype* mutable_cpu_diff();
+ Dtype* mutable_gpu_diff();
+ void Update();
+ void FromProto(const BlobProto& proto);
+ void ToProto(BlobProto* proto, bool write_diff = false) const;
+
+ protected:
+ shared_ptr<SyncedMemory> data_;
+ shared_ptr<SyncedMemory> diff_;
+ int num_;
+ int channels_;
+ int height_;
+ int width_;
+ int count_;
+
+ DISABLE_COPY_AND_ASSIGN(Blob);
+}; // class Blob
+
+} // namespace caffe
+
+#endif // CAFFE_BLOB_HPP_
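A minimal sketch (not part of the commit) of how the offset() arithmetic above addresses the cpu_data() buffer; the blob dimensions and the function name are arbitrary example values:

#include "caffe/blob.hpp"

void blob_indexing_example() {
  caffe::Blob<float> blob(2, 3, 4, 5);   // num=2, channels=3, height=4, width=5
  // offset(n, c, h, w) == ((n * channels + c) * height + h) * width + w
  int idx = blob.offset(1, 2, 3, 4);     // ((1*3 + 2)*4 + 3)*5 + 4 == 119
  const float* data = blob.cpu_data();   // backing memory allocated on first access
  float v = data[idx];                   // same element as blob.data_at(1, 2, 3, 4)
  (void)v;
}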
diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp
new file mode 100644
index 00000000..5806bc02
--- /dev/null
+++ b/include/caffe/caffe.hpp
@@ -0,0 +1,15 @@
+// Copyright Yangqing Jia 2013
+
+#ifndef CAFFE_CAFFE_HPP_
+#define CAFFE_CAFFE_HPP_
+
+#include "caffe/common.hpp"
+#include "caffe/blob.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/net.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/proto/caffe.pb.h"
+
+#endif // CAFFE_CAFFE_HPP_
diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp
new file mode 100644
index 00000000..cc6e31b7
--- /dev/null
+++ b/include/caffe/common.hpp
@@ -0,0 +1,112 @@
+// Copyright 2013 Yangqing Jia
+
+#ifndef CAFFE_COMMON_HPP_
+#define CAFFE_COMMON_HPP_
+
+#include <boost/shared_ptr.hpp>
+#include <cublas_v2.h>
+#include <cuda.h>
+#include <curand.h>
+// cuda driver types
+#include <driver_types.h>
+#include <glog/logging.h>
+#include <mkl_vsl.h>
+
+// various checks for different function calls.
+#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
+#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
+#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
+#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
+
+// After a kernel is executed, this will check the error and if there is one,
+// exit loudly.
+#define CUDA_POST_KERNEL_CHECK \
+ if (cudaSuccess != cudaPeekAtLastError()) \
+ LOG(FATAL) << "Cuda kernel failed. Error: " \
+ << cudaGetErrorString(cudaPeekAtLastError())
+
+// Disable the copy and assignment operator for a class.
+#define DISABLE_COPY_AND_ASSIGN(classname) \
+private:\
+ classname(const classname&);\
+ classname& operator=(const classname&)
+
+// Instantiate a class with float and double specifications.
+#define INSTANTIATE_CLASS(classname) \
+ template class classname<float>; \
+ template class classname<double>
+
+// A simple macro to mark code that is not implemented yet, so that if it is
+// ever executed we see a fatal log.
+#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
+
+
+namespace caffe {
+
+// We will use the boost shared_ptr instead of the new C++11 one mainly
+// because CUDA does not (at least for now) work well with C++11 features.
+using boost::shared_ptr;
+
+
+// We will use 1024 threads per block, which requires cuda sm_2x or above.
+const int CAFFE_CUDA_NUM_THREADS = 1024;
+
+
+inline int CAFFE_GET_BLOCKS(const int N) {
+ return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+}
+
+
+// A singleton class to hold common caffe state, such as the handles that
+// caffe uses for cublas, curand, etc.
+class Caffe {
+ public:
+ ~Caffe();
+ inline static Caffe& Get() {
+ if (!singleton_.get()) {
+ singleton_.reset(new Caffe());
+ }
+ return *singleton_;
+ }
+ enum Brew { CPU, GPU };
+ enum Phase { TRAIN, TEST };
+
+ // The getters for the variables.
+ // Returns the cublas handle.
+ inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
+ // Returns the curand generator.
+ inline static curandGenerator_t curand_generator() {
+ return Get().curand_generator_;
+ }
+ // Returns the MKL random stream.
+ inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
+ // Returns the mode: running on CPU or GPU.
+ inline static Brew mode() { return Get().mode_; }
+ // Returns the phase: TRAIN or TEST.
+ inline static Phase phase() { return Get().phase_; }
+ // The setters for the variables
+ // Sets the mode.
+ inline static void set_mode(Brew mode) { Get().mode_ = mode; }
+ // Sets the phase.
+ inline static void set_phase(Phase phase) { Get().phase_ = phase; }
+ // Sets the random seed of both MKL and curand
+ static void set_random_seed(const unsigned int seed);
+
+ protected:
+ cublasHandle_t cublas_handle_;
+ curandGenerator_t curand_generator_;
+ VSLStreamStatePtr vsl_stream_;
+ Brew mode_;
+ Phase phase_;
+ static shared_ptr<Caffe> singleton_;
+
+ private:
+ // The private constructor to avoid duplicate instantiation.
+ Caffe();
+
+ DISABLE_COPY_AND_ASSIGN(Caffe);
+};
+
+} // namespace caffe
+
+#endif // CAFFE_COMMON_HPP_
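A hedged sketch (not part of the diff) of the launch pattern these helpers are written for; scale_kernel and scale_gpu are made-up example names:

#include "caffe/common.hpp"

__global__ void scale_kernel(const int n, const float alpha, float* x) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= alpha;
  }
}

void scale_gpu(const int n, const float alpha, float* x) {
  // CAFFE_GET_BLOCKS picks enough blocks of CAFFE_CUDA_NUM_THREADS to cover n.
  scale_kernel<<<caffe::CAFFE_GET_BLOCKS(n), caffe::CAFFE_CUDA_NUM_THREADS>>>(
      n, alpha, x);
  CUDA_POST_KERNEL_CHECK;  // aborts with the CUDA error string if the launch failed
}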
diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp
new file mode 100644
index 00000000..effe62ff
--- /dev/null
+++ b/include/caffe/filler.hpp
@@ -0,0 +1,147 @@
+// Copyright 2013 Yangqing Jia
+
+// Fillers are random number generators that fill a blob using the specified
+// algorithm. The expectation is that they are only used at initialization
+// time and do not involve any GPUs.
+
+#ifndef CAFFE_FILLER_HPP_
+#define CAFFE_FILLER_HPP_
+
+#include <mkl.h>
+#include <string>
+
+#include "caffe/common.hpp"
+#include "caffe/blob.hpp"
+#include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype>
+class Filler {
+ public:
+ explicit Filler(const FillerParameter& param) : filler_param_(param) {}
+ virtual ~Filler() {}
+ virtual void Fill(Blob<Dtype>* blob) = 0;
+ protected:
+ FillerParameter filler_param_;
+}; // class Filler
+
+
+template <typename Dtype>
+class ConstantFiller : public Filler<Dtype> {
+ public:
+ explicit ConstantFiller(const FillerParameter& param)
+ : Filler<Dtype>(param) {}
+ virtual void Fill(Blob<Dtype>* blob) {
+ Dtype* data = blob->mutable_cpu_data();
+ const int count = blob->count();
+ const Dtype value = this->filler_param_.value();
+ CHECK(count);
+ for (int i = 0; i < count; ++i) {
+ data[i] = value;
+ }
+ };
+};
+
+template <typename Dtype>
+class UniformFiller : public Filler<Dtype> {
+ public:
+ explicit UniformFiller(const FillerParameter& param)
+ : Filler<Dtype>(param) {}
+ virtual void Fill(Blob<Dtype>* blob) {
+ CHECK(blob->count());
+ caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
+ Dtype(this->filler_param_.min()),
+ Dtype(this->filler_param_.max()));
+ }
+};
+
+template <typename Dtype>
+class GaussianFiller : public Filler<Dtype> {
+ public:
+ explicit GaussianFiller(const FillerParameter& param)
+ : Filler<Dtype>(param) {}
+ virtual void Fill(Blob<Dtype>* blob) {
+ Dtype* data = blob->mutable_cpu_data();
+ CHECK(blob->count());
+ caffe_vRngGaussian<Dtype>(blob->count(), blob->mutable_cpu_data(),
+ Dtype(this->filler_param_.mean()),
+ Dtype(this->filler_param_.std()));
+ }
+};
+
+template <typename Dtype>
+class PositiveUnitballFiller : public Filler<Dtype> {
+ public:
+ explicit PositiveUnitballFiller(const FillerParameter& param)
+ : Filler<Dtype>(param) {}
+ virtual void Fill(Blob<Dtype>* blob) {
+ Dtype* data = blob->mutable_cpu_data();
+ DCHECK(blob->count());
+ caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(), 0, 1);
+ // We expect the filler to not be called very frequently, so we will
+ // just use a simple implementation
+ int dim = blob->count() / blob->num();
+ CHECK(dim);
+ for (int i = 0; i < blob->num(); ++i) {
+ Dtype sum = 0;
+ for (int j = 0; j < dim; ++j) {
+ sum += data[i * dim + j];
+ }
+ for (int j = 0; j < dim; ++j) {
+ data[i * dim + j] /= sum;
+ }
+ }
+ }
+};
+
+// A filler based on the paper [Glorot and Bengio 2010]: Understanding
+// the difficulty of training deep feedforward neural networks, but does not
+// use the fan_out value.
+//
+// It fills the incoming matrix by randomly sampling uniform data from
+// [-scale, scale], where scale = sqrt(3 / fan_in) and fan_in is the number
+// of input nodes. You should make sure the input blob has shape (num, a, b, c)
+// where a * b * c = fan_in.
+template <typename Dtype>
+class XavierFiller : public Filler<Dtype> {
+ public:
+ explicit XavierFiller(const FillerParameter& param)
+ : Filler<Dtype>(param) {}
+ virtual void Fill(Blob<Dtype>* blob) {
+ CHECK(blob->count());
+ int fan_in = blob->count() / blob->num();
+ Dtype scale = sqrt(Dtype(3) / fan_in);
+ caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
+ -scale, scale);
+ }
+};
+
+
+// A function to get a specific filler from the specification given in
+// FillerParameter. Ideally this would be replaced by a factory pattern,
+// but we will leave it this way for now.
+template <typename Dtype>
+Filler<Dtype>* GetFiller(const FillerParameter& param) {
+ const std::string& type = param.type();
+ if (type == "constant") {
+ return new ConstantFiller<Dtype>(param);
+ } else if (type == "gaussian") {
+ return new GaussianFiller<Dtype>(param);
+ } else if (type == "positive_unitball") {
+ return new PositiveUnitballFiller<Dtype>(param);
+ } else if (type == "uniform") {
+ return new UniformFiller<Dtype>(param);
+ } else if (type == "xavier") {
+ return new XavierFiller<Dtype>(param);
+ } else {
+ CHECK(false) << "Unknown filler name: " << param.type();
+ }
+ return (Filler<Dtype>*)(NULL);
+}
+
+} // namespace caffe
+
+#endif // CAFFE_FILLER_HPP_
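A minimal usage sketch (not part of the diff) of the GetFiller() factory; the setter names assume the usual protobuf-generated accessors for the FillerParameter fields read above (type, min, max), and fill_example is a hypothetical helper:

#include "caffe/blob.hpp"
#include "caffe/filler.hpp"

void fill_example() {
  caffe::FillerParameter param;
  param.set_type("uniform");
  param.set_min(-0.1f);
  param.set_max(0.1f);
  caffe::Blob<float> weights(20, 10, 1, 1);
  caffe::shared_ptr<caffe::Filler<float> > filler(caffe::GetFiller<float>(param));
  filler->Fill(&weights);  // every entry drawn from U(-0.1, 0.1)
}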
diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp
new file mode 100644
index 00000000..adc63657
--- /dev/null
+++ b/include/caffe/layer.hpp
@@ -0,0 +1,136 @@
+// Copyright 2013 Yangqing Jia
+
+#ifndef CAFFE_LAYER_H_
+#define CAFFE_LAYER_H_
+
+#include <vector>
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+using std::vector;
+
+namespace caffe {
+
+template <typename Dtype>
+class Layer {
+ public:
+ // You should not implement your own constructor. Any set up code should go
+ // to SetUp(), where the dimensions of the bottom blobs are provided to the
+ // layer.
+ explicit Layer(const LayerParameter& param)
+ : layer_param_(param) {
+ // The only thing we do is to copy blobs if there are any.
+ if (layer_param_.blobs_size() > 0) {
+ blobs_.resize(layer_param_.blobs_size());
+ for (int i = 0; i < layer_param_.blobs_size(); ++i) {
+ blobs_[i].reset(new Blob<Dtype>());
+ blobs_[i]->FromProto(layer_param_.blobs(i));
+ }
+ }
+ }
+ virtual ~Layer() {}
+ // SetUp: your layer should implement this function.
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) = 0;
+
+ // Forward and backward wrappers. You should implement the cpu and
+ // gpu specific implementations instead, and should not change these
+ // functions.
+ inline void Forward(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ inline Dtype Backward(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom);
+
+ // Returns the vector of blobs.
+ vector<shared_ptr<Blob<Dtype> > >& blobs() {
+ return blobs_;
+ }
+
+ // Returns the layer parameter
+ const LayerParameter& layer_param() { return layer_param_; }
+ // Writes the layer parameter to a protocol buffer
+ virtual void ToProto(LayerParameter* param, bool write_diff = false);
+
+ protected:
+ // The protobuf that stores the layer parameters
+ LayerParameter layer_param_;
+ // The vector that stores the parameters as a set of blobs.
+ vector<shared_ptr<Blob<Dtype> > > blobs_;
+
+ // Forward functions
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) = 0;
+ // If no gpu code is provided, we will simply use cpu code.
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ // LOG(WARNING) << "Using CPU code as backup.";
+ Forward_cpu(bottom, top);
+ };
+
+ // Backward functions: the backward function will compute the gradients for
+ // any parameters and also for the bottom blobs if propagate_down is true.
+ // It will return the loss produced from this layer.
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) = 0;
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ // LOG(WARNING) << "Using CPU code as backup.";
+ return Backward_cpu(top, propagate_down, bottom);
+ };
+
+ DISABLE_COPY_AND_ASSIGN(Layer);
+}; // class Layer
+
+// Forward and backward wrappers. You should implement the cpu and
+// gpu specific implementations instead, and should not change these
+// functions.
+template <typename Dtype>
+inline void Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ switch (Caffe::mode()) {
+ case Caffe::CPU:
+ Forward_cpu(bottom, top);
+ break;
+ case Caffe::GPU:
+ Forward_gpu(bottom, top);
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode.";
+ }
+}
+
+template <typename Dtype>
+inline Dtype Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ switch (Caffe::mode()) {
+ case Caffe::CPU:
+ return Backward_cpu(top, propagate_down, bottom);
+ case Caffe::GPU:
+ return Backward_gpu(top, propagate_down, bottom);
+ default:
+ LOG(FATAL) << "Unknown caffe mode.";
+ }
+ return Dtype(0); // never reached: LOG(FATAL) aborts, but this keeps compilers happy.
+}
+
+template <typename Dtype>
+void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
+ param->Clear();
+ param->CopyFrom(layer_param_);
+ param->clear_blobs();
+ for (int i = 0; i < blobs_.size(); ++i) {
+ blobs_[i]->ToProto(param->add_blobs(), write_diff);
+ }
+}
+
+// The layer factory function
+template <typename Dtype>
+Layer<Dtype>* GetLayer(const LayerParameter& param);
+
+} // namespace caffe
+
+#endif // CAFFE_LAYER_H_
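A hedged sketch (not part of the diff) of the contract described above: a subclass implements SetUp() and the *_cpu functions, while the non-virtual Forward()/Backward() wrappers dispatch on Caffe::mode(). ScaleByTwoLayer is a hypothetical example, not a layer shipped in this commit:

#include "caffe/layer.hpp"

namespace caffe {

template <typename Dtype>
class ScaleByTwoLayer : public Layer<Dtype> {
 public:
  explicit ScaleByTwoLayer(const LayerParameter& param) : Layer<Dtype>(param) {}
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
    CHECK_EQ(bottom.size(), 1) << "takes a single input blob";
    (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
        bottom[0]->height(), bottom[0]->width());
  }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top) {
    const Dtype* in = bottom[0]->cpu_data();
    Dtype* out = (*top)[0]->mutable_cpu_data();
    for (int i = 0; i < bottom[0]->count(); ++i) {
      out[i] = in[i] * Dtype(2);  // y = 2x
    }
  }
  // Without Forward_gpu/Backward_gpu overrides, the base class falls back to
  // the cpu implementations in GPU mode.
  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
    if (propagate_down) {
      const Dtype* top_diff = top[0]->cpu_diff();
      Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
      for (int i = 0; i < top[0]->count(); ++i) {
        bottom_diff[i] = top_diff[i] * Dtype(2);  // dL/dx = 2 * dL/dy
      }
    }
    return Dtype(0);  // this layer contributes no loss of its own
  }
};

}  // namespace caffe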
diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
new file mode 100644
index 00000000..c27442b8
--- /dev/null
+++ b/include/caffe/net.hpp
@@ -0,0 +1,95 @@
+// Copyright 2013 Yangqing Jia
+
+#ifndef CAFFE_NET_HPP_
+#define CAFFE_NET_HPP_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+using std::map;
+using std::vector;
+using std::string;
+
+namespace caffe {
+
+
+template <typename Dtype>
+class Net {
+ public:
+ Net(const NetParameter& param,
+ const vector<Blob<Dtype>* >& bottom);
+ ~Net() {}
+ const vector<Blob<Dtype>*>& Forward(const vector<Blob<Dtype>* > & bottom);
+ // The network backward should take no input or output, since it solely
+ // computes the gradient w.r.t. the parameters, and the data has already
+ // been provided during the forward pass.
+ Dtype Backward();
+
+ Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
+ Forward(bottom);
+ return Backward();
+ }
+
+ // For an already initialized net, CopyTrainedLayersFrom() copies the already
+ // trained layers from another net parameter instance.
+ void CopyTrainedLayersFrom(const NetParameter& param);
+ // Writes the net to a proto.
+ void ToProto(NetParameter* param, bool write_diff = false);
+
+ // returns the network name.
+ inline const string& name() { return name_; }
+ // returns the layer names
+ inline const vector<string>& layer_names() { return layer_names_; }
+ // returns the blob names
+ inline const vector<string>& blob_names() { return blob_names_; }
+ // returns the blobs
+ inline const vector<shared_ptr<Blob<Dtype> > >& blobs() { return blobs_; }
+ // returns the layers
+ inline const vector<shared_ptr<Layer<Dtype> > >& layers() { return layers_; }
+ // returns the bottom and top vecs for each layer - usually you won't need
+ // this unless you do per-layer checks such as gradient checks.
+ inline vector<vector<Blob<Dtype>*> >& bottom_vecs() { return bottom_vecs_; }
+ inline vector<vector<Blob<Dtype>*> >& top_vecs() { return top_vecs_; }
+ // returns the parameters
+ inline vector<shared_ptr<Blob<Dtype> > >& params() { return params_; }
+ // returns the parameter learning rate multipliers
+ inline vector<float>& params_lr() { return params_lr_; }
+ // Updates the network
+ void Update();
+
+ protected:
+ // Individual layers in the net
+ vector<shared_ptr<Layer<Dtype> > > layers_;
+ vector<string> layer_names_;
+ // blobs stores the blobs that store intermediate results between the
+ // layers.
+ vector<shared_ptr<Blob<Dtype> > > blobs_;
+ vector<string> blob_names_;
+ // bottom_vecs stores the vectors containing the input for each layer
+ vector<vector<Blob<Dtype>*> > bottom_vecs_;
+ vector<vector<int> > bottom_id_vecs_;
+ // top_vecs stores the vectors containing the output for each layer
+ vector<vector<Blob<Dtype>*> > top_vecs_;
+ vector<vector<int> > top_id_vecs_;
+ // blob indices for the input and the output of the net.
+ vector<int> net_input_blob_indices_;
+ vector<int> net_output_blob_indices_;
+ vector<Blob<Dtype>*> net_output_blobs_;
+ string name_;
+ // The parameters in the network.
+ vector<shared_ptr<Blob<Dtype> > > params_;
+ // the learning rate multipliers
+ vector<float> params_lr_;
+ DISABLE_COPY_AND_ASSIGN(Net);
+};
+
+
+} // namespace caffe
+
+#endif // CAFFE_NET_HPP_
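A minimal sketch (not part of the diff) of driving a Net through one training step with the interface above; the prototxt name is a placeholder and the net is assumed to take no external input blobs (e.g. it starts with a data layer):

#include <vector>

#include "caffe/net.hpp"
#include "caffe/util/io.hpp"

void train_step_example() {
  caffe::NetParameter net_param;
  caffe::ReadProtoFromTextFile("train_net.prototxt", &net_param);
  std::vector<caffe::Blob<float>*> bottom;   // no external inputs
  caffe::Net<float> net(net_param, bottom);
  float loss = net.ForwardBackward(bottom);  // forward pass, then gradients
  net.Update();                              // apply the accumulated diffs
  LOG(INFO) << "loss: " << loss;
}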
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
new file mode 100644
index 00000000..98c872dc
--- /dev/null
+++ b/include/caffe/solver.hpp
@@ -0,0 +1,64 @@
+// Copyright Yangqing Jia 2013
+
+#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
+#define CAFFE_OPTIMIZATION_SOLVER_HPP_
+
+#include <vector>
+
+#include "caffe/net.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+template <typename Dtype>
+class Solver {
+ public:
+ explicit Solver(const SolverParameter& param)
+ : param_(param) {}
+ // The main entry of the solver function. By default training starts from
+ // iteration zero; pass a snapshot state file to resume a previous run.
+ void Solve(Net<Dtype>* net, char* state_file = NULL);
+ virtual ~Solver() {}
+
+ protected:
+ // PreSolve is run before any solving iteration starts, allowing one to
+ // put up some scaffold.
+ virtual void PreSolve() {}
+ // Get the update value for the current iteration.
+ virtual void ComputeUpdateValue() = 0;
+ // The Solver::Snapshot function implements the basic snapshotting utility
+ // that stores the learned net. You should implement the SnapshotSolverState()
+ // function that produces a SolverState protocol buffer that needs to be
+ // written to disk together with the learned net.
+ void Snapshot();
+ virtual void SnapshotSolverState(SolverState* state) = 0;
+ // The Restore function implements how one should restore the solver to a
+ // previously snapshotted state. You should implement the RestoreSolverState()
+ // function that restores the state from a SolverState protocol buffer.
+ void Restore(char* state_file);
+ virtual void RestoreSolverState(const SolverState& state) = 0;
+ SolverParameter param_;
+ int iter_;
+ Net<Dtype>* net_;
+
+ DISABLE_COPY_AND_ASSIGN(Solver);
+};
+
+
+template <typename Dtype>
+class SGDSolver : public Solver<Dtype> {
+ public:
+ explicit SGDSolver(const SolverParameter& param)
+ : Solver<Dtype>(param) {}
+
+ protected:
+ virtual void PreSolve();
+ virtual Dtype GetLearningRate();
+ virtual void ComputeUpdateValue();
+ virtual void SnapshotSolverState(SolverState * state);
+ virtual void RestoreSolverState(const SolverState& state);
+ // history maintains the historical momentum data.
+ vector<shared_ptr<Blob<Dtype> > > history_;
+};
+
+
+} // namespace caffe
+
+#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_
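A hedged sketch (not part of the diff) of the intended solver flow: build an SGDSolver from a SolverParameter and hand it an initialized net. The prototxt names and solve_example are placeholders:

#include <vector>

#include "caffe/net.hpp"
#include "caffe/solver.hpp"
#include "caffe/util/io.hpp"

void solve_example() {
  caffe::SolverParameter solver_param;
  caffe::ReadProtoFromTextFile("solver.prototxt", &solver_param);

  caffe::NetParameter net_param;
  caffe::ReadProtoFromTextFile("train_net.prototxt", &net_param);
  std::vector<caffe::Blob<float>*> bottom;
  caffe::Net<float> net(net_param, bottom);

  caffe::SGDSolver<float> solver(solver_param);
  solver.Solve(&net);  // pass a snapshot state file as the second argument to resume
}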
diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp
new file mode 100644
index 00000000..862512f9
--- /dev/null
+++ b/include/caffe/syncedmem.hpp
@@ -0,0 +1,35 @@
+// Copyright 2013 Yangqing Jia
+
+#ifndef CAFFE_SYNCEDMEM_HPP_
+#define CAFFE_SYNCEDMEM_HPP_
+
+#include "caffe/common.hpp"
+
+namespace caffe {
+
+class SyncedMemory {
+ public:
+ SyncedMemory()
+ : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED) {}
+ explicit SyncedMemory(size_t size)
+ : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED) {}
+ ~SyncedMemory();
+ const void* cpu_data();
+ const void* gpu_data();
+ void* mutable_cpu_data();
+ void* mutable_gpu_data();
+ enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED };
+ SyncedHead head() { return head_; }
+ size_t size() { return size_; }
+ private:
+ void to_cpu();
+ void to_gpu();
+ void* cpu_ptr_;
+ void* gpu_ptr_;
+ size_t size_;
+ SyncedHead head_;
+
+ DISABLE_COPY_AND_ASSIGN(SyncedMemory);
+}; // class SyncedMemory
+
+} // namespace caffe
+
+#endif // CAFFE_SYNCEDMEM_HPP_
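A minimal sketch (not part of the diff) of the head-state bookkeeping above: each accessor makes sure the requested side is up to date, and the mutable accessors mark that side as holding the freshest copy:

#include "caffe/syncedmem.hpp"

void syncedmem_example() {
  caffe::SyncedMemory mem(16 * sizeof(float));               // head() == UNINITIALIZED
  float* cpu = static_cast<float*>(mem.mutable_cpu_data());  // head() == HEAD_AT_CPU
  cpu[0] = 1.0f;
  const void* gpu = mem.gpu_data();                          // copies to GPU, head() == SYNCED
  void* gpu_mut = mem.mutable_gpu_data();                    // head() == HEAD_AT_GPU
  // A later cpu_data() call would copy the GPU buffer back before returning.
  (void)gpu; (void)gpu_mut;
}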
diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp
new file mode 100644
index 00000000..83c01dda
--- /dev/null
+++ b/include/caffe/util/im2col.hpp
@@ -0,0 +1,30 @@
+// Copyright 2013 Yangqing Jia
+
+#ifndef CAFFE_UTIL_IM2COL_HPP_
+#define CAFFE_UTIL_IM2COL_HPP_
+
+namespace caffe {
+
+template <typename Dtype>
+void im2col_cpu(const Dtype* data_im, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ Dtype* data_col);
+
+template <typename Dtype>
+void col2im_cpu(const Dtype* data_col, const int channels,
+ const int height, const int width, const int psize, const int stride,
+ Dtype* data_im);
+
+template <typename Dtype>
+void im2col_gpu(const Dtype* data_im, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ Dtype* data_col);
+
+template <typename Dtype>
+void col2im_gpu(const Dtype* data_col, const int channels,
+ const int height, const int width, const int psize, const int stride,
+ Dtype* data_im);
+
+} // namespace caffe
+
+#endif // CAFFE_UTIL_IM2COL_HPP_
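A hedged sketch (not part of the diff) of calling im2col_cpu: for kernel size ksize and stride (no padding here), every output position receives a column of channels * ksize * ksize values, which fixes the required size of data_col. im2col_example is a hypothetical helper:

#include <vector>

#include "caffe/util/im2col.hpp"

void im2col_example() {
  const int channels = 3, height = 8, width = 8, ksize = 3, stride = 1;
  const int height_col = (height - ksize) / stride + 1;  // 6
  const int width_col = (width - ksize) / stride + 1;    // 6
  std::vector<float> im(channels * height * width, 1.0f);
  std::vector<float> col(channels * ksize * ksize * height_col * width_col);
  caffe::im2col_cpu(&im[0], channels, height, width, ksize, stride, &col[0]);
}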
diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp
new file mode 100644
index 00000000..0dce4e7e
--- /dev/null
+++ b/include/caffe/util/io.hpp
@@ -0,0 +1,46 @@
+// Copyright Yangqing Jia 2013
+
+#ifndef CAFFE_UTIL_IO_H_
+#define CAFFE_UTIL_IO_H_
+
+#include <google/protobuf/message.h>
+
+#include <string>
+
+#include "caffe/blob.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+using std::string;
+using ::google::protobuf::Message;
+
+namespace caffe {
+
+void ReadProtoFromTextFile(const char* filename,
+ Message* proto);
+inline void ReadProtoFromTextFile(const string& filename,
+ Message* proto) {
+ ReadProtoFromTextFile(filename.c_str(), proto);
+}
+
+void WriteProtoToTextFile(const Message& proto, const char* filename);
+inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
+ WriteProtoToTextFile(proto, filename.c_str());
+}
+
+void ReadProtoFromBinaryFile(const char* filename,
+ Message* proto);
+inline void ReadProtoFromBinaryFile(const string& filename,
+ Message* proto) {
+ ReadProtoFromBinaryFile(filename.c_str(), proto);
+}
+
+void WriteProtoToBinaryFile(const Message& proto, const char* filename);
+inline void WriteProtoToBinaryFile(
+ const Message& proto, const string& filename) {
+ WriteProtoToBinaryFile(proto, filename.c_str());
+}
+
+
+} // namespace caffe
+
+#endif // CAFFE_UTIL_IO_H_
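A minimal sketch (not part of the diff) of the round trip these helpers provide for any protobuf Message; the file names are placeholders:

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

void io_example() {
  caffe::NetParameter net_param;
  caffe::ReadProtoFromTextFile("net.prototxt", &net_param);     // human-readable text format
  caffe::WriteProtoToBinaryFile(net_param, "net.binaryproto");  // compact binary format
  caffe::NetParameter reloaded;
  caffe::ReadProtoFromBinaryFile("net.binaryproto", &reloaded);
}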
diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp
new file mode 100644
index 00000000..e9e2db8f
--- /dev/null
+++ b/include/caffe/util/math_functions.hpp
@@ -0,0 +1,106 @@
+// Copyright 2013 Yangqing Jia
+
+#ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
+#define CAFFE_UTIL_MATH_FUNCTIONS_H_
+
+#include <mkl.h>
+#include <cublas_v2.h>
+
+namespace caffe {
+
+// Caffe gemm provides a simpler interface to the gemm functions, with the
+// limitation that the data has to be contiguous in memory.
+template <typename Dtype>
+void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+ const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta,
+ Dtype* C);
+
+// Caffe gpu gemm provides an interface that is almost the same as the cpu
+// gemm function - following the C convention and calling the Fortran-order
+// gpu code under the hood.
+template <typename Dtype>
+void caffe_gpu_gemm(const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+ const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta,
+ Dtype* C);
+
+template <typename Dtype>
+void caffe_cpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta,
+ Dtype* y);
+
+template <typename Dtype>
+void caffe_gpu_gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
+ const Dtype alpha, const Dtype* A, const Dtype* x, const Dtype beta,
+ Dtype* y);
+
+template <typename Dtype>
+void caffe_axpy(const int N, const Dtype alpha, const Dtype* X,
+ Dtype* Y);
+
+template <typename Dtype>
+void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
+ Dtype* Y);
+
+template <typename Dtype>
+void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
+ const Dtype beta, Dtype* Y);
+
+template <typename Dtype>
+void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X,
+ const Dtype beta, Dtype* Y);
+
+template <typename Dtype>
+void caffe_copy(const int N, const Dtype *X, Dtype *Y);
+
+template <typename Dtype>
+void caffe_gpu_copy(const int N, const Dtype *X, Dtype *Y);
+
+template <typename Dtype>
+void caffe_scal(const int N, const Dtype alpha, Dtype *X);
+
+template <typename Dtype>
+void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X);
+
+template <typename Dtype>
+void caffe_sqr(const int N, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
+void caffe_add(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
+void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
+void caffe_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
+void caffe_gpu_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
+void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
+void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
+
+template <typename Dtype>
+void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);
+
+template <typename Dtype>
+void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
+ const Dtype sigma);
+
+template <typename Dtype>
+void caffe_exp(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
+Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
+
+template <typename Dtype>
+void caffe_gpu_dot(const int n, const Dtype* x, const Dtype* y, Dtype* out);
+
+} // namespace caffe
+
+
+#endif // CAFFE_UTIL_MATH_FUNCTIONS_H_
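A hedged sketch (not part of the diff) of caffe_cpu_gemm computing C = alpha * A * B + beta * C for contiguous row-major matrices (A is M x K, B is K x N, C is M x N); gemm_example is a hypothetical helper:

#include <vector>

#include "caffe/util/math_functions.hpp"

void gemm_example() {
  const int M = 2, N = 3, K = 4;
  std::vector<float> A(M * K, 1.0f);  // 2 x 4, all ones
  std::vector<float> B(K * N, 2.0f);  // 4 x 3, all twos
  std::vector<float> C(M * N, 0.0f);  // 2 x 3 result
  caffe::caffe_cpu_gemm<float>(CblasNoTrans, CblasNoTrans, M, N, K,
      1.0f, &A[0], &B[0], 0.0f, &C[0]);
  // Each entry of C is the dot product of a row of A with a column of B: 8.
}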
diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
new file mode 100644
index 00000000..0dc34763
--- /dev/null
+++ b/include/caffe/vision_layers.hpp
@@ -0,0 +1,405 @@
+// Copyright 2013 Yangqing Jia
+
+#ifndef CAFFE_VISION_LAYERS_HPP_
+#define CAFFE_VISION_LAYERS_HPP_
+
+#include <leveldb/db.h>
+#include <pthread.h>
+
+#include <vector>
+
+#include "caffe/layer.hpp"
+
+namespace caffe {
+
+
+// The neuron layer is a specific type of layer that just operates on single
+// elements.
+template <typename Dtype>
+class NeuronLayer : public Layer<Dtype> {
+ public:
+ explicit NeuronLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+};
+
+
+template <typename Dtype>
+class ReLULayer : public NeuronLayer<Dtype> {
+ public:
+ explicit ReLULayer(const LayerParameter& param)
+ : NeuronLayer<Dtype>(param) {}
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+
+template <typename Dtype>
+class DropoutLayer : public NeuronLayer<Dtype> {
+ public:
+ explicit DropoutLayer(const LayerParameter& param)
+ : NeuronLayer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ shared_ptr<SyncedMemory> rand_vec_;
+ float threshold_;
+ float scale_;
+ unsigned int uint_thres_;
+};
+
+
+template <typename Dtype>
+class InnerProductLayer : public Layer<Dtype> {
+ public:
+ explicit InnerProductLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ int M_;
+ int K_;
+ int N_;
+ bool biasterm_;
+ shared_ptr<SyncedMemory> bias_multiplier_;
+};
+
+
+template <typename Dtype>
+class PaddingLayer : public Layer<Dtype> {
+ public:
+ explicit PaddingLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ unsigned int PAD_;
+ int NUM_;
+ int CHANNEL_;
+ int HEIGHT_IN_;
+ int WIDTH_IN_;
+ int HEIGHT_OUT_;
+ int WIDTH_OUT_;
+};
+
+
+template <typename Dtype>
+class LRNLayer : public Layer<Dtype> {
+ public:
+ explicit LRNLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ // scale_ stores the intermediate summing results
+ Blob<Dtype> scale_;
+ int size_;
+ int pre_pad_;
+ Dtype alpha_;
+ Dtype beta_;
+ int num_;
+ int channels_;
+ int height_;
+ int width_;
+};
+
+
+template <typename Dtype>
+class Im2colLayer : public Layer<Dtype> {
+ public:
+ explicit Im2colLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ int KSIZE_;
+ int STRIDE_;
+ int CHANNELS_;
+ int HEIGHT_;
+ int WIDTH_;
+};
+
+
+template <typename Dtype>
+class PoolingLayer : public Layer<Dtype> {
+ public:
+ explicit PoolingLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ int KSIZE_;
+ int STRIDE_;
+ int CHANNELS_;
+ int HEIGHT_;
+ int WIDTH_;
+ int POOLED_HEIGHT_;
+ int POOLED_WIDTH_;
+};
+
+
+template <typename Dtype>
+class ConvolutionLayer : public Layer<Dtype> {
+ public:
+ explicit ConvolutionLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ Blob<Dtype> col_bob_;
+
+ int KSIZE_;
+ int STRIDE_;
+ int NUM_;
+ int CHANNELS_;
+ int HEIGHT_;
+ int WIDTH_;
+ int NUM_OUTPUT_;
+ int GROUP_;
+ Blob<Dtype> col_buffer_;
+ shared_ptr<SyncedMemory> bias_multiplier_;
+ bool biasterm_;
+ int M_;
+ int K_;
+ int N_;
+};
+
+
+// The thread routine, launched via pthread_create, that prefetches the data.
+template <typename Dtype>
+void* DataLayerPrefetch(void* layer_pointer);
+
+template <typename Dtype>
+class DataLayer : public Layer<Dtype> {
+ // The function used to perform prefetching.
+ friend void* DataLayerPrefetch<Dtype>(void* layer_pointer);
+
+ public:
+ explicit DataLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+ shared_ptr<leveldb::DB> db_;
+ shared_ptr<leveldb::Iterator> iter_;
+ int datum_channels_;
+ int datum_height_;
+ int datum_width_;
+ int datum_size_;
+ pthread_t thread_;
+ shared_ptr<Blob<Dtype> > prefetch_data_;
+ shared_ptr<Blob<Dtype> > prefetch_label_;
+ Blob<Dtype> data_mean_;
+};
+
+
+template <typename Dtype>
+class SoftmaxLayer : public Layer<Dtype> {
+ public:
+ explicit SoftmaxLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+ // sum_multiplier is just used to carry out summation using BLAS
+ Blob<Dtype> sum_multiplier_;
+ // scale is an intermediate blob to hold temporary results.
+ Blob<Dtype> scale_;
+};
+
+
+template <typename Dtype>
+class MultinomialLogisticLossLayer : public Layer<Dtype> {
+ public:
+ explicit MultinomialLogisticLossLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ // The loss layer will do nothing during forward - all computation is
+ // carried out in the backward pass.
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+
+// SoftmaxWithLossLayer is a layer that implements softmax and then computes
+// the loss - it is preferred over softmax + multinomial logistic loss in the
+// sense that during training, it produces more numerically stable
+// gradients. During testing this layer could be replaced by a softmax layer
+// to generate probability outputs.
+template <typename Dtype>
+class SoftmaxWithLossLayer : public Layer<Dtype> {
+ public:
+ explicit SoftmaxWithLossLayer(const LayerParameter& param)
+ : Layer<Dtype>(param), softmax_layer_(new SoftmaxLayer<Dtype>(param)) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+ shared_ptr<SoftmaxLayer<Dtype> > softmax_layer_;
+ // prob stores the output probability of the layer.
+ Blob<Dtype> prob_;
+ // Vector holders to call the underlying softmax layer forward and backward.
+ vector<Blob<Dtype>*> softmax_bottom_vec_;
+ vector<Blob<Dtype>*> softmax_top_vec_;
+};
+
+
+template <typename Dtype>
+class EuclideanLossLayer : public Layer<Dtype> {
+ public:
+ explicit EuclideanLossLayer(const LayerParameter& param)
+ : Layer<Dtype>(param), difference_() {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ // The loss layer will do nothing during forward - all computation is
+ // carried out in the backward pass.
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ Blob<Dtype> difference_;
+};
+
+
+template <typename Dtype>
+class AccuracyLayer : public Layer<Dtype> {
+ public:
+ explicit AccuracyLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ // The accuracy layer should not be used to compute backward operations.
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+ NOT_IMPLEMENTED;
+ return Dtype(0.);
+ }
+};
+
+} // namespace caffe
+
+#endif // CAFFE_VISION_LAYERS_HPP_
+