diff options
Diffstat (limited to 'tools/tflitefile_tool/tensor_wrapping.py')
-rwxr-xr-x | tools/tflitefile_tool/tensor_wrapping.py | 82 |
1 files changed, 52 insertions, 30 deletions
diff --git a/tools/tflitefile_tool/tensor_wrapping.py b/tools/tflitefile_tool/tensor_wrapping.py index b1fba57d2..a32a573ce 100755 --- a/tools/tflitefile_tool/tensor_wrapping.py +++ b/tools/tflitefile_tool/tensor_wrapping.py @@ -1,5 +1,19 @@ #!/usr/bin/python +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import tflite.Tensor import tflite.TensorType @@ -16,39 +30,47 @@ def SetTensorTypeStr(): TensorTypeList[fieldValue] = fieldName +TYPES = { + 'BOOL': 1, + 'COMPLEX64': 8, + 'FLOAT16': 2, + 'FLOAT32': 4, + 'INT16': 2, + 'INT32': 4, + 'INT64': 8, + 'UINT8': 1 +} + + +def GetTypeSize(type_name): + try: + return TYPES[type_name] + + except KeyError as error: + return 0 + + class Tensor(object): def __init__(self, tensor_idx, tf_tensor, tf_buffer): self.tensor_idx = tensor_idx self.tf_tensor = tf_tensor self.tf_buffer = tf_buffer + self.type_name = TensorTypeList[self.tf_tensor.Type()] + self.memory_size = self.GetMemorySize() + + def GetMemorySize(self): + type_size = GetTypeSize(self.type_name) + if type_size == 0: + return 0 + + # memory size in bytes + size = int(type_size) + shape_length = self.tf_tensor.ShapeLength() + if shape_length == 0: + return size + + for shape_idx in range(shape_length): + shape_size = int(self.tf_tensor.Shape(shape_idx)) + size *= shape_size - def PrintInfo(self, depth_str=""): - print_str = "" - if self.tensor_idx < 0: - print_str = "Tensor {0:4}".format(self.tensor_idx) - else: - buffer_idx = self.tf_tensor.Buffer() - isEmpty = "Filled" - if (self.tf_buffer.DataLength() == 0): - isEmpty = " Empty" - shape_str = self.GetShapeString() - type_name = TensorTypeList[self.tf_tensor.Type()] - - shape_name = "" - if self.tf_tensor.Name() != 0: - shape_name = self.tf_tensor.Name() - - print_str = "Tensor {0:4} : buffer {1:4} | {2} | {3:7} | Shape {4} ({5})".format( - self.tensor_idx, buffer_idx, isEmpty, type_name, shape_str, shape_name) - print(depth_str + print_str) - - def GetShapeString(self): - if self.tf_tensor.ShapeLength() == 0: - return "Scalar" - return_string = "[" - for shape_idx in range(self.tf_tensor.ShapeLength()): - if (shape_idx != 0): - return_string += ", " - return_string += str(self.tf_tensor.Shape(shape_idx)) - return_string += "]" - return return_string + return size |