diff options
author | Saulo Aldighieri Moraes/AI R&D /SRBR/Engineer/삼성전자 <s.moraes@samsung.com> | 2019-04-23 04:06:08 -0300 |
---|---|---|
committer | 오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com> | 2019-04-23 16:06:08 +0900 |
commit | 0b3dc5b948e2148ba1df372a6bc03814c1c94102 (patch) | |
tree | 7f26e41ff97d8e21e311f8d94018303de6d0adcd | |
parent | bf8af66a7d961c58c08131e48a31d12eabce3f37 (diff) | |
download | nnfw-0b3dc5b948e2148ba1df372a6bc03814c1c94102.tar.gz nnfw-0b3dc5b948e2148ba1df372a6bc03814c1c94102.tar.bz2 nnfw-0b3dc5b948e2148ba1df372a6bc03814c1c94102.zip |
[batch_run] A batch execution tool used to run experiments (#4507)
* Related: #4405
The batch execution tool can be used to run experiments. It reads a neural network model from a file and a series of input images from a directory, runs each image through the network, and collect
statistics, such as execution time and accuracy.
Signed-off-by: Saulo A. Moraes <s.moraes@samsung.com>
* Remove boost_ext reference
Build script cleanup, remove unnecessary reference to boost_ext.
Signed-off-by: Saulo A. Moraes <s.moraes@samsung.com>
* Related: #4405
Code review fixes and tool rename from batch_run to tflite_accuracy.
Signed-off-by: Saulo A. Moraes <s.moraes@samsung.com>
-rw-r--r-- | tools/tflite_accuracy/CMakeLists.txt | 9 | ||||
-rw-r--r-- | tools/tflite_accuracy/README.md | 37 | ||||
-rw-r--r-- | tools/tflite_accuracy/src/labels.h | 1023 | ||||
-rw-r--r-- | tools/tflite_accuracy/src/tflite_accuracy.cc | 494 |
4 files changed, 1563 insertions, 0 deletions
diff --git a/tools/tflite_accuracy/CMakeLists.txt b/tools/tflite_accuracy/CMakeLists.txt new file mode 100644 index 000000000..bbb25da58 --- /dev/null +++ b/tools/tflite_accuracy/CMakeLists.txt @@ -0,0 +1,9 @@ +list(APPEND TFLITE_ACCURACY_SRCS "src/tflite_accuracy.cc") + +add_executable(tflite_accuracy ${TFLITE_ACCURACY_SRCS}) +target_include_directories(tflite_accuracy PRIVATE src) +target_link_libraries(tflite_accuracy tensorflow-lite ${LIB_PTHREAD} dl nnfw_lib_tflite) +target_link_libraries(tflite_accuracy boost_program_options boost_system boost_filesystem) + +install(TARGETS tflite_accuracy DESTINATION bin) + diff --git a/tools/tflite_accuracy/README.md b/tools/tflite_accuracy/README.md new file mode 100644 index 000000000..22804e140 --- /dev/null +++ b/tools/tflite_accuracy/README.md @@ -0,0 +1,37 @@ +Using the batch execution tool +============================== + +The batch execution tool (`tflite_accuracy`) can be used to run experiments +where execution time and accuracy are to be measured on a test set. +`tflite_accuracy` reads a neural network model from a file and a series of +input images from a directory, runs each image through the network, +and collect statistics, such as execution time and accuracy. + +In order to run this tool, you'll need: + +* a model in `.tflite` format; +* a set of preprocessed input images in binary format, properly named +(see below). + +`tflite_accuracy` expects all the input images to be located in the same directory +in the file system. Each image file is the binary dump of the network's +input tensor. So, if the network's input tensor is a `float32` tensor of +format (1, 224, 224, 3) containing 1 image of height 224, width 224, and +3 channels, each image file is expected to be a series of 224 * 224 * 3 +`float32` values. + +`tflite_accuracy` does **not** perform any preprocessing on the input tensor +(e.g., subtraction of mean or division by standard deviation). Each image +file is treated as the final value of the input tensor, so all the +necessary preprocessing should be done prior to invoking the tool. + +In order to calculate accuracy on the image set, `tflite_accuracy` needs to know +the correct label corresponding to each image. This information is +extracted from the file's name: the first four characters in the name are +assumed to be the numerical code of the image's class. So, a file named +`0123_0123456789.bin` is assumed to represent an image belonging to class +`123`. The remainder of the name (`0123456789` in the example) is assumed +to be an identifier of the image itself. + +The width and height each image can be informed via the command line +argument `--imgsize`, whose default value is 224.
\ No newline at end of file diff --git a/tools/tflite_accuracy/src/labels.h b/tools/tflite_accuracy/src/labels.h new file mode 100644 index 000000000..1e5170e06 --- /dev/null +++ b/tools/tflite_accuracy/src/labels.h @@ -0,0 +1,1023 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LABELS_H_ +#define LABELS_H_ + +// Labels used for image classification (imagenet dataset) +static const char *labels[] = {"background", + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead", + "electric ray", + "stingray", + "cock", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "water ouzel", + "kite", + "bald eagle", + "vulture", + "great grey owl", + "European fire salamander", + "common newt", + "eft", + "spotted salamander", + "axolotl", + "bullfrog", + "tree frog", + "tailed frog", + "loggerhead", + "leatherback turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "common iguana", + "American chameleon", + "whiptail", + "agama", + "frilled lizard", + "alligator lizard", + "Gila monster", + "green lizard", + "African chameleon", + "Komodo dragon", + "African crocodile", + "American alligator", + "triceratops", + "thunder snake", + "ringneck snake", + "hognose snake", + "green snake", + "king snake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "rock python", + "Indian cobra", + "green mamba", + "sea snake", + "horned viper", + "diamondback", + "sidewinder", + "trilobite", + "harvestman", + "scorpion", + "black and gold garden spider", + "barn spider", + "garden spider", + "black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie chicken", + "peacock", + "quail", + "partridge", + "African grey", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "drake", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "American egret", + "bittern", + "crane", + "limpkin", + "European gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "red-backed sandpiper", + "redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese spaniel", + "Maltese dog", + "Pekinese", + "Shih-Tzu", + "Blenheim spaniel", + "papillon", + "toy terrier", + "Rhodesian ridgeback", + "Afghan hound", + "basset", + "beagle", + "bloodhound", + "bluetick", + "black-and-tan coonhound", + "Walker hound", + "English foxhound", + "redbone", + "borzoi", + "Irish wolfhound", + "Italian greyhound", + "whippet", + "Ibizan hound", + "Norwegian elkhound", + "otterhound", + "Saluki", + "Scottish deerhound", + "Weimaraner", + "Staffordshire bullterrier", + "American Staffordshire terrier", + "Bedlington terrier", + "Border terrier", + "Kerry blue terrier", + "Irish terrier", + "Norfolk terrier", + "Norwich terrier", + "Yorkshire terrier", + "wire-haired fox terrier", + "Lakeland terrier", + "Sealyham terrier", + "Airedale", + "cairn", + "Australian terrier", + "Dandie Dinmont", + "Boston bull", + "miniature schnauzer", + "giant schnauzer", + "standard schnauzer", + "Scotch terrier", + "Tibetan terrier", + "silky terrier", + "soft-coated wheaten terrier", + "West Highland white terrier", + "Lhasa", + "flat-coated retriever", + "curly-coated retriever", + "golden retriever", + "Labrador retriever", + "Chesapeake Bay retriever", + "German short-haired pointer", + "vizsla", + "English setter", + "Irish setter", + "Gordon setter", + "Brittany spaniel", + "clumber", + "English springer", + "Welsh springer spaniel", + "cocker spaniel", + "Sussex spaniel", + "Irish water spaniel", + "kuvasz", + "schipperke", + "groenendael", + "malinois", + "briard", + "kelpie", + "komondor", + "Old English sheepdog", + "Shetland sheepdog", + "collie", + "Border collie", + "Bouvier des Flandres", + "Rottweiler", + "German shepherd", + "Doberman", + "miniature pinscher", + "Greater Swiss Mountain dog", + "Bernese mountain dog", + "Appenzeller", + "EntleBucher", + "boxer", + "bull mastiff", + "Tibetan mastiff", + "French bulldog", + "Great Dane", + "Saint Bernard", + "Eskimo dog", + "malamute", + "Siberian husky", + "dalmatian", + "affenpinscher", + "basenji", + "pug", + "Leonberg", + "Newfoundland", + "Great Pyrenees", + "Samoyed", + "Pomeranian", + "chow", + "keeshond", + "Brabancon griffon", + "Pembroke", + "Cardigan", + "toy poodle", + "miniature poodle", + "standard poodle", + "Mexican hairless", + "timber wolf", + "white wolf", + "red wolf", + "coyote", + "dingo", + "dhole", + "African hunting dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian cat", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "ice bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "long-horned beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket", + "walking stick", + "cockroach", + "mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "admiral", + "ringlet", + "monarch", + "cabbage butterfly", + "sulphur butterfly", + "lycaenid", + "starfish", + "sea urchin", + "sea cucumber", + "wood rabbit", + "hare", + "Angora", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "sorrel", + "zebra", + "hog", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram", + "bighorn", + "ibex", + "hartebeest", + "impala", + "gazelle", + "Arabian camel", + "llama", + "weasel", + "mink", + "polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas", + "baboon", + "macaque", + "langur", + "colobus", + "proboscis monkey", + "marmoset", + "capuchin", + "howler monkey", + "titi", + "spider monkey", + "squirrel monkey", + "Madagascar cat", + "indri", + "Indian elephant", + "African elephant", + "lesser panda", + "giant panda", + "barracouta", + "eel", + "coho", + "rock beauty", + "anemone fish", + "sturgeon", + "gar", + "lionfish", + "puffer", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibian", + "analog clock", + "apiary", + "apron", + "ashcan", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint", + "Band Aid", + "banjo", + "bannister", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "barrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "bathing cap", + "bath towel", + "bathtub", + "beach wagon", + "beacon", + "beaker", + "bearskin", + "beer bottle", + "beer glass", + "bell cote", + "bib", + "bicycle-built-for-two", + "bikini", + "binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsled", + "bolo tie", + "bonnet", + "bookcase", + "bookshop", + "bottlecap", + "bow", + "bow tie", + "brass", + "brassiere", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "bullet train", + "butcher shop", + "cab", + "caldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "carpenter's kit", + "carton", + "car wheel", + "cash machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "cellular telephone", + "chain", + "chainlink fence", + "chain mail", + "chain saw", + "chest", + "chiffonier", + "chime", + "china cabinet", + "Christmas stocking", + "church", + "cinema", + "cleaver", + "cliff dwelling", + "cloak", + "clog", + "cocktail shaker", + "coffee mug", + "coffeepot", + "coil", + "combination lock", + "computer keyboard", + "confectionery", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "crane", + "crash helmet", + "crate", + "crib", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishrag", + "dishwasher", + "disk brake", + "dock", + "dogsled", + "dome", + "doormat", + "drilling platform", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso maker", + "face powder", + "feather boa", + "file", + "fireboat", + "fire engine", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gasmask", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golfcart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "grille", + "grocery store", + "guillotine", + "hair slide", + "hair spray", + "half track", + "hammer", + "hamper", + "hand blower", + "hand-held computer", + "handkerchief", + "hard disc", + "harmonica", + "harp", + "harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoopskirt", + "horizontal bar", + "horse cart", + "hourglass", + "iPod", + "iron", + "jack-o'-lantern", + "jean", + "jeep", + "jersey", + "jigsaw puzzle", + "jinrikisha", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "liner", + "lipstick", + "Loafer", + "lotion", + "loudspeaker", + "loupe", + "lumbermill", + "magnetic compass", + "mailbag", + "mailbox", + "maillot", + "maillot", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine chest", + "megalith", + "microphone", + "microwave", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "mortarboard", + "mosque", + "mosquito net", + "motor scooter", + "mountain bike", + "mountain tent", + "mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "organ", + "oscilloscope", + "overskirt", + "oxcart", + "oxygen mask", + "packet", + "paddle", + "paddlewheel", + "padlock", + "paintbrush", + "pajama", + "palace", + "panpipe", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "passenger car", + "patio", + "pay-phone", + "pedestal", + "pencil box", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "pick", + "pickelhaube", + "picket fence", + "pickup", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate", + "pitcher", + "plane", + "planetarium", + "plastic bag", + "plate rack", + "plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "pop bottle", + "pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "projectile", + "projector", + "puck", + "punching bag", + "purse", + "quill", + "quilt", + "racer", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "rubber eraser", + "rugby ball", + "rule", + "running shoe", + "safe", + "safety pin", + "saltshaker", + "sandal", + "sarong", + "sax", + "scabbard", + "scale", + "school bus", + "schooner", + "scoreboard", + "screen", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe shop", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar dish", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "speedboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "steel arch bridge", + "steel drum", + "stethoscope", + "stole", + "stone wall", + "stopwatch", + "stove", + "strainer", + "streetcar", + "stretcher", + "studio couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglass", + "sunglasses", + "sunscreen", + "suspension bridge", + "swab", + "sweatshirt", + "swimming trunks", + "swing", + "switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy", + "television", + "tennis ball", + "thatch", + "theater curtain", + "thimble", + "thresher", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toyshop", + "tractor", + "trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright", + "vacuum", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "warplane", + "washbasin", + "washer", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool", + "worm fence", + "wreck", + "yawl", + "yurt", + "web site", + "comic book", + "crossword puzzle", + "street sign", + "traffic light", + "book jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "ice lolly", + "French loaf", + "bagel", + "pretzel", + "cheeseburger", + "hotdog", + "mashed potato", + "head cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate sauce", + "dough", + "meat loaf", + "pizza", + "potpie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeside", + "promontory", + "sandbar", + "seashore", + "valley", + "volcano", + "ballplayer", + "groom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "hip", + "buckeye", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn", + "earthstar", + "hen-of-the-woods", + "bolete", + "ear", + "toilet tissue"}; + +#endif /* LABELS_H_ */ diff --git a/tools/tflite_accuracy/src/tflite_accuracy.cc b/tools/tflite_accuracy/src/tflite_accuracy.cc new file mode 100644 index 000000000..83b7ba9a4 --- /dev/null +++ b/tools/tflite_accuracy/src/tflite_accuracy.cc @@ -0,0 +1,494 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <algorithm> +#include <atomic> +#include <chrono> +#include <forward_list> +#include <fstream> +#include <iostream> +#include <memory> +#include <numeric> +#include <stdexcept> +#include <string> +#include <thread> + +#include <boost/filesystem.hpp> +#include <boost/format.hpp> +#include <boost/program_options.hpp> + +#include <cmath> +#include <cstdint> +#include <signal.h> + +#include <tensorflow/contrib/lite/context.h> +#include <tensorflow/contrib/lite/interpreter.h> +#include <tensorflow/contrib/lite/model.h> + +#include "labels.h" +#include "tflite/ext/nnapi_delegate.h" +#include "tflite/ext/kernels/register.h" + +const std::string kDefaultImagesDir = "res/input/"; +const std::string kDefaultModelFile = "res/model.tflite"; + +template <typename... Args> void Print(const char *fmt, Args... args) +{ +#if __cplusplus >= 201703L + std::cerr << boost::str(boost::format(fmt) % ... % std::forward<Args>(args)) << std::endl; +#else + boost::format f(fmt); + using unroll = int[]; + unroll{0, (f % std::forward<Args>(args), 0)...}; + std::cerr << boost::str(f) << std::endl; +#endif +} + +template <typename DataType> struct BaseLabelData +{ + explicit BaseLabelData(int label = -1, DataType confidence = 0) + : label(label), confidence(confidence) + { + } + + static std::vector<BaseLabelData<DataType>> FindLabels(const DataType *output_tensor, + unsigned int top_n = 5) + { + top_n = top_n > 1000 ? 1000 : top_n; + size_t n = 0; + std::vector<size_t> indices(1000); + std::generate(indices.begin(), indices.end(), [&n]() { return n++; }); + std::sort(indices.begin(), indices.end(), [output_tensor](const size_t &i1, const size_t &i2) { + return output_tensor[i1] > output_tensor[i2]; + }); + std::vector<BaseLabelData<DataType>> results(top_n); + for (unsigned int i = 0; i < top_n; ++i) + { + results[i].label = indices[i]; + results[i].confidence = output_tensor[indices[i]]; + } + return results; + } + + int label; + DataType confidence; +}; + +class BaseRunner +{ +public: + virtual ~BaseRunner() = default; + + /** + * @brief Run a model for each file in a directory, and collect and print + * statistics. + */ + virtual void IterateInDirectory(const std::string &dir_path, const int labels_offset) = 0; + + /** + * @brief Request that the iteration be stopped after the current file. + */ + virtual void ScheduleInterruption() = 0; +}; + +template <typename DataType_> class Runner : public BaseRunner +{ +public: + using DataType = DataType_; + using LabelData = BaseLabelData<DataType>; + + const int kInputSize; + const int KOutputSize = 1001 * sizeof(DataType); + + Runner(std::unique_ptr<tflite::Interpreter> interpreter, + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<::nnfw::tflite::NNAPIDelegate> delegate, unsigned img_size) + : interpreter(std::move(interpreter)), model(std::move(model)), delegate(std::move(delegate)), + interrupted(false), kInputSize(1 * img_size * img_size * 3 * sizeof(DataType)) + { + inference_times.reserve(500); + top1.reserve(500); + top5.reserve(500); + } + + virtual ~Runner() = default; + + /** + * @brief Get the model's input tensor. + */ + virtual DataType *GetInputTensor() = 0; + + /** + * @brief Get the model's output tensor. + */ + virtual DataType *GetOutputTensor() = 0; + + /** + * @brief Load Image file into tensor. + * @return Class number if present in filename, -1 otherwise. + */ + virtual int LoadFile(const boost::filesystem::path &input_file) + { + DataType *input_tensor = GetInputTensor(); + if (input_file.extension() == ".bin") + { + // Load data as raw tensor + std::ifstream input_stream(input_file.string(), std::ifstream::binary); + input_stream.read(reinterpret_cast<char *>(input_tensor), kInputSize); + input_stream.close(); + int class_num = boost::lexical_cast<int>(input_file.filename().string().substr(0, 4)); + return class_num; + } + else + { + // Load data as image file + throw std::runtime_error("Runner can only load *.bin files"); + } + } + + void Invoke() + { + TfLiteStatus status; + if (delegate) + { + status = delegate->Invoke(interpreter.get()); + } + else + { + status = interpreter->Invoke(); + } + if (status != kTfLiteOk) + { + throw std::runtime_error("Failed to invoke interpreter."); + } + } + + int Process() + { + auto t0 = std::chrono::high_resolution_clock::now(); + Invoke(); + auto t1 = std::chrono::high_resolution_clock::now(); + std::chrono::duration<double> fs = t1 - t0; + auto d = std::chrono::duration_cast<std::chrono::milliseconds>(fs); + inference_times.push_back(d.count()); + if (d > std::chrono::milliseconds(10)) + { + Print(" -- inference duration: %lld ms", d.count()); + } + else + { + auto du = std::chrono::duration_cast<std::chrono::microseconds>(fs); + Print(" -- inference duration: %lld us", du.count()); + } + return 0; + } + + void DumpOutputTensor(const std::string &output_file) + { + DataType *output_tensor = GetOutputTensor(); + std::ofstream output_stream(output_file, std::ofstream::binary); + output_stream.write(reinterpret_cast<char *>(output_tensor), KOutputSize); + } + + void PrintExecutionSummary() const + { + Print("Execution summary:"); + Print(" -- # of processed images: %d", num_images); + if (num_images == 0) + { + return; + } + // Inference time - mean + double mean = std::accumulate(inference_times.begin(), inference_times.end(), 0.0) / num_images; + Print(" -- mean inference time: %.1f ms", mean); + // Inference time - std + std::vector<double> diff(num_images); + std::transform(inference_times.begin(), inference_times.end(), diff.begin(), + [mean](size_t n) { return n - mean; }); + double sq_sum = std::inner_product(diff.begin(), diff.end(), diff.begin(), 0.0); + double std_inference_time = std::sqrt(sq_sum / num_images); + Print(" -- std inference time: %.1f ms", std_inference_time); + // Top-1 and Top-5 accuracies + float num_top1 = std::accumulate(top1.begin(), top1.end(), 0); + float num_top5 = std::accumulate(top5.begin(), top5.end(), 0); + Print(" -- top1: %.3f, top5: %.3f", num_top1 / num_images, num_top5 / num_images); + } + + virtual void ScheduleInterruption() override { interrupted = true; } + + virtual void IterateInDirectory(const std::string &dir_path, const int labels_offset) override + { + interrupted = false; + namespace fs = boost::filesystem; + if (!fs::is_directory(dir_path)) + { + throw std::runtime_error("Could not open input directory."); + } + + inference_times.clear(); + top1.clear(); + top5.clear(); + int class_num; + num_images = 0; + std::vector<LabelData> lds; + fs::directory_iterator end; + for (auto it = fs::directory_iterator(dir_path); it != end; ++it) + { + if (interrupted) + { + break; + } + if (!fs::is_regular_file(*it)) + { + continue; + } + Print("File : %s", it->path().string()); + try + { + class_num = LoadFile(*it) + labels_offset; + Print("Class: %d", class_num); + } + catch (std::exception &e) + { + Print("%s", e.what()); + continue; + } + int status = Process(); + if (status == 0) + { + DataType *output_tensor = GetOutputTensor(); + lds = LabelData::FindLabels(output_tensor, 5); + bool is_top1 = lds[0].label == class_num; + bool is_top5 = false; + for (const auto &ld : lds) + { + is_top5 = is_top5 || (ld.label == class_num); + Print(" -- label: %s (%d), prob: %.3f", ld.label >= 0 ? labels[ld.label] : "", ld.label, + static_cast<float>(ld.confidence)); + } + Print(" -- top1: %d, top5: %d", is_top1, is_top5); + top1.push_back(is_top1); + top5.push_back(is_top5); + } + ++num_images; + } + PrintExecutionSummary(); + } + +protected: + std::unique_ptr<tflite::Interpreter> interpreter; + std::unique_ptr<tflite::FlatBufferModel> model; + std::unique_ptr<::nnfw::tflite::NNAPIDelegate> delegate; + + std::vector<size_t> inference_times; + std::vector<bool> top1; + std::vector<bool> top5; + uint num_images; + std::atomic_bool interrupted; +}; + +class FloatRunner : public Runner<float> +{ +public: + using Runner<float>::DataType; + + FloatRunner(std::unique_ptr<tflite::Interpreter> interpreter, + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<::nnfw::tflite::NNAPIDelegate> delegate, unsigned img_size) + : Runner<float>(std::move(interpreter), std::move(model), std::move(delegate), img_size) + { + } + + virtual ~FloatRunner() = default; + + virtual DataType *GetInputTensor() override + { + return interpreter->tensor(interpreter->inputs()[0])->data.f; + } + + virtual DataType *GetOutputTensor() override + { + return interpreter->tensor(interpreter->outputs()[0])->data.f; + } +}; + +class QuantizedRunner : public Runner<uint8_t> +{ +public: + using Runner<uint8_t>::DataType; + + QuantizedRunner(std::unique_ptr<tflite::Interpreter> interpreter, + std::unique_ptr<tflite::FlatBufferModel> model, + std::unique_ptr<::nnfw::tflite::NNAPIDelegate> delegate, unsigned img_size) + : Runner<uint8_t>(std::move(interpreter), std::move(model), std::move(delegate), img_size) + { + } + + virtual ~QuantizedRunner() = default; + + virtual DataType *GetInputTensor() override + { + return interpreter->tensor(interpreter->inputs()[0])->data.uint8; + } + + virtual DataType *GetOutputTensor() override + { + return interpreter->tensor(interpreter->outputs()[0])->data.uint8; + } +}; + +enum class Target +{ + TfLiteCpu, /**< Use Tensorflow Lite's CPU kernels. */ + TfLiteDelegate, /**< Use Tensorflow Lite's NN API delegate. */ + NnfwDelegate /**< Use NNFW's NN API delegate. */ +}; + +std::unique_ptr<BaseRunner> MakeRunner(const std::string &model_path, unsigned img_size, + Target target = Target::NnfwDelegate) +{ + auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str()); + if (not model) + { + throw std::runtime_error(model_path + ": file not found or corrupted."); + } + Print("Model loaded."); + + std::unique_ptr<tflite::Interpreter> interpreter; + nnfw::tflite::BuiltinOpResolver resolver; + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (not interpreter) + { + throw std::runtime_error("interpreter construction failed."); + } + if (target == Target::TfLiteCpu) + { + interpreter->SetNumThreads(std::max(std::thread::hardware_concurrency(), 1U)); + } + else + { + interpreter->SetNumThreads(1); + } + if (target == Target::TfLiteDelegate) + { + interpreter->UseNNAPI(true); + } + + int input_index = interpreter->inputs()[0]; + interpreter->ResizeInputTensor(input_index, + {1, static_cast<int>(img_size), static_cast<int>(img_size), 3}); + if (interpreter->AllocateTensors() != kTfLiteOk) + { + throw std::runtime_error("tensor allocation failed."); + } + + if (target == Target::TfLiteDelegate) + { + // Do a fake run to load NN API functions. + interpreter->Invoke(); + } + + std::unique_ptr<::nnfw::tflite::NNAPIDelegate> delegate; + if (target == Target::NnfwDelegate) + { + delegate.reset(new ::nnfw::tflite::NNAPIDelegate); + delegate->BuildGraph(interpreter.get()); + } + + if (interpreter->tensor(input_index)->type == kTfLiteFloat32) + { + return std::unique_ptr<FloatRunner>( + new FloatRunner(std::move(interpreter), std::move(model), std::move(delegate), img_size)); + } + else if (interpreter->tensor(input_index)->type == kTfLiteUInt8) + { + return std::unique_ptr<QuantizedRunner>(new QuantizedRunner( + std::move(interpreter), std::move(model), std::move(delegate), img_size)); + } + throw std::invalid_argument("data type of model's input tensor is not supported."); +} + +Target GetTarget(const std::string &str) +{ + static const std::map<std::string, Target> target_names{ + {"tflite-cpu", Target::TfLiteCpu}, + {"tflite-delegate", Target::TfLiteDelegate}, + {"nnfw-delegate", Target::NnfwDelegate}}; + if (target_names.find(str) == target_names.end()) + { + throw std::invalid_argument( + str + ": invalid target. Run with --help for a list of available targets."); + } + return target_names.at(str); +} + +// We need a global pointer to the runner for the SIGINT handler +BaseRunner *runner_ptr = nullptr; +void HandleSigInt(int) +{ + if (runner_ptr != nullptr) + { + Print("Interrupted. Execution will stop after current image."); + runner_ptr->ScheduleInterruption(); + runner_ptr = nullptr; + } + else + { + exit(1); + } +} + +int main(int argc, char *argv[]) try +{ + namespace po = boost::program_options; + po::options_description desc("Run a model on multiple binary images and print" + " statistics"); + desc.add_options()("help", "print this message and quit")( + "model", po::value<std::string>()->default_value(kDefaultModelFile), "tflite file")( + "input", po::value<std::string>()->default_value(kDefaultImagesDir), + "directory with input images")("offset", po::value<int>()->default_value(1), "labels offset")( + "target", po::value<std::string>()->default_value("nnfw-delegate"), + "how the model will be run (available targets: tflite-cpu, " + "tflite-delegate, nnfw-delegate)")("imgsize", po::value<unsigned>()->default_value(224), + "the width and height of the image"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cerr << desc << std::endl; + return 0; + } + + auto runner = MakeRunner(vm["model"].as<std::string>(), vm["imgsize"].as<unsigned>(), + GetTarget(vm["target"].as<std::string>())); + runner_ptr = runner.get(); + + struct sigaction sigint_handler; + sigint_handler.sa_handler = HandleSigInt; + sigemptyset(&sigint_handler.sa_mask); + sigint_handler.sa_flags = 0; + sigaction(SIGINT, &sigint_handler, nullptr); + + Print("Running TensorFlow Lite..."); + runner->IterateInDirectory(vm["input"].as<std::string>(), vm["offset"].as<int>()); + Print("Done."); + return 0; +} +catch (std::exception &e) +{ + Print("%s: %s", argv[0], e.what()); + return 1; +} |