summaryrefslogtreecommitdiff
path: root/compiler/luci/import/src/Nodes/CircleConst.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/import/src/Nodes/CircleConst.cpp')
-rw-r--r--compiler/luci/import/src/Nodes/CircleConst.cpp110
1 files changed, 110 insertions, 0 deletions
diff --git a/compiler/luci/import/src/Nodes/CircleConst.cpp b/compiler/luci/import/src/Nodes/CircleConst.cpp
new file mode 100644
index 000000000..1d798983b
--- /dev/null
+++ b/compiler/luci/import/src/Nodes/CircleConst.cpp
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Import/Nodes/CircleConst.h"
+
+#include <luci/IR/Nodes/CircleConst.h>
+#include <luci/Log.h>
+
+#include <loco.h>
+#include <oops/UserExn.h>
+
+#include <cassert>
+
+namespace luci
+{
+
+template <loco::DataType DT>
+static void copy_data(const std::vector<uint8_t> &raw_data, uint32_t num_elements,
+ CircleConst *const_node)
+{
+ using T = typename loco::DataTypeImpl<DT>::Type;
+
+ assert(raw_data.size() == num_elements * sizeof(T));
+ const auto *data = reinterpret_cast<const T *>(raw_data.data());
+
+ const_node->size<DT>(num_elements);
+ for (uint32_t i = 0; i < num_elements; ++i)
+ {
+ const_node->at<DT>(i) = data[i];
+ }
+}
+
+//
+// circleconst_from_tensor() ?
+//
+CircleConst *create_circleconst(GraphBuilderContext *context, int32_t tensor_index)
+{
+ LOGGER(l);
+
+ auto graph = context->graph();
+ auto reader = context->reader();
+ const auto &tensors = reader->tensors();
+
+ // (1) create CircleConst
+ auto const_node = graph->nodes()->create<CircleConst>();
+ const circle::TensorT &const_tensor = *tensors[tensor_index];
+ const_node->name(tensor_name(const_tensor));
+ auto quantization = luci::tensor_quantization(const_tensor);
+ if (quantization)
+ {
+ auto quantparam = luci::luci_quantparam(quantization);
+ if (quantparam.get())
+ const_node->quantparam(std::move(quantparam));
+ }
+
+ INFO(l) << "[luci] NodeFinder const_node(" << tensor_index << ") -> " << const_node << std::endl;
+
+ // (2) set data_type to CircleConst
+ const_node->dtype(luci_datatype(const_tensor.type));
+
+ // (3) set shape to CicleConst
+ std::vector<int32_t> const_dims = const_tensor.shape; // in NHWC
+ const_node->rank(const_dims.size());
+ uint32_t num_elements = 1;
+ for (uint32_t r = 0; r < const_dims.size(); ++r)
+ {
+ const_node->dim(r) = loco::Dimension(const_dims[r]);
+ num_elements = num_elements * const_dims[r];
+ }
+
+ // (4) constant values from circle buffer
+ const std::vector<uint8_t> &buffer = reader->buffers()[const_tensor.buffer]->data;
+ if (buffer.empty())
+ throw oops::UserExn("Empty buffer");
+
+ switch (luci_datatype(const_tensor.type))
+ {
+ case loco::DataType::FLOAT32:
+ copy_data<loco::DataType::FLOAT32>(buffer, num_elements, const_node);
+ break;
+
+ case loco::DataType::U8:
+ copy_data<loco::DataType::U8>(buffer, num_elements, const_node);
+ break;
+
+ case loco::DataType::S32:
+ copy_data<loco::DataType::S32>(buffer, num_elements, const_node);
+ break;
+
+ default:
+ throw oops::UserExn("Unsupported tensor type", circle::EnumNameTensorType(const_tensor.type));
+ }
+
+ return const_node;
+}
+
+} // namespace luci