path: root/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
Diffstat (limited to 'runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc')
-rw-r--r-- runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc | 151
1 file changed, 151 insertions(+), 0 deletions(-)
diff --git a/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
new file mode 100644
index 000000000..8b00db468
--- /dev/null
+++ b/runtime/onert/core/src/backend/builtin/kernel/WhileLayer.cc
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WhileLayer.h"
+
+#include "PermuteLayer.h"
+#include "../../../exec/ExecutorBase.h"
+
+#include <misc/polymorphic_downcast.h>
+
+#include <algorithm>
+
+namespace onert
+{
+namespace backend
+{
+namespace builtin
+{
+namespace kernel
+{
+
+WhileLayer::WhileLayer(const std::vector<backend::IPortableTensor *> input_tensors,
+ const std::vector<backend::IPortableTensor *> output_tensors,
+ const ir::SubgraphIndex &cond_subg_index,
+ const ir::SubgraphIndex &body_subg_index, exec::IExecutors *executors,
+ const ir::ModelIndex &model_index,
+ basic::DynamicMemoryManager *dyn_memory_manager,
+ const std::shared_ptr<ExternalContext> &external_context)
+ : _cond_subg_index{cond_subg_index}, _body_subg_index{body_subg_index},
+ _input_tensors{input_tensors}, _output_tensors{output_tensors}, _executors{executors},
+ _model_index{model_index}, _dyn_memory_manager{dyn_memory_manager},
+ _external_context{external_context}
+{
+ // At this point, _executors may not yet contain the executors for the cond and body subgraphs
+}
+
+void WhileLayer::run()
+{
+ // Copy "_input_tensors" -> "cond subg inputs"
+ // Run cond subg
+ // Loop while the output of cond subg is true
+ // // In the first iteration copy "_input_tensors" -> "body subg inputs"; from the second
+ // // iteration onward copy "body subg outputs" -> "body subg inputs"
+ // // Run body subg
+ // // Copy "body subg outputs" -> "cond subg inputs"
+ // // Run cond subg
+ // If the loop body never ran, copy "_input_tensors" -> "_output_tensors"; otherwise
+ // "_output_tensors" already holds the values copied from the last "body subg outputs"
+ auto cond_exec = _executors->at(_model_index, _cond_subg_index);
+ auto body_exec = _executors->at(_model_index, _body_subg_index);
+
+ // Need a temp tensor to hold the cond subgraph output
+ assert(cond_exec->getOutputTensors().size() == 1);
+ auto cond_output_tensor = [&]() {
+ auto cond_output = cond_exec->getOutputTensors().at(0);
+ auto tensor = std::make_unique<Tensor>(cond_output->orig_info(), cond_output->orig_layout(),
+ _dyn_memory_manager);
+ tensor->set_dynamic();
+ tensor->setBuffer(_dyn_memory_manager->allocate(tensor.get(), tensor->total_size()));
+ return tensor;
+ }();
+
+ VERBOSE(While) << "Call to $" << _cond_subg_index << " (cond)" << std::endl;
+ cond_exec->execute(_input_tensors, {cond_output_tensor.get()});
+ VERBOSE(While) << "Return from $" << _cond_subg_index << std::endl;
+
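+ // Helper that reads the boolean scalar result out of the cond subgraph's output tensor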
+ auto getResultCond = [](backend::ITensor *tensor) -> bool {
+ bool ret = false;
+ tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); });
+ return ret;
+ };
+
+ std::vector<ITensor *> op_inputs(_input_tensors.begin(), _input_tensors.end());
+ std::vector<ITensor *> op_outputs(_output_tensors.begin(), _output_tensors.end());
+ // Copy "_input_tensors" -> "_output_tensors" when the loop body is never executed
+ if (!getResultCond(cond_output_tensor.get()))
+ {
+ PermuteLayer copy_body_inputs_to_op_outputs{op_inputs, op_outputs, _external_context};
+ copy_body_inputs_to_op_outputs.run();
+ return;
+ }
+
+ // Need temp tensors to hold the body subgraph outputs
+ std::vector<std::unique_ptr<Tensor>> temp_outputs_o;
+ std::vector<IPortableTensor *> temp_outputs;
+ for (auto &&io_tensor : body_exec->getOutputTensors())
+ {
+ auto tensor = std::make_unique<Tensor>(io_tensor->orig_info(), io_tensor->orig_layout(),
+ _dyn_memory_manager);
+ tensor->set_dynamic();
+ tensor->setBuffer(_dyn_memory_manager->allocate(tensor.get(), tensor->total_size()));
+ temp_outputs.push_back(tensor.get());
+ temp_outputs_o.push_back(std::move(tensor));
+ }
+
+ std::vector<ITensor *> body_outputs(temp_outputs.begin(), temp_outputs.end());
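+ // PermuteLayer performs the element copy from the temp body outputs into the op outputs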
+ PermuteLayer copy_body_outputs_to_op_outputs{body_outputs, op_outputs, _external_context};
+
+ const auto body_execute_with_op_inputs = [&]() {
+ VERBOSE(While) << "Call to $" << _body_subg_index << " (body)" << std::endl;
+ body_exec->execute(_input_tensors, temp_outputs);
+ VERBOSE(While) << "Return from $" << _body_subg_index << std::endl;
+ };
+
+ const auto body_execute_with_body_outputs = [&]() {
+ VERBOSE(While) << "Call to $" << _body_subg_index << " (body)" << std::endl;
+ body_exec->execute(_output_tensors, temp_outputs);
+ VERBOSE(While) << "Return from $" << _body_subg_index << std::endl;
+ };
+
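+ // The first body iteration reads the op inputs; after one pass through the loop below,
+ // body_execute is switched to read the op outputs written by the previous iteration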
+ std::function<void()> body_execute = body_execute_with_op_inputs;
+ const auto cond_execute = [&]() {
+ VERBOSE(While) << "Call to $" << _cond_subg_index << " (cond)" << std::endl;
+ cond_exec->execute(_output_tensors, {cond_output_tensor.get()});
+ VERBOSE(While) << "Return from $" << _cond_subg_index << std::endl;
+ };
+
+ // Loop while Cond subgraph's output is true
+ while (getResultCond(cond_output_tensor.get()))
+ {
+ body_execute();
+ copy_body_outputs_to_op_outputs.run();
+ cond_execute();
+ body_execute = body_execute_with_body_outputs;
+ }
+
+ // Clean up the temp tensors
+ _dyn_memory_manager->deallocate(cond_output_tensor.get());
+ for (auto &&tensor : temp_outputs)
+ {
+ _dyn_memory_manager->deallocate(tensor);
+ }
+}
+
+} // namespace kernel
+} // namespace builtin
+} // namespace backend
+} // namespace onert
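
For reference, the control flow implemented by WhileLayer::run() condenses to the standalone sketch below. Values, CondFn, and BodyFn are hypothetical stand-ins for onert's tensor lists and subgraph executors (they are not part of the runtime API), so this is a minimal sketch of the pattern rather than the kernel itself:

#include <functional>
#include <utility>
#include <vector>

using Values = std::vector<float>;                    // stand-in for a list of tensors
using CondFn = std::function<bool(const Values &)>;   // stand-in for the cond executor
using BodyFn = std::function<Values(const Values &)>; // stand-in for the body executor

// Evaluate cond once; if it holds, repeatedly run body and feed its outputs
// back in until cond becomes false.
Values run_while(const Values &inputs, const CondFn &cond, const BodyFn &body)
{
  if (!cond(inputs))
    return inputs; // the loop body never executed: outputs are just the inputs

  Values outputs = inputs; // the first body iteration effectively reads the op inputs
  do
  {
    Values temp = body(outputs); // temp plays the role of temp_outputs
    outputs = std::move(temp);   // the PermuteLayer copy: body outputs -> op outputs
  } while (cond(outputs));
  return outputs;
}

The staging step matters in the real kernel: body results land in temp_outputs and are only then copied into _output_tensors, presumably because _output_tensors also feeds the next body and cond invocations and should not be overwritten by the body while it is still being read.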