summaryrefslogtreecommitdiff
path: root/tools/tensorflow_model_freezer/sample/MUL_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tensorflow_model_freezer/sample/MUL_gen.py')
-rwxr-xr-xtools/tensorflow_model_freezer/sample/MUL_gen.py128
1 files changed, 128 insertions, 0 deletions
diff --git a/tools/tensorflow_model_freezer/sample/MUL_gen.py b/tools/tensorflow_model_freezer/sample/MUL_gen.py
new file mode 100755
index 000000000..f2a92547b
--- /dev/null
+++ b/tools/tensorflow_model_freezer/sample/MUL_gen.py
@@ -0,0 +1,128 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+ '''
+ class to generate tflite files for MUL
+ '''
+
+ def __init__(self, path):
+ super(self.__class__, self).__init__(path)
+
+ def getOutputDirectory(self):
+ return os.path.join(self.root_output_path,
+ 'mul') # the root path of generated files
+
+ def getTestCases(self):
+ '''
+ this returns a a hash containg test cases.
+ key of return hash is test case name and
+ value of return hash is test is a list of input tensor metadata.
+ test name (key of hash) is used as
+ - prefix of file name to be generated (don't use white space or special characters)
+ - output node name pf graph
+ '''
+ return {
+ "mul_scalarConst_scalarConst":
+ [base.Tensor([], const_val=1.2),
+ base.Tensor([], const_val=-2.3)],
+ "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])],
+ "mul_2d_2d": [base.Tensor([5, 3]), base.Tensor([5, 3])],
+ "mul_3d_3d": [base.Tensor([5, 4, 3]),
+ base.Tensor([5, 4, 3])],
+ "mul_2d_1d": [base.Tensor([5, 3]), base.Tensor([3])], # broadcasting
+ "mul_3d_1d": [base.Tensor([5, 4, 3]),
+ base.Tensor([3])],
+ "mul_1d_scalarConst": [base.Tensor([5]),
+ base.Tensor([], const_val=1.1)], # mul by scalar
+ "mul_2d_scalarConst": [base.Tensor([5, 3]),
+ base.Tensor([], const_val=1.1)],
+ "mul_1d_scalar": [base.Tensor([5, 3]), base.Tensor([])]
+ }
+
+ def buildModel(self, sess, test_case_tensor, tc_name):
+ '''
+ This method is called per test case (defined by getTestCases()).
+
+ keyword argument:
+ test_case_tensor -- test case tensor metadata
+ For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+ test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+ '''
+
+ input_list = []
+
+ # ------ modify below for your model FROM here -------#
+
+ x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+ y_tensor = self.createTFInput(test_case_tensor[1], input_list)
+
+ # defining output node = x_input * y_input
+ # and input list
+ output_node = tf.multiply(x_tensor, y_tensor, name=tc_name) # do not modify name
+
+ # ------ modify UNTIL here for your model -------#
+
+ # Note if don't have any CONST value, creating checkpoint file fails.
+ # The next lines insert such (CONST) to prevent such error.
+ # So, Graph.pb/pbtxt contains this garbage info,
+ # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+ garbage = tf.get_variable(
+ "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+ init_op = tf.global_variables_initializer()
+ garbage_value = [0]
+ sess.run(tf.assign(garbage, garbage_value))
+
+ sess.run(init_op)
+
+ # ------ modify appropriate return value -------#
+
+ # returning (input_node_list, output_node_list)
+ return (input_list, [output_node])
+
+
+'''
+How to run
+$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py
+$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \
+ tools/tensorflow_model_freezer/sample/name_of_this_file.py \
+ ~/temp # directory where model files are saved
+'''
+# --------
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Converted Tensorflow model in python to frozen model.')
+ parser.add_argument(
+ "out_dir",
+ help=
+ "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+ )
+
+ args = parser.parse_args()
+ root_output_path = args.out_dir
+
+ Gen(root_output_path).createSaveFreezeModel()