diff options
Diffstat (limited to 'tools/tflitefile_tool/saver/config_saver.py')
-rwxr-xr-x | tools/tflitefile_tool/saver/config_saver.py | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/tools/tflitefile_tool/saver/config_saver.py b/tools/tflitefile_tool/saver/config_saver.py new file mode 100755 index 000000000..fa359693f --- /dev/null +++ b/tools/tflitefile_tool/saver/config_saver.py @@ -0,0 +1,122 @@ +#!/usr/bin/python + +# Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from printer.string_builder import GetStringShape + + +# TODO: Revise it as minimized `write` methods by using `StringBuilder` +class ConfigSaver(object): + def __init__(self, file_name, operator): + self.file_name = file_name + self.operator = operator + # Set self.verbose to 1 level to print more information + self.verbose = 1 + self.op_idx = operator.index + self.op_name = operator.op_name + self.options = operator.tf_options + + self.f = open(file_name, 'at') + + def __del__(self): + self.f.close() + + def SaveInfo(self): + self.f.write("[{}]\n".format(self.op_idx)) + if (self.op_name == 'CONV_2D'): + self.SaveConv2DInputs() + else: + self.SaveInputs() + + self.SaveOutputs() + + self.SaveAttributes() + + self.f.write('\n') + + def SaveConv2DInputs(self): + if (len(self.operator.inputs) != 3): + raise AssertionError('Conv2D input count should be 3') + + input = self.operator.inputs[0] + weight = self.operator.inputs[1] + bias = self.operator.inputs[2] + + self.f.write("input: {}\n".format(GetStringShape(input))) + self.f.write("input_type: {}\n".format(input.type_name)) + self.f.write("weights: {}\n".format(GetStringShape(weight))) + self.f.write("weights_type: {}\n".format(weight.type_name)) + self.f.write("bias: {}\n".format(GetStringShape(bias))) + self.f.write("bias_type: {}\n".format(bias.type_name)) + + def SaveInputs(self): + total = len(self.operator.inputs) + self.f.write("input_counts: {}\n".format(total)) + for idx in range(total): + tensor = self.operator.inputs[idx] + input_shape_str = GetStringShape(tensor) + self.f.write("input{}: {}\n".format(idx, input_shape_str)) + self.f.write("input{}_type: {}\n".format(idx, tensor.type_name)) + + def SaveOutputs(self): + total = len(self.operator.outputs) + self.f.write("output_counts: {}\n".format(total)) + for idx in range(total): + tensor = self.operator.outputs[idx] + output_shape_str = GetStringShape(tensor) + self.f.write("output{}: {}\n".format(idx, output_shape_str)) + self.f.write("output{}_type: {}\n".format(idx, tensor.type_name)) + + def SaveFilter(self): + self.f.write("filter_w: {}\n".format(self.options.FilterWidth())) + self.f.write("filter_h: {}\n".format(self.options.FilterHeight())) + + def SaveStride(self): + self.f.write("stride_w: {}\n".format(self.options.StrideW())) + self.f.write("stride_h: {}\n".format(self.options.StrideH())) + + def SaveDilation(self): + self.f.write("dilation_w: {}\n".format(self.options.DilationWFactor())) + self.f.write("dilation_h: {}\n".format(self.options.DilationHFactor())) + + def SavePadding(self): + if self.options.Padding() == 0: + self.f.write("padding: SAME\n") + elif self.options.Padding() == 1: + self.f.write("padding: VALID\n") + + def SaveFusedAct(self): + if self.operator.activation is not "NONE": + self.f.write("fused_act: {}\n".format(self.operator.activation)) + + def SaveAttributes(self): + if self.op_name == 'AVERAGE_POOL_2D' or self.op_name == 'MAX_POOL_2D': + self.SaveFilter() + self.SaveStride() + self.SavePadding() + elif self.op_name == 'CONV_2D': + self.SaveStride() + self.SaveDilation() + self.SavePadding() + elif self.op_name == 'TRANSPOSE_CONV': + self.SaveStride() + self.SavePadding() + elif self.op_name == 'DEPTHWISE_CONV_2D': + self.SaveStride() + self.SaveDilation() + self.SavePadding() + self.f.write("depthmultiplier: {}\n".format(self.options.DepthMultiplier())) + + self.SaveFusedAct() |