diff options
-rw-r--r-- | tests/tools/tflite_benchmark/src/tflite_benchmark.cc | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/tests/tools/tflite_benchmark/src/tflite_benchmark.cc b/tests/tools/tflite_benchmark/src/tflite_benchmark.cc index 21eee4a18..1fde0c449 100644 --- a/tests/tools/tflite_benchmark/src/tflite_benchmark.cc +++ b/tests/tools/tflite_benchmark/src/tflite_benchmark.cc @@ -70,6 +70,23 @@ bool checkParams(const int argc, char **argv) return true; } +// Verifies whether the model is a flatbuffer file. +class BMFlatBufferVerifier : public tflite::TfLiteVerifier +{ +public: + bool Verify(const char *data, int length, tflite::ErrorReporter *reporter) override + { + + flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t *>(data), length); + if (!tflite::VerifyModelBuffer(verifier)) + { + reporter->Report("The model is not a valid Flatbuffer file"); + return false; + } + return true; + } +}; + int main(const int argc, char **argv) { @@ -98,7 +115,9 @@ int main(const int argc, char **argv) StderrReporter error_reporter; - auto model = FlatBufferModel::BuildFromFile(filename, &error_reporter); + std::unique_ptr<tflite::TfLiteVerifier> verifier{new BMFlatBufferVerifier}; + + auto model = FlatBufferModel::VerifyAndBuildFromFile(filename, verifier.get(), &error_reporter); if (model == nullptr) { std::cerr << "Cannot create model" << std::endl; |