summaryrefslogtreecommitdiff
path: root/tools/tensorflow_model_freezer/sample
diff options
context:
space:
mode:
Diffstat (limited to 'tools/tensorflow_model_freezer/sample')
-rwxr-xr-xtools/tensorflow_model_freezer/sample/ARGMAX_gen.py100
-rwxr-xr-xtools/tensorflow_model_freezer/sample/ARGMIN_gen.py100
-rwxr-xr-xtools/tensorflow_model_freezer/sample/DIV_gen.py7
-rwxr-xr-xtools/tensorflow_model_freezer/sample/LOGICAL_AND_gen.py105
-rwxr-xr-xtools/tensorflow_model_freezer/sample/LOGICAL_NOT_gen.py100
-rwxr-xr-xtools/tensorflow_model_freezer/sample/LOGICAL_OR_gen.py104
-rwxr-xr-xtools/tensorflow_model_freezer/sample/MUL_gen.py11
-rwxr-xr-xtools/tensorflow_model_freezer/sample/SQUEEZE_gen.py7
-rwxr-xr-xtools/tensorflow_model_freezer/sample/STACK_gen.py101
-rwxr-xr-xtools/tensorflow_model_freezer/sample/TOPK_gen.py14
-rw-r--r--tools/tensorflow_model_freezer/sample/UNSTACK_gen.py100
11 files changed, 717 insertions, 32 deletions
diff --git a/tools/tensorflow_model_freezer/sample/ARGMAX_gen.py b/tools/tensorflow_model_freezer/sample/ARGMAX_gen.py
new file mode 100755
index 000000000..68e2262c0
--- /dev/null
+++ b/tools/tensorflow_model_freezer/sample/ARGMAX_gen.py
@@ -0,0 +1,100 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+ '''
+ class to generate tflite files for MUL
+ '''
+
+ def __init__(self, path):
+ super(self.__class__, self).__init__(path)
+
+ def getOutputDirectory(self):
+ return os.path.join(self.root_output_path,
+ 'argmax') # the root path of generated files
+
+ def getTestCases(self):
+ '''
+ this returns a a hash containg test cases.
+ key of return hash is test case name and
+ value of return hash is test is a list of input tensor metadata.
+ test name (key of hash) is used as
+ - prefix of file name to be generated (don't use white space or special characters)
+ - output node name pf graph
+ '''
+ return {"argmax_4d": [base.Tensor([1, 2, 4, 3])]}
+
+ def buildModel(self, sess, test_case_tensor, tc_name):
+ '''
+ This method is called per test case (defined by getTestCases()).
+
+ keyword argument:
+ test_case_tensor -- test case tensor metadata
+ For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+ test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+ '''
+
+ input_list = []
+
+ # ------ modify below for your model FROM here -------#
+ x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+
+ output_node = tf.arg_max(x_tensor, 0, output_type=tf.int32, name=tc_name)
+ # ------ modify UNTIL here for your model -------#
+
+ # Note if don't have any CONST value, creating checkpoint file fails.
+ # The next lines insert such (CONST) to prevent such error.
+ # So, Graph.pb/pbtxt contains this garbage info,
+ # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+ garbage = tf.get_variable(
+ "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+ init_op = tf.global_variables_initializer()
+ garbage_value = [0]
+ sess.run(tf.assign(garbage, garbage_value))
+
+ sess.run(init_op)
+
+ # ------ modify appropriate return value -------#
+
+ # returning (input_node_list, output_node_list)
+ return (input_list, [output_node])
+
+
+# --------
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Converted Tensorflow model in python to frozen model.')
+ parser.add_argument(
+ "out_dir",
+ help=
+ "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+ )
+
+ args = parser.parse_args()
+ root_output_path = args.out_dir
+
+ Gen(root_output_path).createSaveFreezeModel()
diff --git a/tools/tensorflow_model_freezer/sample/ARGMIN_gen.py b/tools/tensorflow_model_freezer/sample/ARGMIN_gen.py
new file mode 100755
index 000000000..68b399234
--- /dev/null
+++ b/tools/tensorflow_model_freezer/sample/ARGMIN_gen.py
@@ -0,0 +1,100 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+ '''
+ class to generate tflite files for MUL
+ '''
+
+ def __init__(self, path):
+ super(self.__class__, self).__init__(path)
+
+ def getOutputDirectory(self):
+ return os.path.join(self.root_output_path,
+ 'argmin') # the root path of generated files
+
+ def getTestCases(self):
+ '''
+ this returns a a hash containg test cases.
+ key of return hash is test case name and
+ value of return hash is test is a list of input tensor metadata.
+ test name (key of hash) is used as
+ - prefix of file name to be generated (don't use white space or special characters)
+ - output node name pf graph
+ '''
+ return {"argmin_4d": [base.Tensor([1, 2, 4, 3])]}
+
+ def buildModel(self, sess, test_case_tensor, tc_name):
+ '''
+ This method is called per test case (defined by getTestCases()).
+
+ keyword argument:
+ test_case_tensor -- test case tensor metadata
+ For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+ test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+ '''
+
+ input_list = []
+
+ # ------ modify below for your model FROM here -------#
+ x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+
+ output_node = tf.arg_min(x_tensor, 0, name=tc_name)
+ # ------ modify UNTIL here for your model -------#
+
+ # Note if don't have any CONST value, creating checkpoint file fails.
+ # The next lines insert such (CONST) to prevent such error.
+ # So, Graph.pb/pbtxt contains this garbage info,
+ # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+ garbage = tf.get_variable(
+ "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+ init_op = tf.global_variables_initializer()
+ garbage_value = [0]
+ sess.run(tf.assign(garbage, garbage_value))
+
+ sess.run(init_op)
+
+ # ------ modify appropriate return value -------#
+
+ # returning (input_node_list, output_node_list)
+ return (input_list, [output_node])
+
+
+# --------
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Converted Tensorflow model in python to frozen model.')
+ parser.add_argument(
+ "out_dir",
+ help=
+ "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+ )
+
+ args = parser.parse_args()
+ root_output_path = args.out_dir
+
+ Gen(root_output_path).createSaveFreezeModel()
diff --git a/tools/tensorflow_model_freezer/sample/DIV_gen.py b/tools/tensorflow_model_freezer/sample/DIV_gen.py
index c4e9cde07..d1b794cd7 100755
--- a/tools/tensorflow_model_freezer/sample/DIV_gen.py
+++ b/tools/tensorflow_model_freezer/sample/DIV_gen.py
@@ -124,13 +124,6 @@ class Gen(base.BaseFreezer):
return (input_list, [output_node])
-'''
-How to run
-$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py
-$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \
- tools/tensorflow_model_freezer/sample/name_of_this_file.py \
- ~/temp # directory where model files are saved
-'''
# --------
if __name__ == "__main__":
diff --git a/tools/tensorflow_model_freezer/sample/LOGICAL_AND_gen.py b/tools/tensorflow_model_freezer/sample/LOGICAL_AND_gen.py
new file mode 100755
index 000000000..912af65b2
--- /dev/null
+++ b/tools/tensorflow_model_freezer/sample/LOGICAL_AND_gen.py
@@ -0,0 +1,105 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+ '''
+ class to generate tflite files for MUL
+ '''
+
+ def __init__(self, path):
+ super(self.__class__, self).__init__(path)
+
+ def getOutputDirectory(self):
+ return os.path.join(self.root_output_path,
+ 'logical_and') # the root path of generated files
+
+ def getTestCases(self):
+ '''
+ this returns a a hash containg test cases.
+ key of return hash is test case name and
+ value of return hash is test is a list of input tensor metadata.
+ test name (key of hash) is used as
+ - prefix of file name to be generated (don't use white space or special characters)
+ - output node name pf graph
+ '''
+ return {"logical_and_4d": [base.Tensor([1, 2, 4, 3]), base.Tensor([1, 2, 4, 3])]}
+
+ def buildModel(self, sess, test_case_tensor, tc_name):
+ '''
+ This method is called per test case (defined by getTestCases()).
+
+ keyword argument:
+ test_case_tensor -- test case tensor metadata
+ For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+ test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+ '''
+
+ input_list = []
+
+ # ------ modify below for your model FROM here -------#
+ x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+ y_tensor = self.createTFInput(test_case_tensor[1], input_list)
+
+ output_node = tf.logical_and(
+ tf.greater(x_tensor, tf.constant(0.0)),
+ tf.less(y_tensor, tf.constant(1.0)),
+ name=tc_name)
+
+ # ------ modify UNTIL here for your model -------#
+
+ # Note if don't have any CONST value, creating checkpoint file fails.
+ # The next lines insert such (CONST) to prevent such error.
+ # So, Graph.pb/pbtxt contains this garbage info,
+ # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+ garbage = tf.get_variable(
+ "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+ init_op = tf.global_variables_initializer()
+ garbage_value = [0]
+ sess.run(tf.assign(garbage, garbage_value))
+
+ sess.run(init_op)
+
+ # ------ modify appropriate return value -------#
+
+ # returning (input_node_list, output_node_list)
+ return (input_list, [output_node])
+
+
+# --------
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Converted Tensorflow model in python to frozen model.')
+ parser.add_argument(
+ "out_dir",
+ help=
+ "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+ )
+
+ args = parser.parse_args()
+ root_output_path = args.out_dir
+
+ Gen(root_output_path).createSaveFreezeModel()
diff --git a/tools/tensorflow_model_freezer/sample/LOGICAL_NOT_gen.py b/tools/tensorflow_model_freezer/sample/LOGICAL_NOT_gen.py
new file mode 100755
index 000000000..34c0994ea
--- /dev/null
+++ b/tools/tensorflow_model_freezer/sample/LOGICAL_NOT_gen.py
@@ -0,0 +1,100 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+ '''
+ class to generate tflite files for MUL
+ '''
+
+ def __init__(self, path):
+ super(self.__class__, self).__init__(path)
+
+ def getOutputDirectory(self):
+ return os.path.join(self.root_output_path,
+ 'logical_not') # the root path of generated files
+
+ def getTestCases(self):
+ '''
+ this returns a a hash containg test cases.
+ key of return hash is test case name and
+ value of return hash is test is a list of input tensor metadata.
+ test name (key of hash) is used as
+ - prefix of file name to be generated (don't use white space or special characters)
+ - output node name pf graph
+ '''
+ return {"logical_not_4d": [base.Tensor([1, 2, 4, 3])]}
+
+ def buildModel(self, sess, test_case_tensor, tc_name):
+ '''
+ This method is called per test case (defined by getTestCases()).
+
+ keyword argument:
+ test_case_tensor -- test case tensor metadata
+ For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+ test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+ '''
+
+ input_list = []
+
+ # ------ modify below for your model FROM here -------#
+ x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+
+ output_node = tf.logical_not(tf.greater(x_tensor, tf.constant(0.0)), name=tc_name)
+ # ------ modify UNTIL here for your model -------#
+
+ # Note if don't have any CONST value, creating checkpoint file fails.
+ # The next lines insert such (CONST) to prevent such error.
+ # So, Graph.pb/pbtxt contains this garbage info,
+ # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+ garbage = tf.get_variable(
+ "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+ init_op = tf.global_variables_initializer()
+ garbage_value = [0]
+ sess.run(tf.assign(garbage, garbage_value))
+
+ sess.run(init_op)
+
+ # ------ modify appropriate return value -------#
+
+ # returning (input_node_list, output_node_list)
+ return (input_list, [output_node])
+
+
+# --------
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Converted Tensorflow model in python to frozen model.')
+ parser.add_argument(
+ "out_dir",
+ help=
+ "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+ )
+
+ args = parser.parse_args()
+ root_output_path = args.out_dir
+
+ Gen(root_output_path).createSaveFreezeModel()
diff --git a/tools/tensorflow_model_freezer/sample/LOGICAL_OR_gen.py b/tools/tensorflow_model_freezer/sample/LOGICAL_OR_gen.py
new file mode 100755
index 000000000..714a52e8d
--- /dev/null
+++ b/tools/tensorflow_model_freezer/sample/LOGICAL_OR_gen.py
@@ -0,0 +1,104 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+ '''
+ class to generate tflite files for MUL
+ '''
+
+ def __init__(self, path):
+ super(self.__class__, self).__init__(path)
+
+ def getOutputDirectory(self):
+ return os.path.join(self.root_output_path,
+ 'logical_or') # the root path of generated files
+
+ def getTestCases(self):
+ '''
+ this returns a a hash containg test cases.
+ key of return hash is test case name and
+ value of return hash is test is a list of input tensor metadata.
+ test name (key of hash) is used as
+ - prefix of file name to be generated (don't use white space or special characters)
+ - output node name pf graph
+ '''
+ return {"logical_or_4d": [base.Tensor([1, 2, 4, 3]), base.Tensor([1, 2, 4, 3])]}
+
+ def buildModel(self, sess, test_case_tensor, tc_name):
+ '''
+ This method is called per test case (defined by getTestCases()).
+
+ keyword argument:
+ test_case_tensor -- test case tensor metadata
+ For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+ test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+ '''
+
+ input_list = []
+
+ # ------ modify below for your model FROM here -------#
+ x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+ y_tensor = self.createTFInput(test_case_tensor[1], input_list)
+
+ output_node = tf.logical_or(
+ tf.greater(x_tensor, tf.constant(0.0)),
+ tf.less(y_tensor, tf.constant(1.0)),
+ name=tc_name)
+ # ------ modify UNTIL here for your model -------#
+
+ # Note if don't have any CONST value, creating checkpoint file fails.
+ # The next lines insert such (CONST) to prevent such error.
+ # So, Graph.pb/pbtxt contains this garbage info,
+ # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+ garbage = tf.get_variable(
+ "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+ init_op = tf.global_variables_initializer()
+ garbage_value = [0]
+ sess.run(tf.assign(garbage, garbage_value))
+
+ sess.run(init_op)
+
+ # ------ modify appropriate return value -------#
+
+ # returning (input_node_list, output_node_list)
+ return (input_list, [output_node])
+
+
+# --------
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Converted Tensorflow model in python to frozen model.')
+ parser.add_argument(
+ "out_dir",
+ help=
+ "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+ )
+
+ args = parser.parse_args()
+ root_output_path = args.out_dir
+
+ Gen(root_output_path).createSaveFreezeModel()
diff --git a/tools/tensorflow_model_freezer/sample/MUL_gen.py b/tools/tensorflow_model_freezer/sample/MUL_gen.py
index f2a92547b..596898dbb 100755
--- a/tools/tensorflow_model_freezer/sample/MUL_gen.py
+++ b/tools/tensorflow_model_freezer/sample/MUL_gen.py
@@ -59,8 +59,8 @@ class Gen(base.BaseFreezer):
"mul_1d_scalarConst": [base.Tensor([5]),
base.Tensor([], const_val=1.1)], # mul by scalar
"mul_2d_scalarConst": [base.Tensor([5, 3]),
- base.Tensor([], const_val=1.1)],
- "mul_1d_scalar": [base.Tensor([5, 3]), base.Tensor([])]
+ base.Tensor([], const_val=1.1)]
+ # "mul_2d_scalar": [base.Tensor([5, 3]), base.Tensor([])] # not support scalar input
}
def buildModel(self, sess, test_case_tensor, tc_name):
@@ -104,13 +104,6 @@ class Gen(base.BaseFreezer):
return (input_list, [output_node])
-'''
-How to run
-$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py
-$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \
- tools/tensorflow_model_freezer/sample/name_of_this_file.py \
- ~/temp # directory where model files are saved
-'''
# --------
if __name__ == "__main__":
diff --git a/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py b/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py
index 88b3dfcb2..12fb5122e 100755
--- a/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py
+++ b/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py
@@ -103,13 +103,6 @@ class Gen(base.BaseFreezer):
return (input_list, [output_node])
-'''
-How to run
-$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py
-$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \
- tools/tensorflow_model_freezer/sample/name_of_this_file.py \
- ~/temp # directory where model files are saved
-'''
# --------
if __name__ == "__main__":
diff --git a/tools/tensorflow_model_freezer/sample/STACK_gen.py b/tools/tensorflow_model_freezer/sample/STACK_gen.py
new file mode 100755
index 000000000..2bea40698
--- /dev/null
+++ b/tools/tensorflow_model_freezer/sample/STACK_gen.py
@@ -0,0 +1,101 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+ '''
+ class to generate tflite files for MUL
+ '''
+
+ def __init__(self, path):
+ super(self.__class__, self).__init__(path)
+
+ def getOutputDirectory(self):
+ return os.path.join(self.root_output_path,
+ 'stack') # the root path of generated files
+
+ def getTestCases(self):
+ '''
+ this returns a a hash containg test cases.
+ key of return hash is test case name and
+ value of return hash is test is a list of input tensor metadata.
+ test name (key of hash) is used as
+ - prefix of file name to be generated (don't use white space or special characters)
+ - output node name pf graph
+ '''
+ return {"stack_4d": [base.Tensor([1, 4, 3]), base.Tensor([1, 4, 3])]}
+
+ def buildModel(self, sess, test_case_tensor, tc_name):
+ '''
+ This method is called per test case (defined by getTestCases()).
+
+ keyword argument:
+ test_case_tensor -- test case tensor metadata
+ For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+ test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+ '''
+
+ input_list = []
+
+ # ------ modify below for your model FROM here -------#
+ x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+ y_tensor = self.createTFInput(test_case_tensor[1], input_list)
+
+ output_node = tf.stack([x_tensor, y_tensor], 0, name=tc_name)
+ # ------ modify UNTIL here for your model -------#
+
+ # Note if don't have any CONST value, creating checkpoint file fails.
+ # The next lines insert such (CONST) to prevent such error.
+ # So, Graph.pb/pbtxt contains this garbage info,
+ # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+ garbage = tf.get_variable(
+ "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+ init_op = tf.global_variables_initializer()
+ garbage_value = [0]
+ sess.run(tf.assign(garbage, garbage_value))
+
+ sess.run(init_op)
+
+ # ------ modify appropriate return value -------#
+
+ # returning (input_node_list, output_node_list)
+ return (input_list, [output_node])
+
+
+# --------
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Converted Tensorflow model in python to frozen model.')
+ parser.add_argument(
+ "out_dir",
+ help=
+ "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+ )
+
+ args = parser.parse_args()
+ root_output_path = args.out_dir
+
+ Gen(root_output_path).createSaveFreezeModel()
diff --git a/tools/tensorflow_model_freezer/sample/TOPK_gen.py b/tools/tensorflow_model_freezer/sample/TOPK_gen.py
index 0c16d5b75..8f1882bd1 100755
--- a/tools/tensorflow_model_freezer/sample/TOPK_gen.py
+++ b/tools/tensorflow_model_freezer/sample/TOPK_gen.py
@@ -63,6 +63,7 @@ class Gen(base.BaseFreezer):
'''
input_list = []
+ output_list = []
# ------ modify below for your model FROM here -------#
@@ -70,11 +71,13 @@ class Gen(base.BaseFreezer):
y_tensor = self.createTFInput(test_case_tensor[1], input_list)
# defining output node and input list
- output_node = tf.nn.top_k(
+ values_op, indices_op = tf.nn.top_k(
x_tensor,
y_tensor, # add your input here
name=tc_name) # do not modify name
+ output_list.append(values_op)
+ output_list.append(indices_op)
# ------ modify UNTIL here for your model -------#
# Note if don't have any CONST value, creating checkpoint file fails.
@@ -92,16 +95,9 @@ class Gen(base.BaseFreezer):
# ------ modify appropriate return value -------#
# returning (input_node_list, output_node_list)
- return (input_list, [output_node])
+ return (input_list, output_list)
-'''
-How to run
-$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py
-$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \
- tools/tensorflow_model_freezer/sample/name_of_this_file.py \
- ~/temp # directory where model files are saved
-'''
# --------
if __name__ == "__main__":
diff --git a/tools/tensorflow_model_freezer/sample/UNSTACK_gen.py b/tools/tensorflow_model_freezer/sample/UNSTACK_gen.py
new file mode 100644
index 000000000..c5bce0d18
--- /dev/null
+++ b/tools/tensorflow_model_freezer/sample/UNSTACK_gen.py
@@ -0,0 +1,100 @@
+#!/usr/bin/python
+
+# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import platform
+import tensorflow as tf
+import argparse
+
+import base_freezer as base
+import model_freezer_util as util
+
+
+class Gen(base.BaseFreezer):
+ '''
+ class to generate tflite files for MUL
+ '''
+
+ def __init__(self, path):
+ super(self.__class__, self).__init__(path)
+
+ def getOutputDirectory(self):
+ return os.path.join(self.root_output_path,
+ 'unstack') # the root path of generated files
+
+ def getTestCases(self):
+ '''
+ this returns a a hash containg test cases.
+ key of return hash is test case name and
+ value of return hash is test is a list of input tensor metadata.
+ test name (key of hash) is used as
+ - prefix of file name to be generated (don't use white space or special characters)
+ - output node name pf graph
+ '''
+ return {"unstack_4d": [base.Tensor([4, 4, 3])]}
+
+ def buildModel(self, sess, test_case_tensor, tc_name):
+ '''
+ This method is called per test case (defined by getTestCases()).
+
+ keyword argument:
+ test_case_tensor -- test case tensor metadata
+ For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] }
+ test_case_tensor is [base.Tensor([5]), base.Tensor([5])]
+ '''
+
+ input_list = []
+
+ # ------ modify below for your model FROM here -------#
+ x_tensor = self.createTFInput(test_case_tensor[0], input_list)
+
+ output_node = tf.unstack([x_tensor], 4, 1, name=tc_name)
+ # ------ modify UNTIL here for your model -------#
+
+ # Note if don't have any CONST value, creating checkpoint file fails.
+ # The next lines insert such (CONST) to prevent such error.
+ # So, Graph.pb/pbtxt contains this garbage info,
+ # but this garbage info will be removed in Graph_frozen.pb/pbtxt
+ garbage = tf.get_variable(
+ "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer())
+ init_op = tf.global_variables_initializer()
+ garbage_value = [0]
+ sess.run(tf.assign(garbage, garbage_value))
+
+ sess.run(init_op)
+
+ # ------ modify appropriate return value -------#
+
+ # returning (input_node_list, output_node_list)
+ return (input_list, [output_node])
+
+
+# --------
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(
+ description='Converted Tensorflow model in python to frozen model.')
+ parser.add_argument(
+ "out_dir",
+ help=
+ "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored."
+ )
+
+ args = parser.parse_args()
+ root_output_path = args.out_dir
+
+ Gen(root_output_path).createSaveFreezeModel()