Diffstat (limited to 'compiler/mir2loco/src/mir2loco.cpp')
-rw-r--r-- compiler/mir2loco/src/mir2loco.cpp | 725
1 file changed, 725 insertions(+), 0 deletions(-)
diff --git a/compiler/mir2loco/src/mir2loco.cpp b/compiler/mir2loco/src/mir2loco.cpp
new file mode 100644
index 000000000..fc1f6933b
--- /dev/null
+++ b/compiler/mir2loco/src/mir2loco.cpp
@@ -0,0 +1,725 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "mir2loco.h"
+
+#include "mir/ops/AddOp.h"
+#include "mir/ops/AvgPool2DOp.h"
+#include "mir/ops/ConcatOp.h"
+#include "mir/ops/ConstantOp.h"
+#include "mir/ops/Conv2DOp.h"
+#include "mir/ops/Deconv2DOp.h"
+#include "mir/ops/DepthwiseConv2DOp.h"
+#include "mir/ops/DivOp.h"
+#include "mir/ops/FullyConnectedOp.h"
+#include "mir/ops/MaxPool2DOp.h"
+#include "mir/ops/MulOp.h"
+#include "mir/ops/ReluOp.h"
+#include "mir/ops/ReshapeOp.h"
+#include "mir/ops/SoftmaxOp.h"
+#include "mir/ops/SubOp.h"
+#include "mir/ops/TransposeOp.h"
+
+#include "mir/ShapeRange.h"
+
+#include <cassert>
+#include <cstring>
+#include <stdex/Memory.h>
+
+namespace mir2loco
+{
+namespace
+{
+template <class NodeType> void setupShape(const mir::Shape &shape, NodeType *node)
+{
+ node->rank(shape.rank());
+ for (int32_t i = 0; i < shape.rank(); i++)
+ {
+ node->dim(i) = static_cast<uint32_t>(shape.dim(i));
+ }
+}
+
+std::unique_ptr<loco::TensorShape> make_tensor_shape(const mir::Shape &shape)
+{
+ auto res = stdex::make_unique<loco::TensorShape>();
+ setupShape(shape, res.get());
+ return res;
+}
+
+void setupPad(const std::vector<std::int32_t> &padding_before,
+ const std::vector<std::int32_t> &padding_after, loco::Padding2D *pad)
+{
+ assert(padding_before.size() == 2 && padding_after.size() == 2);
+ pad->top(padding_before[0]);
+ pad->left(padding_before[1]);
+ pad->bottom(padding_after[0]);
+ pad->right(padding_after[1]);
+}
+
+void setupWindow(const std::vector<std::int32_t> &window_size, loco::Window<2> *window)
+{
+ assert(window_size.size() == 2);
+ window->vertical(window_size[0]);
+ window->horizontal(window_size[1]);
+}
+
+void setupStride(const std::vector<std::int32_t> &strides, loco::Stride<2> *stride)
+{
+ assert(strides.size() == 2);
+ stride->vertical(strides[0]);
+ stride->horizontal(strides[1]);
+}
+
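+// Builds the permutation that maps mir's NHWC/NCHW layouts onto loco's canonical
+// feature axes. Each entry records which tensor dimension a feature axis occupies;
+// e.g. in NHWC the channel (Depth) axis lives in dimension 3.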
+loco::Permutation<loco::Domain::Feature> createFeaturePermutation(mir::DataFormat format)
+{
+ loco::Permutation<loco::Domain::Feature> perm;
+ if (format == mir::DataFormat::NHWC)
+ {
+ perm.axis(loco::FeatureAxis::Count) = 0;
+ perm.axis(loco::FeatureAxis::Height) = 1;
+ perm.axis(loco::FeatureAxis::Width) = 2;
+ perm.axis(loco::FeatureAxis::Depth) = 3;
+ }
+ else
+ {
+ assert(format == mir::DataFormat::NCHW);
+ perm.axis(loco::FeatureAxis::Count) = 0;
+ perm.axis(loco::FeatureAxis::Depth) = 1;
+ perm.axis(loco::FeatureAxis::Height) = 2;
+ perm.axis(loco::FeatureAxis::Width) = 3;
+ }
+ return perm;
+}
+
+std::unique_ptr<loco::FeatureEncoder> createFeatureEncoder(mir::DataFormat data_format)
+{
+ auto perm = createFeaturePermutation(data_format);
+ return stdex::make_unique<loco::PermutingEncoder<loco::Domain::Feature>>(perm);
+}
+
+std::unique_ptr<loco::FeatureDecoder> createFeatureDecoder(mir::DataFormat data_format)
+{
+ auto perm = createFeaturePermutation(data_format);
+ return stdex::make_unique<loco::PermutingDecoder<loco::Domain::Feature>>(perm);
+}
+
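+// mir stores Conv2D kernels in OHWI order, DeConv2D kernels in HWOI order, and
+// DepthwiseConv2D kernels in HWIM order; the encoders below describe these layouts
+// to loco. The IHWM encoder serves grouped Conv2D kernels lowered to depthwise
+// (see Transformer::visit(mir::ops::Conv2DOp &) below).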
+std::unique_ptr<loco::FilterEncoder> createOHWIFilterEncoder()
+{
+ loco::Permutation<loco::Domain::Filter> perm;
+ perm.axis(loco::FilterAxis::Count) = 0;
+ perm.axis(loco::FilterAxis::Height) = 1;
+ perm.axis(loco::FilterAxis::Width) = 2;
+ perm.axis(loco::FilterAxis::Depth) = 3;
+ return stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>(perm);
+}
+
+std::unique_ptr<loco::FilterEncoder> createHWOIFilterEncoder()
+{
+ loco::Permutation<loco::Domain::Filter> perm;
+ perm.axis(loco::FilterAxis::Height) = 0;
+ perm.axis(loco::FilterAxis::Width) = 1;
+ perm.axis(loco::FilterAxis::Count) = 2;
+ perm.axis(loco::FilterAxis::Depth) = 3;
+ return stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>(perm);
+}
+
+std::unique_ptr<loco::DepthwiseFilterEncoder> createHWIMDepthwiseFilterEncoder()
+{
+ loco::Permutation<loco::Domain::DepthwiseFilter> perm;
+ perm.axis(loco::DepthwiseFilterAxis::Height) = 0;
+ perm.axis(loco::DepthwiseFilterAxis::Width) = 1;
+ perm.axis(loco::DepthwiseFilterAxis::Depth) = 2;
+ perm.axis(loco::DepthwiseFilterAxis::Multiplier) = 3;
+ return stdex::make_unique<loco::PermutingEncoder<loco::Domain::DepthwiseFilter>>(perm);
+}
+
+std::unique_ptr<loco::DepthwiseFilterEncoder> createIHWMDepthwiseFilterEncoder()
+{
+ loco::Permutation<loco::Domain::DepthwiseFilter> perm;
+ perm.axis(loco::DepthwiseFilterAxis::Depth) = 0;
+ perm.axis(loco::DepthwiseFilterAxis::Height) = 1;
+ perm.axis(loco::DepthwiseFilterAxis::Width) = 2;
+ perm.axis(loco::DepthwiseFilterAxis::Multiplier) = 3;
+ return stdex::make_unique<loco::PermutingEncoder<loco::Domain::DepthwiseFilter>>(perm);
+}
+
+std::unique_ptr<loco::MatrixEncoder> createHWMatrixEncoder()
+{
+ loco::Permutation<loco::Domain::Matrix> perm;
+ perm.axis(loco::MatrixAxis::Height) = 0;
+ perm.axis(loco::MatrixAxis::Width) = 1;
+ return stdex::make_unique<loco::PermutingEncoder<loco::Domain::Matrix>>(perm);
+}
+
+std::unique_ptr<loco::MatrixDecoder> createHWMatrixDecoder()
+{
+ loco::Permutation<loco::Domain::Matrix> perm;
+ perm.axis(loco::MatrixAxis::Height) = 0;
+ perm.axis(loco::MatrixAxis::Width) = 1;
+ return stdex::make_unique<loco::PermutingDecoder<loco::Domain::Matrix>>(perm);
+}
+
+loco::DataType convertDataType(mir::DataType data_type)
+{
+ switch (data_type)
+ {
+ case mir::DataType::UNKNOWN:
+ return loco::DataType::Unknown;
+ case mir::DataType::FLOAT32:
+ return loco::DataType::FLOAT32;
+ case mir::DataType::FLOAT64:
+ return loco::DataType::FLOAT64;
+ case mir::DataType::INT32:
+ return loco::DataType::S32;
+ case mir::DataType::INT64:
+ return loco::DataType::S64;
+ default:
+ break;
+ }
+ throw std::runtime_error("Unsupported data type");
+}
+
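+// Adapts 'node' to 'out_shape' with numpy-style broadcasting: ranks are aligned at
+// the trailing dimensions, missing leading dimensions are filled in with 1 via a
+// FixedReshape, and size-1 dimensions are stretched through the TensorBroadcast
+// mapping. For example, broadcasting a [3] tensor to [2, 3] reshapes it to [1, 3]
+// and then maps dimension 0 to size 2.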
+loco::Node *createBroadcastIfNeeded(loco::Node *node, const mir::Shape &shape,
+ const mir::Shape &out_shape)
+{
+ auto graph = node->graph();
+
+ if (shape == out_shape)
+ return node; // not needed
+
+ int32_t out_rank = out_shape.rank();
+ int32_t rank_diff = out_rank - shape.rank();
+ // Create Broadcast
+ auto *broadcast = graph->nodes()->create<loco::TensorBroadcast>();
+ // Create Reshape for equal ranks
+ if (shape.rank() != out_rank)
+ {
+ auto *reshape = graph->nodes()->create<loco::FixedReshape>();
+ reshape->input(node);
+ reshape->rank(out_rank);
+ broadcast->input(reshape);
+ // Set reshape dims
+ for (int32_t dim = 0; dim < out_rank; dim++)
+ {
+ if (dim < rank_diff)
+ reshape->dim(dim) = 1;
+ else
+ reshape->dim(dim) = shape.dim(dim - rank_diff);
+ }
+ }
+ else
+ {
+ broadcast->input(node);
+ }
+ // Track whether every non-broadcast dimension matches the output shape
+ bool compatible_shapes = true;
+ for (int32_t dim = 0; dim < out_rank; dim++)
+ {
+ // Set broadcast mapping
+ if (dim < rank_diff || (shape.dim(dim - rank_diff) == 1 && out_shape.dim(dim) != 1))
+ broadcast->mapping()->dim(dim) = out_shape.dim(dim);
+ // Check compatibility
+ if (dim >= rank_diff && shape.dim(dim - rank_diff) != 1 &&
+ shape.dim(dim - rank_diff) != out_shape.dim(dim))
+ compatible_shapes = false;
+ }
+ // Check compatibility
+ if (!compatible_shapes)
+ throw std::runtime_error("Not compatible shapes for broadcasting!");
+
+ return broadcast;
+}
+
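+// Shared lowering for the binary eltwise ops (Add/Div/Mul/Sub): broadcast both
+// operands to the output shape if needed, then create the matching loco node.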
+template <typename NodeType>
+NodeType *createEltwiseBinary(const mir::ops::BinaryElementwiseOp &op, loco::Node *lhs,
+ loco::Node *rhs)
+{
+ auto graph = lhs->graph();
+
+ const auto &lhs_shape = op.getInput(0)->getShape();
+ const auto &rhs_shape = op.getInput(1)->getShape();
+ const auto &out_shape = op.getOutputShape(0);
+ // Create Broadcast if it's needed
+ auto lhs_node = createBroadcastIfNeeded(lhs, lhs_shape, out_shape);
+ auto rhs_node = createBroadcastIfNeeded(rhs, rhs_shape, out_shape);
+ // Create Node
+ auto result = graph->nodes()->create<NodeType>();
+ result->lhs(lhs_node);
+ result->rhs(rhs_node);
+ return result;
+}
+} // namespace
+
+void Transformer::visit(mir::ops::AddOp &op)
+{
+ // Get Input
+ auto lhs = _mir2loco_map.at(op.getInput(0));
+ auto rhs = _mir2loco_map.at(op.getInput(1));
+ auto result = createEltwiseBinary<loco::EltwiseAdd>(op, lhs, rhs);
+ // Shape is not set here
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), result);
+}
+
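+// The 2D ops below share one pattern: FeatureEncode lifts the plain tensor into
+// the Feature domain, the loco op consumes it, and FeatureDecode lowers the result
+// back to a plain tensor in the original data format.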
+void Transformer::visit(mir::ops::AvgPool2DOp &op)
+{
+ loco::Node *input = _mir2loco_map.at(op.getInput(0));
+
+ auto *encoded_input = _loco_graph->nodes()->create<loco::FeatureEncode>();
+ encoded_input->input(input);
+ encoded_input->encoder(createFeatureEncoder(op.getDataFormat()));
+
+ auto *avg_pool_node = _loco_graph->nodes()->create<loco::AvgPool2D>();
+ avg_pool_node->ifm(encoded_input);
+ avg_pool_node->convention(op.getIncludePad() ? loco::AvgPool2D::Convention::Full
+ : loco::AvgPool2D::Convention::Valid);
+ setupWindow(op.getWindowSize(), avg_pool_node->window());
+ setupStride(op.getStrides(), avg_pool_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), avg_pool_node->pad());
+
+ auto *output = _loco_graph->nodes()->create<loco::FeatureDecode>();
+ output->input(avg_pool_node);
+ output->decoder(createFeatureDecoder(op.getDataFormat()));
+
+ _mir2loco_map.emplace(op.getOutput(0), output);
+}
+
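+// loco's TensorConcat is binary, so an N-input concatenation is lowered into a
+// left-leaning chain: concat(concat(input0, input1), input2), and so on.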
+void Transformer::visit(mir::ops::ConcatOp &op)
+{
+ if (op.getNumInputs() < 2)
+ throw std::runtime_error("Not enough tensors for concatenation!");
+
+ loco::Node *last_concat = nullptr;
+
+ for (std::size_t i = 1; i < op.getNumInputs(); i++)
+ {
+ loco::Node *lhs = last_concat;
+ if (lhs == nullptr)
+ {
+ mir::Operation::Output *mir_lhs = op.getInput(i - 1);
+ lhs = _mir2loco_map.at(mir_lhs);
+ }
+ mir::Operation::Output *mir_rhs = op.getInput(i);
+ loco::Node *rhs = _mir2loco_map.at(mir_rhs);
+ // Create TensorConcat
+ auto concat_node = _loco_graph->nodes()->create<loco::TensorConcat>();
+ // Set inputs
+ concat_node->lhs(lhs);
+ concat_node->rhs(rhs);
+ // Set axis
+ concat_node->axis(op.getAxis());
+ // Set last concat
+ last_concat = concat_node;
+ }
+ // Shape is not set here
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), last_concat);
+}
+
+void Transformer::visit(mir::ops::ConstantOp &op)
+{
+ auto const_node = _loco_graph->nodes()->create<loco::ConstGen>();
+ // ConstGen has no inputs
+ // Set Shape
+ const auto &out_shape = op.getOutputShape(0);
+ setupShape(out_shape, const_node);
+ // Copy value
+ const auto &value = op.getValue();
+ const_node->dtype(convertDataType(value.getElementType()));
+ // TODO Support other data types
+ switch (const_node->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ {
+ const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
+ float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
+ char *loco_ptr = reinterpret_cast<char *>(&const_float);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(float));
+ break;
+ }
+ case loco::DataType::FLOAT64:
+ {
+ // TODO Change this when loco supports other DataTypeImpls
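+ // Until then, narrow the FLOAT64 values to FLOAT32 one element at a time
+ // (lossy for values outside float precision).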
+ const_node->dtype(loco::DataType::FLOAT32);
+ const_node->size<loco::DataType::FLOAT32>(out_shape.numElements());
+ float &const_float = const_node->at<loco::DataType::FLOAT32>(0);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ double *mir_double = reinterpret_cast<double *>(mir_ptr);
+ float *loco_float = &const_float;
+ for (const mir::Index &idx : mir::ShapeRange(out_shape))
+ {
+ *loco_float = static_cast<float>(*mir_double);
+ loco_float++;
+ mir_double++;
+ }
+ break;
+ }
+ case loco::DataType::S32:
+ {
+ const_node->size<loco::DataType::S32>(out_shape.numElements());
+ int32_t &const_int32 = const_node->at<loco::DataType::S32>(0);
+ char *loco_ptr = reinterpret_cast<char *>(&const_int32);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ std::memcpy(loco_ptr, mir_ptr, out_shape.numElements() * sizeof(int32_t));
+ break;
+ }
+ case loco::DataType::S64:
+ {
+ // TODO Change this when loco supports other DataTypeImpls
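+ // Until then, narrow the INT64 values to INT32 one element at a time
+ // (lossy for values outside the int32_t range).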
+ const_node->dtype(loco::DataType::S32);
+ const_node->size<loco::DataType::S32>(out_shape.numElements());
+ int32_t &const_int32 = const_node->at<loco::DataType::S32>(0);
+ char *mir_ptr = value.at(mir::Index(out_shape.rank()));
+ int64_t *mir_int64 = reinterpret_cast<int64_t *>(mir_ptr);
+ int32_t *loco_int32 = &const_int32;
+ for (const mir::Index &idx : mir::ShapeRange(out_shape))
+ {
+ *loco_int32 = static_cast<float>(*mir_int64);
+ loco_int32++;
+ mir_int64++;
+ }
+ break;
+ }
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), const_node);
+}
+
+void Transformer::visit(mir::ops::Conv2DOp &op)
+{
+ mir::Operation::Output *mir_input = op.getInput(0);
+ mir::Operation::Output *mir_filter = op.getInput(1);
+
+ loco::Node *input = _mir2loco_map.at(mir_input);
+ loco::Node *filter = _mir2loco_map.at(mir_filter);
+
+ // loco does not have grouped Conv2D operation. Try to translate into something else.
+ if (op.getNumGroups() != 1)
+ {
+ const std::int32_t group_size = mir_filter->getShape().dim(3);
+ const std::int32_t num_in_channels = group_size * op.getNumGroups();
+ const std::int32_t num_out_channels = mir_filter->getShape().dim(0);
+
+ // If the size of the group is 1, translate the operation into DepthwiseConv2D. Limit
+ // ourselves to the case of 'multiplier' == 1 for now.
+ if (group_size == 1 && (num_out_channels == num_in_channels))
+ {
+ // With group_size == 1 and multiplier == 1, the kernel layout [O, H, W, I / group]
+ // coincides with the depthwise layout [I, H, W, M].
+ auto *encoded_input = _loco_graph->nodes()->create<loco::FeatureEncode>();
+ encoded_input->input(input);
+ encoded_input->encoder(createFeatureEncoder(op.getDataFormat()));
+
+ auto *encoded_filter = _loco_graph->nodes()->create<loco::DepthwiseFilterEncode>();
+ encoded_filter->input(filter);
+ encoded_filter->encoder(createIHWMDepthwiseFilterEncoder());
+
+ auto *dw_conv2d_node = _loco_graph->nodes()->create<loco::DepthwiseConv2D>();
+ dw_conv2d_node->ifm(encoded_input);
+ dw_conv2d_node->ker(encoded_filter);
+ setupStride(op.getStrides(), dw_conv2d_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), dw_conv2d_node->pad());
+
+ auto *output = _loco_graph->nodes()->create<loco::FeatureDecode>();
+ output->input(dw_conv2d_node);
+ output->decoder(createFeatureDecoder(op.getDataFormat()));
+
+ _mir2loco_map.emplace(op.getOutput(0), output);
+ }
+ else
+ {
+ // There are a few things we can do here:
+ // 1) If group_size == 1, reshape the kernel [O, H, W, I / group] == [I * M, H, W, 1] ->
+ // [I, M, H, W] and use DepthwiseConv2D.
+ // 2) Split the operation into smaller Conv2Ds.
+ // 3) Replicate the filter along 'O' axis 'num_groups' times, zero out some elements, and use
+ // ordinary Conv2D.
+ throw std::runtime_error("Grouped Conv2D operation is not fully supported.");
+ }
+ }
+ else
+ {
+ auto *encoded_input = _loco_graph->nodes()->create<loco::FeatureEncode>();
+ encoded_input->input(input);
+ encoded_input->encoder(createFeatureEncoder(op.getDataFormat()));
+
+ auto *encoded_filter = _loco_graph->nodes()->create<loco::FilterEncode>();
+ encoded_filter->input(filter);
+ encoded_filter->encoder(createOHWIFilterEncoder());
+
+ auto *conv2d_node = _loco_graph->nodes()->create<loco::Conv2D>();
+ conv2d_node->ifm(encoded_input);
+ conv2d_node->ker(encoded_filter);
+ setupStride(op.getStrides(), conv2d_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), conv2d_node->pad());
+
+ auto *output = _loco_graph->nodes()->create<loco::FeatureDecode>();
+ output->input(conv2d_node);
+ output->decoder(createFeatureDecoder(op.getDataFormat()));
+
+ _mir2loco_map.emplace(op.getOutput(0), output);
+ }
+}
+
+void Transformer::visit(mir::ops::DeConv2DOp &op)
+{
+ mir::Operation::Output *mir_input = op.getInput(0);
+ mir::Operation::Output *mir_filter = op.getInput(1);
+
+ loco::Node *input = _mir2loco_map.at(mir_input);
+ loco::Node *filter = _mir2loco_map.at(mir_filter);
+
+ auto *encoded_input = _loco_graph->nodes()->create<loco::FeatureEncode>();
+ encoded_input->input(input);
+ encoded_input->encoder(createFeatureEncoder(op.getDataFormat()));
+
+ auto *encoded_filter = _loco_graph->nodes()->create<loco::FilterEncode>();
+ encoded_filter->input(filter);
+ encoded_filter->encoder(createHWOIFilterEncoder());
+
+ auto *tr_conv2d_node = _loco_graph->nodes()->create<loco::TransposedConv2D>();
+ tr_conv2d_node->ifm(encoded_input);
+ tr_conv2d_node->ker(encoded_filter);
+ setupStride(op.getStrides(), tr_conv2d_node->stride());
+ if (op.getPaddingType() == mir::ops::PaddingType::Explicit)
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), tr_conv2d_node->pad());
+ else
+ throw std::runtime_error("Not supported non explicit paddings on loco!");
+
+ auto *output = _loco_graph->nodes()->create<loco::FeatureDecode>();
+ output->input(tr_conv2d_node);
+ output->decoder(createFeatureDecoder(op.getDataFormat()));
+
+ _mir2loco_map.emplace(op.getOutput(0), output);
+}
+
+void Transformer::visit(mir::ops::DepthwiseConv2DOp &op)
+{
+ mir::Operation::Output *mir_input = op.getInput(0);
+ mir::Operation::Output *mir_filter = op.getInput(1);
+
+ loco::Node *input = _mir2loco_map.at(mir_input);
+ loco::Node *filter = _mir2loco_map.at(mir_filter);
+
+ auto *encoded_input = _loco_graph->nodes()->create<loco::FeatureEncode>();
+ encoded_input->input(input);
+ encoded_input->encoder(createFeatureEncoder(op.getDataFormat()));
+
+ auto *encoded_filter = _loco_graph->nodes()->create<loco::DepthwiseFilterEncode>();
+ encoded_filter->input(filter);
+ encoded_filter->encoder(createHWIMDepthwiseFilterEncoder());
+
+ auto *dw_conv2d_node = _loco_graph->nodes()->create<loco::DepthwiseConv2D>();
+ dw_conv2d_node->ifm(encoded_input);
+ dw_conv2d_node->ker(encoded_filter);
+ setupStride(op.getStrides(), dw_conv2d_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), dw_conv2d_node->pad());
+
+ auto *output = _loco_graph->nodes()->create<loco::FeatureDecode>();
+ output->input(dw_conv2d_node);
+ output->decoder(createFeatureDecoder(op.getDataFormat()));
+
+ _mir2loco_map.emplace(op.getOutput(0), output);
+}
+
+void Transformer::visit(mir::ops::DivOp &op)
+{
+ // Get Input
+ loco::Node *lhs = _mir2loco_map.at(op.getInput(0));
+ loco::Node *rhs = _mir2loco_map.at(op.getInput(1));
+ auto result = createEltwiseBinary<loco::EltwiseDiv>(op, lhs, rhs);
+ // Shape is not set here
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), result);
+}
+
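+// FullyConnected on 2-D operands is lowered as MatrixEncode -> MatMul -> MatrixDecode,
+// since loco's MatMul operates on the Matrix domain rather than on plain tensors.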
+void Transformer::visit(mir::ops::FullyConnectedOp &op)
+{
+ mir::Operation::Output *mir_lhs = op.getInput(0);
+ mir::Operation::Output *mir_rhs = op.getInput(1);
+ // Check 2D shape
+ assert(op.getInput(0)->getShape().rank() == 2);
+ assert(op.getInput(1)->getShape().rank() == 2);
+
+ loco::Node *lhs = _mir2loco_map.at(mir_lhs);
+ loco::Node *rhs = _mir2loco_map.at(mir_rhs);
+
+ auto *encoded_lhs = _loco_graph->nodes()->create<loco::MatrixEncode>();
+ encoded_lhs->input(lhs);
+ encoded_lhs->encoder(createHWMatrixEncoder());
+
+ auto *encoded_rhs = _loco_graph->nodes()->create<loco::MatrixEncode>();
+ encoded_rhs->input(rhs);
+ encoded_rhs->encoder(createHWMatrixEncoder());
+
+ auto *mat_mul = _loco_graph->nodes()->create<loco::MatMul>();
+ mat_mul->lhs(encoded_lhs);
+ mat_mul->rhs(encoded_rhs);
+
+ auto *output = _loco_graph->nodes()->create<loco::MatrixDecode>();
+ output->input(mat_mul);
+ output->decoder(createHWMatrixDecoder());
+
+ _mir2loco_map.emplace(op.getOutput(0), output);
+}
+
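+// Graph boundaries: an InputOp becomes a Pull node linked to a GraphInput entry,
+// and an OutputOp (below) becomes a Push node linked to a GraphOutput entry.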
+void Transformer::visit(mir::ops::InputOp &op)
+{
+ mir::Operation::Output *mir_output = op.getOutput(0);
+
+ loco::GraphInput *graph_input = _loco_graph->inputs()->create();
+ graph_input->name(mir_output->getName());
+ graph_input->dtype(convertDataType(mir_output->getElementType()));
+
+ auto *pull_node = _loco_graph->nodes()->create<loco::Pull>();
+ setupShape(mir_output->getShape(), pull_node);
+
+ loco::link(graph_input, pull_node);
+
+ _mir2loco_map.emplace(mir_output, pull_node);
+}
+
+void Transformer::visit(mir::ops::MaxPool2DOp &op)
+{
+ loco::Node *input = _mir2loco_map.at(op.getInput(0));
+
+ auto *encoded_input = _loco_graph->nodes()->create<loco::FeatureEncode>();
+ encoded_input->input(input);
+ encoded_input->encoder(createFeatureEncoder(op.getDataFormat()));
+
+ auto max_pool_node = _loco_graph->nodes()->create<loco::MaxPool2D>();
+ max_pool_node->ifm(encoded_input);
+ setupWindow(op.getWindowSize(), max_pool_node->window());
+ setupStride(op.getStrides(), max_pool_node->stride());
+ setupPad(op.getPaddingBefore(), op.getPaddingAfter(), max_pool_node->pad());
+
+ auto *output = _loco_graph->nodes()->create<loco::FeatureDecode>();
+ output->input(max_pool_node);
+ output->decoder(createFeatureDecoder(op.getDataFormat()));
+
+ _mir2loco_map.emplace(op.getOutput(0), output);
+}
+
+void Transformer::visit(mir::ops::MulOp &op)
+{
+ // Get Input
+ loco::Node *lhs = _mir2loco_map.at(op.getInput(0));
+ loco::Node *rhs = _mir2loco_map.at(op.getInput(1));
+ auto result = createEltwiseBinary<loco::EltwiseMul>(op, lhs, rhs);
+ // Shape is not set here
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), result);
+}
+
+void Transformer::visit(mir::ops::OutputOp &op)
+{
+ mir::Operation::Output *mir_input = op.getInput(0);
+ loco::Node *input = _mir2loco_map.at(mir_input);
+
+ loco::GraphOutput *graph_output = _loco_graph->outputs()->create();
+ graph_output->name(mir_input->getName());
+ graph_output->dtype(convertDataType(mir_input->getElementType()));
+ graph_output->shape(make_tensor_shape(mir_input->getShape()));
+
+ auto *push_node = _loco_graph->nodes()->create<loco::Push>();
+ push_node->from(input);
+
+ loco::link(graph_output, push_node);
+}
+
+void Transformer::visit(mir::ops::ReluOp &op)
+{
+ loco::Node *input = _mir2loco_map.at(op.getInput(0));
+
+ auto relu_node = _loco_graph->nodes()->create<loco::ReLU>();
+ relu_node->input(input);
+ // Shape is not set here
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), relu_node);
+}
+
+void Transformer::visit(mir::ops::ReshapeOp &op)
+{
+ loco::Node *input = _mir2loco_map.at(op.getInput(0));
+
+ auto reshape_node = _loco_graph->nodes()->create<loco::Reshape<loco::ReshapeType::Fixed>>();
+ reshape_node->input(input);
+ // Set Shape
+ auto &out_shape = op.getOutputShape(0);
+ setupShape(out_shape, reshape_node);
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), reshape_node);
+}
+
+void Transformer::visit(mir::ops::SoftmaxOp &op)
+{
+ loco::Node *input = _mir2loco_map.at(op.getInput(0));
+
+ auto softmax_node = _loco_graph->nodes()->create<loco::TensorSoftmax>();
+ softmax_node->input(input);
+ // Set Axis
+ softmax_node->axis(op.getAxis());
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), softmax_node);
+}
+
+void Transformer::visit(mir::ops::SubOp &op)
+{
+ // Get Input
+ loco::Node *lhs = _mir2loco_map.at(op.getInput(0));
+ loco::Node *rhs = _mir2loco_map.at(op.getInput(1));
+ auto result = createEltwiseBinary<loco::EltwiseSub>(op, lhs, rhs);
+ // Shape is not set here
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), result);
+}
+
+void Transformer::visit(mir::ops::TransposeOp &op)
+{
+ loco::Node *input = _mir2loco_map.at(op.getInput(0));
+ const auto &axis_order = op.getAxisOrder();
+
+ auto transpose_node = _loco_graph->nodes()->create<loco::TensorTranspose>();
+ transpose_node->input(input);
+ // Set axis order
+ transpose_node->perm()->size(axis_order.size());
+ for (size_t i = 0; i < axis_order.size(); i++)
+ transpose_node->perm()->axis(i) = axis_order[i];
+ // Shape is not set here
+ // Add to map
+ _mir2loco_map.emplace(op.getOutput(0), transpose_node);
+}
+
+void Transformer::visit_fallback(mir::Operation &op) { throw std::runtime_error("NYI operation"); }
+
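+// Entry point: builds a fresh loco graph from the given mir graph. The accept()
+// call visits producers before their consumers, so each visit() above finds its
+// inputs already registered in _mir2loco_map.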
+std::unique_ptr<loco::Graph> Transformer::transform(mir::Graph *mir_graph)
+{
+ _mir2loco_map.clear();
+ _loco_graph.reset();
+ _loco_graph = loco::make_graph();
+
+ // Transform Nodes
+ mir_graph->accept(this);
+
+ // Validate the graph
+ assert(loco::valid(_loco_graph.get()));
+
+ return std::move(_loco_graph);
+}
+
+} // namespace mir2loco