summaryrefslogtreecommitdiff
path: root/compiler/moco/support/src/TFShapeInferenceHelper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/moco/support/src/TFShapeInferenceHelper.cpp')
-rw-r--r--compiler/moco/support/src/TFShapeInferenceHelper.cpp354
1 files changed, 354 insertions, 0 deletions
diff --git a/compiler/moco/support/src/TFShapeInferenceHelper.cpp b/compiler/moco/support/src/TFShapeInferenceHelper.cpp
new file mode 100644
index 000000000..13e514a78
--- /dev/null
+++ b/compiler/moco/support/src/TFShapeInferenceHelper.cpp
@@ -0,0 +1,354 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "moco/Support/TFShapeInferenceHelper.h"
+
+#include <loco/Service/ShapeInference.h>
+
+#include <oops/InternalExn.h>
+
+#include <cassert>
+
+namespace
+{
+
+// TODO Use codes in loco and remove duplicate broadcast_shape() and related
+/**
+ * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
+ *
+ * HOW TO USE:
+ *
+ * auto expanded_tensor_shape = expand(tensor_shape).to(N);
+ */
+class TensorShapeExpander
+{
+public:
+ TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ loco::TensorShape to(uint32_t output_rank)
+ {
+ auto const &input_shape = _shape;
+ uint32_t const input_rank = input_shape.rank();
+
+ assert(input_rank <= output_rank && "Cannot shrink rank");
+ uint32_t const axis_shift = output_rank - input_rank;
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(output_rank);
+ for (uint32_t axis = 0; axis < output_rank; ++axis)
+ {
+ output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
+ }
+
+ return output_shape;
+ }
+
+private:
+ const loco::TensorShape _shape;
+};
+
+/**
+ * @breif Expand shape x and y to same rank by align right and filling with 1
+ */
+void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
+{
+ auto x_rank = x.rank();
+ auto y_rank = y.rank();
+
+ if (x_rank == y_rank)
+ return;
+
+ TensorShapeExpander x_exp(x);
+ TensorShapeExpander y_exp(y);
+
+ auto xy_rank = std::max(x_rank, y_rank);
+
+ x = x_rank > y_rank ? x : x_exp.to(xy_rank);
+ y = y_rank > x_rank ? y : y_exp.to(xy_rank);
+}
+
+/**
+ * @breif Returns shape of expanded dimension of input x and y having same rank
+ */
+loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ assert(x.rank() == y.rank());
+
+ auto rank = x.rank();
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ assert(x.dim(axis).known() && y.dim(axis).known());
+
+ auto x_dim = x.dim(axis).value();
+ auto y_dim = y.dim(axis).value();
+
+ // each dimension of x and y should be same or one must be 1 if different
+ if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
+ {
+ // TODO may need to refine message
+ INTERNAL_EXN("ShapeInference: Input shapes don't match");
+ }
+
+ output_shape.dim(axis) = std::max(x_dim, y_dim);
+ }
+
+ return output_shape;
+}
+
+} // namespace
+
+namespace moco
+{
+
+loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ auto x_match = x;
+ auto y_match = y;
+
+ expand_rank(x_match, y_match);
+
+ auto output_shape = expand_dimension(x_match, y_match);
+
+ return output_shape;
+}
+
+} // namespace moco
+
+namespace moco
+{
+
+loco::NodeShape node_shape(const loco::Node *node)
+{
+ loco::NodeShape nodeshape; // default domain is Unknown
+
+ if (loco::shape_known(node))
+ {
+ nodeshape = loco::shape_get(node);
+ }
+
+ return nodeshape;
+}
+
+bool node_shape(const loco::Node *node, loco::NodeShape &nodeshape)
+{
+ nodeshape = node_shape(node);
+ return (nodeshape.domain() != loco::Domain::Unknown);
+}
+
+loco::TensorShape as_tensor_shape(const loco::FeatureShape &feature_shape,
+ const TFDataLayout &data_layout)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(4);
+ if (data_layout == "NHWC")
+ {
+ tensor_shape.dim(0) = feature_shape.count();
+ tensor_shape.dim(1) = feature_shape.height();
+ tensor_shape.dim(2) = feature_shape.width();
+ tensor_shape.dim(3) = feature_shape.depth();
+ }
+ else if (data_layout == "NCHW")
+ {
+ tensor_shape.dim(0) = feature_shape.count();
+ tensor_shape.dim(1) = feature_shape.depth();
+ tensor_shape.dim(2) = feature_shape.height();
+ tensor_shape.dim(3) = feature_shape.width();
+ }
+ else
+ {
+ // TODO support for other data_layout if needed
+ INTERNAL_EXN_V("ShapeInference: Unknown data_format", data_layout);
+ }
+
+ return tensor_shape;
+}
+
+loco::FeatureShape as_feature_shape(const loco::NodeShape &nodeshape,
+ const TFDataLayout &data_layout)
+{
+ if (nodeshape.domain() == loco::Domain::Feature)
+ return nodeshape.as<loco::FeatureShape>();
+
+ loco::FeatureShape feature_shape;
+
+ // only convert from tensor to feature
+ if (nodeshape.domain() != loco::Domain::Tensor)
+ {
+ INTERNAL_EXN("ShapeInference: Invalid shape information");
+ }
+
+ loco::TensorShape tensor_shape = nodeshape.as<loco::TensorShape>();
+
+ if (tensor_shape.rank() != 4)
+ {
+ INTERNAL_EXN("ShapeInference: Rank is not 4");
+ }
+
+ if (data_layout == "NHWC")
+ {
+ feature_shape.count() = tensor_shape.dim(0);
+ feature_shape.height() = tensor_shape.dim(1);
+ feature_shape.width() = tensor_shape.dim(2);
+ feature_shape.depth() = tensor_shape.dim(3);
+ }
+ else if (data_layout == "NCHW")
+ {
+ feature_shape.count() = tensor_shape.dim(0);
+ feature_shape.depth() = tensor_shape.dim(1);
+ feature_shape.height() = tensor_shape.dim(2);
+ feature_shape.width() = tensor_shape.dim(3);
+ }
+ else
+ {
+ // TODO support for other data_layout if needed
+ INTERNAL_EXN_V("ShapeInference: Unknown data_format", data_layout);
+ }
+
+ return feature_shape;
+}
+
+} // namespace moco
+
+namespace moco
+{
+
+PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape)
+{
+ PlaneShape plane_shape;
+
+ plane_shape.height = feature_shape.height();
+ plane_shape.width = feature_shape.width();
+
+ return plane_shape;
+}
+
+FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
+{
+ return FeatureShapeUpdater{&feature_shape};
+}
+
+} // namespace moco
+
+namespace
+{
+
+/**
+ * @brief Class to represent TensorFlow "data_format" attr.
+ */
+enum class DataLayout
+{
+ NHWC,
+ NCHW,
+};
+
+DataLayout as_data_layout(const std::string &tf_layout_str)
+{
+ if (tf_layout_str == "NHWC")
+ return DataLayout::NHWC;
+ else if (tf_layout_str == "NCHW")
+ return DataLayout::NCHW;
+ else
+ /// @note data layout tag in TensorFlow is 'data_format'
+ INTERNAL_EXN_V("ShapeInference: Unknown data_format", tf_layout_str);
+}
+
+} // namespace
+
+namespace moco
+{
+
+loco::Stride<2> stride_of(const TFStrides &strides, const TFDataLayout &datalayout)
+{
+ loco::Stride<2> stride;
+
+ auto data_layout = as_data_layout(datalayout);
+ if (data_layout == DataLayout::NHWC)
+ {
+ stride.vertical(strides[1]);
+ stride.horizontal(strides[2]);
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ stride.vertical(strides[2]);
+ stride.horizontal(strides[3]);
+ }
+ else
+ {
+ // TODO add more datalayout supports if needed
+ INTERNAL_EXN("ShapeInference: Unknown data_format");
+ }
+
+ return stride;
+}
+
+loco::Window<2> window_of(const TFKSize &ksize, const TFDataLayout &datalayout)
+{
+ loco::Window<2> window;
+
+ auto data_layout = as_data_layout(datalayout);
+ if (data_layout == DataLayout::NHWC)
+ {
+ window.vertical(ksize[1]);
+ window.horizontal(ksize[2]);
+ }
+ else if (data_layout == DataLayout::NCHW)
+ {
+ window.vertical(ksize[2]);
+ window.horizontal(ksize[3]);
+ }
+ else
+ {
+ // TODO add more datalayout supports if needed
+ INTERNAL_EXN("ShapeInference: Unknown data_format");
+ }
+
+ return window;
+}
+
+loco::Window<2> window_of(const loco::TensorShape &shape, const TFDataLayout &datalayout)
+{
+ loco::Window<2> window;
+
+ if (datalayout == "HWIO")
+ {
+ window.vertical(shape.dim(0).value());
+ window.horizontal(shape.dim(1).value());
+ }
+ else if (datalayout == "HWCM")
+ {
+ window.vertical(shape.dim(0).value());
+ window.horizontal(shape.dim(1).value());
+ }
+ else
+ {
+ // TODO add more datalayout supports if needed
+ INTERNAL_EXN_V("ShapeInference: Unknown data_format", datalayout);
+ }
+
+ return window;
+}
+
+} // namespace moco