summaryrefslogtreecommitdiff
path: root/tests/nnapi/nnapi_test_generator/android-p/vts_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/nnapi/nnapi_test_generator/android-p/vts_generator.py')
-rw-r--r--tests/nnapi/nnapi_test_generator/android-p/vts_generator.py247
1 files changed, 247 insertions, 0 deletions
diff --git a/tests/nnapi/nnapi_test_generator/android-p/vts_generator.py b/tests/nnapi/nnapi_test_generator/android-p/vts_generator.py
new file mode 100644
index 000000000..ab34e2bda
--- /dev/null
+++ b/tests/nnapi/nnapi_test_generator/android-p/vts_generator.py
@@ -0,0 +1,247 @@
+#!/usr/bin/python3
+
+# Copyright 2017, The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""VTS testcase generator
+
+Implements VTS test backend. Shares most logic with the CTS test
+generator. Invoked by ml/nn/runtime/test/specs/generate_vts_tests.sh;
+See that script for details on how this script is used.
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import argparse
+from functools import reduce
+import math
+import os
+import struct
+import sys
+import contextlib
+import test_generator
+import pprint
+# Stuff from test generator
+from test_generator import Configuration
+from test_generator import Example
+from test_generator import Float32Scalar
+from test_generator import IgnoredOutput
+from test_generator import Input
+from test_generator import Int32Scalar
+from test_generator import Internal
+from test_generator import Model
+from test_generator import Operand
+from test_generator import Output
+from test_generator import Parameter
+from test_generator import smart_open
+
+# Take a model from command line
+def import_source():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("spec", help="the spec file")
+ parser.add_argument(
+ "-m", "--model", help="the output model file", default="-")
+ parser.add_argument(
+ "-e", "--example", help="the output example file", default="-")
+ args = parser.parse_args()
+
+ if os.path.exists(args.spec):
+ test_generator.FileNames.SpecFile = os.path.basename(args.spec)
+ exec (open(args.spec).read())
+
+ return (args.model, args.example)
+
+# Generate operands in VTS format
+def generate_vts_operands():
+ # Dump operand definitions
+ op_def = """\
+ {{
+ .type = OperandType::{operand_type},
+ .dimensions = {shape},
+ .numberOfConsumers = {no_consumers},
+ .scale = {scale},
+ .zeroPoint = {zero_point},
+ .lifetime = OperandLifeTime::{lifetime},
+ .location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}},
+ }}"""
+ offset = 0
+ op_definitions = []
+ for o in Operand.operands.objects():
+ ty = o.type
+ no_consumers = len(o.outs) if o.traversable() else 0
+ lifetime = o.lifetime()
+ length = ty.get_size() if o.is_weight() else 0
+ real_shape, scale, zero_point = ty.get_parsed_shape()
+ scale = float(scale)
+ zero_point = int(zero_point)
+ op = {
+ "operand_type": ty.get_element_type(),
+ "shape": "{%s}" % real_shape,
+ "no_consumers": no_consumers,
+ "scale": test_generator.pretty_print_as_float(scale),
+ "zero_point": str(int(zero_point)),
+ "lifetime": lifetime,
+ "offset": offset if o.is_weight() else 0,
+ "length": length
+ }
+ offset += length
+ op_definitions.append(op_def.format(**op))
+
+ op_vec = """\
+ const std::vector<Operand> operands = {{
+{0}
+ }};""".format(",\n".join(op_definitions))
+ return op_vec
+
+# Generate VTS operand values
+def generate_vts_operand_values():
+ weights = [o for o in Operand.operands.objects() if o.is_weight()]
+ binit = []
+ for w in weights:
+ ty = w.type.get_element_type()
+ if ty == "TENSOR_QUANT8_ASYMM":
+ binit += w.initializer
+ elif ty in {"TENSOR_FLOAT32", "FLOAT32", "TENSOR_INT32", "INT32"}:
+ fmt = "f" if (ty == "TENSOR_FLOAT32" or ty == "FLOAT32") else "i"
+ for f in w.initializer:
+ binit += [int(x) for x in struct.pack(fmt, f)]
+ else:
+ assert 0 and "Unsupported VTS operand type"
+
+ init_defs = ", ".join([str(x) for x in binit])
+ if (init_defs != ""):
+ init_defs = "\n %s\n " % init_defs
+ byte_vec_fmt = """{%s}""" % init_defs
+ return byte_vec_fmt
+
+# Generate VTS operations
+class VTSOps(object):
+ vts_ops = []
+ def generate_vts_operation(op):
+ try:
+ opcode =op.optype
+ except AttributeError: # not an op, but things like weights
+ return
+ op_fmt = """\
+ {{
+ .type = OperationType::{op_code},
+ .inputs = {{{ins}}},
+ .outputs = {{{outs}}},
+ }}"""
+ op_content = {
+ 'op_code': op.optype,
+ 'op_type': op.type.get_element_type(),
+ 'ins': ", ".join([str(x.ID()) for x in op.ins]),
+ 'outs': ", ".join([str(x.ID()) for x in op.outs]),
+ }
+ VTSOps.vts_ops.append(op_fmt.format(**op_content))
+ return True
+
+def generate_vts_operations(model_file):
+ test_generator.TopologicalSort(lambda x: VTSOps.generate_vts_operation(x))
+ return ",\n".join(VTSOps.vts_ops)
+
+
+def generate_vts_model(model_file):
+ operand_values_fmt = ""
+ if Configuration.useSHM():
+ # Boilerplate code for passing weights in shared memory
+ operand_values_fmt = """\
+ std::vector<uint8_t> operandValues = {{}};
+ const uint8_t data[] = {operand_values};
+
+ // Allocate segment of android shared memory, wrapped in hidl_memory.
+ // This object will be automatically freed when sharedMemory is destroyed.
+ hidl_memory sharedMemory = allocateSharedMemory(sizeof(data));
+
+ // Mmap ashmem into usable address and hold it within the mappedMemory object.
+ // MappedMemory will automatically munmap the memory when it is destroyed.
+ sp<IMemory> mappedMemory = mapMemory(sharedMemory);
+
+ if (mappedMemory != nullptr) {{
+ // Retrieve the mmapped pointer.
+ uint8_t* mappedPointer =
+ static_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer()));
+
+ if (mappedPointer != nullptr) {{
+ // Acquire the write lock for the shared memory segment, upload the data,
+ // and release the lock.
+ mappedMemory->update();
+ std::copy(data, data + sizeof(data), mappedPointer);
+ mappedMemory->commit();
+ }}
+ }}
+
+ const std::vector<hidl_memory> pools = {{sharedMemory}};
+"""
+ else:
+ # Passing weights via operandValues
+ operand_values_fmt = """\
+ std::vector<uint8_t> operandValues = {operand_values};
+ const std::vector<hidl_memory> pools = {{}};
+"""
+
+ operand_values_val = {
+ 'operand_values': generate_vts_operand_values()
+ }
+ operand_values = operand_values_fmt.format(**operand_values_val)
+ # operand_values = operand_values_fmt
+ model_fmt = """\
+// Generated code. Do not edit
+// Create the model
+Model createTestModel() {{
+{operand_decls}
+
+ const std::vector<Operation> operations = {{
+{operations}
+ }};
+
+ const std::vector<uint32_t> inputIndexes = {{{input_indices}}};
+ const std::vector<uint32_t> outputIndexes = {{{output_indices}}};
+{operand_values}
+ return {{
+ .operands = operands,
+ .operations = operations,
+ .inputIndexes = inputIndexes,
+ .outputIndexes = outputIndexes,
+ .operandValues = operandValues,
+ .pools = pools,{relaxed_field}
+ }};
+}}
+"""
+ model = {
+ "operations": generate_vts_operations(sys.stdout),
+ "operand_decls": generate_vts_operands(),
+ "operand_values": operand_values,
+ "output_indices": ", ".join([str(i.ID()) for i in Output.get_outputs()]),
+ "input_indices": ", ".join([str(i.ID()) for i in Input.get_inputs(True)]),
+ "relaxed_field":
+ "\n .relaxComputationFloat32toFloat16 = true," if (Model.isRelaxed()) else ""
+ }
+ print(model_fmt.format(**model), file = model_file)
+
+def generate_vts(model_file):
+ generate_vts_model(model_file)
+ print (IgnoredOutput.gen_ignored(), file=model_file)
+
+if __name__ == "__main__":
+ (model, example) = import_source()
+ print("Output VTS model: %s" % model, file=sys.stderr)
+ print("Output example:" + example, file=sys.stderr)
+
+ with smart_open(model) as model_file:
+ generate_vts(model_file)
+ with smart_open(example) as example_file:
+ Example.dump(example_file)