summaryrefslogtreecommitdiff
path: root/tools/tflitefile_tool/tensor_wrapping.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tflitefile_tool/tensor_wrapping.py')
-rwxr-xr-xtools/tflitefile_tool/tensor_wrapping.py82
1 files changed, 52 insertions, 30 deletions
diff --git a/tools/tflitefile_tool/tensor_wrapping.py b/tools/tflitefile_tool/tensor_wrapping.py
index b1fba57d2..a32a573ce 100755
--- a/tools/tflitefile_tool/tensor_wrapping.py
+++ b/tools/tflitefile_tool/tensor_wrapping.py
@@ -1,5 +1,19 @@
#!/usr/bin/python
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import tflite.Tensor
import tflite.TensorType
@@ -16,39 +30,47 @@ def SetTensorTypeStr():
TensorTypeList[fieldValue] = fieldName
+TYPES = {
+ 'BOOL': 1,
+ 'COMPLEX64': 8,
+ 'FLOAT16': 2,
+ 'FLOAT32': 4,
+ 'INT16': 2,
+ 'INT32': 4,
+ 'INT64': 8,
+ 'UINT8': 1
+}
+
+
+def GetTypeSize(type_name):
+ try:
+ return TYPES[type_name]
+
+ except KeyError as error:
+ return 0
+
+
class Tensor(object):
def __init__(self, tensor_idx, tf_tensor, tf_buffer):
self.tensor_idx = tensor_idx
self.tf_tensor = tf_tensor
self.tf_buffer = tf_buffer
+ self.type_name = TensorTypeList[self.tf_tensor.Type()]
+ self.memory_size = self.GetMemorySize()
+
+ def GetMemorySize(self):
+ type_size = GetTypeSize(self.type_name)
+ if type_size == 0:
+ return 0
+
+ # memory size in bytes
+ size = int(type_size)
+ shape_length = self.tf_tensor.ShapeLength()
+ if shape_length == 0:
+ return size
+
+ for shape_idx in range(shape_length):
+ shape_size = int(self.tf_tensor.Shape(shape_idx))
+ size *= shape_size
- def PrintInfo(self, depth_str=""):
- print_str = ""
- if self.tensor_idx < 0:
- print_str = "Tensor {0:4}".format(self.tensor_idx)
- else:
- buffer_idx = self.tf_tensor.Buffer()
- isEmpty = "Filled"
- if (self.tf_buffer.DataLength() == 0):
- isEmpty = " Empty"
- shape_str = self.GetShapeString()
- type_name = TensorTypeList[self.tf_tensor.Type()]
-
- shape_name = ""
- if self.tf_tensor.Name() != 0:
- shape_name = self.tf_tensor.Name()
-
- print_str = "Tensor {0:4} : buffer {1:4} | {2} | {3:7} | Shape {4} ({5})".format(
- self.tensor_idx, buffer_idx, isEmpty, type_name, shape_str, shape_name)
- print(depth_str + print_str)
-
- def GetShapeString(self):
- if self.tf_tensor.ShapeLength() == 0:
- return "Scalar"
- return_string = "["
- for shape_idx in range(self.tf_tensor.ShapeLength()):
- if (shape_idx != 0):
- return_string += ", "
- return_string += str(self.tf_tensor.Shape(shape_idx))
- return_string += "]"
- return return_string
+ return size