diff options
Diffstat (limited to 'tools/tensorflow_model_freezer/sample')
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/ARGMAX_gen.py | 100 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/ARGMIN_gen.py | 100 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/DIV_gen.py | 7 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/LOGICAL_AND_gen.py | 105 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/LOGICAL_NOT_gen.py | 100 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/LOGICAL_OR_gen.py | 104 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/MUL_gen.py | 11 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py | 7 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/STACK_gen.py | 101 | ||||
-rwxr-xr-x | tools/tensorflow_model_freezer/sample/TOPK_gen.py | 14 | ||||
-rw-r--r-- | tools/tensorflow_model_freezer/sample/UNSTACK_gen.py | 100 |
11 files changed, 717 insertions, 32 deletions
diff --git a/tools/tensorflow_model_freezer/sample/ARGMAX_gen.py b/tools/tensorflow_model_freezer/sample/ARGMAX_gen.py new file mode 100755 index 000000000..68e2262c0 --- /dev/null +++ b/tools/tensorflow_model_freezer/sample/ARGMAX_gen.py @@ -0,0 +1,100 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import platform +import tensorflow as tf +import argparse + +import base_freezer as base +import model_freezer_util as util + + +class Gen(base.BaseFreezer): + ''' + class to generate tflite files for MUL + ''' + + def __init__(self, path): + super(self.__class__, self).__init__(path) + + def getOutputDirectory(self): + return os.path.join(self.root_output_path, + 'argmax') # the root path of generated files + + def getTestCases(self): + ''' + this returns a a hash containg test cases. + key of return hash is test case name and + value of return hash is test is a list of input tensor metadata. + test name (key of hash) is used as + - prefix of file name to be generated (don't use white space or special characters) + - output node name pf graph + ''' + return {"argmax_4d": [base.Tensor([1, 2, 4, 3])]} + + def buildModel(self, sess, test_case_tensor, tc_name): + ''' + This method is called per test case (defined by getTestCases()). + + keyword argument: + test_case_tensor -- test case tensor metadata + For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] } + test_case_tensor is [base.Tensor([5]), base.Tensor([5])] + ''' + + input_list = [] + + # ------ modify below for your model FROM here -------# + x_tensor = self.createTFInput(test_case_tensor[0], input_list) + + output_node = tf.arg_max(x_tensor, 0, output_type=tf.int32, name=tc_name) + # ------ modify UNTIL here for your model -------# + + # Note if don't have any CONST value, creating checkpoint file fails. + # The next lines insert such (CONST) to prevent such error. + # So, Graph.pb/pbtxt contains this garbage info, + # but this garbage info will be removed in Graph_frozen.pb/pbtxt + garbage = tf.get_variable( + "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer()) + init_op = tf.global_variables_initializer() + garbage_value = [0] + sess.run(tf.assign(garbage, garbage_value)) + + sess.run(init_op) + + # ------ modify appropriate return value -------# + + # returning (input_node_list, output_node_list) + return (input_list, [output_node]) + + +# -------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Converted Tensorflow model in python to frozen model.') + parser.add_argument( + "out_dir", + help= + "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored." + ) + + args = parser.parse_args() + root_output_path = args.out_dir + + Gen(root_output_path).createSaveFreezeModel() diff --git a/tools/tensorflow_model_freezer/sample/ARGMIN_gen.py b/tools/tensorflow_model_freezer/sample/ARGMIN_gen.py new file mode 100755 index 000000000..68b399234 --- /dev/null +++ b/tools/tensorflow_model_freezer/sample/ARGMIN_gen.py @@ -0,0 +1,100 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import platform +import tensorflow as tf +import argparse + +import base_freezer as base +import model_freezer_util as util + + +class Gen(base.BaseFreezer): + ''' + class to generate tflite files for MUL + ''' + + def __init__(self, path): + super(self.__class__, self).__init__(path) + + def getOutputDirectory(self): + return os.path.join(self.root_output_path, + 'argmin') # the root path of generated files + + def getTestCases(self): + ''' + this returns a a hash containg test cases. + key of return hash is test case name and + value of return hash is test is a list of input tensor metadata. + test name (key of hash) is used as + - prefix of file name to be generated (don't use white space or special characters) + - output node name pf graph + ''' + return {"argmin_4d": [base.Tensor([1, 2, 4, 3])]} + + def buildModel(self, sess, test_case_tensor, tc_name): + ''' + This method is called per test case (defined by getTestCases()). + + keyword argument: + test_case_tensor -- test case tensor metadata + For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] } + test_case_tensor is [base.Tensor([5]), base.Tensor([5])] + ''' + + input_list = [] + + # ------ modify below for your model FROM here -------# + x_tensor = self.createTFInput(test_case_tensor[0], input_list) + + output_node = tf.arg_min(x_tensor, 0, name=tc_name) + # ------ modify UNTIL here for your model -------# + + # Note if don't have any CONST value, creating checkpoint file fails. + # The next lines insert such (CONST) to prevent such error. + # So, Graph.pb/pbtxt contains this garbage info, + # but this garbage info will be removed in Graph_frozen.pb/pbtxt + garbage = tf.get_variable( + "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer()) + init_op = tf.global_variables_initializer() + garbage_value = [0] + sess.run(tf.assign(garbage, garbage_value)) + + sess.run(init_op) + + # ------ modify appropriate return value -------# + + # returning (input_node_list, output_node_list) + return (input_list, [output_node]) + + +# -------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Converted Tensorflow model in python to frozen model.') + parser.add_argument( + "out_dir", + help= + "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored." + ) + + args = parser.parse_args() + root_output_path = args.out_dir + + Gen(root_output_path).createSaveFreezeModel() diff --git a/tools/tensorflow_model_freezer/sample/DIV_gen.py b/tools/tensorflow_model_freezer/sample/DIV_gen.py index c4e9cde07..d1b794cd7 100755 --- a/tools/tensorflow_model_freezer/sample/DIV_gen.py +++ b/tools/tensorflow_model_freezer/sample/DIV_gen.py @@ -124,13 +124,6 @@ class Gen(base.BaseFreezer): return (input_list, [output_node]) -''' -How to run -$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py -$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \ - tools/tensorflow_model_freezer/sample/name_of_this_file.py \ - ~/temp # directory where model files are saved -''' # -------- if __name__ == "__main__": diff --git a/tools/tensorflow_model_freezer/sample/LOGICAL_AND_gen.py b/tools/tensorflow_model_freezer/sample/LOGICAL_AND_gen.py new file mode 100755 index 000000000..912af65b2 --- /dev/null +++ b/tools/tensorflow_model_freezer/sample/LOGICAL_AND_gen.py @@ -0,0 +1,105 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import platform +import tensorflow as tf +import argparse + +import base_freezer as base +import model_freezer_util as util + + +class Gen(base.BaseFreezer): + ''' + class to generate tflite files for MUL + ''' + + def __init__(self, path): + super(self.__class__, self).__init__(path) + + def getOutputDirectory(self): + return os.path.join(self.root_output_path, + 'logical_and') # the root path of generated files + + def getTestCases(self): + ''' + this returns a a hash containg test cases. + key of return hash is test case name and + value of return hash is test is a list of input tensor metadata. + test name (key of hash) is used as + - prefix of file name to be generated (don't use white space or special characters) + - output node name pf graph + ''' + return {"logical_and_4d": [base.Tensor([1, 2, 4, 3]), base.Tensor([1, 2, 4, 3])]} + + def buildModel(self, sess, test_case_tensor, tc_name): + ''' + This method is called per test case (defined by getTestCases()). + + keyword argument: + test_case_tensor -- test case tensor metadata + For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] } + test_case_tensor is [base.Tensor([5]), base.Tensor([5])] + ''' + + input_list = [] + + # ------ modify below for your model FROM here -------# + x_tensor = self.createTFInput(test_case_tensor[0], input_list) + y_tensor = self.createTFInput(test_case_tensor[1], input_list) + + output_node = tf.logical_and( + tf.greater(x_tensor, tf.constant(0.0)), + tf.less(y_tensor, tf.constant(1.0)), + name=tc_name) + + # ------ modify UNTIL here for your model -------# + + # Note if don't have any CONST value, creating checkpoint file fails. + # The next lines insert such (CONST) to prevent such error. + # So, Graph.pb/pbtxt contains this garbage info, + # but this garbage info will be removed in Graph_frozen.pb/pbtxt + garbage = tf.get_variable( + "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer()) + init_op = tf.global_variables_initializer() + garbage_value = [0] + sess.run(tf.assign(garbage, garbage_value)) + + sess.run(init_op) + + # ------ modify appropriate return value -------# + + # returning (input_node_list, output_node_list) + return (input_list, [output_node]) + + +# -------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Converted Tensorflow model in python to frozen model.') + parser.add_argument( + "out_dir", + help= + "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored." + ) + + args = parser.parse_args() + root_output_path = args.out_dir + + Gen(root_output_path).createSaveFreezeModel() diff --git a/tools/tensorflow_model_freezer/sample/LOGICAL_NOT_gen.py b/tools/tensorflow_model_freezer/sample/LOGICAL_NOT_gen.py new file mode 100755 index 000000000..34c0994ea --- /dev/null +++ b/tools/tensorflow_model_freezer/sample/LOGICAL_NOT_gen.py @@ -0,0 +1,100 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import platform +import tensorflow as tf +import argparse + +import base_freezer as base +import model_freezer_util as util + + +class Gen(base.BaseFreezer): + ''' + class to generate tflite files for MUL + ''' + + def __init__(self, path): + super(self.__class__, self).__init__(path) + + def getOutputDirectory(self): + return os.path.join(self.root_output_path, + 'logical_not') # the root path of generated files + + def getTestCases(self): + ''' + this returns a a hash containg test cases. + key of return hash is test case name and + value of return hash is test is a list of input tensor metadata. + test name (key of hash) is used as + - prefix of file name to be generated (don't use white space or special characters) + - output node name pf graph + ''' + return {"logical_not_4d": [base.Tensor([1, 2, 4, 3])]} + + def buildModel(self, sess, test_case_tensor, tc_name): + ''' + This method is called per test case (defined by getTestCases()). + + keyword argument: + test_case_tensor -- test case tensor metadata + For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] } + test_case_tensor is [base.Tensor([5]), base.Tensor([5])] + ''' + + input_list = [] + + # ------ modify below for your model FROM here -------# + x_tensor = self.createTFInput(test_case_tensor[0], input_list) + + output_node = tf.logical_not(tf.greater(x_tensor, tf.constant(0.0)), name=tc_name) + # ------ modify UNTIL here for your model -------# + + # Note if don't have any CONST value, creating checkpoint file fails. + # The next lines insert such (CONST) to prevent such error. + # So, Graph.pb/pbtxt contains this garbage info, + # but this garbage info will be removed in Graph_frozen.pb/pbtxt + garbage = tf.get_variable( + "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer()) + init_op = tf.global_variables_initializer() + garbage_value = [0] + sess.run(tf.assign(garbage, garbage_value)) + + sess.run(init_op) + + # ------ modify appropriate return value -------# + + # returning (input_node_list, output_node_list) + return (input_list, [output_node]) + + +# -------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Converted Tensorflow model in python to frozen model.') + parser.add_argument( + "out_dir", + help= + "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored." + ) + + args = parser.parse_args() + root_output_path = args.out_dir + + Gen(root_output_path).createSaveFreezeModel() diff --git a/tools/tensorflow_model_freezer/sample/LOGICAL_OR_gen.py b/tools/tensorflow_model_freezer/sample/LOGICAL_OR_gen.py new file mode 100755 index 000000000..714a52e8d --- /dev/null +++ b/tools/tensorflow_model_freezer/sample/LOGICAL_OR_gen.py @@ -0,0 +1,104 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import platform +import tensorflow as tf +import argparse + +import base_freezer as base +import model_freezer_util as util + + +class Gen(base.BaseFreezer): + ''' + class to generate tflite files for MUL + ''' + + def __init__(self, path): + super(self.__class__, self).__init__(path) + + def getOutputDirectory(self): + return os.path.join(self.root_output_path, + 'logical_or') # the root path of generated files + + def getTestCases(self): + ''' + this returns a a hash containg test cases. + key of return hash is test case name and + value of return hash is test is a list of input tensor metadata. + test name (key of hash) is used as + - prefix of file name to be generated (don't use white space or special characters) + - output node name pf graph + ''' + return {"logical_or_4d": [base.Tensor([1, 2, 4, 3]), base.Tensor([1, 2, 4, 3])]} + + def buildModel(self, sess, test_case_tensor, tc_name): + ''' + This method is called per test case (defined by getTestCases()). + + keyword argument: + test_case_tensor -- test case tensor metadata + For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] } + test_case_tensor is [base.Tensor([5]), base.Tensor([5])] + ''' + + input_list = [] + + # ------ modify below for your model FROM here -------# + x_tensor = self.createTFInput(test_case_tensor[0], input_list) + y_tensor = self.createTFInput(test_case_tensor[1], input_list) + + output_node = tf.logical_or( + tf.greater(x_tensor, tf.constant(0.0)), + tf.less(y_tensor, tf.constant(1.0)), + name=tc_name) + # ------ modify UNTIL here for your model -------# + + # Note if don't have any CONST value, creating checkpoint file fails. + # The next lines insert such (CONST) to prevent such error. + # So, Graph.pb/pbtxt contains this garbage info, + # but this garbage info will be removed in Graph_frozen.pb/pbtxt + garbage = tf.get_variable( + "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer()) + init_op = tf.global_variables_initializer() + garbage_value = [0] + sess.run(tf.assign(garbage, garbage_value)) + + sess.run(init_op) + + # ------ modify appropriate return value -------# + + # returning (input_node_list, output_node_list) + return (input_list, [output_node]) + + +# -------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Converted Tensorflow model in python to frozen model.') + parser.add_argument( + "out_dir", + help= + "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored." + ) + + args = parser.parse_args() + root_output_path = args.out_dir + + Gen(root_output_path).createSaveFreezeModel() diff --git a/tools/tensorflow_model_freezer/sample/MUL_gen.py b/tools/tensorflow_model_freezer/sample/MUL_gen.py index f2a92547b..596898dbb 100755 --- a/tools/tensorflow_model_freezer/sample/MUL_gen.py +++ b/tools/tensorflow_model_freezer/sample/MUL_gen.py @@ -59,8 +59,8 @@ class Gen(base.BaseFreezer): "mul_1d_scalarConst": [base.Tensor([5]), base.Tensor([], const_val=1.1)], # mul by scalar "mul_2d_scalarConst": [base.Tensor([5, 3]), - base.Tensor([], const_val=1.1)], - "mul_1d_scalar": [base.Tensor([5, 3]), base.Tensor([])] + base.Tensor([], const_val=1.1)] + # "mul_2d_scalar": [base.Tensor([5, 3]), base.Tensor([])] # not support scalar input } def buildModel(self, sess, test_case_tensor, tc_name): @@ -104,13 +104,6 @@ class Gen(base.BaseFreezer): return (input_list, [output_node]) -''' -How to run -$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py -$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \ - tools/tensorflow_model_freezer/sample/name_of_this_file.py \ - ~/temp # directory where model files are saved -''' # -------- if __name__ == "__main__": diff --git a/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py b/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py index 88b3dfcb2..12fb5122e 100755 --- a/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py +++ b/tools/tensorflow_model_freezer/sample/SQUEEZE_gen.py @@ -103,13 +103,6 @@ class Gen(base.BaseFreezer): return (input_list, [output_node]) -''' -How to run -$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py -$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \ - tools/tensorflow_model_freezer/sample/name_of_this_file.py \ - ~/temp # directory where model files are saved -''' # -------- if __name__ == "__main__": diff --git a/tools/tensorflow_model_freezer/sample/STACK_gen.py b/tools/tensorflow_model_freezer/sample/STACK_gen.py new file mode 100755 index 000000000..2bea40698 --- /dev/null +++ b/tools/tensorflow_model_freezer/sample/STACK_gen.py @@ -0,0 +1,101 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import platform +import tensorflow as tf +import argparse + +import base_freezer as base +import model_freezer_util as util + + +class Gen(base.BaseFreezer): + ''' + class to generate tflite files for MUL + ''' + + def __init__(self, path): + super(self.__class__, self).__init__(path) + + def getOutputDirectory(self): + return os.path.join(self.root_output_path, + 'stack') # the root path of generated files + + def getTestCases(self): + ''' + this returns a a hash containg test cases. + key of return hash is test case name and + value of return hash is test is a list of input tensor metadata. + test name (key of hash) is used as + - prefix of file name to be generated (don't use white space or special characters) + - output node name pf graph + ''' + return {"stack_4d": [base.Tensor([1, 4, 3]), base.Tensor([1, 4, 3])]} + + def buildModel(self, sess, test_case_tensor, tc_name): + ''' + This method is called per test case (defined by getTestCases()). + + keyword argument: + test_case_tensor -- test case tensor metadata + For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] } + test_case_tensor is [base.Tensor([5]), base.Tensor([5])] + ''' + + input_list = [] + + # ------ modify below for your model FROM here -------# + x_tensor = self.createTFInput(test_case_tensor[0], input_list) + y_tensor = self.createTFInput(test_case_tensor[1], input_list) + + output_node = tf.stack([x_tensor, y_tensor], 0, name=tc_name) + # ------ modify UNTIL here for your model -------# + + # Note if don't have any CONST value, creating checkpoint file fails. + # The next lines insert such (CONST) to prevent such error. + # So, Graph.pb/pbtxt contains this garbage info, + # but this garbage info will be removed in Graph_frozen.pb/pbtxt + garbage = tf.get_variable( + "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer()) + init_op = tf.global_variables_initializer() + garbage_value = [0] + sess.run(tf.assign(garbage, garbage_value)) + + sess.run(init_op) + + # ------ modify appropriate return value -------# + + # returning (input_node_list, output_node_list) + return (input_list, [output_node]) + + +# -------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Converted Tensorflow model in python to frozen model.') + parser.add_argument( + "out_dir", + help= + "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored." + ) + + args = parser.parse_args() + root_output_path = args.out_dir + + Gen(root_output_path).createSaveFreezeModel() diff --git a/tools/tensorflow_model_freezer/sample/TOPK_gen.py b/tools/tensorflow_model_freezer/sample/TOPK_gen.py index 0c16d5b75..8f1882bd1 100755 --- a/tools/tensorflow_model_freezer/sample/TOPK_gen.py +++ b/tools/tensorflow_model_freezer/sample/TOPK_gen.py @@ -63,6 +63,7 @@ class Gen(base.BaseFreezer): ''' input_list = [] + output_list = [] # ------ modify below for your model FROM here -------# @@ -70,11 +71,13 @@ class Gen(base.BaseFreezer): y_tensor = self.createTFInput(test_case_tensor[1], input_list) # defining output node and input list - output_node = tf.nn.top_k( + values_op, indices_op = tf.nn.top_k( x_tensor, y_tensor, # add your input here name=tc_name) # do not modify name + output_list.append(values_op) + output_list.append(indices_op) # ------ modify UNTIL here for your model -------# # Note if don't have any CONST value, creating checkpoint file fails. @@ -92,16 +95,9 @@ class Gen(base.BaseFreezer): # ------ modify appropriate return value -------# # returning (input_node_list, output_node_list) - return (input_list, [output_node]) + return (input_list, output_list) -''' -How to run -$ chmod +x tools/tensorflow_model_freezer/sample/name_of_this_file.py -$ PYTHONPATH=$PYTHONPATH:./tools/tensorflow_model_freezer/ \ - tools/tensorflow_model_freezer/sample/name_of_this_file.py \ - ~/temp # directory where model files are saved -''' # -------- if __name__ == "__main__": diff --git a/tools/tensorflow_model_freezer/sample/UNSTACK_gen.py b/tools/tensorflow_model_freezer/sample/UNSTACK_gen.py new file mode 100644 index 000000000..c5bce0d18 --- /dev/null +++ b/tools/tensorflow_model_freezer/sample/UNSTACK_gen.py @@ -0,0 +1,100 @@ +#!/usr/bin/python + +# Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import platform +import tensorflow as tf +import argparse + +import base_freezer as base +import model_freezer_util as util + + +class Gen(base.BaseFreezer): + ''' + class to generate tflite files for MUL + ''' + + def __init__(self, path): + super(self.__class__, self).__init__(path) + + def getOutputDirectory(self): + return os.path.join(self.root_output_path, + 'unstack') # the root path of generated files + + def getTestCases(self): + ''' + this returns a a hash containg test cases. + key of return hash is test case name and + value of return hash is test is a list of input tensor metadata. + test name (key of hash) is used as + - prefix of file name to be generated (don't use white space or special characters) + - output node name pf graph + ''' + return {"unstack_4d": [base.Tensor([4, 4, 3])]} + + def buildModel(self, sess, test_case_tensor, tc_name): + ''' + This method is called per test case (defined by getTestCases()). + + keyword argument: + test_case_tensor -- test case tensor metadata + For example, if a test case is { "mul_1d_1d": [base.Tensor([5]), base.Tensor([5])] } + test_case_tensor is [base.Tensor([5]), base.Tensor([5])] + ''' + + input_list = [] + + # ------ modify below for your model FROM here -------# + x_tensor = self.createTFInput(test_case_tensor[0], input_list) + + output_node = tf.unstack([x_tensor], 4, 1, name=tc_name) + # ------ modify UNTIL here for your model -------# + + # Note if don't have any CONST value, creating checkpoint file fails. + # The next lines insert such (CONST) to prevent such error. + # So, Graph.pb/pbtxt contains this garbage info, + # but this garbage info will be removed in Graph_frozen.pb/pbtxt + garbage = tf.get_variable( + "garbage", [1], dtype=tf.float32, initializer=tf.zeros_initializer()) + init_op = tf.global_variables_initializer() + garbage_value = [0] + sess.run(tf.assign(garbage, garbage_value)) + + sess.run(init_op) + + # ------ modify appropriate return value -------# + + # returning (input_node_list, output_node_list) + return (input_list, [output_node]) + + +# -------- +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Converted Tensorflow model in python to frozen model.') + parser.add_argument( + "out_dir", + help= + "directory where generated pb, pbtxt, checkpoint and Tensorboard log files are stored." + ) + + args = parser.parse_args() + root_output_path = args.out_dir + + Gen(root_output_path).createSaveFreezeModel() |