summaryrefslogtreecommitdiff
path: root/tools/tflitefile_tool/operator_wrapping.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tflitefile_tool/operator_wrapping.py')
-rwxr-xr-xtools/tflitefile_tool/operator_wrapping.py120
1 files changed, 120 insertions, 0 deletions
diff --git a/tools/tflitefile_tool/operator_wrapping.py b/tools/tflitefile_tool/operator_wrapping.py
new file mode 100755
index 000000000..1b7f55a4c
--- /dev/null
+++ b/tools/tflitefile_tool/operator_wrapping.py
@@ -0,0 +1,120 @@
+#!/usr/bin/python
+
+import tflite.Operator
+import tflite.OperatorCode
+import tflite.BuiltinOperator
+import tflite.ActivationFunctionType
+from tensor_wrapping import Tensor
+from operation import Operation
+from perf_predictor import PerfPredictor
+
+
+# Match enum value integer to name string
+# Assumption 1: enum value is defined by old style (can be used on python 2)
+# Assumption 2: when class define enum value, only constant value is defined and methods are not defined
+# Assumption 3: only integer value is set by constant definition
+def BuildEnumClassStrMap(obj):
+ ret = {}
+ for fieldName in dir(obj):
+ if (not fieldName.startswith('_')):
+ fieldValue = getattr(obj, fieldName)
+ if (isinstance(fieldValue, (int))):
+ ret[fieldValue] = fieldName
+ return ret
+
+
+class EnumStrMaps():
+ BuiltinOpcode = BuildEnumClassStrMap(tflite.BuiltinOperator.BuiltinOperator())
+ ActivationFunctionType = BuildEnumClassStrMap(
+ tflite.ActivationFunctionType.ActivationFunctionType())
+ BuiltinOptions = BuildEnumClassStrMap(tflite.BuiltinOptions.BuiltinOptions())
+
+
+def GetStrTensorIndex(tensors):
+ return_string = "["
+ for idx in range(len(tensors)):
+ if idx != 0:
+ return_string += ", "
+ return_string += str(tensors[idx].tensor_idx)
+ return_string += "]"
+ return return_string
+
+
+def GetAttribute(o, *args):
+ import functools
+ return functools.reduce(getattr, args, o)
+
+
+def BuildBuiltinOptionGen():
+ bo_gen = {}
+ for val_enum in EnumStrMaps.BuiltinOptions:
+ val_str = EnumStrMaps.BuiltinOptions[val_enum]
+ try:
+ # Dynamically import Builtin Option classes
+ # 0 (NONE) is the only exception that does not have no corresponding flatbuffer-generated class
+ module = __import__("tflite." + val_str)
+ bo_gen[val_enum] = GetAttribute(module, val_str, val_str)
+ except ImportError as e:
+ assert val_enum == 0 and val_str == "NONE"
+ return bo_gen
+
+
+class OptionLoader:
+ builtinOptionGen = BuildBuiltinOptionGen()
+
+ @staticmethod
+ def GetBuiltinOptions(options_type, options_table):
+ options = OptionLoader.builtinOptionGen[options_type]()
+ options.Init(options_table.Bytes, options_table.Pos)
+ return options
+
+
+class Operator(object):
+ def __init__(self, operator_idx, tf_operator, input_tensors, output_tensors,
+ opcode_str):
+ self.operator_idx = operator_idx
+ self.tf_operator = tf_operator
+ self.inputs = input_tensors
+ self.outputs = output_tensors
+ self.opcode_str = opcode_str
+ self.operation = Operation(self.tf_operator, self.opcode_str, self.inputs,
+ self.outputs)
+
+ def PrintInfo(self, perf_predictor=None):
+ # total instruction num
+ instrs = "{:,}".format(
+ self.operation.TotalInstrNum()) if self.operation.can_compute else "???"
+
+ # total operation cycles
+ cycles = "{:,}".format(
+ (perf_predictor.PredictCycles(self.operation)
+ )) if self.operation.can_compute and perf_predictor != None else "???"
+
+ print("Operator {0}: {1} (instrs: {2}, cycls: {3})".format(
+ self.operator_idx, self.opcode_str, instrs, cycles))
+
+ self.PrintOptionInfo()
+
+ print("\tInput Tensors" + GetStrTensorIndex(self.inputs))
+ for tensor in self.inputs:
+ tensor.PrintInfo("\t\t")
+ print("\tOutput Tensors" + GetStrTensorIndex(self.outputs))
+ for tensor in self.outputs:
+ tensor.PrintInfo("\t\t")
+
+ def PrintOptionInfo(self):
+ # FIXME: workaround for ops such as custom
+ try:
+ options = OptionLoader.GetBuiltinOptions(
+ self.tf_operator.BuiltinOptionsType(), self.tf_operator.BuiltinOptions())
+ except KeyError:
+ return
+
+ # fused activation function
+ try:
+ activation_code = options.FusedActivationFunction()
+ fused_activation = EnumStrMaps.ActivationFunctionType[activation_code]
+ print("\tFused Activation: " + fused_activation)
+ except AttributeError:
+ # This operator does not support FusedActivationFunction
+ pass