diff options
Diffstat (limited to 'libs/util/src/tensor/Shape.cpp')
-rw-r--r-- | libs/util/src/tensor/Shape.cpp | 55 |
1 files changed, 54 insertions, 1 deletions
diff --git a/libs/util/src/tensor/Shape.cpp b/libs/util/src/tensor/Shape.cpp index d177d1382..f1de26fdc 100644 --- a/libs/util/src/tensor/Shape.cpp +++ b/libs/util/src/tensor/Shape.cpp @@ -16,6 +16,8 @@ #include "util/tensor/Shape.h" +#include <cassert> + namespace nnfw { namespace util @@ -32,7 +34,7 @@ bool operator==(const Shape &lhs, const Shape &rhs) for (size_t axis = 0; axis < lhs.rank(); ++axis) { - if(lhs.dim(axis) != rhs.dim(axis)) + if (lhs.dim(axis) != rhs.dim(axis)) { return false; } @@ -41,6 +43,57 @@ bool operator==(const Shape &lhs, const Shape &rhs) return true; } +Shape Shape::from(const std::string &str) +{ + Shape shape(0); + + bool pending = false; + int value = 0; + + for (const char *cur = str.c_str(); true; ++cur) + { + if (*cur == ',' || *cur == '\0') + { + if (pending) + { + shape.append(value); + } + + if (*cur == '\0') + { + break; + } + + pending = false; + value = 0; + continue; + } + + assert(*cur >= '0' && *cur <= '9'); + + pending = true; + value *= 10; + value += *cur - '0'; + } + + return shape; +} + +std::ostream &operator<<(std::ostream &os, const Shape &shape) +{ + if (shape.rank() > 0) + { + os << shape.dim(0); + + for (uint32_t axis = 1; axis < shape.rank(); ++axis) + { + os << "," << shape.dim(axis); + } + } + + return os; +} + } // namespace tensor } // namespace util } // namespace nnfw |