diff options
Diffstat (limited to 'compiler/moco/pass/src/Passes/ConstantFoldMul.cpp')
-rw-r--r-- | compiler/moco/pass/src/Passes/ConstantFoldMul.cpp | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/compiler/moco/pass/src/Passes/ConstantFoldMul.cpp b/compiler/moco/pass/src/Passes/ConstantFoldMul.cpp new file mode 100644 index 000000000..c1870ffee --- /dev/null +++ b/compiler/moco/pass/src/Passes/ConstantFoldMul.cpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moco/Pass/Passes/ConstantFoldMul.h" + +#include "ConstantFoldHelper.h" + +#include <moco/IR/Nodes/TFMul.h> +#include <moco/IR/Nodes/TFConst.h> + +#include <moco/Support/NodeAs.h> + +namespace +{ + +struct Func final : public moco::BinaryFunc +{ + float apply(float lhs, float rhs) const { return lhs * rhs; } + int32_t apply(int32_t lhs, int32_t rhs) const { return lhs * rhs; } +}; + +bool constantfold_mul(moco::TFMul *node) +{ + auto x_const = moco::as<moco::TFConst>(node->x()); + auto y_const = moco::as<moco::TFConst>(node->y()); + if (x_const == nullptr || y_const == nullptr) + return false; + + if (x_const->dtype() != y_const->dtype()) + return false; + // TODO support other types + if (x_const->dtype() != loco::DataType::S32 && x_const->dtype() != loco::DataType::FLOAT32) + return false; + + // NOTE we support limited shape of elementwise mul or multiply with a scalar. + // valid_shape_for_constfold_binary_op() explains limited shape. + auto x_shape = moco::tensor_shape(x_const); + auto y_shape = moco::tensor_shape(y_const); + if (!moco::valid_shape_for_constfold_binary_op(x_shape, y_shape)) + return false; + + loco::TensorShape output_shape; + if (y_shape.rank() == 0 || y_shape.rank() == 1) + output_shape = x_shape; + else + output_shape = y_shape; + + auto graph = node->graph(); + auto output_const = moco::new_const(graph, output_shape, x_const->dtype()); + Func f; + + if (x_const->dtype() == loco::DataType::S32) + { + moco::apply_binary<int32_t>(x_const, y_const, output_const, f); + } + else if (x_const->dtype() == loco::DataType::FLOAT32) + { + moco::apply_binary<float>(x_const, y_const, output_const, f); + } + + // replace + loco::replace(node).with(output_const); + + return true; +} + +} // namespace + +namespace moco +{ + +/** + * @note This will Replace TFMul with TFConst when input are TFConst + * + * Before + * A --- TFMul --- C + * B --/ + * After + * A --- TFMul + * B --/ + * TFConst ---------- C + * Where + * A,B : inputs of TFMul + * C : a node that uses TFMul as an input + * TFMul is disconnected from C + * Nodes are drawn multiple times to simplify the diagram + */ +bool ConstantFoldMul::run(loco::Graph *graph) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(graph))) + { + if (auto mul_node = as<moco::TFMul>(node)) + { + if (constantfold_mul(mul_node)) + changed = true; + } + } + + return changed; +} + +} // namespace moco |