diff options
Diffstat (limited to 'tools/tflitefile_tool/operation.py')
-rwxr-xr-x | tools/tflitefile_tool/operation.py | 32 |
1 files changed, 14 insertions, 18 deletions
diff --git a/tools/tflitefile_tool/operation.py b/tools/tflitefile_tool/operation.py index 127d6c566..6aa752772 100755 --- a/tools/tflitefile_tool/operation.py +++ b/tools/tflitefile_tool/operation.py @@ -20,13 +20,11 @@ import tflite.BuiltinOptions import tflite.Tensor from tensor_wrapping import Tensor import math -''' -NOTICE -- an internal class. do not import outside this file. -- REF: https://stackoverflow.com/questions/551038/private-implementation-class-in-python -''' +# NOTICE +# - an internal class. do not import outside this file. +# - REF: https://stackoverflow.com/questions/551038/private-implementation-class-in-python class _OperationComputeMethod(object): ''' NOTE: How to count operations of convolution(and also pooling)? @@ -55,7 +53,7 @@ class _OperationComputeMethod(object): Anyway, we can calculate total operations on this way. This can apply to the way of pooling. ''' - def ComputeOperationForConv2D(tf_operator, inputs, outputs): + def ComputeOperationForConv2D(self, tf_operator, inputs, outputs): assert ( tf_operator.BuiltinOptionsType() == tflite.BuiltinOptions.BuiltinOptions() .Conv2DOptions) @@ -83,16 +81,14 @@ class _OperationComputeMethod(object): nonlinear_instr_num = 0 return (add_instr_num, mul_instr_num, nonlinear_instr_num) - ''' - NOTE: Reference the comment 'NOTE' of ComputeOperationForConv2D - ''' + # NOTE: Reference the comment 'NOTE' of ComputeOperationForConv2D - def ComputeOperationForPooling(tf_operator, inputs, outputs): + def ComputeOperationForPooling(self, tf_operator, inputs, outputs): assert ( tf_operator.BuiltinOptionsType() == tflite.BuiltinOptions.BuiltinOptions() .Pool2DOptions) - input_tensor = inputs[0].tf_tensor + dummy_input_tensor = inputs[0].tf_tensor output_tensor = outputs[0].tf_tensor pool2d_options = tflite.Pool2DOptions.Pool2DOptions() @@ -113,14 +109,14 @@ class _OperationComputeMethod(object): nonlinear_instr_num = 0 return (add_instr_num, mul_instr_num, nonlinear_instr_num) - def ComputeOperationForSoftmax(tf_operator, inputs, outputs): + def ComputeOperationForSoftmax(self, tf_operator, inputs, outputs): assert ( tf_operator.BuiltinOptionsType() == tflite.BuiltinOptions.BuiltinOptions() .SoftmaxOptions) input_tensor = inputs[0].tf_tensor - batch_size = input_tensor.Shape(0) + dummy_batch_size = input_tensor.Shape(0) input_dim = input_tensor.Shape(1) # Softmax(x_i) = exp(x_i) / sum of exp(x) @@ -129,7 +125,7 @@ class _OperationComputeMethod(object): nonlinear_instr_num = input_dim + input_dim # sum of exp(x) and exp(x_i) return (add_instr_num, mul_instr_num, nonlinear_instr_num) - def ComputeOperationForFullyConnected(tf_operator, inputs, outputs): + def ComputeOperationForFullyConnected(self, tf_operator, inputs, outputs): assert ( tf_operator.BuiltinOptionsType() == tflite.BuiltinOptions.BuiltinOptions() .FullyConnectedOptions) @@ -150,13 +146,13 @@ class _OperationComputeMethod(object): nonlinear_instr_num = 0 return (add_instr_num, mul_instr_num, nonlinear_instr_num) - def ComputeOperationForNothing(tf_operator, inputs, outputs): + def ComputeOperationForNothing(self, tf_operator, inputs, outputs): add_instr_num = 0 mul_instr_num = 0 nonlinear_instr_num = 0 return (add_instr_num, mul_instr_num, nonlinear_instr_num) - def NYI_ComputeOperation(tf_operator, inputs, outputs): + def NYI_ComputeOperation(self, tf_operator, inputs, outputs): pass operation_to_method_map = { @@ -167,7 +163,7 @@ class _OperationComputeMethod(object): "SOFTMAX": ComputeOperationForSoftmax, "FULLY_CONNECTED": ComputeOperationForFullyConnected, "CONCATENATION": ComputeOperationForNothing, - # ADAS + # Extension "TOPK_V2": NYI_ComputeOperation, "SUB": NYI_ComputeOperation, "STRIDED_SLICE": NYI_ComputeOperation, @@ -207,7 +203,7 @@ class Operation(object): return self.add_instr_num, self.mul_instr_num, self.nonlinear_instr_num = method( - self.tf_operator, self.inputs, self.outputs) + _OperationComputeMethod(), self.tf_operator, self.inputs, self.outputs) def TotalInstrNum(self): return (self.add_instr_num + self.mul_instr_num + self.nonlinear_instr_num) |