summaryrefslogtreecommitdiff
path: root/modules/dnn/src/tensorflow/tf_importer.cpp
diff options
context:
space:
mode:
authorDavid <34099314+dmonterom@users.noreply.github.com>2018-06-07 15:29:04 +0200
committerVadim Pisarevsky <vadim.pisarevsky@gmail.com>2018-06-07 16:29:04 +0300
commit7175f257b513fd1e45d108dda2d1f56575d839e8 (patch)
tree5c983a816f6eecb3a0c0f83a32d75fa3575b5a2a /modules/dnn/src/tensorflow/tf_importer.cpp
parent60fa6bea70530b31ff88d473c84baf38c6463e5c (diff)
downloadopencv-7175f257b513fd1e45d108dda2d1f56575d839e8.tar.gz
opencv-7175f257b513fd1e45d108dda2d1f56575d839e8.tar.bz2
opencv-7175f257b513fd1e45d108dda2d1f56575d839e8.zip
Added ResizeBilinear op for tf (#11050)
* Added ResizeBilinear op for tf Combined ResizeNearestNeighbor and ResizeBilinear layers into Resize (with an interpolation param). Minor changes to tf_importer and resize layer to save some code lines Minor changes in init.cpp Minor changes in tf_importer.cpp * Replaced implementation of a custom ResizeBilinear layer to all layers * Use Mat::ptr. Replace interpolation flags
Diffstat (limited to 'modules/dnn/src/tensorflow/tf_importer.cpp')
-rw-r--r--modules/dnn/src/tensorflow/tf_importer.cpp30
1 files changed, 24 insertions, 6 deletions
diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp
index f19daf9cc6..4bff84175d 100644
--- a/modules/dnn/src/tensorflow/tf_importer.cpp
+++ b/modules/dnn/src/tensorflow/tf_importer.cpp
@@ -1450,18 +1450,36 @@ void TFImporter::populateNet(Net dstNet)
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
- else if (type == "ResizeNearestNeighbor")
+ else if (type == "ResizeNearestNeighbor" || type == "ResizeBilinear")
{
- Mat outSize = getTensorContent(getConstBlob(layer, value_id, 1));
- CV_Assert(outSize.type() == CV_32SC1, outSize.total() == 2);
+ if (layer.input_size() == 2)
+ {
+ Mat outSize = getTensorContent(getConstBlob(layer, value_id, 1));
+ CV_Assert(outSize.type() == CV_32SC1, outSize.total() == 2);
+ layerParams.set("height", outSize.at<int>(0, 0));
+ layerParams.set("width", outSize.at<int>(0, 1));
+ }
+ else if (layer.input_size() == 3)
+ {
+ Mat factorHeight = getTensorContent(getConstBlob(layer, value_id, 1));
+ Mat factorWidth = getTensorContent(getConstBlob(layer, value_id, 2));
+ CV_Assert(factorHeight.type() == CV_32SC1, factorHeight.total() == 1,
+ factorWidth.type() == CV_32SC1, factorWidth.total() == 1);
+ layerParams.set("zoom_factor_x", factorWidth.at<int>(0));
+ layerParams.set("zoom_factor_y", factorHeight.at<int>(0));
+ }
+ else
+ CV_Assert(layer.input_size() == 2 || layer.input_size() == 3);
- layerParams.set("height", outSize.at<int>(0, 0));
- layerParams.set("width", outSize.at<int>(0, 1));
+ if (type == "ResizeNearestNeighbor")
+ layerParams.set("interpolation", "nearest");
+ else
+ layerParams.set("interpolation", "bilinear");
if (hasLayerAttr(layer, "align_corners"))
layerParams.set("align_corners", getLayerAttr(layer, "align_corners").b());
- int id = dstNet.addLayer(name, "ResizeNearestNeighbor", layerParams);
+ int id = dstNet.addLayer(name, "Resize", layerParams);
layer_id[name] = id;
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);