diff options
author | 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com> | 2019-09-16 15:28:44 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-16 15:28:44 +0900 |
commit | cb6523d3a81d391f77ca1e0a7a92c6f65a5720e4 (patch) | |
tree | cc07e2147553863629a6f89de3b36b768aebc097 | |
parent | 44096cbda3b475f3f6b5b12a5c98ca41436f263e (diff) | |
download | nnfw-cb6523d3a81d391f77ca1e0a7a92c6f65a5720e4.tar.gz nnfw-cb6523d3a81d391f77ca1e0a7a92c6f65a5720e4.tar.bz2 nnfw-cb6523d3a81d391f77ca1e0a7a92c6f65a5720e4.zip |
[locop] Show TensorShape if possible (#7437)
With this change, Linear V1 Graph Formtter now shows some details of
TensorShape if possible.
Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
-rw-r--r-- | compiler/locop/include/locop/FormattedTensorShape.h | 18 | ||||
-rw-r--r-- | compiler/locop/src/FormattedGraph.cpp | 33 | ||||
-rw-r--r-- | compiler/locop/src/FormattedTensorShape.cpp | 18 |
3 files changed, 67 insertions, 2 deletions
diff --git a/compiler/locop/include/locop/FormattedTensorShape.h b/compiler/locop/include/locop/FormattedTensorShape.h index cebd984d3..25621d6c3 100644 --- a/compiler/locop/include/locop/FormattedTensorShape.h +++ b/compiler/locop/include/locop/FormattedTensorShape.h @@ -26,6 +26,8 @@ namespace locop enum class TensorShapeFormat { + // D_0 x D_1 x ... D_N + Plain, // [ D_0 x D_1 x D_2 x ... ] Bracket, }; @@ -33,6 +35,22 @@ enum class TensorShapeFormat template <TensorShapeFormat Format> class FormattedTensorShape; template <> +class FormattedTensorShape<TensorShapeFormat::Plain> final : public Spec<Interface::Formatted> +{ +public: + FormattedTensorShape(const loco::TensorShape *ptr) : _ptr{ptr} + { + // DO NOTHING + } + +public: + void dump(std::ostream &os) const final; + +private: + const loco::TensorShape *_ptr = nullptr; +}; + +template <> class FormattedTensorShape<TensorShapeFormat::Bracket> final : public Spec<Interface::Formatted> { public: diff --git a/compiler/locop/src/FormattedGraph.cpp b/compiler/locop/src/FormattedGraph.cpp index ef483ba26..04ca7c46f 100644 --- a/compiler/locop/src/FormattedGraph.cpp +++ b/compiler/locop/src/FormattedGraph.cpp @@ -57,6 +57,31 @@ std::string str(const loco::Domain &domain) throw std::invalid_argument{"domain"}; } +std::string str(const loco::NodeShape &node_shape) +{ + using namespace locop; + + switch (node_shape.domain()) + { + case loco::Domain::Tensor: + { + auto tensor_shape = node_shape.as<loco::TensorShape>(); + return pp::fmt(locop::fmt<TensorShapeFormat::Plain>(&tensor_shape)); + } + // TODO Show details + case loco::Domain::Feature: + case loco::Domain::Filter: + case loco::Domain::DepthwiseFilter: + case loco::Domain::Bias: + return "..."; + + default: + break; + } + + throw std::invalid_argument{"domain"}; +} + // TODO Use locop::fmt<TensorShapeFormat ...> locop::FormattedTensorShape<locop::TensorShapeFormat::Bracket> formatted_tensor_shape(const loco::TensorShape *ptr) @@ -343,9 +368,13 @@ void FormattedGraphImpl<Formatter::LinearV1>::dump(std::ostream &os) const if (loco::shape_known(node)) { auto node_shape = loco::shape_get(node); - os << " : " << str(node_shape.domain()) << "(...)"; + os << " : " << str(node_shape.domain()); + os << "<"; + os << str(node_shape); + os << ", "; + os << "?"; // TODO Show DataType - // TODO Show Shape details + os << ">"; } os << " = " << node_summary << std::endl; diff --git a/compiler/locop/src/FormattedTensorShape.cpp b/compiler/locop/src/FormattedTensorShape.cpp index 2741dd7a9..b2b6ea074 100644 --- a/compiler/locop/src/FormattedTensorShape.cpp +++ b/compiler/locop/src/FormattedTensorShape.cpp @@ -30,6 +30,24 @@ std::ostream &operator<<(std::ostream &os, const loco::Dimension &d) namespace locop { +void FormattedTensorShape<TensorShapeFormat::Plain>::dump(std::ostream &os) const +{ + if (_ptr->rank() > 0) + { + os << _ptr->dim(0); + + for (uint32_t axis = 1; axis < _ptr->rank(); ++axis) + { + os << " x " << _ptr->dim(axis); + } + } +} + +} // namespace locop + +namespace locop +{ + void FormattedTensorShape<TensorShapeFormat::Bracket>::dump(std::ostream &os) const { os << "["; |