summaryrefslogtreecommitdiff
path: root/compiler/ann-ref/src/Validation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/ann-ref/src/Validation.cpp')
-rw-r--r--compiler/ann-ref/src/Validation.cpp263
1 files changed, 263 insertions, 0 deletions
diff --git a/compiler/ann-ref/src/Validation.cpp b/compiler/ann-ref/src/Validation.cpp
new file mode 100644
index 000000000..679b14a9a
--- /dev/null
+++ b/compiler/ann-ref/src/Validation.cpp
@@ -0,0 +1,263 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Validation.h"
+#include "Macro.h"
+#include "Assert.h"
+
+static inline bool validCode(uint32_t codeCount, uint32_t code)
+{
+ return (code < codeCount);
+}
+
+int validateOperationType(const OperationType &type)
+{
+ return validCode(kNumberOfOperationTypes, static_cast<uint32_t>(type));
+}
+
+// Validates the type. The used dimensions can be underspecified.
+int validateOperandType(const ANeuralNetworksOperandType &type, const char *tag, bool allowPartial)
+{
+ if (!allowPartial)
+ {
+ for (uint32_t i = 0; i < type.dimensionCount; i++)
+ {
+ if (type.dimensions[i] == 0)
+ {
+ LOG(ERROR) << tag << " OperandType invalid dimensions[" << i
+ << "] = " << type.dimensions[i];
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ }
+ }
+ if (!validCode(kNumberOfDataTypes, type.type))
+ {
+ LOG(ERROR) << tag << " OperandType invalid type " << type.type;
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ if (type.type == ANEURALNETWORKS_TENSOR_QUANT8_ASYMM)
+ {
+ if (type.zeroPoint < 0 || type.zeroPoint > 255)
+ {
+ LOG(ERROR) << tag << " OperandType invalid zeroPoint " << type.zeroPoint;
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ if (type.scale < 0.f)
+ {
+ LOG(ERROR) << tag << " OperandType invalid scale " << type.scale;
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ }
+
+ // TODO-NNRT : add 'type.type == ANEURALNETWORKS_OEM_SCALAR' later.
+ // OEM operaters are not supported now.
+ if (type.type == ANEURALNETWORKS_FLOAT32 || type.type == ANEURALNETWORKS_INT32 ||
+ type.type == ANEURALNETWORKS_UINT32)
+ {
+ if (type.dimensionCount != 0 || type.dimensions != nullptr)
+ {
+ LOG(ERROR) << tag << " Invalid dimensions for scalar type";
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ }
+
+ return ANEURALNETWORKS_NO_ERROR;
+}
+
+int validateOperandList(uint32_t count, const uint32_t *list, uint32_t operandCount,
+ const char *tag)
+{
+ for (uint32_t i = 0; i < count; i++)
+ {
+ if (list[i] >= operandCount)
+ {
+ LOG(ERROR) << tag << " invalid operand index at " << i << " = " << list[i]
+ << ", operandCount " << operandCount;
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ }
+ return ANEURALNETWORKS_NO_ERROR;
+}
+
+static bool validOperandIndexes(const std::vector<uint32_t> indexes, size_t operandCount)
+{
+ for (uint32_t i : indexes)
+ {
+ if (i >= operandCount)
+ {
+ LOG(ERROR) << "Index out of range " << i << "/" << operandCount;
+ return false;
+ }
+ }
+ return true;
+}
+
+static bool validOperands(const std::vector<Operand> &operands, const std::vector<uint8_t> &operandValues)
+{
+ for (auto &operand : operands)
+ {
+ if (!validCode(kNumberOfDataTypes, static_cast<uint32_t>(operand.type)))
+ {
+ LOG(ERROR) << "Invalid operand type ";
+ return false;
+ }
+ /* TODO validate dim with type
+ if (!validOperandIndexes(operand.dimensions, mDimensions)) {
+ return false;
+ }
+ */
+ switch (operand.lifetime)
+ {
+ case OperandLifeTime::CONSTANT_COPY:
+ if (operand.location.offset + operand.location.length > operandValues.size())
+ {
+ LOG(ERROR) << "OperandValue location out of range. Starts at " << operand.location.offset
+ << ", length " << operand.location.length << ", max " << operandValues.size();
+ return false;
+ }
+ break;
+ case OperandLifeTime::TEMPORARY_VARIABLE:
+ case OperandLifeTime::MODEL_INPUT:
+ case OperandLifeTime::MODEL_OUTPUT:
+ case OperandLifeTime::NO_VALUE:
+ if (operand.location.offset != 0 || operand.location.length != 0)
+ {
+ LOG(ERROR) << "Unexpected offset " << operand.location.offset << " or length "
+ << operand.location.length << " for runtime location.";
+ return false;
+ }
+ break;
+ case OperandLifeTime::CONSTANT_REFERENCE:
+#if 0
+ if (operand.location.poolIndex >= poolCount)
+ {
+ LOG(ERROR) << "Invalid poolIndex " << operand.location.poolIndex << "/" << poolCount;
+ return false;
+ }
+#endif
+ break;
+ // TODO: Validate that we are within the pool.
+ default:
+ LOG(ERROR) << "Invalid lifetime";
+ return false;
+ }
+ }
+ return true;
+}
+
+static bool validOperations(const std::vector<Operation> &operations, size_t operandCount)
+{
+ for (auto &op : operations)
+ {
+ if (!validCode(kNumberOfOperationTypes, static_cast<uint32_t>(op.type)))
+ {
+ LOG(ERROR) << "Invalid operation type ";
+ return false;
+ }
+ if (!validOperandIndexes(op.inputs, operandCount) ||
+ !validOperandIndexes(op.outputs, operandCount))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+// TODO doublecheck
+bool validateModel(const Model &model)
+{
+ const size_t operandCount = model.operands.size();
+ return (validOperands(model.operands, model.operandValues) &&
+ validOperations(model.operations, operandCount) &&
+ validOperandIndexes(model.inputIndexes, operandCount) &&
+ validOperandIndexes(model.outputIndexes, operandCount));
+}
+
+bool validRequestArguments(const std::vector<RequestArgument> &arguments,
+ const std::vector<uint32_t> &operandIndexes,
+ const std::vector<Operand> &operands, size_t poolCount, const char *type)
+{
+ const size_t argumentCount = arguments.size();
+ if (argumentCount != operandIndexes.size())
+ {
+ LOG(ERROR) << "Request specifies " << argumentCount << " " << type << "s but the model has "
+ << operandIndexes.size();
+ return false;
+ }
+ for (size_t argumentIndex = 0; argumentIndex < argumentCount; argumentIndex++)
+ {
+ const RequestArgument &argument = arguments[argumentIndex];
+ const uint32_t operandIndex = operandIndexes[argumentIndex];
+ const Operand &operand = operands[operandIndex];
+ if (argument.hasNoValue)
+ {
+ if (argument.location.poolIndex != 0 || argument.location.offset != 0 ||
+ argument.location.length != 0 || argument.dimensions.size() != 0)
+ {
+ LOG(ERROR) << "Request " << type << " " << argumentIndex
+ << " has no value yet has details.";
+ return false;
+ }
+ }
+ if (argument.location.poolIndex >= poolCount)
+ {
+ LOG(ERROR) << "Request " << type << " " << argumentIndex << " has an invalid poolIndex "
+ << argument.location.poolIndex << "/" << poolCount;
+ return false;
+ }
+ // TODO: Validate that we are within the pool.
+ uint32_t rank = argument.dimensions.size();
+ if (rank > 0)
+ {
+ if (rank != operand.dimensions.size())
+ {
+ LOG(ERROR) << "Request " << type << " " << argumentIndex << " has number of dimensions ("
+ << rank << ") different than the model's (" << operand.dimensions.size() << ")";
+ return false;
+ }
+ for (size_t i = 0; i < rank; i++)
+ {
+ if (argument.dimensions[i] != operand.dimensions[i] && operand.dimensions[i] != 0)
+ {
+ LOG(ERROR) << "Request " << type << " " << argumentIndex << " has dimension " << i
+ << " of " << operand.dimensions[i] << " different than the model's "
+ << operand.dimensions[i];
+ return false;
+ }
+ if (argument.dimensions[i] == 0)
+ {
+ LOG(ERROR) << "Request " << type << " " << argumentIndex << " has dimension " << i
+ << " of zero";
+ return false;
+ }
+ }
+ }
+ }
+ return true;
+}
+
+// TODO doublecheck
+bool validateRequest(const Request &request, const Model &model)
+{
+ //const size_t poolCount = request.pools.size();
+ const size_t poolCount = 0;
+ return (validRequestArguments(request.inputs, model.inputIndexes, model.operands, poolCount,
+ "input") &&
+ validRequestArguments(request.outputs, model.outputIndexes, model.operands, poolCount,
+ "output"));
+}
+