summaryrefslogtreecommitdiff
path: root/runtime/onert/frontend/trix/src/trix_loader.cc
blob: cdf2396482dccc0b1605bffc8046cc1f8286b94a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "trix_loader.h"

#include "ir/Graph.h"
#include "ir/operation/Bulk.h"

#include <libnpuhost.h>
#include <npubinfmt.h>
#include <typedef.h>

namespace onert
{
namespace trix_loader
{

/**
 * @brief A tvn metadata reader
 *
 * Owns the npubin_meta buffer obtained from getNPUmodel_metadata() and releases it with free()
 * on destruction. Because it owns that raw buffer, the reader is non-copyable: a copy would
 * share the pointer and free it twice.
 *
 * Every accessor dereferences the metadata, so init() must have succeeded before any accessor
 * is called. Segment index arguments are not range-checked here.
 */
class TrixMetaReader
{
public:
  TrixMetaReader() = default;
  ~TrixMetaReader() { free(_meta); }

  // Owns a raw malloc'd buffer: copying would lead to a double free.
  TrixMetaReader(const TrixMetaReader &) = delete;
  TrixMetaReader &operator=(const TrixMetaReader &) = delete;

  /**
   * @brief Read the metadata of the TRIX model at @p path
   * @throw std::runtime_error if the metadata cannot be read or its version is not 3
   */
  void init(const char *path);
  data_layout input_seg_layout(uint32_t n) const { return _meta->input_seg_layout[n]; }
  data_layout output_seg_layout(uint32_t n) const { return _meta->output_seg_layout[n]; }
  data_type input_seg_quant_type(uint32_t n) const { return _meta->input_seg_quant_type[n]; }
  data_type output_seg_quant_type(uint32_t n) const { return _meta->output_seg_quant_type[n]; }
  float input_seg_quant_scale(uint32_t n) const { return _meta->input_seg_quant_s[n]; }
  float output_seg_quant_scale(uint32_t n) const { return _meta->output_seg_quant_s[n]; }
  int32_t input_seg_quant_zp(uint32_t n) const { return _meta->input_seg_quant_z[n]; }
  int32_t output_seg_quant_zp(uint32_t n) const { return _meta->output_seg_quant_z[n]; }
  uint32_t input_seg_num() const { return _meta->input_seg_num; }
  uint32_t output_seg_num() const { return _meta->output_seg_num; }
  uint32_t input_seg_dims(uint32_t n, uint32_t axis) const
  {
    return _meta->input_seg_dims[n][axis];
  }
  uint32_t output_seg_dims(uint32_t n, uint32_t axis) const
  {
    return _meta->output_seg_dims[n][axis];
  }

private:
  /** Metadata buffer owned by this reader; nullptr until init() succeeds */
  npubin_meta *_meta = nullptr;
};

/**
 * @brief Load and validate the metadata of the TRIX model at @p path
 *
 * The newly read metadata is validated before it replaces any metadata from an earlier call.
 * The earlier buffer is then released, so repeated calls do not leak. On failure the rejected
 * buffer is freed and the previously loaded metadata (if any) is left untouched.
 *
 * @param path Path to the TRIX model file (must not be null)
 * @throw std::runtime_error if the metadata cannot be read or its version is not 3
 */
void TrixMetaReader::init(const char *path)
{
  assert(path);
  auto *meta = getNPUmodel_metadata(path, false);
  if (meta == nullptr)
  {
    throw std::runtime_error("Failed to get TRIX model metadata");
  }
  if (NPUBIN_VERSION(meta->magiccode) != 3)
  {
    // The rejected buffer is never stored in _meta, so release it here.
    free(meta);
    throw std::runtime_error("TRIX model metadata version mismatched.");
  }
  // Release metadata from any previous init() call before taking ownership of the new one.
  free(_meta);
  _meta = meta;
}

class TrixLoader
{
public:
  /**
   * @brief Construct a new Loader object
   *
   * @param model Reference to the model that the loaded subgraph is pushed into. Only the
   *              reference is stored, so it must outlive this loader.
   */
  explicit TrixLoader(std::unique_ptr<ir::Model> &model) : _model(model) {}

  /**
   * @brief Load a model from file
   * @param file_path Path to the TRIX model (e.g. tvn) file
   * @throw std::runtime_error e.g. when the model metadata cannot be read or uses an
   *        unsupported data type
   */
  void loadFromFile(const std::string &file_path);

private:
  /**
   * @brief Build the model's single subgraph and register it as subgraph 0 of @c _model
   * @throw std::runtime_error when tvn path is wrong or tvn is invalid
   */
  void loadModel();
  std::unique_ptr<ir::Graph> loadSubgraph();
  void loadOperands(ir::Graph &subg);
  ir::OperandIndex loadOperandFromInput(uint32_t idx, ir::Graph &subg);
  ir::OperandIndex loadOperandFromOutput(uint32_t idx, ir::Graph &subg);
  void loadBulk(ir::Graph &subg);
  void loadOperationIO(ir::OperandIndexSequence &inputs, ir::OperandIndexSequence &outputs);
  ir::OperandIndex inputIdxToOperandIdx(uint32_t i) const;
  ir::OperandIndex outputIdxToOperandIdx(uint32_t i) const;
  ir::DataType toDataType(const data_type type) const;

protected:
  /** path to model (e.g. tvn) */
  std::string _model_path;
  /** original IO shapes */
  std::vector<ir::Shape> _origin_input_shapes;
  std::vector<ir::Shape> _origin_output_shapes;
  /** Reference on loadable subgraphs */
  std::unique_ptr<ir::Model> &_model;
  /** Metadata of the TRIX model at _model_path */
  TrixMetaReader _meta;
};

/**
 * @brief Translate a TRIX segment quantization type into an onert IR data type
 *
 * Only DATA_TYPE_QASYMM8 and DATA_TYPE_QSYMM16 have a mapping; any other value is rejected.
 *
 * @throw std::runtime_error for any other data type
 */
ir::DataType TrixLoader::toDataType(const data_type type) const
{
  if (type == DATA_TYPE_QASYMM8)
    return ir::DataType::QUANT_UINT8_ASYMM;

  if (type == DATA_TYPE_QSYMM16)
    return ir::DataType::QUANT_INT16_SYMM;

  throw std::runtime_error("Unsupported data type from trix model");
}

/**
 * @brief Operand index of the @p i-th model input
 *
 * loadOperands() adds all input operands before any output operand, so input i maps to
 * operand i. This presumes the graph numbers operands sequentially from 0 (TODO confirm).
 */
ir::OperandIndex TrixLoader::inputIdxToOperandIdx(uint32_t i) const
{
  ir::OperandIndex operand_idx(i);
  return operand_idx;
}
/**
 * @brief Operand index of the @p i-th model output
 *
 * Output operands are added after all input operands (see loadOperands()), so output i maps to
 * operand input_seg_num() + i. This presumes sequential operand numbering from 0.
 */
ir::OperandIndex TrixLoader::outputIdxToOperandIdx(uint32_t i) const
{
  const uint32_t first_output = _meta.input_seg_num();
  return ir::OperandIndex(first_output + i);
}

/**
 * @brief Fill the operand lists of the Bulk operation
 *
 * @param inputs  Receives the operand index of every model input, in segment order
 * @param outputs Receives the operand index of every model output, in segment order
 */
void TrixLoader::loadOperationIO(ir::OperandIndexSequence &inputs,
                                 ir::OperandIndexSequence &outputs)
{
  const uint32_t num_in = _meta.input_seg_num();
  for (uint32_t seg = 0; seg < num_in; ++seg)
    inputs.append(inputIdxToOperandIdx(seg));

  const uint32_t num_out = _meta.output_seg_num();
  for (uint32_t seg = 0; seg < num_out; ++seg)
    outputs.append(outputIdxToOperandIdx(seg));
}

/**
 * @brief Add the single Bulk operation that represents the whole TRIX model
 *
 * The operation consumes every model input operand and produces every model output operand.
 * Its parameters carry the model path and the original IO shapes recorded by loadOperands(),
 * so this must run after loadOperands().
 */
void TrixLoader::loadBulk(ir::Graph &subg)
{
  ir::operation::Bulk::Param param;
  param.binary_path = _model_path;
  param.origin_input_shapes = _origin_input_shapes;
  param.origin_output_shapes = _origin_output_shapes;

  ir::OperandIndexSequence inputs;
  ir::OperandIndexSequence outputs;
  loadOperationIO(inputs, outputs);

  subg.addOperation(std::make_unique<ir::operation::Bulk>(inputs, outputs, param));
}

/**
 * @brief Create the graph operand for model input segment @p idx
 *
 * The shape is built from all MAX_RANK dimension entries of the segment and is also appended
 * to @c _origin_input_shapes. Type, scale and zero point come from the segment's quantization
 * metadata.
 *
 * @return Index of the newly added operand
 * @throw std::runtime_error if the segment's quantization type is unsupported
 */
ir::OperandIndex TrixLoader::loadOperandFromInput(uint32_t idx, ir::Graph &subg)
{
  ir::Shape seg_shape;
  uint32_t axis = 0;
  while (axis < MAX_RANK)
  {
    seg_shape.append(_meta.input_seg_dims(idx, axis));
    ++axis;
  }

  const auto seg_type = toDataType(_meta.input_seg_quant_type(idx));
  const float seg_scale = _meta.input_seg_quant_scale(idx);
  const int32_t seg_zero_point = _meta.input_seg_quant_zp(idx);
  ir::TypeInfo seg_info(seg_type, seg_scale, seg_zero_point);

  _origin_input_shapes.push_back(seg_shape);
  return subg.addOperand(seg_shape, seg_info);
}

/**
 * @brief Create the graph operand for model output segment @p idx
 *
 * The shape is built from all MAX_RANK dimension entries of the segment and is also appended
 * to @c _origin_output_shapes. Type, scale and zero point come from the segment's quantization
 * metadata.
 *
 * @return Index of the newly added operand
 * @throw std::runtime_error if the segment's quantization type is unsupported
 */
ir::OperandIndex TrixLoader::loadOperandFromOutput(uint32_t idx, ir::Graph &subg)
{
  ir::Shape seg_shape;
  uint32_t axis = 0;
  while (axis < MAX_RANK)
  {
    seg_shape.append(_meta.output_seg_dims(idx, axis));
    ++axis;
  }

  const auto seg_type = toDataType(_meta.output_seg_quant_type(idx));
  const float seg_scale = _meta.output_seg_quant_scale(idx);
  const int32_t seg_zero_point = _meta.output_seg_quant_zp(idx);
  ir::TypeInfo seg_info(seg_type, seg_scale, seg_zero_point);

  _origin_output_shapes.push_back(seg_shape);
  return subg.addOperand(seg_shape, seg_info);
}

/**
 * @brief Add one operand per model input segment, then one per output segment
 *
 * The returned operand indices are not kept. inputIdxToOperandIdx()/outputIdxToOperandIdx()
 * recompute them from this inputs-then-outputs creation order instead.
 */
void TrixLoader::loadOperands(ir::Graph &subg)
{
  for (uint32_t n = 0, count = _meta.input_seg_num(); n < count; ++n)
    loadOperandFromInput(n, subg);

  for (uint32_t n = 0, count = _meta.output_seg_num(); n < count; ++n)
    loadOperandFromOutput(n, subg);
}

/**
 * @brief Build the graph of the model's single subgraph
 *
 * The graph contains one operand per input/output segment, named model inputs/outputs, and a
 * single Bulk operation. Metadata is read by loadFromFile() before this is reached, so it is
 * not re-read here.
 *
 * @return The verified subgraph
 * @throw std::runtime_error if an IO segment uses an unsupported data type
 */
std::unique_ptr<ir::Graph> TrixLoader::loadSubgraph()
{
  auto subg = std::make_unique<ir::Graph>();

  // Load tensors. Inputs are created before outputs, and the *IdxToOperandIdx helpers
  // depend on that order.
  loadOperands(*subg);

  // Set inputs
  for (uint32_t i = 0; i < _meta.input_seg_num(); ++i)
  {
    subg->addInput(inputIdxToOperandIdx(i), "tvn_input" + std::to_string(i));
  }
  // Set outputs
  for (uint32_t i = 0; i < _meta.output_seg_num(); ++i)
  {
    subg->addOutput(outputIdxToOperandIdx(i), "tvn_out" + std::to_string(i));
  }
  // Create operations
  loadBulk(*subg);

  // TODO: NHWC only supported at this moment.
  subg->setLayout(ir::Layout::NHWC);
  subg->verify();
  return subg;
}

void TrixLoader::loadModel()
{
  // one subgraph only
  auto subg = loadSubgraph();
  _model->push(ir::SubgraphIndex(0), std::move(subg));
}

/**
 * @brief Load the TRIX model stored at @p file_path into @c _model
 */
void TrixLoader::loadFromFile(const std::string &file_path)
{
  // Keep the path: loadBulk() forwards it to the Bulk operation as binary_path.
  _model_path = file_path;

  // Read the TRIX metadata up front; the load steps below query it.
  _meta.init(_model_path.c_str());

  loadModel();
}

std::unique_ptr<ir::Model> loadModel(const std::string &filename)
{
  auto model = std::make_unique<ir::Model>();
  TrixLoader loader(model);
  loader.loadFromFile(filename);
  return model;
}
} // namespace trix_loader
} // namespace onert