summaryrefslogtreecommitdiff
path: root/compiler/nnc/driver/Driver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/nnc/driver/Driver.cpp')
-rw-r--r--compiler/nnc/driver/Driver.cpp216
1 files changed, 216 insertions, 0 deletions
diff --git a/compiler/nnc/driver/Driver.cpp b/compiler/nnc/driver/Driver.cpp
new file mode 100644
index 000000000..5b369623e
--- /dev/null
+++ b/compiler/nnc/driver/Driver.cpp
@@ -0,0 +1,216 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pass/PassData.h"
+
+#include "passes/transformations/DataFormatSwitcher.h"
+#include "passes/transformations/LowerConv2D.h"
+
+#include "backends/interpreter/InterpreterBackend.h"
+#include "backends/soft_backend/CPPGenerator.h"
+#include "passes/dot_dumper/DumperPass.h"
+#include "backends/acl_soft_backend/AclCppGenerator.h"
+
+#include "passes/optimizations/CombineTransposes.h"
+#include "passes/optimizations/ConstantFoldTranspose.h"
+#include "passes/optimizations/DeadCodeElimination.h"
+#include "passes/optimizations/FuseArithmeticOps.h"
+#include "passes/optimizations/SinkRelu.h"
+#include "passes/optimizations/SinkTranspose.h"
+
+#include "support/CommandLine.h"
+#include "Definitions.h"
+#include "Options.h"
+#include "Driver.h"
+
+#ifdef NNC_FRONTEND_CAFFE2_ENABLED
+#include <caffe2_importer.h>
+#endif // NNC_FRONTEND_CAFFE2_ENABLED
+#ifdef NNC_FRONTEND_CAFFE_ENABLED
+#include <caffe_importer.h>
+#endif // NNC_FRONTEND_CAFFE_ENABLED
+#ifdef NNC_FRONTEND_TFLITE_ENABLED
+#include <tflite_importer.h>
+#endif // NNC_FRONTEND_TFLITE_ENABLED
+#ifdef NNC_FRONTEND_ONNX_ENABLED
+#include <ONNXImporterImpl.h>
+#endif // NNC_FRONTEND_ONNX_ENABLED
+
+#include <stdex/Memory.h>
+
+namespace nnc
+{
+
+static std::string getFrontendOptionsString()
+{
+ std::string res;
+
+ if (!cli::caffeFrontend.isDisabled())
+ res += "'" + cli::caffeFrontend.getNames()[0] + "' ";
+
+ if (!cli::caffe2Frontend.isDisabled())
+ res += "'" + cli::caffe2Frontend.getNames()[0] + "' ";
+
+ if (!cli::onnxFrontend.isDisabled())
+ res += "'" + cli::onnxFrontend.getNames()[0] + "' ";
+
+ if (!cli::tflFrontend.isDisabled())
+ res += "'" + cli::tflFrontend.getNames()[0] + "'";
+
+ return res;
+}
+
+static std::unique_ptr<mir::Graph> importModel()
+{
+ // For bool, the value false is converted to zero and the value true is converted to one
+ if (cli::caffeFrontend + cli::caffe2Frontend + cli::tflFrontend + cli::onnxFrontend != 1)
+ throw DriverException("One and only one of the following options are allowed and have to be set"
+ "in the same time: " +
+ getFrontendOptionsString());
+
+ if (cli::caffeFrontend)
+ {
+#ifdef NNC_FRONTEND_CAFFE_ENABLED
+ return mir_caffe::loadModel(cli::inputFile.getRawValue());
+#endif // NNC_FRONTEND_CAFFE_ENABLED
+ }
+ else if (cli::caffe2Frontend)
+ {
+#ifdef NNC_FRONTEND_CAFFE2_ENABLED
+ // FIXME: caffe2 input shapes are not provided by model and must be set from cli
+ // current 'inputShapes' could provide only one shape, while model could has several inputs
+ return mir_caffe2::loadModel(cli::inputFile.getRawValue(), cli::initNet.getRawValue(),
+ {cli::inputShapes.getRawValue()});
+#endif // NNC_FRONTEND_CAFFE2_ENABLED
+ }
+ else if (cli::onnxFrontend)
+ {
+#ifdef NNC_FRONTEND_ONNX_ENABLED
+ return mir_onnx::loadModel(cli::inputFile.getRawValue());
+#endif // NNC_FRONTEND_ONNX_ENABLED
+ }
+ else if (cli::tflFrontend)
+ {
+#ifdef NNC_FRONTEND_TFLITE_ENABLED
+ return mir_tflite::loadModel(cli::inputFile.getRawValue());
+#endif // NNC_FRONTEND_TFLITE_ENABLED
+ }
+
+ assert(false);
+ return nullptr;
+}
+
+static void backend(mir::Graph *graph)
+{
+ if (cli::target == NNC_TARGET_ARM_CPP || cli::target == NNC_TARGET_X86_CPP)
+ {
+ CPPCodeGenerator(cli::artifactDir, cli::artifactName).run(graph);
+ }
+ else if (cli::target == NNC_TARGET_ARM_GPU_CPP)
+ {
+ AclCppCodeGenerator(cli::artifactDir, cli::artifactName).run(graph);
+ }
+ else if (cli::target == NNC_TARGET_INTERPRETER)
+ {
+ InterpreterBackend(cli::interInputDataDir, cli::artifactDir).run(graph);
+ }
+ else
+ {
+ assert(false && "invalid option value");
+ }
+}
+
+/**
+ * @brief run all registered passes
+ * @throw PassException, if errors occured
+ */
+void Driver::runPasses()
+{
+ auto graph = importModel();
+ PassData pass_data(graph.get());
+
+ for (const auto &pass : _passManager.getPasses())
+ {
+ pass_data = pass->run(pass_data);
+ if (cli::dumpGraph && static_cast<mir::Graph *>(pass_data))
+ {
+ DumperPass d(pass->getName());
+ d.run(pass_data);
+ }
+ }
+
+ backend(pass_data);
+
+ // NOTE. Now we destroy data of all passes when PassManager is destroyed.
+ // In future to reduce memory consumption we can destory it when passes are being performed
+
+} // runPasses
+
+/**
+ * @brief Register backend specific passes
+ * @throw DriverException if errors occurred
+ */
+void Driver::registerBackendSpecificPasses()
+{
+ std::unique_ptr<Pass> data_format_pass;
+
+ if (cli::target == NNC_TARGET_ARM_CPP || cli::target == NNC_TARGET_X86_CPP)
+ {
+ _passManager.registerPass(stdex::make_unique<LowerConv2D>());
+ _passManager.registerPass(stdex::make_unique<DataFormatSwitcher>(mir::DataFormat::NHWC));
+ }
+ else if (cli::target == NNC_TARGET_ARM_GPU_CPP)
+ {
+ _passManager.registerPass(stdex::make_unique<LowerConv2D>());
+ _passManager.registerPass(stdex::make_unique<ConstantFoldTranspose>());
+ // TODO Change to DataFormat::NCHW when fix it in ACL
+ _passManager.registerPass(stdex::make_unique<DataFormatSwitcher>(mir::DataFormat::NHWC));
+ }
+ else if (cli::target == NNC_TARGET_INTERPRETER)
+ {
+ _passManager.registerPass(stdex::make_unique<DataFormatSwitcher>(mir::DataFormat::NHWC));
+ }
+ else
+ {
+ assert(false && "invalid option value");
+ }
+}
+
+void Driver::registerOptimizationPass()
+{
+ if (cli::doOptimizationPass)
+ {
+ // TODO: maybe we should start managing the optimizations more intelligently?
+ _passManager.registerPass(std::unique_ptr<Pass>(new CombineTransposes()));
+ _passManager.registerPass(std::unique_ptr<Pass>(new SinkTranspose()));
+ _passManager.registerPass(std::unique_ptr<Pass>(new SinkRelu()));
+#if 0
+ // TODO Support broadcasting.
+ _passManager.registerPass(std::unique_ptr<Pass>(new FuseArithmeticOps()));
+#endif
+ _passManager.registerPass(std::unique_ptr<Pass>(new DeadCodeElimination()));
+ }
+} // registerOptimizationPass
+
+void Driver::runDriver()
+{
+ registerOptimizationPass();
+ registerBackendSpecificPasses();
+
+ runPasses();
+}
+
+} // namespace nnc