summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/tools/tflite_benchmark/src/tflite_benchmark.cc21
1 files changed, 20 insertions, 1 deletions
diff --git a/tests/tools/tflite_benchmark/src/tflite_benchmark.cc b/tests/tools/tflite_benchmark/src/tflite_benchmark.cc
index 21eee4a18..1fde0c449 100644
--- a/tests/tools/tflite_benchmark/src/tflite_benchmark.cc
+++ b/tests/tools/tflite_benchmark/src/tflite_benchmark.cc
@@ -70,6 +70,23 @@ bool checkParams(const int argc, char **argv)
return true;
}
+// Verifies whether the model is a flatbuffer file.
+class BMFlatBufferVerifier : public tflite::TfLiteVerifier
+{
+public:
+ bool Verify(const char *data, int length, tflite::ErrorReporter *reporter) override
+ {
+
+ flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t *>(data), length);
+ if (!tflite::VerifyModelBuffer(verifier))
+ {
+ reporter->Report("The model is not a valid Flatbuffer file");
+ return false;
+ }
+ return true;
+ }
+};
+
int main(const int argc, char **argv)
{
@@ -98,7 +115,9 @@ int main(const int argc, char **argv)
StderrReporter error_reporter;
- auto model = FlatBufferModel::BuildFromFile(filename, &error_reporter);
+ std::unique_ptr<tflite::TfLiteVerifier> verifier{new BMFlatBufferVerifier};
+
+ auto model = FlatBufferModel::VerifyAndBuildFromFile(filename, verifier.get(), &error_reporter);
if (model == nullptr)
{
std::cerr << "Cannot create model" << std::endl;