1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
|
/*
* Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnkit/support/tf/Backend.h"
#include "nnkit/support/tftestinfo/ParsedTensor.h"
#include "nnkit/support/tftestinfo/TensorInfoParser.h"
#include "nnkit/support/tf/TensorDataMap.h"
#include "nnkit/support/tf/TensorContext.h"
#include "nnkit/support/tf/Runner.h"
#include <angkor/TensorShape.h>
#include <nnkit/Backend.h>
#include <cstring> // memcpy
namespace nnkit
{
namespace support
{
namespace tf
{
using nnkit::support::tftestinfo::ParsedTensor;
Backend::Backend(const char *pb_path, const char *info_path) : _tf_runner(pb_path)
{
auto parsed_tensors = nnkit::support::tftestinfo::parse(info_path);
for (auto &parsed_tensor : parsed_tensors)
{
if (parsed_tensor->kind() == ParsedTensor::Kind::Input)
{
// user didn't specify input
if (!parsed_tensor->hasShape())
{
angkor::TensorShape shape;
if (!_tf_runner.getTensorShapeFromGraphDef(parsed_tensor, shape))
throw oops::UserExn(
"Info you provided may be wrong or not enough. Please check the info file.");
parsed_tensor->mutable_shape().resize(shape.rank());
for (int r = 0; r < shape.rank(); r++)
{
parsed_tensor->mutable_shape().dim(r) = shape.dim(r);
}
}
_inputs.emplace_back(std::move(parsed_tensor));
}
else
_outputs.emplace_back(std::move(parsed_tensor));
}
}
void Backend::prepare(const std::function<void(nnkit::TensorContext &)> &f)
{
for (const auto &input_tensor : _inputs)
_data_map.allocate(input_tensor.get());
TensorContext ctx(_inputs, _data_map);
f(ctx); // fill values
_tf_runner.prepareInputs(_inputs, _data_map);
_tf_runner.prepareOutputs(_outputs);
}
void Backend::run(void)
{
_tf_runner.run();
// get result
auto actual_outputs = _tf_runner.output();
for (int n = 0; n < _outputs.size(); n++)
{
auto actual = actual_outputs[n];
const size_t byte_size = TF_TensorByteSize(actual);
const uint8_t *tf_data = reinterpret_cast<const uint8_t *>(TF_TensorData(actual));
const uint32_t shape_rank = TF_NumDims(actual);
_outputs[n]->mutable_shape().resize(shape_rank);
for (uint32_t r = 0; r < shape_rank; r++)
{
_outputs[n]->mutable_shape().dim(r) = TF_Dim(actual, r);
}
uint8_t *dest = _data_map.allocate(_outputs[n].get());
std::memcpy(dest, tf_data, byte_size);
}
}
void Backend::teardown(const std::function<void(nnkit::TensorContext &)> &f)
{
TensorContext ctx(_outputs, _data_map);
f(ctx);
}
} // namespace tf
} // namespace support
} // namespace nnkit
|