summaryrefslogtreecommitdiff
path: root/compiler/moco/import/src/Nodes/Const.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/moco/import/src/Nodes/Const.cpp')
-rw-r--r--compiler/moco/import/src/Nodes/Const.cpp242
1 files changed, 242 insertions, 0 deletions
diff --git a/compiler/moco/import/src/Nodes/Const.cpp b/compiler/moco/import/src/Nodes/Const.cpp
new file mode 100644
index 000000000..15ea717db
--- /dev/null
+++ b/compiler/moco/import/src/Nodes/Const.cpp
@@ -0,0 +1,242 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "moco/Import/Nodes/Const.h"
+
+#include <moco/Names.h>
+#include <moco/IR/TFNodes.h>
+
+#include <loco.h>
+#include <plier/tf/Convert.h>
+#include <oops/UserExn.h>
+
+#include <cassert>
+#include <stdexcept>
+#include <string>
+
+namespace
+{
+
+using namespace moco;
+
+void read_value_int8(TFConst *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::S8>(num_elements);
+
+ int32_t input_elements = input_tensor.int_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(int8_t))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const int8_t *s8_ptr = reinterpret_cast<const int8_t *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S8>(i) = *(s8_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::S8>(i) = input_tensor.int_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S8>(i) = input_tensor.int_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw oops::UserExn("Invalid Const values", const_node->name());
+ }
+}
+
+void read_value_int32(TFConst *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::S32>(num_elements);
+
+ int32_t input_elements = input_tensor.int_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(int32_t))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const int32_t *s32_ptr = reinterpret_cast<const int32_t *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = *(s32_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw oops::UserExn("Invalid Const values", const_node->name());
+ }
+}
+
+void read_value_float32(TFConst *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::FLOAT32>(num_elements);
+
+ int32_t input_elements = input_tensor.float_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(float))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const float *float_ptr = reinterpret_cast<const float *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = *(float_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw oops::UserExn("Invalid Const values", const_node->name());
+ }
+}
+
+} // namespace
+
+namespace moco
+{
+
+bool ConstGraphBuilder::validate(const tensorflow::NodeDef &node) const
+{
+ if (!plier::tf::has_attrs(node, {"dtype", "value"}))
+ return false;
+
+ const auto &input_tensor = plier::tf::get_tensor_attr(node, "value");
+ const auto &input_shape = input_tensor.tensor_shape();
+ const auto &input_dims = input_shape.dim();
+
+ if (!(input_shape.dim_size() <= 6))
+ return false;
+
+ for (auto &d : input_dims)
+ {
+ if (d.size() > std::numeric_limits<int>::max())
+ throw oops::UserExn("Const Shape element overflows", node.name());
+
+ if (d.size() < 0)
+ throw oops::UserExn("Unknown dim size", node.name());
+ }
+
+ auto dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "dtype"));
+ if (!(dtype == loco::DataType::S32 || dtype == loco::DataType::FLOAT32 ||
+ dtype == loco::DataType::S8))
+ return false;
+ // TODO support other dtype
+
+ return true;
+}
+
+void ConstGraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+
+ // Create a "TFConstant" node for Const
+ auto const_node = graph->nodes()->create<TFConst>();
+ const_node->name(node.name());
+
+ // set dtype
+ auto dtype = plier::tf::as_loco_datatype(plier::tf::get_datatype_attr(node, "dtype"));
+ const_node->dtype(dtype);
+
+ // import shape and value
+ const auto &input_tensor = plier::tf::get_tensor_attr(node, "value");
+ const auto &input_shape = input_tensor.tensor_shape();
+ const auto &input_dims = input_shape.dim();
+ assert(input_shape.dim_size() <= 6);
+ const_node->rank(input_shape.dim_size());
+ int index = 0;
+ bool zero_sized_shape = false;
+ for (auto &d : input_dims)
+ {
+ assert(d.size() <= std::numeric_limits<int>::max());
+ if (d.size() == 0)
+ zero_sized_shape = true;
+
+ assert(d.size() >= 0);
+ const_node->dim(index++) = d.size();
+ }
+
+ int num_elements = 1;
+ if (zero_sized_shape)
+ {
+ const_node->rank(0);
+ num_elements = 0;
+ }
+ else
+ {
+ for (uint32_t d = 0; d < const_node->rank(); d++)
+ {
+ num_elements *= const_node->dim(d).value();
+ }
+ }
+
+ switch (dtype)
+ {
+ case loco::DataType::S8:
+ read_value_int8(const_node, num_elements, input_tensor);
+ break;
+
+ case loco::DataType::S32:
+ read_value_int32(const_node, num_elements, input_tensor);
+ break;
+
+ case loco::DataType::FLOAT32:
+ read_value_float32(const_node, num_elements, input_tensor);
+ break;
+
+ // TODO support other types
+
+ default:
+ assert(false);
+ }
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, const_node);
+}
+
+} // namespace moco