summaryrefslogtreecommitdiff
path: root/compiler/locop/src/FormattedGraph.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/locop/src/FormattedGraph.cpp')
-rw-r--r--compiler/locop/src/FormattedGraph.cpp390
1 files changed, 390 insertions, 0 deletions
diff --git a/compiler/locop/src/FormattedGraph.cpp b/compiler/locop/src/FormattedGraph.cpp
new file mode 100644
index 000000000..84de1e888
--- /dev/null
+++ b/compiler/locop/src/FormattedGraph.cpp
@@ -0,0 +1,390 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "locop/FormattedGraph.h"
+#include "locop/FormattedTensorShape.h"
+#include "locop/GenericNodeSummaryBuilder.h"
+
+#include <loco/Service/TypeInference.h>
+#include <loco/Service/ShapeInference.h>
+
+#include <pp/Format.h>
+
+#include <stdex/Memory.h>
+
+#include <map>
+#include <set>
+
+#include <cassert>
+
+using locop::SymbolTable;
+
+namespace
+{
+
+std::string str(const loco::DataType &dtype)
+{
+ switch (dtype)
+ {
+ case loco::DataType::Unknown:
+ return "Unknown";
+
+ case loco::DataType::U8:
+ return "U8";
+ case loco::DataType::U16:
+ return "U16";
+ case loco::DataType::U32:
+ return "U32";
+ case loco::DataType::U64:
+ return "U64";
+
+ case loco::DataType::S8:
+ return "S8";
+ case loco::DataType::S16:
+ return "S16";
+ case loco::DataType::S32:
+ return "S32";
+ case loco::DataType::S64:
+ return "S64";
+
+ case loco::DataType::FLOAT16:
+ return "FLOAT16";
+ case loco::DataType::FLOAT32:
+ return "FLOAT32";
+ case loco::DataType::FLOAT64:
+ return "FLOAT64";
+
+ default:
+ break;
+ };
+
+ throw std::invalid_argument{"dtype"};
+}
+
+std::string str(const loco::Domain &domain)
+{
+ // TODO Generate!
+ switch (domain)
+ {
+ case loco::Domain::Unknown:
+ return "Unknown";
+ case loco::Domain::Tensor:
+ return "Tensor";
+ case loco::Domain::Feature:
+ return "Feature";
+ case loco::Domain::Filter:
+ return "Filter";
+ case loco::Domain::DepthwiseFilter:
+ return "DWFilter";
+ case loco::Domain::Bias:
+ return "Bias";
+ default:
+ break;
+ }
+
+ throw std::invalid_argument{"domain"};
+}
+
+std::string str(const loco::NodeShape &node_shape)
+{
+ using namespace locop;
+
+ switch (node_shape.domain())
+ {
+ case loco::Domain::Tensor:
+ {
+ auto tensor_shape = node_shape.as<loco::TensorShape>();
+ return pp::fmt(locop::fmt<TensorShapeFormat::Plain>(&tensor_shape));
+ }
+ // TODO Show details
+ case loco::Domain::Feature:
+ case loco::Domain::Filter:
+ case loco::Domain::DepthwiseFilter:
+ case loco::Domain::Bias:
+ return "...";
+
+ default:
+ break;
+ }
+
+ throw std::invalid_argument{"domain"};
+}
+
+// TODO Use locop::fmt<TensorShapeFormat ...>
+locop::FormattedTensorShape<locop::TensorShapeFormat::Bracket>
+formatted_tensor_shape(const loco::TensorShape *ptr)
+{
+ return locop::FormattedTensorShape<locop::TensorShapeFormat::Bracket>{ptr};
+}
+
+} // namespace
+
+namespace
+{
+
+struct NodeDesc : public locop::NodeDesc
+{
+public:
+ NodeDesc() = default;
+ NodeDesc(const locop::OpName &opname) : locop::NodeDesc{opname}
+ {
+ // DO NOTHING
+ }
+
+public:
+ // DEPRECATED
+ const locop::OpName &name(void) const { return opname(); }
+
+ // DEPRECATED
+ uint32_t arg_size(void) const { return args().count(); }
+ // DEPRECATED
+ const locop::ArgElem &arg(uint32_t n) const { return args().at(n); }
+ // DEPRECATED
+ void arg(const locop::ArgName &name, const locop::ArgValue &value) { args().append(name, value); }
+};
+
+} // namespace
+
+// TODO Remove this workaround
+namespace locop
+{
+
+std::ostream &operator<<(std::ostream &os, const NodeDesc &d)
+{
+ assert(d.state() != NodeDesc::State::Invalid);
+
+ std::vector<std::string> values;
+
+ for (uint32_t n = 0; n < d.args().count(); ++n)
+ {
+ values.emplace_back(d.args().at(n).first + ": " + d.args().at(n).second);
+ }
+
+ if (d.state() == NodeDesc::State::PartiallyKnown)
+ {
+ values.emplace_back("...");
+ }
+
+ os << d.opname();
+ os << "(";
+ if (values.size() > 0)
+ {
+ os << values.at(0);
+ for (uint32_t n = 1; n < values.size(); ++n)
+ {
+ os << ", " << values.at(n);
+ }
+ }
+ os << ")";
+
+ return os;
+}
+
+} // namespace locop
+
+namespace locop
+{
+
+std::ostream &operator<<(std::ostream &os, const FormattedGraph &fmt)
+{
+ fmt.dump(os);
+ return os;
+}
+
+} // namespace locop
+
+namespace locop
+{
+
+void FormattedGraphImpl<Formatter::LinearV1>::dump(std::ostream &os) const
+{
+ struct SymbolTableImpl final : public SymbolTable
+ {
+ std::string lookup(const loco::Node *node) const final
+ {
+ if (node == nullptr)
+ {
+ return "(null)";
+ }
+
+ return _content.at(node);
+ }
+
+ std::map<const loco::Node *, std::string> _content;
+ };
+
+ SymbolTableImpl symbols;
+
+ auto symbol = [&symbols](const loco::Node *node) { return symbols.lookup(node); };
+
+ for (uint32_t n = 0; n < _graph->nodes()->size(); ++n)
+ {
+ symbols._content[_graph->nodes()->at(n)] = pp::fmt("%", n);
+ }
+
+ // Find the disjoint node clusters
+ //
+ // TODO Move this implementation into loco Algorithms.h
+ std::map<loco::Node *, loco::Node *> parents;
+
+ for (auto node : loco::all_nodes(_graph))
+ {
+ parents[node] = nullptr;
+ }
+
+ for (auto node : loco::all_nodes(_graph))
+ {
+ for (uint32_t n = 0; n < node->arity(); ++n)
+ {
+ if (auto arg = node->arg(n))
+ {
+ parents[arg] = node;
+ }
+ }
+ }
+
+ auto find = [&parents](loco::Node *node) {
+ loco::Node *cur = node;
+
+ while (parents.at(cur) != nullptr)
+ {
+ cur = parents.at(cur);
+ }
+
+ return cur;
+ };
+
+ std::set<loco::Node *> roots;
+
+ for (auto node : loco::all_nodes(_graph))
+ {
+ roots.insert(find(node));
+ }
+
+ std::map<loco::Node *, std::set<loco::Node *>> clusters;
+
+ // Create clusters
+ for (auto root : roots)
+ {
+ clusters[root] = std::set<loco::Node *>{};
+ }
+
+ for (auto node : loco::all_nodes(_graph))
+ {
+ clusters.at(find(node)).insert(node);
+ }
+
+ std::unique_ptr<locop::NodeSummaryBuilder> node_summary_builder;
+
+ if (_factory)
+ {
+ // Use User-defined NodeSummaryBuilder if NodeSummaryBuilderFactory is present
+ node_summary_builder = _factory->create(&symbols);
+ }
+ else
+ {
+ // Use Built-in NodeSummaryBuilder otherwise
+ node_summary_builder = stdex::make_unique<GenericNodeSummaryBuilder>(&symbols);
+ }
+
+ // Print Graph Input(s)
+ for (uint32_t n = 0; n < _graph->inputs()->size(); ++n)
+ {
+ auto input = _graph->inputs()->at(n);
+
+ std::string name = input->name();
+
+ std::string shape = "?";
+ if (input->shape() != nullptr)
+ {
+ shape = pp::fmt(formatted_tensor_shape(input->shape()));
+ }
+
+ // TODO Print dtype
+ os << pp::fmt("In #", n, " { name: ", name, ", shape: ", shape, " }") << std::endl;
+ }
+
+ // Print Graph Output(s)
+ for (uint32_t n = 0; n < _graph->outputs()->size(); ++n)
+ {
+ auto output = _graph->outputs()->at(n);
+
+ std::string name = output->name();
+
+ std::string shape = "?";
+ if (output->shape() != nullptr)
+ {
+ shape = pp::fmt(formatted_tensor_shape(output->shape()));
+ }
+
+ // TODO Print dtype
+ os << pp::fmt("Out #", n, " { name: ", name, ", shape: ", shape, " }") << std::endl;
+ }
+
+ if (_graph->inputs()->size() + _graph->outputs()->size() != 0)
+ {
+ os << std::endl;
+ }
+
+ for (auto it = clusters.begin(); it != clusters.end(); ++it)
+ {
+ std::vector<loco::Node *> cluster_outputs;
+
+ for (auto node : it->second)
+ {
+ // NOTE This is inefficient but anyway working :)
+ if (loco::succs(node).empty())
+ {
+ cluster_outputs.emplace_back(node);
+ }
+ }
+
+ for (auto node : loco::postorder_traversal(cluster_outputs))
+ {
+ locop::NodeSummary node_summary;
+
+ // Build a node summary
+ if (!node_summary_builder->build(node, node_summary))
+ {
+ throw std::runtime_error{"Fail to build a node summary"};
+ }
+
+ for (uint32_t n = 0; n < node_summary.comments().count(); ++n)
+ {
+ os << "; " << node_summary.comments().at(n) << std::endl;
+ }
+
+ os << symbol(node);
+
+ if (loco::shape_known(node))
+ {
+ auto node_shape = loco::shape_get(node);
+ os << " : " << str(node_shape.domain());
+ os << "<";
+ os << str(node_shape);
+ os << ", ";
+ // Show DataType
+ os << (loco::dtype_known(node) ? str(loco::dtype_get(node)) : std::string{"?"});
+ os << ">";
+ }
+
+ os << " = " << node_summary << std::endl;
+ }
+ os << std::endl;
+ }
+}
+
+} // namespace locop