summaryrefslogtreecommitdiff
path: root/externals/nnapi_test_generator/slicing.py
blob: f08e9d1a1f985d096dbcf27cad71410def7e298a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Slicing the input Model file

Invoked by ml/nn/runtime/test/specs/slicing.sh; this Python code is
not intended to be invoked directly by the users. See that script for
details on how to use the slicing tool is used.

This script does the following work:

Perform a topological sort similar to the test generator, except that:
* It would stop at the N-th operation it encounters, and
* Rename the output of the N-th operation to a model output, and
* Name that as the output of the model.
* Also only inputs and weights used by the submodel would be emitted.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from functools import reduce
import math
import os
import struct
import sys
import contextlib
import test_generator
import pprint
# Stuff from test generator
from test_generator import Example
from test_generator import Float32Scalar
from test_generator import Input
from test_generator import Int32Scalar
from test_generator import Internal
from test_generator import Model
from test_generator import Output
from test_generator import Parameter
from test_generator import smart_open


# Take a model from command line
def import_source():
  parser = argparse.ArgumentParser()
  parser.add_argument("spec", help="the spec file")
  parser.add_argument(
      "-n", "--number",
      help="number of operations in the sliced model. Default = 1",
      default=1)
  parser.add_argument(
      "-m", "--model", help="the output model file", default="-")
  parser.add_argument(
      "-e", "--example", help="the output example file", default="-")
  args = parser.parse_args()

  if os.path.exists(args.spec):
    test_generator.FileNames.SpecFile = os.path.basename(args.spec)
    exec (open(args.spec).read())

  return (args.model, args.example, args.number)


# Slice till the Nth op the topological sort finds
# the output of that op becomes the output of the model
class slicing:

  def __init__(self, threshold):
    self.__nr_op_seen = 0
    self.__threshold = threshold
    self.__last_outs = []
    self.__all_formatted_ops = []
    self.__referenced_operands = set()

  def format_as_py_op(self, op):
    try:
      fmt = op.PyDefinition()
    except AttributeError:  # not an op, but things like weights
      return True
    if fmt is not None:
      self.__nr_op_seen += 1
      if self.__nr_op_seen > self.__threshold:
        return False
      self.__last_outs = op.outs
      for o in op.ins:
        self.__referenced_operands.add(o)
      for o in op.outs:
        self.__referenced_operands.add(o)
      self.__all_formatted_ops.append("model = model.%s" % fmt)
      return True

  def dump(self, model_file):
    for x in self.__all_formatted_ops:
      print(x, file=model_file)

  def dump_example(self, example_file):
    override = {}
    # Make alias for the output variable
    for lo in self.__last_outs:
      override[lo.get_name()] = lo.type.get_nr_elements()
      alias_def = """\
# Alias for the output variable {operand_name}
aliased_output{number} = {operand_name}
"""
      op = {
          'operand_name': lo.get_name(),
          'number': 0 # only support one output as of now
      }
      print (alias_def.format(**op), file=example_file)
    Example.py_dump(example_file, override, self.__referenced_operands)

  def format_operands(self):
    # Dump operand definitions
    op_definitions = []
    for o in test_generator.Operand.operands.objects():
      if o not in self.__referenced_operands:
        continue
      ty = o.type
      raw_shape = ty.get_raw_shape()
      op_def = """{op_name} = {operand}("{op_name}", "{element_type}", "{shape}" """
      if isinstance(o, test_generator.Parameter):
        op_def += """, {initializer})"""
        init = o.initializer
        py_operand_name = "Parameter"
      else:
        op_def += ")"
        init = []
        py_operand_name = "IgnoredOutput" if o in set(
            self.__last_outs) else o.__class__.__name__

      op = {
          "element_type": ty.get_element_type(),
          "shape": ty.get_raw_shape(),
          "op_name": o.get_name(),
          "operand": py_operand_name,
          "initializer": init
      }
      op_definitions.append(op_def.format(**op))
    return "\n".join(op_definitions)


if __name__ == "__main__":
  (model, example, number) = import_source()
  s = slicing(int(number))

  with smart_open(model) as model_file:
    spec_file = " (from: %s)" % (test_generator.FileNames.SpecFile)
    print("# Generated file%s. Do not edit" % (spec_file), file=model_file)
    print("model = Model()", file=model_file)
    test_generator.TopologicalSort(lambda x: s.format_as_py_op(x))
    print(s.format_operands(), file=model_file)
    s.dump(model_file)
  with smart_open(example) as example_file:
    s.dump_example(example_file)