diff options
Diffstat (limited to 'compiler/mir-onnx-importer/ONNXImporterImpl.cpp')
-rw-r--r-- | compiler/mir-onnx-importer/ONNXImporterImpl.cpp | 241 |
1 file changed, 241 insertions, 0 deletions
diff --git a/compiler/mir-onnx-importer/ONNXImporterImpl.cpp b/compiler/mir-onnx-importer/ONNXImporterImpl.cpp new file mode 100644 index 000000000..c33104198 --- /dev/null +++ b/compiler/mir-onnx-importer/ONNXImporterImpl.cpp @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ONNXImporterImpl.h" +#include "ONNXHelpers.h" +#include "ONNXOpRegistration.h" +#include "onnx/onnx.pb.h" + +#include "mir/Shape.h" +#include "mir/TensorUtil.h" + +#include "mir/ops/ConstantOp.h" + +#include <fcntl.h> + +#include <google/protobuf/io/zero_copy_stream_impl.h> +#include <google/protobuf/io/coded_stream.h> +#include <google/protobuf/text_format.h> +#include <functional> +#include <iostream> +#include <stdex/Memory.h> +#include <utility> + +namespace mir_onnx +{ + +namespace +{ + +class ONNXImporterImpl final +{ +public: + ONNXImporterImpl(); + ~ONNXImporterImpl(); + /// @brief Load the model and convert it into a MIR Graph. 
+ std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename); + std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename); + +private: + std::unique_ptr<mir::Graph> createIR(); + void createGraphInputs(); + void collectUnsupportedOps(); + std::unique_ptr<onnx::ModelProto> _model; + std::unique_ptr<ConverterContext> _converterCtx; + std::unique_ptr<ModelContext> _modelCtx; + std::unique_ptr<mir::Graph> _graph; +}; + +ONNXImporterImpl::ONNXImporterImpl() { registerSupportedOps(); } + +ONNXImporterImpl::~ONNXImporterImpl() = default; + +void loadModelFromBinaryFile(const std::string &filename, onnx::ModelProto *model) +{ + GOOGLE_PROTOBUF_VERIFY_VERSION; + + int file_handle = open(filename.c_str(), O_RDONLY); + + if (file_handle == -1) + throw std::runtime_error("Couldn't open file \"" + filename + "\": " + std::strerror(errno) + + "."); + + google::protobuf::io::FileInputStream file_stream(file_handle); + file_stream.SetCloseOnDelete(true); + + google::protobuf::io::CodedInputStream coded_stream(&file_stream); + coded_stream.SetTotalBytesLimit(INT_MAX, INT_MAX); + + if (!model->ParseFromCodedStream(&coded_stream)) + throw std::runtime_error("Couldn't parse file \"" + filename + "\"."); + + // If the file has not been consumed entirely, assume that the file is in the wrong format. 
+ if (!coded_stream.ConsumedEntireMessage()) + throw std::runtime_error("File \"" + filename + "\" has not been consumed entirely."); +} + +void loadModelFromTextFile(const std::string &filename, onnx::ModelProto *model) +{ + GOOGLE_PROTOBUF_VERIFY_VERSION; + + int file_handle = open(filename.c_str(), O_RDONLY); + + if (file_handle == -1) + throw std::runtime_error("Couldn't open file \"" + filename + "\": " + std::strerror(errno) + + "."); + + google::protobuf::io::FileInputStream file_stream(file_handle); + file_stream.SetCloseOnDelete(true); + + if (!google::protobuf::TextFormat::Parse(&file_stream, model)) + throw std::runtime_error("Couldn't parse file \"" + filename + "\"."); +} + +std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromBinaryFile(const std::string &filename) +{ + _model = stdex::make_unique<onnx::ModelProto>(); + loadModelFromBinaryFile(filename, _model.get()); + _modelCtx = stdex::make_unique<ModelContext>(_model.get()); + collectUnsupportedOps(); + return createIR(); +} + +std::unique_ptr<mir::Graph> ONNXImporterImpl::importModelFromTextFile(const std::string &filename) +{ + _model = stdex::make_unique<onnx::ModelProto>(); + loadModelFromTextFile(filename, _model.get()); + _modelCtx = stdex::make_unique<ModelContext>(_model.get()); + collectUnsupportedOps(); + return createIR(); +} + +void ONNXImporterImpl::collectUnsupportedOps() +{ + std::set<std::pair<std::string, int64_t>> problems_op_set; + + for (int i = 0; i < _model->graph().node_size(); i++) + { + const auto &onnx_node = _model->graph().node(i); + assert(onnx_node.has_op_type()); + const auto &op_type = onnx_node.op_type(); + auto opset = _modelCtx->getDomainOpsetVersion(onnx_node.domain()); + + NodeConverterRegistry::ConverterFunc converter = + NodeConverterRegistry::getInstance().lookup(op_type, opset); + + if (converter == nullptr) + problems_op_set.emplace(op_type, opset); + } + if (!problems_op_set.empty()) + { + std::cerr << "The following operators are not 
supported:\n"; + for (const auto &op : problems_op_set) + std::cerr << op.first << " opset " << op.second << std::endl; + throw std::runtime_error("Unsupported operators found"); + } +} + +void ONNXImporterImpl::createGraphInputs() +{ + const auto &graph = _model->graph(); + const auto &initializer = graph.initializer(); + const auto &value_info = graph.value_info(); + + // Create all initializer Tensors + for (const auto &tensor : initializer) + { + const auto mir_tensor = createTensor(&tensor); + auto *op = _graph->create<mir::ops::ConstantOp>(mir_tensor); + _converterCtx->setOutput(tensor.name(), op->getOutput(0)); + } + + for (const auto &input : graph.input()) + { + assert(input.has_name()); + + if (_converterCtx->getOutput(input.name()) == nullptr) + { + const auto &onnx_input_shape = input.type().tensor_type().shape(); + mir::Shape shape(onnx_input_shape.dim_size()); + for (int i = 0; i < onnx_input_shape.dim_size(); i++) + { + assert(onnx_input_shape.dim(i).has_dim_value()); + shape.dim(i) = static_cast<int32_t>(onnx_input_shape.dim(i).dim_value()); + } + + auto elem_type = onnxDataTypeToMirDataType( + (onnx::TensorProto_DataType)input.type().tensor_type().elem_type()); + mir::TensorType type{elem_type, shape}; + auto *op = _graph->create<mir::ops::InputOp>(type); + _converterCtx->setOutput(input.name(), op->getOutput(0)); + } + } +} + +std::unique_ptr<mir::Graph> ONNXImporterImpl::createIR() +{ + _graph = stdex::make_unique<mir::Graph>(); + _converterCtx = stdex::make_unique<ConverterContext>(_graph.get()); + + createGraphInputs(); + + // Forming partially ordered computation graph + for (const auto &onnx_node : _model->graph().node()) + { + assert(onnx_node.has_op_type()); + auto &op_type = onnx_node.op_type(); + auto opset = _modelCtx->getDomainOpsetVersion(onnx_node.domain()); + // Get converter + NodeConverterRegistry::ConverterFunc converter = + NodeConverterRegistry::getInstance().lookup(op_type, opset); + assert(converter != nullptr); + 
converter(onnx_node, _converterCtx.get()); + } + // Set graph outputs + const auto &outputs = _model->graph().output(); + for (const auto &output : outputs) + { + assert(output.has_name()); + auto mir_output = _converterCtx->getOutput(output.name()); + if (mir_output == nullptr) + throw std::runtime_error("Bad output name!"); + + _graph->create<mir::ops::OutputOp>(mir_output); + } + + return std::move(_graph); +} + +} // namespace + +std::unique_ptr<mir::Graph> importModelFromBinaryFile(const std::string &filename) +{ + ONNXImporterImpl importer; + return importer.importModelFromBinaryFile(filename); +} + +std::unique_ptr<mir::Graph> importModelFromTextFile(const std::string &filename) +{ + ONNXImporterImpl importer; + return importer.importModelFromTextFile(filename); +} + +std::unique_ptr<mir::Graph> loadModel(const std::string &filename) +{ + return importModelFromBinaryFile(filename); +} + +} // namespace mir_onnx |