summaryrefslogtreecommitdiff
path: root/compiler/luci/tester/src/WriteTester.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/tester/src/WriteTester.cpp')
-rw-r--r--compiler/luci/tester/src/WriteTester.cpp142
1 files changed, 142 insertions, 0 deletions
diff --git a/compiler/luci/tester/src/WriteTester.cpp b/compiler/luci/tester/src/WriteTester.cpp
new file mode 100644
index 000000000..80019d1b1
--- /dev/null
+++ b/compiler/luci/tester/src/WriteTester.cpp
@@ -0,0 +1,142 @@
+#include "Model.h"
+
+#include <luci/Importer.h>
+#include <luci/Pass/ShapeInferencePass.h>
+#include <luci/Pass/TypeInferencePass.h>
+#include <luci/Service/Validate.h>
+#include <luci/CircleExporter.h>
+#include <oops/InternalExn.h>
+
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <string>
+
+namespace
+{
+
+void show_help_message(const char *progname, std::ostream &os)
+{
+ os << "USAGE: " << progname << " circlefile_in circlefile_out" << std::endl << std::endl;
+}
+
+void show_error_message(const char *progname, std::ostream &os, const std::string &msg)
+{
+ os << "ERROR: " << msg << std::endl;
+ os << std::endl;
+
+ show_help_message(progname, os);
+}
+
+struct CircleExpContract : public luci::CircleExporter::Contract
+{
+public:
+ CircleExpContract(loco::Graph *graph, const std::string &filename)
+ : _graph(graph), _filepath(filename)
+ {
+ // NOTHING TO DO
+ }
+ CircleExpContract(luci::Module *module, const std::string &filename)
+ : _module(module), _filepath(filename)
+ {
+ // NOTHING TO DO
+ }
+ virtual ~CircleExpContract() = default;
+
+public:
+ loco::Graph *graph(void) const final { return _graph; }
+
+ luci::Module *module(void) const final { return _module; }
+
+public:
+ bool store(const char *ptr, const size_t size) const final;
+
+private:
+ loco::Graph *_graph;
+ luci::Module *_module;
+ const std::string _filepath;
+};
+
+bool CircleExpContract::store(const char *ptr, const size_t size) const
+{
+ if (!ptr)
+ INTERNAL_EXN("Graph was not serialized by FlatBuffer for some reason");
+
+ std::ofstream fs(_filepath.c_str(), std::ofstream::binary);
+ fs.write(ptr, size);
+
+ return fs.good();
+}
+
+} // namespace
+
+/*
+ * @brief WriteTester main
+ *
+ * Give two Circle file as an argument
+ *
+ * This will use luci_import to read the first file and get loco graph
+ * With the graph, this will use luci_export to write to the second file
+ * Like ReadTester, LUCI_LOG=1 environment variable is available to dump the graph
+ */
+int main(int argc, char **argv)
+{
+ if (argc != 3)
+ {
+ show_error_message(argv[0], std::cerr, "In/Out Circle file path is not specified");
+ return 255;
+ }
+
+ std::string input_path = argv[1];
+ std::string output_path = argv[2];
+
+ std::cout << "[INFO] Circle from '" << input_path << "' to '" << output_path << "'" << std::endl;
+
+ // Load model from the file
+ std::unique_ptr<luci::Model> model = luci::load_model(input_path);
+ if (model == nullptr)
+ {
+ std::cerr << "ERROR: Failed to load '" << input_path << "'" << std::endl;
+ return 255;
+ }
+
+ const circle::Model *input_model = model->model();
+ if (input_model == nullptr)
+ {
+ std::cerr << "ERROR: Failed to read '" << input_path << "'" << std::endl;
+ return 255;
+ }
+
+ // Import from input Circle file
+ luci::Importer importer;
+ auto module = importer.importModule(input_model);
+ assert(module->size() > 0);
+
+ for (size_t g = 0; g < module->size(); ++g)
+ {
+ auto graph = module->graph(g);
+ if (graph == nullptr)
+ return 255;
+
+ {
+ luci::ShapeInferencePass pass;
+ while (pass.run(graph) == true)
+ ;
+ }
+ {
+ luci::TypeInferencePass pass;
+ while (pass.run(graph) == true)
+ ;
+ }
+
+ if (!luci::validate(graph))
+ return 255;
+ }
+
+ // Export to output Circle file
+ luci::CircleExporter exporter;
+
+ CircleExpContract contract(module.get(), output_path);
+
+ return exporter.invoke(&contract) ? 0 : 255;
+}