diff options
Diffstat (limited to 'tools/tflitefile_tool/tensor_printer.py')
-rw-r--r-- | tools/tflitefile_tool/tensor_printer.py | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/tools/tflitefile_tool/tensor_printer.py b/tools/tflitefile_tool/tensor_printer.py new file mode 100644 index 000000000..f566a6e10 --- /dev/null +++ b/tools/tflitefile_tool/tensor_printer.py @@ -0,0 +1,80 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tensor_wrapping import Tensor + +SYMBOLS = ['B', 'K', 'M', 'G', 'T'] + + +def ConvertBytesToHuman(n): + n = int(n) + if n < 0: + return 0 + + format_str = "%(val)3.1f%(symb)s" + prefix = {} + for i, s in enumerate(SYMBOLS[1:]): + prefix[s] = 1 << (i + 1) * 10 + + for symbol in reversed(SYMBOLS[1:]): + if n >= prefix[symbol]: + v = float(n) / prefix[symbol] + return format_str % dict(symb=symbol, val=v) + + return format_str % dict(symb=SYMBOLS[0], val=n) + + +class TensorPrinter(object): + def __init__(self, verbose, tensor): + self.verbose = verbose + self.tensor = tensor + + def PrintInfo(self, depth_str=""): + if (self.verbose < 1): + pass + + print_str = "" + if self.tensor.tensor_idx < 0: + print_str = "Tensor {0:4}".format(self.tensor.tensor_idx) + else: + buffer_idx = self.tensor.tf_tensor.Buffer() + isEmpty = "Filled" + if (self.tensor.tf_buffer.DataLength() == 0): + isEmpty = " Empty" + shape_str = self.GetShapeString() + type_name = self.tensor.type_name + + shape_name = "" + if self.tensor.tf_tensor.Name() != 0: + shape_name = self.tensor.tf_tensor.Name() + + memory_size = ConvertBytesToHuman(self.tensor.memory_size) + + print_str = "Tensor {0:4} : buffer {1:4} | {2} | {3:7} | Memory {4:6} | Shape {5} ({6})".format( + self.tensor.tensor_idx, buffer_idx, isEmpty, type_name, memory_size, + shape_str, shape_name) + print(depth_str + print_str) + + def GetShapeString(self): + if self.tensor.tf_tensor.ShapeLength() == 0: + return "Scalar" + return_string = "[" + for shape_idx in range(self.tensor.tf_tensor.ShapeLength()): + if (shape_idx != 0): + return_string += ", " + return_string += str(self.tensor.tf_tensor.Shape(shape_idx)) + return_string += "]" + return return_string |