author | honggui <hongguiyao@msn.com> | 2017-08-26 12:06:12 +0800
committer | honggui <hongguiyao@msn.com> | 2017-08-26 12:06:12 +0800
commit | 25e0cec114d641fd15356edb9e18a4403a0c7309 (patch)
tree | e860d5baf6ecd4ea15a0a9116f25fcd3498c004c /include
parent | a9d213eb6936d92a493a45a0f9640da723996b03 (diff)
download | caffeonacl-25e0cec114d641fd15356edb9e18a4403a0c7309.tar.gz caffeonacl-25e0cec114d641fd15356edb9e18a4403a0c7309.tar.bz2 caffeonacl-25e0cec114d641fd15356edb9e18a4403a0c7309.zip
add support for ACL batch norm, direct conv, local connect, and concat layers
Diffstat (limited to 'include')
-rw-r--r-- | include/caffe/acl_layer.hpp | 56
-rw-r--r-- | include/caffe/layers/acl_base_conv_layer.hpp | 61
-rw-r--r-- | include/caffe/layers/acl_batch_norm_layer.hpp | 54
-rw-r--r-- | include/caffe/layers/acl_concat_layer.hpp | 57
-rw-r--r-- | include/caffe/layers/acl_conv_layer.hpp | 70
-rw-r--r-- | include/caffe/layers/acl_local_connect_layer.hpp | 56
-rw-r--r-- | include/caffe/layers/local_connect_layer.hpp | 59
7 files changed, 362 insertions, 51 deletions
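The acl_layer.hpp hunk below adds three new bypass bits (FLAGS_ENABLE_ACL_LC, FLAGS_ENABLE_ACL_BN, FLAGS_ENABLE_ACL_CONCAT) alongside the existing ones, all combinable into the extern `bypass_acl_class_layer` mask. A minimal standalone sketch of how such a bit is typically tested is shown here; the helper name `bypass_acl_bn()` and the assumption that a set bit means "skip the ACL path and fall back to the stock Caffe layer" are illustrative, not part of the commit.

```cpp
// Sketch only. The #define values are copied from the acl_layer.hpp hunk in
// this commit; bypass_acl_class_layer is the extern mask declared there.
// The helper and its fallback semantics are an assumption for illustration.
#define FLAGS_ENABLE_ACL_LC     0x00000400
#define FLAGS_ENABLE_ACL_BN     0x00000800
#define FLAGS_ENABLE_ACL_CONCAT 0x00001000

unsigned int bypass_acl_class_layer = 0;  // normally set elsewhere, e.g. from an env variable

// Hypothetical check: a non-zero bit would mean "do not use the ACL BatchNorm
// path; use the plain BatchNormLayer<Dtype> implementation instead".
bool bypass_acl_bn() {
  return (bypass_acl_class_layer & FLAGS_ENABLE_ACL_BN) != 0;
}
```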
diff --git a/include/caffe/acl_layer.hpp b/include/caffe/acl_layer.hpp index db9fee5c..b188bb8c 100644 --- a/include/caffe/acl_layer.hpp +++ b/include/caffe/acl_layer.hpp @@ -3,6 +3,7 @@ #ifdef USE_ACL #include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h" #include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEActivationLayer.h" #include "arm_compute/runtime/CL/functions/CLActivationLayer.h" @@ -14,6 +15,14 @@ #include "arm_compute/runtime/CL/functions/CLSoftmaxLayer.h" #include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h" #include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h" +#include "arm_compute/runtime/NEON/functions/NELocallyConnectedLayer.h" +#include "arm_compute/runtime/CL/functions/CLLocallyConnectedLayer.h" +#include "arm_compute/runtime/NEON/functions/NEBatchNormalizationLayer.h" +#include "arm_compute/runtime/CL/functions/CLBatchNormalizationLayer.h" +#include "arm_compute/core/NEON/kernels/NEDepthConcatenateKernel.h" +#include "arm_compute/runtime/NEON/functions/NEDepthConcatenate.h" +#include "arm_compute/core/CL/kernels/CLDepthConcatenateKernel.h" +#include "arm_compute/runtime/CL/functions/CLDepthConcatenate.h" #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/CL/CLScheduler.h" @@ -28,6 +37,9 @@ using namespace arm_compute; #define FLAGS_ENABLE_ACL_SIGMOID 0x00000080 #define FLAGS_ENABLE_ACL_SOFTMAX 0x00000100 #define FLAGS_ENABLE_ACL_TANH 0x00000200 +#define FLAGS_ENABLE_ACL_LC 0x00000400 +#define FLAGS_ENABLE_ACL_BN 0x00000800 +#define FLAGS_ENABLE_ACL_CONCAT 0x00001000 extern unsigned int bypass_acl_class_layer; #endif #ifdef USE_PROFILING @@ -48,6 +60,9 @@ extern unsigned int bypass_acl_class_layer; #define MASK_LOG_SIGMOID 0x00001000 #define MASK_LOG_SOFTMAX 0x00002000 #define MASK_LOG_TANH 0x00004000 +#define MASK_LOG_LC 0x00008000 +#define MASK_LOG_BN 0x00010000 +#define MASK_LOG_CONCAT 0x00020000 #define APP_TIME_INFO MASK_LOG_APP_TIME,"time: \t" #define ACL_ALLOCATE_INFO MASK_LOG_ALLOCATE,"allocate: \t\t" #define ACL_RUN_INFO MASK_LOG_RUN, "run: \t\t\t" @@ -63,6 +78,9 @@ extern unsigned int bypass_acl_class_layer; #define ACL_SIGMOID_INFO MASK_LOG_SIGMOID, "ACL_SIGMOID:\t\t\t\t\t\t\t\t\t\t\t\t\t" #define ACL_SOFTMAX_INFO MASK_LOG_SOFTMAX, "ACL_SOFTMAX:\t\t\t\t\t\t\t\t\t\t\t\t\t\t" #define ACL_TANH_INFO MASK_LOG_TANH, "ACL_TANH :\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t" +#define ACL_LC_INFO MASK_LOG_LC, "ACL_LC :\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t" +#define ACL_BN_INFO MASK_LOG_BN, "ACL_BN :\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t" +#define ACL_CONCAT_INFO MASK_LOG_CONCAT, "ACL_CONCAT :\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t" extern unsigned int acl_log_flags; #endif //USE_PROFILING namespace caffe { @@ -72,6 +90,7 @@ enum TensorType{ tensor_output, tensor_weights, tensor_biases, + tensor_data, }; template <typename ACLTensor> class BaseTensor:public ACLTensor{ @@ -88,7 +107,7 @@ public: }; virtual void map(bool blocking = true){} virtual void unmap(){} - virtual void commit(); + virtual void commit(TensorType type=tensor_data); int tensor_copy(void * mem, bool toTensor=true); protected: void* mem_; @@ -129,20 +148,16 @@ class ACLXPUBaseLayer{ public: virtual void commit(){ if (input) { - input->settensortype(tensor_input); - input->commit(); + input->commit(tensor_input); } if (output){ - output->settensortype(tensor_output); - output->commit(); + 
output->commit(tensor_output); } if (weights){ - weights->settensortype(tensor_weights); - weights->commit(); + weights->commit(tensor_weights); } if (biases){ - biases->settensortype(tensor_biases); - biases->commit(); + biases->commit(tensor_biases); } } virtual void run(bool gpu){ @@ -163,6 +178,10 @@ public: output=nullptr; weights=nullptr; biases=nullptr; + mean=nullptr; + var=nullptr; + beta=nullptr; + gamma=nullptr; #ifdef USE_CONV_CACHE for(int i = 0; i < 16; ++i){ cache.layer[i] = nullptr; @@ -180,12 +199,20 @@ public: if (output) delete output; if (weights) delete weights; if (biases) delete biases; + if (mean) delete mean; + if (var) delete var; + if (beta) delete beta; + if (gamma) delete gamma; #endif //USE_CONV_CACHE layer=nullptr; input=nullptr; output=nullptr; weights=nullptr; biases=nullptr; + mean=nullptr; + var=nullptr; + beta=nullptr; + gamma=nullptr; } virtual ~ACLXPUBaseLayer(){ freelayer(); @@ -195,6 +222,11 @@ public: ACLTensor *output; ACLTensor *weights; ACLTensor *biases; + //for BN + ACLTensor *mean; + ACLTensor *var; + ACLTensor *beta; + ACLTensor *gamma; #ifdef USE_CONV_CACHE struct{ ACLLayer *layer[16]; @@ -223,7 +255,7 @@ public: bool checkreshape(TensorShape shape,bool gpu=false, TensorType type=tensor_input); template <typename ACLTensor> bool tensor_mem(ACLTensor *tensor,void *mem,bool share=false); template <typename ACLTensor> bool tensor_mem(void *mem,ACLTensor *tensor,bool share=false); - template <typename ACLTensor> ACLTensor * new_tensor(TensorShape shape,void *mem=nullptr,bool share=false); + template <typename ACLTensor> bool new_tensor(ACLTensor *&tensor,TensorShape shape,void *mem=nullptr,bool share=false); protected: ACLXPUBaseLayer<GPULayer,GPUTensor> gpu_; ACLXPUBaseLayer<CPULayer,CPUTensor> cpu_; @@ -238,9 +270,9 @@ protected: template class ACLBaseLayer<GPULayer,CPULayer>; #define INSTANTIATE_ACLBASE_FUNCTION(GPULayer,CPULayer,ACLTensor) \ - template bool ACLBaseLayer<GPULayer,CPULayer>::tensor_mem<ACLTensor>(ACLTensor *tensor,void *mem,bool share); \ + template bool ACLBaseLayer<GPULayer,CPULayer>::tensor_mem(ACLTensor *tensor,void *mem,bool share); \ template bool ACLBaseLayer<GPULayer,CPULayer>::tensor_mem(void *mem,ACLTensor *tensor,bool share); \ - template ACLTensor * ACLBaseLayer<GPULayer,CPULayer>::new_tensor(TensorShape shape,void *mem,bool share); \ + template bool ACLBaseLayer<GPULayer,CPULayer>::new_tensor(ACLTensor *&tensor,TensorShape shape,void *mem,bool share); \ #endif diff --git a/include/caffe/layers/acl_base_conv_layer.hpp b/include/caffe/layers/acl_base_conv_layer.hpp new file mode 100644 index 00000000..6b38eb28 --- /dev/null +++ b/include/caffe/layers/acl_base_conv_layer.hpp @@ -0,0 +1,61 @@ +#ifndef CAFFE_ACL_BASE_CONV_LAYER_HPP_ +#define CAFFE_ACL_BASE_CONV_LAYER_HPP_ + +#include <vector> + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +#include "caffe/layers/conv_layer.hpp" + +#ifdef USE_ACL +#include "caffe/acl_layer.hpp" +#endif + +namespace caffe { + +#ifdef USE_ACL +/* + * @brief ACL implementation of ConvolutionLayer. + * Fallback to ConvolutionLayer for some corner cases. 
+ * +*/ +template <typename Dtype,typename GPUConvLayer,typename CPUConvLayer> +class ACLConvolutionLayer : public ACLBaseLayer<GPUConvLayer,CPUConvLayer>,public ConvolutionLayer<Dtype> { + public: + explicit ACLConvolutionLayer(const LayerParameter& param) + : ConvolutionLayer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual ~ACLConvolutionLayer(); + + protected: + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ + NOT_IMPLEMENTED; + } + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ + NOT_IMPLEMENTED; + } + virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + +}; +#endif + +} // namespace caffe + +// Instantiate a class with float and double specifications. +#define INSTANTIATE_CONV_CLASS(classname,GPUConvLayer,CPUConvLayer) \ + template class classname<float,GPUConvLayer,CPUConvLayer>; \ + template class classname<double,GPUConvLayer,CPUConvLayer> + +#endif // CAFFE_ACL_BASE_CONV_LAYER_HPP_ diff --git a/include/caffe/layers/acl_batch_norm_layer.hpp b/include/caffe/layers/acl_batch_norm_layer.hpp new file mode 100644 index 00000000..e899804f --- /dev/null +++ b/include/caffe/layers/acl_batch_norm_layer.hpp @@ -0,0 +1,54 @@ +#ifndef CAFFE_ACL_BATCH_NORMAL_HPP_ +#define CAFFE_ACL_BATCH_NORMAL_HPP_ + +#include <vector> + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +#include "caffe/layers/batch_norm_layer.hpp" + +#ifdef USE_ACL +#include "caffe/acl_layer.hpp" +#endif + +namespace caffe { + +#ifdef USE_ACL +/* + * @brief ACL implementation of BatchNormLayer. + * Fallback to BatchNormLayer for some corner cases. 
+*/ +template <typename Dtype> +class ACLBatchNormLayer : public ACLBaseLayer<CLBatchNormalizationLayer,NEBatchNormalizationLayer>,public BatchNormLayer<Dtype> { + public: + explicit ACLBatchNormLayer(const LayerParameter& param) + : BatchNormLayer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual ~ACLBatchNormLayer(); + + protected: + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ + NOT_IMPLEMENTED; + } + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ + NOT_IMPLEMENTED; + } + virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); +}; +#endif + +} // namespace caffe + +#endif // CAFFE_ACL_BATCH_NORMAL_HPP_ diff --git a/include/caffe/layers/acl_concat_layer.hpp b/include/caffe/layers/acl_concat_layer.hpp new file mode 100644 index 00000000..90212192 --- /dev/null +++ b/include/caffe/layers/acl_concat_layer.hpp @@ -0,0 +1,57 @@ +#ifndef CAFFE_ACL_CONCAT_LAYER_HPP_ +#define CAFFE_ACL_CONCAT_LAYER_HPP_ + +#include <vector> + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +#include "caffe/layers/concat_layer.hpp" + +#ifdef USE_ACL +#include "caffe/acl_layer.hpp" +#endif + +namespace caffe { + +#ifdef USE_ACL +/* + * @brief ACL implementation of ConcatLayer. + * Fallback to ConcatLayer for some corner cases. 
+*/ +template <typename Dtype> +class ACLConcatLayer : public ACLBaseLayer<CLDepthConcatenate,NEDepthConcatenate>,public ConcatLayer<Dtype> { + public: + explicit ACLConcatLayer(const LayerParameter& param) + : ConcatLayer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual ~ACLConcatLayer(); + + protected: + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ + NOT_IMPLEMENTED; + } + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ + NOT_IMPLEMENTED; + } + virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + private: + std::vector<ITensor *> cpu_vectors; + std::vector<ICLTensor *> gpu_vectors; +}; +#endif + +} // namespace caffe + +#endif // CAFFE_ACL_CONCAT_LAYER_HPP_ diff --git a/include/caffe/layers/acl_conv_layer.hpp b/include/caffe/layers/acl_conv_layer.hpp index b4a75848..2fd795c9 100644 --- a/include/caffe/layers/acl_conv_layer.hpp +++ b/include/caffe/layers/acl_conv_layer.hpp @@ -1,54 +1,46 @@ #ifndef CAFFE_ACL_CONV_LAYER_HPP_ #define CAFFE_ACL_CONV_LAYER_HPP_ -#include <vector> - -#include "caffe/blob.hpp" -#include "caffe/layer.hpp" -#include "caffe/proto/caffe.pb.h" - -#include "caffe/layers/conv_layer.hpp" - #ifdef USE_ACL -#include "caffe/acl_layer.hpp" +#include "caffe/layers/acl_base_conv_layer.hpp" #endif namespace caffe { +extern bool use_direct_conv_; #ifdef USE_ACL -/* - * @brief ACL implementation of ConvolutionLayer. - * Fallback to ConvolutionLayer for some corner cases. 
- * -*/ template <typename Dtype> -class ACLConvolutionLayer : public ACLBaseLayer<CLConvolutionLayer,NEConvolutionLayer>,public ConvolutionLayer<Dtype> { - public: - explicit ACLConvolutionLayer(const LayerParameter& param) - : ConvolutionLayer<Dtype>(param) {} - virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, - const vector<Blob<Dtype>*>& top); - virtual void Reshape(const vector<Blob<Dtype>*>& bottom, - const vector<Blob<Dtype>*>& top); - virtual ~ACLConvolutionLayer(); - - protected: - virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, - const vector<Blob<Dtype>*>& top); - virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, - const vector<Blob<Dtype>*>& top); - virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, - const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ - NOT_IMPLEMENTED; +inline shared_ptr<Layer<Dtype> > GetACLConvolutionLayer( + const LayerParameter& param) { + ConvolutionParameter conv_param = param.convolution_param(); + const char* pDirectConv; + pDirectConv = getenv ("DIRECTCONV"); + if (pDirectConv){ + unsigned int bdirectconv; + sscanf(pDirectConv,"%i", &bdirectconv); + if(bdirectconv != use_direct_conv_){ + use_direct_conv_ = bdirectconv; + printf("DIRECTCONV<%s>\n", pDirectConv); + printf("DIRECTCONV: %x\n", use_direct_conv_); } - virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, - const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ - NOT_IMPLEMENTED; + } + int pad_data[3]; + if (conv_param.has_pad_h() || conv_param.has_pad_w()) { + pad_data[0] = conv_param.pad_h(); + pad_data[1] = conv_param.pad_w(); + } else { + const int kDefaultPad = 0; + const int num_pad_dims = conv_param.pad_size(); + for (int i = 0; i < 2; ++i) { + pad_data[i] = (num_pad_dims == 0) ? kDefaultPad : + conv_param.pad((num_pad_dims == 1) ? 0 : i); } - virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom, - const vector<Blob<Dtype>*>& top); - -}; + } + if (use_direct_conv_ && ( (conv_param.kernel_size(0)==1 &&pad_data[0]==0 && pad_data[1]==0) || (conv_param.kernel_size(0)==3 && pad_data[0]<=1 && pad_data[1] <=1 ) )) { + return shared_ptr<Layer<Dtype> >(new ACLConvolutionLayer<Dtype, CLConvolutionLayer, NEDirectConvolutionLayer>(param)); //NEDirectConvolutionLayer only for 1x1 and 3x3 + } + return shared_ptr<Layer<Dtype> >(new ACLConvolutionLayer<Dtype, CLConvolutionLayer, NEConvolutionLayer>(param)); +} #endif } // namespace caffe diff --git a/include/caffe/layers/acl_local_connect_layer.hpp b/include/caffe/layers/acl_local_connect_layer.hpp new file mode 100644 index 00000000..fdb30757 --- /dev/null +++ b/include/caffe/layers/acl_local_connect_layer.hpp @@ -0,0 +1,56 @@ +#ifndef CAFFE_ACL_LOCALCONNECT_LAYER_HPP_ +#define CAFFE_ACL_LOCALCONNECT_LAYER_HPP_ + +#include <vector> + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +#include "caffe/layers/local_connect_layer.hpp" + +#ifdef USE_ACL +#include "caffe/acl_layer.hpp" +#endif + +namespace caffe { + +#ifdef USE_ACL +/* + * @brief ACL implementation of LocalConnectLayer. + * Fallback to LocalConnectLayer for some corner cases. 
+ * +*/ +template <typename Dtype> +class ACLLocalConnectLayer : public ACLBaseLayer<CLLocallyConnectedLayer,NELocallyConnectedLayer>,public LocalConnectLayer<Dtype> { + public: + explicit ACLLocalConnectLayer(const LayerParameter& param) + : LocalConnectLayer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual ~ACLLocalConnectLayer(); + + protected: + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ + NOT_IMPLEMENTED; + } + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ + NOT_IMPLEMENTED; + } + virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + +}; +#endif + +} // namespace caffe + +#endif // CAFFE_ACL_LOCALCONNECT_LAYER_HPP_ diff --git a/include/caffe/layers/local_connect_layer.hpp b/include/caffe/layers/local_connect_layer.hpp new file mode 100644 index 00000000..4dda9780 --- /dev/null +++ b/include/caffe/layers/local_connect_layer.hpp @@ -0,0 +1,59 @@ +#ifndef CAFFE_LOCALCONNECT_LAYERS_HPP_ +#define CAFFE_LOCALCONNECT_LAYERS_HPP_ + +#include <string> +#include <utility> +#include <vector> + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/util/im2col.hpp" + +namespace caffe { + +template <typename Dtype> +class LocalConnectLayer : public Layer<Dtype> { + public: + explicit LocalConnectLayer(const LayerParameter& param) + : Layer<Dtype>(param) {} + virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Reshape(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + + virtual inline int MinBottomBlobs() const { return 1; } + virtual inline int MinTopBlobs() const { return 1; } + virtual inline bool EqualNumBottomTopBlobs() const { return true; } + + protected: + virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, + const vector<Blob<Dtype>*>& top); + virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, + const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom); + + + int kernel_size_; + int stride_; + int dilation_; + int num_; + int channels_; + int pad_; + int height_, width_; + int height_out_, width_out_; + int num_output_; + bool bias_term_; + + int M_; + int K_; + int N_; + + Blob<Dtype> col_buffer_; +}; +} // namespace caffe + +#endif // CAFFE_LOCALCONNECT_LAYERS_HPP_ |
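The new `GetACLConvolutionLayer()` factory added in acl_conv_layer.hpp above picks between `NEConvolutionLayer` and `NEDirectConvolutionLayer` based on the DIRECTCONV environment variable and the kernel/pad geometry (direct convolution only for 1x1 kernels with zero padding or 3x3 kernels with padding of at most 1). Below is a small standalone program that mirrors just the environment-variable parsing so the accepted value formats are clear; everything outside the getenv/sscanf calls is a stand-in for illustration, not commit code.

```cpp
// Standalone sketch of the DIRECTCONV switch used by GetACLConvolutionLayer().
// The value is parsed with sscanf("%i"), so decimal ("1"), octal ("01") and
// hex ("0x1") inputs are all accepted; any non-zero value enables the
// direct-convolution path for eligible 1x1/3x3 kernels.
#include <cstdio>
#include <cstdlib>

static bool use_direct_conv_ = false;  // stand-in for caffe::use_direct_conv_

int main() {
  const char* pDirectConv = std::getenv("DIRECTCONV");
  if (pDirectConv) {
    int bdirectconv = 0;  // the commit stores this into an unsigned int
    std::sscanf(pDirectConv, "%i", &bdirectconv);
    use_direct_conv_ = (bdirectconv != 0);
  }
  std::printf("direct convolution: %s\n", use_direct_conv_ ? "enabled" : "disabled");
  return 0;
}
```

In a CaffeOnACL build this would typically be driven as `DIRECTCONV=1 <caffe binary> ...` (the binary name is a placeholder); when the variable is absent, the factory keeps the default CLConvolutionLayer/NEConvolutionLayer pair.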