Diffstat (limited to 'compiler/locomotiv/src/NodeExecution.cpp')
-rw-r--r--  compiler/locomotiv/src/NodeExecution.cpp  170
1 file changed, 170 insertions, 0 deletions
diff --git a/compiler/locomotiv/src/NodeExecution.cpp b/compiler/locomotiv/src/NodeExecution.cpp
new file mode 100644
index 000000000..e532b5af6
--- /dev/null
+++ b/compiler/locomotiv/src/NodeExecution.cpp
@@ -0,0 +1,170 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "NodeExecution.h"
+
+#include "NodeDomain.h"
+#include "NodeDataImpl.h"
+#include "Validation.h"
+
+#include <nncc/core/ADT/tensor/Shape.h>
+#include <nncc/core/ADT/tensor/Buffer.h>
+#include <nncc/core/ADT/tensor/Index.h>
+#include <nncc/core/ADT/tensor/IndexEnumerator.h>
+#include <nncc/core/ADT/tensor/LexicalLayout.h>
+
+#include <cassert>
+#include <stdexcept>
+
+using nncc::core::ADT::tensor::Index;
+using nncc::core::ADT::tensor::IndexEnumerator;
+using nncc::core::ADT::tensor::LexicalLayout;
+using nncc::core::ADT::tensor::make_buffer;
+
+namespace locomotiv
+{
+
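+// Default apply() implementations: a concrete UnaryFunc/BinaryFunc is expected
+// to override only the data types it actually supports, so any unsupported
+// data type falls through to these throwing stubs.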
+float UnaryFunc::apply(float) const { throw std::runtime_error{"F32 is not supported yet"}; }
+int32_t UnaryFunc::apply(int32_t) const { throw std::runtime_error{"S32 is not supported yet"}; }
+
+float BinaryFunc::apply(float, float) const
+{
+ throw std::runtime_error{"F32 is not supported yet"};
+}
+
+int32_t BinaryFunc::apply(int32_t, int32_t) const
+{
+ throw std::runtime_error{"S32 is not supported yet"};
+}
+
+// TODO Use visitor pattern of loco when available
+void NodeExecution::run(loco::Node *node)
+{
+ erase_annot_data(node);
+
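+  // Dispatch on the concrete node type via the NODE(Name) entries listed in
+  // "Node.lst" (X-macro style): each entry expands to a type check followed by
+  // a call to the matching execute() overload.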
+#define NODE(Name) \
+ if (as<loco::Name>(node)) \
+ { \
+ execute(as<loco::Name>(node)); \
+ return; \
+ }
+#include "Node.lst"
+#undef NODE
+
+ throw std::runtime_error("Not supported loco::Node type");
+}
+
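+// Read the tensor annotated on the input of 'node' (arg(0)), apply 'f' to
+// every element, and annotate 'node' itself with the resulting data and with
+// the (unchanged) domain of the input.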
+void NodeExecution::eltwise_unary(loco::Node *node, const UnaryFunc &f)
+{
+  auto input_node = node->arg(0);
+  auto input_domain = annot_domain(input_node);
+  auto input_data = annot_data(input_node);
+
+  validate(input_data, "Input is not ready");
+  validate(input_domain != loco::Domain::Unknown, "Input domain is unknown");
+  auto input_dtype = input_data->dtype();
+
+ auto output_node = node;
+ // Element-wise Unary Operation does not affect Domain
+ auto output_domain = input_domain;
+  // Element-wise Unary Operation does not affect Data Type (ASSUMPTION)
+ //
+ // TODO Check this assumption
+ auto output_dtype = input_dtype;
+ std::unique_ptr<NodeData> output_data = nullptr;
+
+ switch (output_dtype)
+ {
+ case loco::DataType::FLOAT32:
+ {
+ auto input_bufptr = input_data->as_f32_bufptr();
+ auto output_buf = make_buffer<float, LexicalLayout>(*input_data->shape());
+ auto *shape = input_data->shape();
+
+ for (IndexEnumerator e{*shape}; e.valid(); e.advance())
+ {
+ const auto &index = e.current();
+ output_buf.at(index) = f.apply(input_bufptr->at(index));
+ }
+
+ output_data = make_data(output_buf);
+ break;
+ }
+ default:
+ throw std::runtime_error("NYI for this DataType");
+ }
+
+ assert(output_data != nullptr);
+ annot_data(output_node, std::move(output_data));
+ annot_domain(output_node, output_domain);
+}
+
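+// Check that the two inputs of 'node' agree in domain, data type and shape,
+// apply 'f' element-wise to their annotated tensors, and annotate 'node' with
+// the result.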
+void NodeExecution::eltwise_binary(loco::Node *node, const BinaryFunc &f)
+{
+ auto lhs_node = node->arg(0);
+ auto rhs_node = node->arg(1);
+ auto lhs_data = annot_data(lhs_node);
+ auto rhs_data = annot_data(rhs_node);
+
+ validate(lhs_data && rhs_data, "Input not ready");
+ validate(annot_domain(lhs_node) == annot_domain(rhs_node), "Wrong input domain");
+ validate(lhs_data->dtype() == rhs_data->dtype(), "Wrong input type");
+ validate(*lhs_data->shape() == *rhs_data->shape(), "Wrong input shape");
+
+ auto out_node = node;
+ std::unique_ptr<NodeData> out_data = nullptr;
+
+ switch (lhs_data->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ {
+ auto lhs_bufptr = lhs_data->as_f32_bufptr();
+ auto rhs_bufptr = rhs_data->as_f32_bufptr();
+ auto out_bufptr = make_buffer<float, LexicalLayout>(*lhs_data->shape());
+
+ auto *shape = lhs_data->shape();
+
+ for (IndexEnumerator e{*shape}; e.valid(); e.advance())
+ {
+ const auto &index = e.current();
+ out_bufptr.at(index) = f.apply(lhs_bufptr->at(index), rhs_bufptr->at(index));
+ }
+
+ out_data = make_data(out_bufptr);
+ break;
+ }
+ default:
+ throw std::runtime_error("NYI for this DataType");
+ }
+
+ assert(out_data != nullptr);
+ annot_data(out_node, std::move(out_data));
+ annot_domain(out_node, annot_domain(lhs_node));
+}
+
+} // namespace locomotiv
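For context, a concrete executor built on top of eltwise_unary would look roughly like the sketch below. This is an illustration only, not part of the commit: it assumes that UnaryFunc::apply is virtual (as the throwing defaults above suggest), that NodeExecution.h declares an execute(loco::ReLU *) overload, and that Node.lst contains a NODE(ReLU) entry.

// Hypothetical ReLU executor (illustration only, not part of this commit).
// Assumes UnaryFunc::apply is virtual and that NodeExecution declares an
// execute(loco::ReLU *) overload.
#include "NodeExecution.h"

namespace
{

struct ReLUFunc final : public locomotiv::UnaryFunc
{
  // max(x, 0), applied to each element by eltwise_unary
  float apply(float v) const final { return v > 0.0f ? v : 0.0f; }
};

} // namespace

namespace locomotiv
{

void NodeExecution::execute(loco::ReLU *relu)
{
  ReLUFunc f;
  eltwise_unary(relu, f);
}

} // namespace locomotiv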