diff options
Diffstat (limited to 'include/support/tflite')
-rw-r--r-- | include/support/tflite/Diff.h | 119 | ||||
-rw-r--r-- | include/support/tflite/FeatureView.h | 69 | ||||
-rw-r--r-- | include/support/tflite/InputIndex.h | 46 | ||||
-rw-r--r-- | include/support/tflite/OutputIndex.h | 46 | ||||
-rw-r--r-- | include/support/tflite/TensorUtils.h | 43 | ||||
-rw-r--r-- | include/support/tflite/TensorView.h | 64 | ||||
-rw-r--r-- | include/support/tflite/interp/Builder.h | 43 | ||||
-rw-r--r-- | include/support/tflite/interp/FlatBufferBuilder.h | 53 | ||||
-rw-r--r-- | include/support/tflite/interp/FunctionBuilder.h | 56 |
9 files changed, 539 insertions, 0 deletions
diff --git a/include/support/tflite/Diff.h b/include/support/tflite/Diff.h new file mode 100644 index 000000000..b17c9313c --- /dev/null +++ b/include/support/tflite/Diff.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_COMPARE_H__ +#define __NNFW_SUPPORT_TFLITE_COMPARE_H__ + +#include "tensorflow/contrib/lite/interpreter.h" + +#include "util/tensor/Index.h" + +#include "support/tflite/TensorView.h" + +#include <functional> +#include <vector> + +// NOTE The code below is subject to change. +// TODO Introduce namespaces +struct TfLiteTensorDiff +{ + nnfw::util::tensor::Index index; + float expected; + float obtained; + + TfLiteTensorDiff(const nnfw::util::tensor::Index &i) : index(i) + { + // DO NOTHING + } +}; + +class TfLiteTensorComparator +{ +public: + TfLiteTensorComparator(const std::function<bool (float lhs, float rhs)> &fn) : _compare_fn{fn} + { + // DO NOTHING + } + +public: + struct Observer + { + virtual void notify(const nnfw::util::tensor::Index &index, float expected, float obtained) = 0; + }; + +public: + // NOTE Observer should live longer than comparator + std::vector<TfLiteTensorDiff> compare(const nnfw::support::tflite::TensorView<float> &expected, + const nnfw::support::tflite::TensorView<float> &obtained, + Observer *observer = nullptr) const; + +private: + std::function<bool (float lhs, float rhs)> _compare_fn; +}; + +class TfLiteInterpMatchApp +{ +public: + TfLiteInterpMatchApp(const TfLiteTensorComparator &comparator) + : _verbose{false}, _comparator(comparator) + { + // DO NOTHING + } + +public: + int &verbose(void) { return _verbose; } + +private: + int _verbose; + +public: + bool run(::tflite::Interpreter &pure, ::tflite::Interpreter &nnapi) const; + +private: + const TfLiteTensorComparator &_comparator; +}; + +#include "support/tflite/interp/Builder.h" + +#include <random> + +// For NNAPI testing +struct RandomTestParam +{ + int verbose; + int tolerance; +}; + +class RandomTestRunner +{ +public: + RandomTestRunner(int seed, const RandomTestParam ¶m) + : _rand{seed}, _param{param} + { + // DO NOTHING + } + +public: + // NOTE this method updates '_rand' + // Return 0 if test succeeds + int run(const nnfw::support::tflite::interp::Builder &builder); + +private: + std::minstd_rand _rand; + const RandomTestParam _param; +}; + +#endif // __NNFW_SUPPORT_TFLITE_COMPARE_H__ diff --git a/include/support/tflite/FeatureView.h b/include/support/tflite/FeatureView.h new file mode 100644 index 000000000..3a7d75ec4 --- /dev/null +++ b/include/support/tflite/FeatureView.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_FEATURE_VIEW_H__ +#define __NNFW_SUPPORT_TFLITE_FEATURE_VIEW_H__ + +#include "tensorflow/contrib/lite/interpreter.h" + +#include "support/tflite/InputIndex.h" +#include "support/tflite/OutputIndex.h" + +#include "util/feature/Shape.h" +#include "util/feature/Reader.h" + +namespace nnfw +{ +namespace support +{ +namespace tflite +{ + +template<typename T> class FeatureView; + +template<> class FeatureView<float> : public nnfw::util::feature::Reader<float> +{ +public: + FeatureView(::tflite::Interpreter &interp, const InputIndex &index); + FeatureView(::tflite::Interpreter &interp, const OutputIndex &index); + +public: + float at(uint32_t ch, uint32_t row, uint32_t col) const; + float &at(uint32_t ch, uint32_t row, uint32_t col); + +private: + uint32_t getElementOffset(uint32_t ch, uint32_t row, uint32_t col) const + { + uint32_t res = 0; + + // TensorFlow Lite assumes that NHWC ordering for tessor + res += row * _shape.W * _shape.C; + res += col * _shape.C; + res += ch; + + return res; + } + +private: + nnfw::util::feature::Shape _shape; + float *_base; +}; + +} // namespace tflite +} // namespace support +} // namespace nnfw + +#endif // __NNFW_SUPPORT_TFLITE_FEATURE_VIEW_H__ diff --git a/include/support/tflite/InputIndex.h b/include/support/tflite/InputIndex.h new file mode 100644 index 000000000..c3ed891fe --- /dev/null +++ b/include/support/tflite/InputIndex.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_INPUT_INDEX_H__ +#define __NNFW_SUPPORT_TFLITE_INPUT_INDEX_H__ + +namespace nnfw +{ +namespace support +{ +namespace tflite +{ + +class InputIndex +{ +public: + InputIndex(int index) : _index(index) + { + // DO NOTHING + } + +public: + int asInt(void) const { return _index; } + +private: + int _index; +}; + +} // namespace tflite +} // namespace support +} // namespace nnfw + +#endif // __NNFW_SUPPORT_TFLITE_INPUT_INDEX_H__ diff --git a/include/support/tflite/OutputIndex.h b/include/support/tflite/OutputIndex.h new file mode 100644 index 000000000..be6556ce7 --- /dev/null +++ b/include/support/tflite/OutputIndex.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_OUTPUT_INDEX_H__ +#define __NNFW_SUPPORT_TFLITE_OUTPUT_INDEX_H__ + +namespace nnfw +{ +namespace support +{ +namespace tflite +{ + +class OutputIndex +{ +public: + OutputIndex(int index) : _index(index) + { + // DO NOTHING + } + +public: + int asInt(void) const { return _index; } + +private: + int _index; +}; + +} // namespace tflite +} // namespace support +} // namespace nnfw + +#endif // __NNFW_SUPPORT_TFLITE_OUTPUT_INDEX_H__ diff --git a/include/support/tflite/TensorUtils.h b/include/support/tflite/TensorUtils.h new file mode 100644 index 000000000..815cfcd29 --- /dev/null +++ b/include/support/tflite/TensorUtils.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_TENSOR_UTILS_H__ +#define __NNFW_SUPPORT_TFLITE_TENSOR_UTILS_H__ + +#include <tensorflow/contrib/lite/context.h> + +namespace nnfw +{ +namespace support +{ +namespace tflite +{ + +inline bool isFloatTensor(const TfLiteTensor *tensor) +{ + return tensor->type == kTfLiteFloat32; +} + +inline bool isFeatureTensor(const TfLiteTensor *tensor) +{ + return (tensor->dims->size == 4) && (tensor->dims->data[0] == 1); +} + +} // namespace tflite +} // namespace support +} // namespace nnfw + +#endif // __NNFW_SUPPORT_TFLITE_TENSOR_UTILS_H__ diff --git a/include/support/tflite/TensorView.h b/include/support/tflite/TensorView.h new file mode 100644 index 000000000..35c90a372 --- /dev/null +++ b/include/support/tflite/TensorView.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_TENSOR_VIEW_H__ +#define __NNFW_SUPPORT_TFLITE_TENSOR_VIEW_H__ + +#include "tensorflow/contrib/lite/interpreter.h" + +#include "util/tensor/Shape.h" +#include "util/tensor/Index.h" +#include "util/tensor/Reader.h" +#include "util/tensor/NonIncreasingStride.h" + +namespace nnfw +{ +namespace support +{ +namespace tflite +{ + +template<typename T> class TensorView; + +template<> class TensorView<float> final : public nnfw::util::tensor::Reader<float> +{ +public: + TensorView(const nnfw::util::tensor::Shape &shape, float *base); + +public: + const nnfw::util::tensor::Shape &shape(void) const { return _shape; } + +public: + float at(const nnfw::util::tensor::Index &index) const override; + float &at(const nnfw::util::tensor::Index &index); + +private: + nnfw::util::tensor::Shape _shape; + +public: + float *_base; + nnfw::util::tensor::NonIncreasingStride _stride; + +public: + // TODO Introduce Operand ID class + static TensorView<float> make(::tflite::Interpreter &interp, int operand_id); +}; + +} // namespace tflite +} // namespace support +} // namespace nnfw + +#endif // __NNFW_SUPPORT_TFLITE_TENSOR_VIEW_H__ diff --git a/include/support/tflite/interp/Builder.h b/include/support/tflite/interp/Builder.h new file mode 100644 index 000000000..4a5a2f26f --- /dev/null +++ b/include/support/tflite/interp/Builder.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_INTERP_BUILDER_H__ +#define __NNFW_SUPPORT_TFLITE_INTERP_BUILDER_H__ + +#include <tensorflow/contrib/lite/interpreter.h> + +namespace nnfw +{ +namespace support +{ +namespace tflite +{ +namespace interp +{ + +struct Builder +{ + virtual ~Builder() = default; + + virtual std::unique_ptr<::tflite::Interpreter> build(void) const = 0; +}; + +} // namespace interp +} // namespace tflite +} // namespace support +} // namespace nnfw + +#endif // __NNFW_SUPPORT_TFLITE_INTERP_BUILDER_H__ diff --git a/include/support/tflite/interp/FlatBufferBuilder.h b/include/support/tflite/interp/FlatBufferBuilder.h new file mode 100644 index 000000000..dab151dcf --- /dev/null +++ b/include/support/tflite/interp/FlatBufferBuilder.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_INTERP_FLAT_BUFFER_BUILDER_H__ +#define __NNFW_SUPPORT_TFLITE_INTERP_FLAT_BUFFER_BUILDER_H__ + +#include <tensorflow/contrib/lite/model.h> + +#include "support/tflite/interp/Builder.h" + +namespace nnfw +{ +namespace support +{ +namespace tflite +{ +namespace interp +{ + +class FlatBufferBuilder final : public Builder +{ +public: + FlatBufferBuilder(const ::tflite::FlatBufferModel &model) : _model{model} + { + // DO NOTHING + } + +public: + std::unique_ptr<::tflite::Interpreter> build(void) const override; + +private: + const ::tflite::FlatBufferModel &_model; +}; + +} // namespace interp +} // namespace tflite +} // namespace support +} // namespace nnfw + +#endif // __NNFW_SUPPORT_TFLITE_INTERP_FLAT_BUFFER_BUILDER_H__ diff --git a/include/support/tflite/interp/FunctionBuilder.h b/include/support/tflite/interp/FunctionBuilder.h new file mode 100644 index 000000000..1ac5918e8 --- /dev/null +++ b/include/support/tflite/interp/FunctionBuilder.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SUPPORT_TFLITE_INTERP_FUNCTION_BUILDER_H__ +#define __NNFW_SUPPORT_TFLITE_INTERP_FUNCTION_BUILDER_H__ + +#include <tensorflow/contrib/lite/model.h> + +#include "support/tflite/interp/Builder.h" + +namespace nnfw +{ +namespace support +{ +namespace tflite +{ +namespace interp +{ + +class FunctionBuilder final : public Builder +{ +public: + using SetupFunc = std::function<void (::tflite::Interpreter &)>; + +public: + FunctionBuilder(const SetupFunc &fn) : _fn{fn} + { + // DO NOTHING + } + +public: + std::unique_ptr<::tflite::Interpreter> build(void) const override; + +private: + SetupFunc _fn; +}; + +} // namespace interp +} // namespace tflite +} // namespace support +} // namespace nnfw + +#endif // __NNFW_SUPPORT_TFLITE_INTERP_FUNCTION_BUILDER_H__ |