summary | refs | log | tree | commit | diff
path: root/runtimes/contrib/tflite_classify/src
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/contrib/tflite_classify/src')
-rw-r--r-- runtimes/contrib/tflite_classify/src/ImageClassifier.cc 107
-rw-r--r-- runtimes/contrib/tflite_classify/src/ImageClassifier.h 99
-rw-r--r-- runtimes/contrib/tflite_classify/src/InferenceInterface.cc 114
-rw-r--r-- runtimes/contrib/tflite_classify/src/InferenceInterface.h 93
-rw-r--r-- runtimes/contrib/tflite_classify/src/tflite_classify.cc 132
5 files changed, 545 insertions, 0 deletions
diff --git a/runtimes/contrib/tflite_classify/src/ImageClassifier.cc b/runtimes/contrib/tflite_classify/src/ImageClassifier.cc
new file mode 100644
index 000000000..fae4f066c
--- /dev/null
+++ b/runtimes/contrib/tflite_classify/src/ImageClassifier.cc
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ImageClassifier.h"
+
+#include <fstream>
+#include <queue>
+#include <algorithm>
+
+/**
+ * @brief Build the inference backend, load the class labels and pre-allocate the buffers
+ * @param[in] model_file  Path of the model file handed to InferenceInterface
+ * @param[in] label_file  Path of a text file holding one class label per line
+ * @param[in] input_size  Edge length of the (square) model input
+ * @param[in] image_mean  Mean used when normalizing pixel values
+ * @param[in] image_std   Standard deviation used when normalizing pixel values
+ * @param[in] input_name  Name of the input tensor
+ * @param[in] output_name Name of the output tensor
+ * @param[in] use_nnapi   Forwarded to InferenceInterface to select the execution session
+ * @throw std::ios_base::failure If the label file cannot be opened
+ */
+ImageClassifier::ImageClassifier(const std::string &model_file, const std::string &label_file,
+                                 const int input_size, const int image_mean, const int image_std,
+                                 const std::string &input_name, const std::string &output_name,
+                                 const bool use_nnapi)
+  : _inference(new InferenceInterface(model_file, use_nnapi)), _input_size(input_size),
+    _image_mean(image_mean), _image_std(image_std), _input_name(input_name),
+    _output_name(output_name)
+{
+  // Load labels. This is a runtime condition, so it is checked explicitly instead of with
+  // assert(), which is compiled out when NDEBUG is defined.
+  std::ifstream label_stream(label_file.c_str());
+  if (!label_stream)
+  {
+    throw std::ios_base::failure("Failed to open label file: " + label_file);
+  }
+
+  std::string line;
+  while (std::getline(label_stream, line))
+  {
+    _labels.push_back(line);
+  }
+
+  _num_classes = _inference->getTensorSize(_output_name);
+  std::cout << "Output tensor size is " << _num_classes << ", label size is " << _labels.size()
+            << std::endl;
+
+  if (_num_classes < 0)
+  {
+    // getTensorSize() reports -1 when no output tensor carries the requested name. Passing a
+    // negative count on to reserve() would turn it into a huge unsigned size.
+    std::cerr << "Output tensor not found: " << _output_name << std::endl;
+    _num_classes = 0;
+  }
+  else if (_labels.size() < static_cast<std::size_t>(_num_classes))
+  {
+    std::cerr << "Label file has fewer entries (" << _labels.size() << ") than output classes ("
+              << _num_classes << ")" << std::endl;
+  }
+
+  // Pre-allocate buffers
+  _fdata.reserve(static_cast<std::size_t>(_input_size) * _input_size * 3);
+  _outputs.reserve(_num_classes);
+}
+
+/**
+ * @brief Classify a single image
+ *
+ * The image is resized to _input_size x _input_size, every channel value is normalized as
+ * (value - _image_mean) / _image_std, the model is run, and the best scoring classes are
+ * collected.
+ *
+ * @param[in] image Image to classify; pixels are read as three 8-bit channels (cv::Vec3b)
+ * @return At most _max_results recognitions whose confidence exceeds _threshold, best first
+ */
+std::vector<Recognition> ImageClassifier::recognizeImage(const cv::Mat &image)
+{
+  // Resize image
+  cv::Mat cropped;
+  cv::resize(image, cropped, cv::Size(_input_size, _input_size), 0, 0, cv::INTER_AREA);
+
+  // Convert the 0~255 integer pixels to normalized floats. The arithmetic is done in float and
+  // written straight into the input buffer: storing the result back into an 8-bit pixel would
+  // lose every negative or fractional value.
+  // NOTE(review): channels are fed in the order stored in the cv::Mat (presumably BGR for
+  // frames captured through OpenCV) - confirm this matches what the model expects.
+  const float mean = static_cast<float>(_image_mean);
+  const float stddev = static_cast<float>(_image_std);
+
+  _fdata.clear();
+  for (int y = 0; y < cropped.rows; ++y)
+  {
+    for (int x = 0; x < cropped.cols; ++x)
+    {
+      const cv::Vec3b &color = cropped.at<cv::Vec3b>(y, x);
+      for (int c = 0; c < 3; ++c)
+      {
+        _fdata.push_back((color[c] - mean) / stddev);
+      }
+    }
+  }
+
+  // Copy the input data into model
+  _inference->feed(_input_name, _fdata, 1, _input_size, _input_size, 3);
+
+  // Run the inference call
+  _inference->run(_output_name);
+
+  // Copy the output tensor back into the output array
+  _inference->fetch(_output_name, _outputs);
+
+  // Find the best classifications
+  auto compare = [](const Recognition &lhs, const Recognition &rhs) {
+    return lhs.confidence < rhs.confidence;
+  };
+
+  // Only visit indices that exist in both the fetched scores and the label table, so a short
+  // label file or an empty fetch cannot cause an out-of-range access.
+  const std::size_t declared = _num_classes > 0 ? static_cast<std::size_t>(_num_classes) : 0;
+  const std::size_t usable = std::min(declared, std::min(_outputs.size(), _labels.size()));
+
+  std::priority_queue<Recognition, std::vector<Recognition>, decltype(compare)> pq(compare);
+  for (std::size_t i = 0; i < usable; ++i)
+  {
+    if (_outputs[i] > _threshold)
+    {
+      pq.push(Recognition(_outputs[i], _labels[i]));
+    }
+  }
+
+  // The explicit template argument keeps std::min well-formed on targets where std::size_t
+  // and unsigned int are different types.
+  const std::size_t count = std::min<std::size_t>(pq.size(), _max_results);
+
+  std::vector<Recognition> results;
+  results.reserve(count);
+  for (std::size_t i = 0; i < count; ++i)
+  {
+    results.push_back(pq.top());
+    pq.pop();
+  }
+
+  return results;
+}
diff --git a/runtimes/contrib/tflite_classify/src/ImageClassifier.h b/runtimes/contrib/tflite_classify/src/ImageClassifier.h
new file mode 100644
index 000000000..1ba19afb0
--- /dev/null
+++ b/runtimes/contrib/tflite_classify/src/ImageClassifier.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file ImageClassifier.h
+ * @brief This file contains ImageClassifier class and Recognition structure
+ * @ingroup COM_AI_RUNTIME
+ */
+
+#ifndef __TFLITE_CLASSIFY_IMAGE_CLASSIFIER_H__
+#define __TFLITE_CLASSIFY_IMAGE_CLASSIFIER_H__
+
+#include "InferenceInterface.h"
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include <opencv2/opencv.hpp>
+
+/**
+ * @brief struct to define an immutable result returned by a Classifier
+ */
+struct Recognition
+{
+  float confidence;  /**< Sortable score of this recognition relative to others; higher should
+                          be better */
+  std::string title; /**< Display name for the recognition */
+
+  /**
+   * @brief Create a recognition from its score and its display name
+   * @param[in] score A sortable score for how good the recognition is relative to others.
+   *                  Higher should be better.
+   * @param[in] name  Display name for the recognition
+   */
+  Recognition(float score, std::string name)
+    : confidence(score), title(name)
+  {
+  }
+};
+
+/**
+ * @brief Class to define a classifier specialized to label images
+ */
+class ImageClassifier
+{
+public:
+ /**
+ * @brief Construct a new ImageClassifier object with parameters
+ * @param[in] model_file The filepath of the model file passed on to InferenceInterface
+ * @param[in] label_file The filepath of the label file; it is read line by line, one label per line
+ * @param[in] input_size The input size. A square image of input_size x input_size is assumed
+ * @param[in] image_mean The assumed mean of the image values, used when normalizing pixels
+ * @param[in] image_std The assumed std of the image values, used when normalizing pixels
+ * @param[in] input_name The label of the image input node
+ * @param[in] output_name The label of the output node
+ * @param[in] use_nnapi The flag to distinguish between TfLite interpreter and NNFW runtime
+ */
+ ImageClassifier(const std::string &model_file, const std::string &label_file,
+ const int input_size, const int image_mean, const int image_std,
+ const std::string &input_name, const std::string &output_name,
+ const bool use_nnapi);
+
+ /**
+ * @brief Recognize the given image data
+ * @param[in] image The image data to recognize; it is resized to input_size x input_size first
+ * @return The recognitions whose confidence exceeds _threshold, highest confidence first,
+ * limited to _max_results entries
+ */
+ std::vector<Recognition> recognizeImage(const cv::Mat &image);
+
+private:
+ /** Minimum confidence an output must exceed to be reported */
+ const float _threshold = 0.1f;
+ /** Maximum number of recognitions returned per image */
+ const unsigned int _max_results = 3;
+
+ // NOTE(review): std::unique_ptr is used here but <memory> is not among the headers this file
+ // includes directly; it appears to be picked up transitively - confirm.
+ /** Backend that owns the model and executes it */
+ std::unique_ptr<InferenceInterface> _inference;
+ /** Edge length of the square model input */
+ int _input_size;
+ /** Mean used for pixel normalization */
+ int _image_mean;
+ /** Standard deviation used for pixel normalization */
+ int _image_std;
+ /** Name of the input tensor */
+ std::string _input_name;
+ /** Name of the output tensor */
+ std::string _output_name;
+
+ /** Class labels, in the order they appear in the label file */
+ std::vector<std::string> _labels;
+ /** Reusable buffer holding the preprocessed input pixels */
+ std::vector<float> _fdata;
+ /** Reusable buffer receiving the output tensor values */
+ std::vector<float> _outputs;
+ /** Size of the output tensor as reported by InferenceInterface::getTensorSize */
+ int _num_classes;
+};
+
+#endif // __TFLITE_CLASSIFY_IMAGE_CLASSIFIER_H__
diff --git a/runtimes/contrib/tflite_classify/src/InferenceInterface.cc b/runtimes/contrib/tflite_classify/src/InferenceInterface.cc
new file mode 100644
index 000000000..160943477
--- /dev/null
+++ b/runtimes/contrib/tflite_classify/src/InferenceInterface.cc
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "InferenceInterface.h"
+
+#include <algorithm>
+#include <cassert>
+#include <stdexcept>
+
+using namespace tflite;
+using namespace tflite::ops::builtin;
+
+/**
+ * @brief Load the model, build the interpreter and prepare the execution session
+ * @param[in] model_file Path of the model file to load
+ * @param[in] use_nnapi  true to run through NNAPISession, false to run through
+ *                       InterpreterSession
+ * @throw std::runtime_error If the model cannot be loaded or the interpreter cannot be built
+ */
+InferenceInterface::InferenceInterface(const std::string &model_file, const bool use_nnapi)
+  : _interpreter(nullptr), _model(nullptr), _sess(nullptr)
+{
+  // Load model. No reporter is passed explicitly so that the library's default reporter is
+  // used: the reporter is referenced beyond this call, and a function-local reporter object
+  // would be destroyed as soon as the constructor returns.
+  _model = FlatBufferModel::BuildFromFile(model_file.c_str());
+  if (!_model)
+  {
+    throw std::runtime_error("Failed to load model file: " + model_file);
+  }
+
+  BuiltinOpResolver resolver;
+  InterpreterBuilder builder(*_model, resolver);
+  if (builder(&_interpreter) != kTfLiteOk || !_interpreter)
+  {
+    throw std::runtime_error("Failed to build interpreter for model: " + model_file);
+  }
+
+  if (use_nnapi)
+  {
+    _sess = std::make_shared<nnfw::tflite::NNAPISession>(_interpreter.get());
+  }
+  else
+  {
+    _sess = std::make_shared<nnfw::tflite::InterpreterSession>(_interpreter.get());
+  }
+
+  _sess->prepare();
+}
+
+/**
+ * @brief Tear down the session and release the session and interpreter before the model
+ */
+InferenceInterface::~InferenceInterface()
+{
+  if (_sess)
+  {
+    _sess->teardown();
+  }
+
+  // Release explicitly in dependency order: the session was created from a raw pointer to the
+  // interpreter, and the interpreter was built from the model. Relying on implicit member
+  // destruction would follow the reverse of the declaration order instead, which releases the
+  // model while the interpreter built from it is still alive.
+  _sess.reset();
+  _interpreter.reset();
+}
+
+/**
+ * @brief Copy height x width x channel floats from data into the named input tensor
+ *
+ * Nothing is copied when no input tensor carries the given name or when any dimension is not
+ * positive.
+ *
+ * @param[in] input_name Name of the input tensor to fill
+ * @param[in] data       Source values, laid out row by row with interleaved channels
+ * @param[in] batch      Batch size (currently not used)
+ * @param[in] height     Number of rows
+ * @param[in] width      Number of columns
+ * @param[in] channel    Number of channels per element
+ * @throw std::invalid_argument If data holds fewer values than requested, or the tensor buffer
+ *                              is too small to receive them
+ */
+void InferenceInterface::feed(const std::string &input_name, const std::vector<float> &data,
+                              const int batch, const int height, const int width,
+                              const int channel)
+{
+  // TODO consider batch
+  (void)batch;
+
+  if (height <= 0 || width <= 0 || channel <= 0)
+  {
+    return;
+  }
+  const std::size_t count = static_cast<std::size_t>(height) * width * channel;
+
+  // Set input tensor
+  for (const auto &id : _interpreter->inputs())
+  {
+    TfLiteTensor *tensor = _interpreter->tensor(id);
+    if (tensor->name != input_name)
+    {
+      continue;
+    }
+
+    assert(tensor->type == kTfLiteFloat32);
+
+    // Reject sizes that would read past the source or write past the tensor buffer
+    if (data.size() < count)
+    {
+      throw std::invalid_argument("feed: data holds fewer values than height*width*channel");
+    }
+    if (tensor->bytes < count * sizeof(float))
+    {
+      throw std::invalid_argument("feed: input tensor is smaller than height*width*channel");
+    }
+
+    // The source index y*width*channel + x*channel + c simply walks the data front to back,
+    // so the transfer is one contiguous copy.
+    std::copy_n(data.begin(), count, tensor->data.f);
+  }
+}
+
+/**
+ * @brief Execute the model through the prepared session
+ * @param[in] output_name Name of the output node (not consulted by this implementation)
+ */
+void InferenceInterface::run(const std::string &output_name)
+{
+  // The session runs the whole model; the requested output name plays no part here
+  (void)output_name;
+
+  _sess->run();
+}
+
+/**
+ * @brief Copy the values of the named output tensor into outputs
+ *
+ * outputs is left untouched when no output tensor carries the given name.
+ *
+ * @param[in]  output_name Name of the output tensor to read
+ * @param[out] outputs     Receives exactly getTensorSize(output_name) values
+ */
+void InferenceInterface::fetch(const std::string &output_name, std::vector<float> &outputs)
+{
+  // Get output tensor
+  for (const auto &id : _interpreter->outputs())
+  {
+    TfLiteTensor *tensor = _interpreter->tensor(id);
+    if (tensor->name != output_name)
+    {
+      continue;
+    }
+
+    assert(tensor->type == kTfLiteFloat32);
+
+    // Size the copy from the tensor itself. The vector's capacity is not a reliable element
+    // count: reserve() only guarantees a lower bound, so it may exceed the tensor size.
+    const int size = getTensorSize(output_name);
+    const float *values = tensor->data.f;
+
+    outputs.clear();
+    if (size > 0)
+    {
+      outputs.assign(values, values + size);
+    }
+  }
+}
+
+/**
+ * @brief Number of elements in the output tensor with the given name
+ * @param[in] name Name of the output tensor to look up
+ * @return Product of all dimensions of the tensor, or -1 when no output tensor has that name
+ */
+int InferenceInterface::getTensorSize(const std::string &name)
+{
+  for (const auto &index : _interpreter->outputs())
+  {
+    TfLiteTensor *tensor = _interpreter->tensor(index);
+    if (!(tensor->name == name))
+    {
+      continue;
+    }
+
+    // Multiply every dimension together to obtain the element count
+    int count = 1;
+    for (int axis = 0; axis < tensor->dims->size; ++axis)
+    {
+      count *= tensor->dims->data[axis];
+    }
+    return count;
+  }
+
+  // No output tensor matched
+  return -1;
+}
diff --git a/runtimes/contrib/tflite_classify/src/InferenceInterface.h b/runtimes/contrib/tflite_classify/src/InferenceInterface.h
new file mode 100644
index 000000000..fe2c1aa6c
--- /dev/null
+++ b/runtimes/contrib/tflite_classify/src/InferenceInterface.h
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @file InferenceInterface.h
+ * @brief This file contains class for running the actual inference model
+ * @ingroup COM_AI_RUNTIME
+ */
+
+#ifndef __TFLITE_CLASSIFY_INFERENCE_INTERFACE_H__
+#define __TFLITE_CLASSIFY_INFERENCE_INTERFACE_H__
+
+#include "tflite/ext/kernels/register.h"
+#include "tensorflow/lite/model.h"
+
+#include "tflite/InterpreterSession.h"
+#include "tflite/NNAPISession.h"
+
+#include <iostream>
+#include <string>
+
+/**
+ * @brief Class to define a inference interface for recognizing data
+ */
+class InferenceInterface
+{
+public:
+ /**
+ * @brief Construct a new InferenceInterface object with parameters
+ * @param[in] model_file The filepath of the model file to load
+ * @param[in] use_nnapi The flag to distinguish between TfLite interpreter and NNFW runtime
+ */
+ InferenceInterface(const std::string &model_file, const bool use_nnapi);
+
+ /**
+ * @brief Destroy the InferenceInterface object, tearing down its session
+ */
+ ~InferenceInterface();
+
+ /**
+ * @brief Copy the input data into model
+ * @param[in] input_name The label of the image input node
+ * @param[in] data The actual data to be copied into input tensor
+ * @param[in] batch The batch size (currently not used by the implementation)
+ * @param[in] height The number of rows
+ * @param[in] width The number of columns
+ * @param[in] channel The number of channels
+ * @return N/A
+ */
+ void feed(const std::string &input_name, const std::vector<float> &data, const int batch,
+ const int height, const int width, const int channel);
+ /**
+ * @brief Run the inference call
+ * @param[in] output_name The label of the output node (not used by the implementation)
+ * @return N/A
+ */
+ void run(const std::string &output_name);
+
+ /**
+ * @brief Copy the output tensor back into the output array
+ * @param[in] output_name The label of the output node
+ * @param[out] outputs The output data array
+ * @return N/A
+ */
+ void fetch(const std::string &output_name, std::vector<float> &outputs);
+
+ /**
+ * @brief Get tensor size
+ * @param[in] name The label of the node; only output tensors are searched
+ * @return The product of the tensor's dimensions, or -1 if no output tensor has that name
+ */
+ int getTensorSize(const std::string &name);
+
+private:
+ // NOTE(review): std::unique_ptr, std::shared_ptr and std::vector are used in this class but
+ // <memory> and <vector> are not among the headers this file includes directly - confirm they
+ // arrive transitively.
+ // NOTE(review): declaration order fixes the implicit destruction order (_sess, then _model,
+ // then _interpreter) - verify this against TFLite's model/interpreter lifetime rules.
+ /** Interpreter built from _model */
+ std::unique_ptr<tflite::Interpreter> _interpreter;
+ /** Model loaded from the model file */
+ std::unique_ptr<tflite::FlatBufferModel> _model;
+ /** Execution session (InterpreterSession or NNAPISession) created from _interpreter */
+ std::shared_ptr<nnfw::tflite::Session> _sess;
+};
+
+#endif // __TFLITE_CLASSIFY_INFERENCE_INTERFACE_H__
diff --git a/runtimes/contrib/tflite_classify/src/tflite_classify.cc b/runtimes/contrib/tflite_classify/src/tflite_classify.cc
new file mode 100644
index 000000000..40c15f331
--- /dev/null
+++ b/runtimes/contrib/tflite_classify/src/tflite_classify.cc
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ImageClassifier.h"
+
+#include <iostream>
+
+#include <boost/filesystem.hpp>
+#include <opencv2/opencv.hpp>
+
+namespace fs = boost::filesystem;
+
+/**
+ * @brief Classify frames from camera 0 in a loop and print the best labels for each frame
+ *
+ * Setting the USE_NNAPI environment variable selects the NNAPI session; setting DEBUG_MODE
+ * additionally prints frame geometry and the time spent in each recognition call.
+ *
+ * @return 0 when the capture loop ends, 1 when the model file, the label file or the camera is
+ *         unavailable
+ */
+int main(const int argc, char **argv)
+{
+  const std::string MODEL_FILE = "tensorflow_inception_graph.tflite";
+  const std::string LABEL_FILE = "imagenet_comp_graph_label_strings.txt";
+
+  const std::string INPUT_NAME = "input";
+  const std::string OUTPUT_NAME = "output";
+  const int INPUT_SIZE = 224;
+  const int IMAGE_MEAN = 117;
+  const int IMAGE_STD = 1;
+
+  const int FRAME_WIDTH = 640;
+  const int FRAME_HEIGHT = 480;
+
+  // Both switches are enabled merely by the variable being present; its value is not inspected
+  const bool use_nnapi = std::getenv("USE_NNAPI") != nullptr;
+  const bool debug_mode = std::getenv("DEBUG_MODE") != nullptr;
+
+  std::cout << "USE_NNAPI : " << use_nnapi << std::endl;
+  std::cout << "DEBUG_MODE : " << debug_mode << std::endl;
+
+  std::cout << "Model : " << MODEL_FILE << std::endl;
+  std::cout << "Label : " << LABEL_FILE << std::endl;
+
+  if (!fs::exists(MODEL_FILE))
+  {
+    std::cerr << "model file not found: " << MODEL_FILE << std::endl;
+    return 1;
+  }
+
+  if (!fs::exists(LABEL_FILE))
+  {
+    std::cerr << "label file not found: " << LABEL_FILE << std::endl;
+    return 1;
+  }
+
+  // Create ImageClassifier
+  std::unique_ptr<ImageClassifier> classifier(
+    new ImageClassifier(MODEL_FILE, LABEL_FILE, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME,
+                        OUTPUT_NAME, use_nnapi));
+
+  // Cam setting
+  cv::VideoCapture cap(0);
+  if (!cap.isOpened())
+  {
+    // Without this check an unavailable camera would skip the loop and exit with status 0
+    std::cerr << "failed to open camera device 0" << std::endl;
+    return 1;
+  }
+
+  // Initialize camera
+  cap.set(CV_CAP_PROP_FRAME_WIDTH, FRAME_WIDTH);
+  cap.set(CV_CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT);
+  cap.set(CV_CAP_PROP_FPS, 5);
+
+  cv::Mat frame;
+  std::vector<Recognition> results;
+  // clock() reports processor time used by the program, not elapsed wall-clock time
+  clock_t begin = 0;
+  clock_t end = 0;
+  while (cap.isOpened())
+  {
+    // Get image data
+    if (!cap.read(frame))
+    {
+      std::cout << "Frame is null..." << std::endl;
+      break;
+    }
+
+    if (debug_mode)
+    {
+      begin = clock();
+    }
+    // Recognize image
+    results = classifier->recognizeImage(frame);
+    if (debug_mode)
+    {
+      end = clock();
+    }
+
+    // Show result data
+    std::cout << std::endl;
+    if (results.empty())
+    {
+      std::cout << "." << std::endl;
+    }
+    else
+    {
+      for (const auto &result : results)
+      {
+        std::cout << result.title << ": " << result.confidence << std::endl;
+      }
+    }
+    if (debug_mode)
+    {
+      std::cout << "Frame: " << FRAME_WIDTH << "x" << FRAME_HEIGHT << std::endl;
+      std::cout << "Crop: " << INPUT_SIZE << "x" << INPUT_SIZE << std::endl;
+      std::cout << "Inference time(ms): " << ((end - begin) / (CLOCKS_PER_SEC / 1000)) << std::endl;
+    }
+  }
+
+  cap.release();
+
+  return 0;
+}