summaryrefslogtreecommitdiff
path: root/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h')
-rw-r--r--runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h33
1 files changed, 31 insertions, 2 deletions
diff --git a/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h b/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h
index 55d8683da..f1519f54d 100644
--- a/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h
+++ b/runtimes/pure_arm_compute/src/internal/layers/GenericFullyConnectedLayer.h
@@ -14,23 +14,52 @@
* limitations under the License.
*/
+/**
+ * @file        GenericFullyConnectedLayer.h
+ * @brief       This file contains GenericFullyConnectedLayer class
+ * @ingroup     COM_AI_RUNTIME
+ */
+
#ifndef __GENERIC_FULLY_CONNECTED_LAYER_H__
#define __GENERIC_FULLY_CONNECTED_LAYER_H__
-#include <arm_compute/runtime/Tensor.h>
-#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h>
#include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
#include "internal/layers/GenericReshapeLayer.h"
+/**
+ * @brief Class to run FullyConnected Layer with both CPU and GPU
+ */
class GenericFullyConnectedLayer : public ::arm_compute::IFunction
{
public:
+ GenericFullyConnectedLayer(void)
+ : _input(nullptr), _weights(nullptr), _biases(nullptr), _output(nullptr), _cl_buffer{},
+ _neon_buffer{}, _cl_fc{}, _neon_fc{}, _generic_reshape{}, _needs_reshape(false)
+ {
+ // DO NOTHING
+ }
+
+public:
+ /**
+ * @brief Configure the layer
+ * @param[in] input The source tensor
+ * @param[in] weights The tensor that is filled with weight values
+ * @param[in] biases The tensor that is filled with biase values
+ * @param[in] output The destination tensor
+ * @param[in] needs_reshape Whether it needs to be reshaped or not
+ * @param[in] reshape The tensor shape to be reshaped. Only valid when needs_reshape is true.
+ * @return N/A
+ */
void configure(::arm_compute::ITensor *input, ::arm_compute::ITensor *weights,
::arm_compute::ITensor *biases, ::arm_compute::ITensor *output, bool needs_reshape,
::arm_compute::TensorShape reshape);
public:
+ /**
+ * @brief Run the operation. Must be called after configure().
+ * @return N/A
+ */
void run(void) override;
private: