blob: 7d766063e0562daa61b849abe78abf0b9240f026 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
|
/*
* Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnkit/support/tflite/AbstractBackend.h"
#include <tensorflow/contrib/lite/kernels/register.h>
#include <tensorflow/contrib/lite/model.h>
#include <stdexcept>
namespace
{
/**
 * @brief Backend that runs a TensorFlow Lite model using the builtin op resolver
 *
 * The constructor loads the model file, builds an interpreter for it and
 * configures that interpreter to use a single thread.
 */
class GenericBackend final : public nnkit::support::tflite::AbstractBackend
{
public:
  /**
   * @brief Load the model at "path" and build an interpreter for it
   *
   * @param path Path handed to FlatBufferModel::BuildFromFile
   * @throws std::runtime_error if the model cannot be loaded or the interpreter cannot be built
   */
  explicit GenericBackend(const std::string &path)
  {
    _model = ::tflite::FlatBufferModel::BuildFromFile(path.c_str(), &_error_reporter);

    // BuildFromFile reports failure by returning a null pointer. Dereferencing
    // it for InterpreterBuilder below would be undefined behaviour, so fail early.
    if (_model == nullptr)
    {
      throw std::runtime_error{"Failed to load a tflite model from '" + path + "'"};
    }

    ::tflite::ops::builtin::BuiltinOpResolver resolver;
    ::tflite::InterpreterBuilder builder(*_model, resolver);

    if (kTfLiteOk != builder(&_interp) || _interp == nullptr)
    {
      throw std::runtime_error{"Failed to build a tflite interpreter"};
    }

    _interp->SetNumThreads(1);
  }

public:
  /// @brief Return the interpreter built by the constructor
  ::tflite::Interpreter &interpreter(void) override { return *_interp; }

private:
  // The reporter's address is handed to BuildFromFile, so it must outlive the
  // model. Members are destroyed in reverse declaration order, hence it is
  // declared first and destroyed last.
  ::tflite::StderrReporter _error_reporter;
  std::unique_ptr<::tflite::FlatBufferModel> _model;
  std::unique_ptr<::tflite::Interpreter> _interp;
};
}
#include <nnkit/CmdlineArguments.h>
#include <stdex/Memory.h>
/**
 * @brief Create a GenericBackend from the given command-line arguments
 *
 * This factory is exported with C linkage.
 *
 * @param args Command-line arguments; args.at(0) is passed to GenericBackend as the model path
 * @return The newly created backend, owned by the caller
 */
extern "C" std::unique_ptr<nnkit::Backend> make_backend(const nnkit::CmdlineArguments &args)
{
  // The first argument names the model file to load
  auto &&model_path = args.at(0);

  return stdex::make_unique<GenericBackend>(model_path);
}
|