summaryrefslogtreecommitdiff
path: root/tools/tflitefile_tool/operation.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tflitefile_tool/operation.py')
-rwxr-xr-xtools/tflitefile_tool/operation.py32
1 files changed, 14 insertions, 18 deletions
diff --git a/tools/tflitefile_tool/operation.py b/tools/tflitefile_tool/operation.py
index 127d6c566..6aa752772 100755
--- a/tools/tflitefile_tool/operation.py
+++ b/tools/tflitefile_tool/operation.py
@@ -20,13 +20,11 @@ import tflite.BuiltinOptions
import tflite.Tensor
from tensor_wrapping import Tensor
import math
-'''
-NOTICE
-- an internal class. do not import outside this file.
-- REF: https://stackoverflow.com/questions/551038/private-implementation-class-in-python
-'''
+# NOTICE
+# - an internal class. do not import outside this file.
+# - REF: https://stackoverflow.com/questions/551038/private-implementation-class-in-python
class _OperationComputeMethod(object):
'''
NOTE: How to count operations of convolution(and also pooling)?
@@ -55,7 +53,7 @@ class _OperationComputeMethod(object):
Anyway, we can calculate total operations on this way. This can apply to the way of pooling.
'''
- def ComputeOperationForConv2D(tf_operator, inputs, outputs):
+ def ComputeOperationForConv2D(self, tf_operator, inputs, outputs):
assert (
tf_operator.BuiltinOptionsType() == tflite.BuiltinOptions.BuiltinOptions()
.Conv2DOptions)
@@ -83,16 +81,14 @@ class _OperationComputeMethod(object):
nonlinear_instr_num = 0
return (add_instr_num, mul_instr_num, nonlinear_instr_num)
- '''
- NOTE: Reference the comment 'NOTE' of ComputeOperationForConv2D
- '''
+ # NOTE: Reference the comment 'NOTE' of ComputeOperationForConv2D
- def ComputeOperationForPooling(tf_operator, inputs, outputs):
+ def ComputeOperationForPooling(self, tf_operator, inputs, outputs):
assert (
tf_operator.BuiltinOptionsType() == tflite.BuiltinOptions.BuiltinOptions()
.Pool2DOptions)
- input_tensor = inputs[0].tf_tensor
+ dummy_input_tensor = inputs[0].tf_tensor
output_tensor = outputs[0].tf_tensor
pool2d_options = tflite.Pool2DOptions.Pool2DOptions()
@@ -113,14 +109,14 @@ class _OperationComputeMethod(object):
nonlinear_instr_num = 0
return (add_instr_num, mul_instr_num, nonlinear_instr_num)
- def ComputeOperationForSoftmax(tf_operator, inputs, outputs):
+ def ComputeOperationForSoftmax(self, tf_operator, inputs, outputs):
assert (
tf_operator.BuiltinOptionsType() == tflite.BuiltinOptions.BuiltinOptions()
.SoftmaxOptions)
input_tensor = inputs[0].tf_tensor
- batch_size = input_tensor.Shape(0)
+ dummy_batch_size = input_tensor.Shape(0)
input_dim = input_tensor.Shape(1)
# Softmax(x_i) = exp(x_i) / sum of exp(x)
@@ -129,7 +125,7 @@ class _OperationComputeMethod(object):
nonlinear_instr_num = input_dim + input_dim # sum of exp(x) and exp(x_i)
return (add_instr_num, mul_instr_num, nonlinear_instr_num)
- def ComputeOperationForFullyConnected(tf_operator, inputs, outputs):
+ def ComputeOperationForFullyConnected(self, tf_operator, inputs, outputs):
assert (
tf_operator.BuiltinOptionsType() == tflite.BuiltinOptions.BuiltinOptions()
.FullyConnectedOptions)
@@ -150,13 +146,13 @@ class _OperationComputeMethod(object):
nonlinear_instr_num = 0
return (add_instr_num, mul_instr_num, nonlinear_instr_num)
- def ComputeOperationForNothing(tf_operator, inputs, outputs):
+ def ComputeOperationForNothing(self, tf_operator, inputs, outputs):
add_instr_num = 0
mul_instr_num = 0
nonlinear_instr_num = 0
return (add_instr_num, mul_instr_num, nonlinear_instr_num)
- def NYI_ComputeOperation(tf_operator, inputs, outputs):
+ def NYI_ComputeOperation(self, tf_operator, inputs, outputs):
pass
operation_to_method_map = {
@@ -167,7 +163,7 @@ class _OperationComputeMethod(object):
"SOFTMAX": ComputeOperationForSoftmax,
"FULLY_CONNECTED": ComputeOperationForFullyConnected,
"CONCATENATION": ComputeOperationForNothing,
- # ADAS
+ # Extension
"TOPK_V2": NYI_ComputeOperation,
"SUB": NYI_ComputeOperation,
"STRIDED_SLICE": NYI_ComputeOperation,
@@ -207,7 +203,7 @@ class Operation(object):
return
self.add_instr_num, self.mul_instr_num, self.nonlinear_instr_num = method(
- self.tf_operator, self.inputs, self.outputs)
+ _OperationComputeMethod(), self.tf_operator, self.inputs, self.outputs)
def TotalInstrNum(self):
return (self.add_instr_num + self.mul_instr_num + self.nonlinear_instr_num)