author    Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>  2017-09-28 16:51:47 +0300
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>  2017-10-08 22:25:29 +0300
commit    e4aa39f9e5a528bfe64b97126489b06c954e60ab (patch)
tree      e800eb480ce88b537b3debf2ca440a6c89ed7754 /modules/dnn/src/tensorflow/tf_importer.cpp
parent    8ac2c5d620b467d3f22802e96c88ddde6da707af (diff)
Parsing of text TensorFlow graphs. MobileNet-SSD for 90 classes.
Diffstat (limited to 'modules/dnn/src/tensorflow/tf_importer.cpp')
-rw-r--r--  modules/dnn/src/tensorflow/tf_importer.cpp | 189
1 file changed, 158 insertions(+), 31 deletions(-)
diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp
index aca1cd5055..f2e83c087e 100644
--- a/modules/dnn/src/tensorflow/tf_importer.cpp
+++ b/modules/dnn/src/tensorflow/tf_importer.cpp
@@ -321,10 +321,10 @@ DictValue parseDims(const tensorflow::TensorProto &tensor) {
CV_Assert(tensor.dtype() == tensorflow::DT_INT32);
CV_Assert(dims == 1);
- int size = tensor.tensor_content().size() / sizeof(int);
- const int *data = reinterpret_cast<const int*>(tensor.tensor_content().c_str());
+ Mat values = getTensorContent(tensor);
+ CV_Assert(values.type() == CV_32SC1);
// TODO: add reordering shape if dims == 4
- return DictValue::arrayInt(data, size);
+ return DictValue::arrayInt((int*)values.data, values.total());
}
void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer)
@@ -448,7 +448,7 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
class TFImporter : public Importer {
public:
- TFImporter(const char *model);
+ TFImporter(const char *model, const char *config = NULL);
void populateNet(Net dstNet);
~TFImporter() {}
@@ -463,13 +463,20 @@ private:
int input_blob_index = -1, int* actual_inp_blob_idx = 0);
- tensorflow::GraphDef net;
+ // Binary serialized TensorFlow graph includes weights.
+ tensorflow::GraphDef netBin;
+ // Optional text definition of the TensorFlow graph. It is more flexible than the
+ // binary format and may be used to build the network, with the binary graph serving
+ // only as weights storage. This approach is similar to Caffe's `.prototxt` and `.caffemodel`.
+ tensorflow::GraphDef netTxt;
};
-TFImporter::TFImporter(const char *model)
+TFImporter::TFImporter(const char *model, const char *config)
{
if (model && model[0])
- ReadTFNetParamsFromBinaryFileOrDie(model, &net);
+ ReadTFNetParamsFromBinaryFileOrDie(model, &netBin);
+ if (config && config[0])
+ ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
}
void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
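The new constructor reads the optional config file in protobuf text format. A minimal sketch of what that path might boil down to, assuming ReadTFNetParamsFromTextFileOrDie wraps protobuf's TextFormat parser (the real helper is defined outside this file and may differ in details):

    #include <fstream>
    #include <iterator>
    #include <string>
    #include <google/protobuf/text_format.h>
    #include "graph.pb.h"  // generated tensorflow::GraphDef

    // Hypothetical stand-in for ReadTFNetParamsFromTextFileOrDie: slurp the
    // .pbtxt file and let protobuf's text-format parser fill the GraphDef.
    bool readTextGraph(const char* path, tensorflow::GraphDef* net)
    {
        std::ifstream ifs(path);
        std::string pbtxt((std::istreambuf_iterator<char>(ifs)),
                           std::istreambuf_iterator<char>());
        return google::protobuf::TextFormat::ParseFromString(pbtxt, net);
    }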
@@ -557,21 +564,23 @@ const tensorflow::TensorProto& TFImporter::getConstBlob(const tensorflow::NodeDe
*actual_inp_blob_idx = input_blob_index;
}
- return net.node(const_layers.at(kernel_inp.name)).attr().at("value").tensor();
+ int nodeIdx = const_layers.at(kernel_inp.name);
+ if (nodeIdx < netBin.node_size() && netBin.node(nodeIdx).name() == kernel_inp.name)
+ {
+ return netBin.node(nodeIdx).attr().at("value").tensor();
+ }
+ else
+ {
+ CV_Assert(nodeIdx < netTxt.node_size(),
+ netTxt.node(nodeIdx).name() == kernel_inp.name);
+ return netTxt.node(nodeIdx).attr().at("value").tensor();
+ }
}
-
-void TFImporter::populateNet(Net dstNet)
+static void addConstNodes(const tensorflow::GraphDef& net, std::map<String, int>& const_layers,
+ std::set<String>& layers_to_ignore)
{
- RemoveIdentityOps(net);
-
- std::map<int, String> layers_to_ignore;
-
- int layersSize = net.node_size();
-
- // find all Const layers for params
- std::map<String, int> value_id;
- for (int li = 0; li < layersSize; li++)
+ for (int li = 0; li < net.node_size(); li++)
{
const tensorflow::NodeDef &layer = net.node(li);
String name = layer.name();
@@ -582,11 +591,27 @@ void TFImporter::populateNet(Net dstNet)
if (layer.attr().find("value") != layer.attr().end())
{
- value_id.insert(std::make_pair(name, li));
+ CV_Assert(const_layers.insert(std::make_pair(name, li)).second);
}
-
- layers_to_ignore[li] = name;
+ layers_to_ignore.insert(name);
}
+}
+
+void TFImporter::populateNet(Net dstNet)
+{
+ RemoveIdentityOps(netBin);
+ RemoveIdentityOps(netTxt);
+
+ std::set<String> layers_to_ignore;
+
+ tensorflow::GraphDef& net = netTxt.ByteSize() != 0 ? netTxt : netBin;
+
+ int layersSize = net.node_size();
+
+ // find all Const layers for params
+ std::map<String, int> value_id;
+ addConstNodes(netBin, value_id, layers_to_ignore);
+ addConstNodes(netTxt, value_id, layers_to_ignore);
std::map<String, int> layer_id;
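The topology source is chosen above by whether a text graph was loaded at all. A sketch of that selection rule, factored out for clarity (a hypothetical helper, not part of the importer):

    // A non-empty text graph wins as the topology source; the binary graph
    // then serves only as weights storage. With no config file supplied,
    // netTxt.ByteSize() == 0 and netBin drives both, as before this change.
    tensorflow::GraphDef& topologyGraph(tensorflow::GraphDef& netBin,
                                        tensorflow::GraphDef& netTxt)
    {
        return netTxt.ByteSize() != 0 ? netTxt : netBin;
    }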
@@ -597,7 +622,7 @@ void TFImporter::populateNet(Net dstNet)
String type = layer.op();
LayerParams layerParams;
- if(layers_to_ignore.find(li) != layers_to_ignore.end())
+ if(layers_to_ignore.find(name) != layers_to_ignore.end())
continue;
if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
@@ -627,7 +652,7 @@ void TFImporter::populateNet(Net dstNet)
StrIntVector next_layers = getNextLayers(net, name, "Conv2D");
CV_Assert(next_layers.size() == 1);
layer = net.node(next_layers[0].second);
- layers_to_ignore[next_layers[0].second] = next_layers[0].first;
+ layers_to_ignore.insert(next_layers[0].first);
name = layer.name();
type = layer.op();
}
@@ -644,7 +669,7 @@ void TFImporter::populateNet(Net dstNet)
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
ExcludeLayer(net, weights_layer_index, 0, false);
- layers_to_ignore[weights_layer_index] = next_layers[0].first;
+ layers_to_ignore.insert(next_layers[0].first);
}
kernelFromTensor(getConstBlob(layer, value_id), layerParams.blobs[0]);
@@ -684,7 +709,7 @@ void TFImporter::populateNet(Net dstNet)
layerParams.set("pad_mode", ""); // We use padding values.
CV_Assert(next_layers.size() == 1);
ExcludeLayer(net, next_layers[0].second, 0, false);
- layers_to_ignore[next_layers[0].second] = next_layers[0].first;
+ layers_to_ignore.insert(next_layers[0].first);
}
int id = dstNet.addLayer(name, "Convolution", layerParams);
@@ -748,7 +773,7 @@ void TFImporter::populateNet(Net dstNet)
int weights_layer_index = next_layers[0].second;
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
ExcludeLayer(net, weights_layer_index, 0, false);
- layers_to_ignore[weights_layer_index] = next_layers[0].first;
+ layers_to_ignore.insert(next_layers[0].first);
}
int kernel_blob_index = -1;
@@ -778,6 +803,30 @@ void TFImporter::populateNet(Net dstNet)
// one input only
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
+ else if (type == "Flatten")
+ {
+ int id = dstNet.addLayer(name, "Flatten", layerParams);
+ layer_id[name] = id;
+ connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+ }
+ else if (type == "Transpose")
+ {
+ Mat perm = getTensorContent(getConstBlob(layer, value_id, 1));
+ CV_Assert(perm.type() == CV_32SC1);
+ int* permData = (int*)perm.data;
+ if (perm.total() == 4)
+ {
+ for (int i = 0; i < 4; ++i)
+ permData[i] = toNCHW[permData[i]];
+ }
+ layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
+
+ int id = dstNet.addLayer(name, "Permute", layerParams);
+ layer_id[name] = id;
+
+ // one input only
+ connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+ }
else if (type == "Const")
{
}
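A worked example of the permutation remap above (a sketch; it assumes toNCHW is the importer's axis table {0, 2, 3, 1}, defined elsewhere in tf_importer.cpp, mapping an NHWC axis index to its NCHW position):

    #include <cstdio>

    int main()
    {
        const int toNCHW[] = {0, 2, 3, 1};  // assumed importer axis table
        int perm[] = {0, 3, 1, 2};          // TF Transpose: NHWC -> NCHW
        // Same element-wise remap the importer applies to permData.
        for (int i = 0; i < 4; ++i)
            perm[i] = toNCHW[perm[i]];
        for (int i = 0; i < 4; ++i)
            printf("%d ", perm[i]);         // prints: 0 1 2 3
        // The identity order is expected: OpenCV blobs are already NCHW,
        // so a layout-changing TF transpose becomes a no-op Permute.
        return 0;
    }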
@@ -807,7 +856,7 @@ void TFImporter::populateNet(Net dstNet)
{
int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
- layerParams.set("axis", toNCHW[axis]);
+ layerParams.set("axis", 0 <= axis && axis < 4 ? toNCHW[axis] : axis);
int id = dstNet.addLayer(name, "Concat", layerParams);
layer_id[name] = id;
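The same table drives the Concat axis conversion; the new range guard keeps out-of-range axes from indexing past it. A small sketch of the rule (toNCHW assumed as above):

    const int toNCHW[] = {0, 2, 3, 1};

    // TF Concat along NHWC channels (axis 3) maps to NCHW channels (axis 1);
    // anything outside [0, 4), e.g. a negative axis, now passes through
    // unchanged instead of reading past the table as the old code would.
    int concatAxisToNCHW(int axis)
    {
        return (0 <= axis && axis < 4) ? toNCHW[axis] : axis;
    }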
@@ -929,6 +978,19 @@ void TFImporter::populateNet(Net dstNet)
else // is a vector
{
layerParams.blobs.resize(1, scaleMat);
+
+ StrIntVector next_layers = getNextLayers(net, name, "Add");
+ if (!next_layers.empty())
+ {
+ layerParams.set("bias_term", true);
+ layerParams.blobs.resize(2);
+
+ int weights_layer_index = next_layers[0].second;
+ blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs.back());
+ ExcludeLayer(net, weights_layer_index, 0, false);
+ layers_to_ignore.insert(next_layers[0].first);
+ }
+
id = dstNet.addLayer(name, "Scale", layerParams);
}
layer_id[name] = id;
@@ -1037,7 +1099,7 @@ void TFImporter::populateNet(Net dstNet)
blobFromTensor(getConstBlob(net.node(weights_layer_index), value_id), layerParams.blobs[1]);
ExcludeLayer(net, weights_layer_index, 0, false);
- layers_to_ignore[weights_layer_index] = next_layers[0].first;
+ layers_to_ignore.insert(next_layers[0].first);
}
kernelFromTensor(getConstBlob(layer, value_id, 1), layerParams.blobs[0]);
@@ -1148,6 +1210,71 @@ void TFImporter::populateNet(Net dstNet)
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
+ else if (type == "PriorBox")
+ {
+ if (hasLayerAttr(layer, "min_size"))
+ layerParams.set("min_size", getLayerAttr(layer, "min_size").i());
+ if (hasLayerAttr(layer, "max_size"))
+ layerParams.set("max_size", getLayerAttr(layer, "max_size").i());
+ if (hasLayerAttr(layer, "flip"))
+ layerParams.set("flip", getLayerAttr(layer, "flip").b());
+ if (hasLayerAttr(layer, "clip"))
+ layerParams.set("clip", getLayerAttr(layer, "clip").b());
+ if (hasLayerAttr(layer, "offset"))
+ layerParams.set("offset", getLayerAttr(layer, "offset").f());
+ if (hasLayerAttr(layer, "variance"))
+ {
+ Mat variance = getTensorContent(getLayerAttr(layer, "variance").tensor());
+ layerParams.set("variance",
+ DictValue::arrayReal<float*>((float*)variance.data, variance.total()));
+ }
+ if (hasLayerAttr(layer, "aspect_ratio"))
+ {
+ Mat aspectRatios = getTensorContent(getLayerAttr(layer, "aspect_ratio").tensor());
+ layerParams.set("aspect_ratio",
+ DictValue::arrayReal<float*>((float*)aspectRatios.data, aspectRatios.total()));
+ }
+ if (hasLayerAttr(layer, "scales"))
+ {
+ Mat scales = getTensorContent(getLayerAttr(layer, "scales").tensor());
+ layerParams.set("scales",
+ DictValue::arrayReal<float*>((float*)scales.data, scales.total()));
+ }
+ int id = dstNet.addLayer(name, "PriorBox", layerParams);
+ layer_id[name] = id;
+ connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+ connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
+ }
+ else if (type == "DetectionOutput")
+ {
+ // op: "DetectionOutput"
+ // input_0: "locations"
+ // input_1: "classifications"
+ // input_2: "prior_boxes"
+ if (hasLayerAttr(layer, "num_classes"))
+ layerParams.set("num_classes", getLayerAttr(layer, "num_classes").i());
+ if (hasLayerAttr(layer, "share_location"))
+ layerParams.set("share_location", getLayerAttr(layer, "share_location").b());
+ if (hasLayerAttr(layer, "background_label_id"))
+ layerParams.set("background_label_id", getLayerAttr(layer, "background_label_id").i());
+ if (hasLayerAttr(layer, "nms_threshold"))
+ layerParams.set("nms_threshold", getLayerAttr(layer, "nms_threshold").f());
+ if (hasLayerAttr(layer, "top_k"))
+ layerParams.set("top_k", getLayerAttr(layer, "top_k").i());
+ if (hasLayerAttr(layer, "code_type"))
+ layerParams.set("code_type", getLayerAttr(layer, "code_type").s());
+ if (hasLayerAttr(layer, "keep_top_k"))
+ layerParams.set("keep_top_k", getLayerAttr(layer, "keep_top_k").i());
+ if (hasLayerAttr(layer, "confidence_threshold"))
+ layerParams.set("confidence_threshold", getLayerAttr(layer, "confidence_threshold").f());
+ if (hasLayerAttr(layer, "loc_pred_transposed"))
+ layerParams.set("loc_pred_transposed", getLayerAttr(layer, "loc_pred_transposed").b());
+
+ int id = dstNet.addLayer(name, "DetectionOutput", layerParams);
+ layer_id[name] = id;
+ for (int i = 0; i < 3; ++i)
+ connect(layer_id, dstNet, parsePin(layer.input(i)), id, i);
+ }
else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
type == "Relu" || type == "Elu" || type == "Softmax" ||
type == "Identity" || type == "Relu6")
@@ -1188,9 +1315,9 @@ Ptr<Importer> createTensorflowImporter(const String&)
#endif //HAVE_PROTOBUF
-Net readNetFromTensorflow(const String &model)
+Net readNetFromTensorflow(const String &model, const String &config)
{
- TFImporter importer(model.c_str());
+ TFImporter importer(model.c_str(), config.c_str());
Net net;
importer.populateNet(net);
return net;
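End to end, the extended API can be exercised like the Caffe loader: a minimal usage sketch, assuming a frozen MobileNet-SSD graph plus a matching hand-written text graph (both file names are hypothetical):

    #include <opencv2/core.hpp>
    #include <opencv2/dnn.hpp>
    using namespace cv;

    int main()
    {
        dnn::Net net = dnn::readNetFromTensorflow("frozen_inference_graph.pb",
                                                  "ssd_mobilenet.pbtxt");
        Mat img(300, 300, CV_8UC3, Scalar::all(0));   // dummy 300x300 input
        Mat blob = dnn::blobFromImage(img, 1.0 / 127.5, Size(300, 300),
                                      Scalar(127.5, 127.5, 127.5));
        net.setInput(blob);
        Mat detections = net.forward();  // SSD-style 1x1xNx7 detection matrix
        return 0;
    }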