diff options
Diffstat (limited to 'runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h')
-rw-r--r-- | runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h b/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h index 55d8683da..f1519f54d 100644 --- a/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h +++ b/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h @@ -14,23 +14,52 @@ * limitations under the License. */ +/** + * @file GenericFullyConnectedLayer.h + * @brief This file contains GenericFullyConnectedLayer class + * @ingroup COM_AI_RUNTIME + */ + #ifndef __GENERIC_FULLY_CONNECTED_LAYER_H__ #define __GENERIC_FULLY_CONNECTED_LAYER_H__ -#include <arm_compute/runtime/Tensor.h> -#include <arm_compute/runtime/CL/CLTensor.h> #include <arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h> #include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h> #include "internal/layers/GenericReshapeLayer.h" +/** + * @brief Class to run FullyConnected Layer with both CPU and GPU + */ class GenericFullyConnectedLayer : public ::arm_compute::IFunction { public: + GenericFullyConnectedLayer(void) + : _input(nullptr), _weights(nullptr), _biases(nullptr), _output(nullptr), _cl_buffer{}, + _neon_buffer{}, _cl_fc{}, _neon_fc{}, _generic_reshape{}, _needs_reshape(false) + { + // DO NOTHING + } + +public: + /** + * @brief Configure the layer + * @param[in] input The source tensor + * @param[in] weights The tensor that is filled with weight values + * @param[in] biases The tensor that is filled with biase values + * @param[in] output The destination tensor + * @param[in] needs_reshape Whether it needs to be reshaped or not + * @param[in] reshape The tensor shape to be reshaped. Only valid when needs_reshape is true. + * @return N/A + */ void configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *weights, ::arm_compute::ITensor *biases, ::arm_compute::ITensor *output, bool needs_reshape, ::arm_compute::TensorShape reshape); public: + /** + * @brief Run the operation. Must be called after configure(). + * @return N/A + */ void run(void) override; private: |