summaryrefslogtreecommitdiff
path: root/tools/tflitefile_tool/saver/config_saver.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tflitefile_tool/saver/config_saver.py')
-rwxr-xr-xtools/tflitefile_tool/saver/config_saver.py122
1 files changed, 122 insertions, 0 deletions
diff --git a/tools/tflitefile_tool/saver/config_saver.py b/tools/tflitefile_tool/saver/config_saver.py
new file mode 100755
index 000000000..fa359693f
--- /dev/null
+++ b/tools/tflitefile_tool/saver/config_saver.py
@@ -0,0 +1,122 @@
+#!/usr/bin/python
+
+# Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from printer.string_builder import GetStringShape
+
+
+# TODO: Revise it as minimized `write` methods by using `StringBuilder`
+class ConfigSaver(object):
+ def __init__(self, file_name, operator):
+ self.file_name = file_name
+ self.operator = operator
+ # Set self.verbose to 1 level to print more information
+ self.verbose = 1
+ self.op_idx = operator.index
+ self.op_name = operator.op_name
+ self.options = operator.tf_options
+
+ self.f = open(file_name, 'at')
+
+ def __del__(self):
+ self.f.close()
+
+ def SaveInfo(self):
+ self.f.write("[{}]\n".format(self.op_idx))
+ if (self.op_name == 'CONV_2D'):
+ self.SaveConv2DInputs()
+ else:
+ self.SaveInputs()
+
+ self.SaveOutputs()
+
+ self.SaveAttributes()
+
+ self.f.write('\n')
+
+ def SaveConv2DInputs(self):
+ if (len(self.operator.inputs) != 3):
+ raise AssertionError('Conv2D input count should be 3')
+
+ input = self.operator.inputs[0]
+ weight = self.operator.inputs[1]
+ bias = self.operator.inputs[2]
+
+ self.f.write("input: {}\n".format(GetStringShape(input)))
+ self.f.write("input_type: {}\n".format(input.type_name))
+ self.f.write("weights: {}\n".format(GetStringShape(weight)))
+ self.f.write("weights_type: {}\n".format(weight.type_name))
+ self.f.write("bias: {}\n".format(GetStringShape(bias)))
+ self.f.write("bias_type: {}\n".format(bias.type_name))
+
+ def SaveInputs(self):
+ total = len(self.operator.inputs)
+ self.f.write("input_counts: {}\n".format(total))
+ for idx in range(total):
+ tensor = self.operator.inputs[idx]
+ input_shape_str = GetStringShape(tensor)
+ self.f.write("input{}: {}\n".format(idx, input_shape_str))
+ self.f.write("input{}_type: {}\n".format(idx, tensor.type_name))
+
+ def SaveOutputs(self):
+ total = len(self.operator.outputs)
+ self.f.write("output_counts: {}\n".format(total))
+ for idx in range(total):
+ tensor = self.operator.outputs[idx]
+ output_shape_str = GetStringShape(tensor)
+ self.f.write("output{}: {}\n".format(idx, output_shape_str))
+ self.f.write("output{}_type: {}\n".format(idx, tensor.type_name))
+
+ def SaveFilter(self):
+ self.f.write("filter_w: {}\n".format(self.options.FilterWidth()))
+ self.f.write("filter_h: {}\n".format(self.options.FilterHeight()))
+
+ def SaveStride(self):
+ self.f.write("stride_w: {}\n".format(self.options.StrideW()))
+ self.f.write("stride_h: {}\n".format(self.options.StrideH()))
+
+ def SaveDilation(self):
+ self.f.write("dilation_w: {}\n".format(self.options.DilationWFactor()))
+ self.f.write("dilation_h: {}\n".format(self.options.DilationHFactor()))
+
+ def SavePadding(self):
+ if self.options.Padding() == 0:
+ self.f.write("padding: SAME\n")
+ elif self.options.Padding() == 1:
+ self.f.write("padding: VALID\n")
+
+ def SaveFusedAct(self):
+ if self.operator.activation is not "NONE":
+ self.f.write("fused_act: {}\n".format(self.operator.activation))
+
+ def SaveAttributes(self):
+ if self.op_name == 'AVERAGE_POOL_2D' or self.op_name == 'MAX_POOL_2D':
+ self.SaveFilter()
+ self.SaveStride()
+ self.SavePadding()
+ elif self.op_name == 'CONV_2D':
+ self.SaveStride()
+ self.SaveDilation()
+ self.SavePadding()
+ elif self.op_name == 'TRANSPOSE_CONV':
+ self.SaveStride()
+ self.SavePadding()
+ elif self.op_name == 'DEPTHWISE_CONV_2D':
+ self.SaveStride()
+ self.SaveDilation()
+ self.SavePadding()
+ self.f.write("depthmultiplier: {}\n".format(self.options.DepthMultiplier()))
+
+ self.SaveFusedAct()