summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorJing Huang <jinghuang@fb.com>2019-03-28 16:58:54 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-28 17:02:50 -0700
commit11ac0cf276d804ec0b1eb07bcb10eaa988282bcf (patch)
tree578dbc424a3d576f7082502ce8a7a96868fecb07 /caffe2
parent1ae2c1950c4ac63cff4c8c9251f3a83a9968376b (diff)
downloadpytorch-11ac0cf276d804ec0b1eb07bcb10eaa988282bcf.tar.gz
pytorch-11ac0cf276d804ec0b1eb07bcb10eaa988282bcf.tar.bz2
pytorch-11ac0cf276d804ec0b1eb07bcb10eaa988282bcf.zip
Implement rotated generate_proposals_op without opencv dependency (CPU version)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18533 Reviewed By: ezyang Differential Revision: D14648083 fbshipit-source-id: e53e8f537100862f8015c4efa4efe4d387cef551
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/operators/generate_proposals_op_test.cc4
-rw-r--r--caffe2/operators/generate_proposals_op_util_nms.h400
-rw-r--r--caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc12
-rw-r--r--caffe2/operators/generate_proposals_op_util_nms_test.cc38
4 files changed, 242 insertions, 212 deletions
diff --git a/caffe2/operators/generate_proposals_op_test.cc b/caffe2/operators/generate_proposals_op_test.cc
index eff256d5d1..f79cf68912 100644
--- a/caffe2/operators/generate_proposals_op_test.cc
+++ b/caffe2/operators/generate_proposals_op_test.cc
@@ -413,7 +413,6 @@ TEST(GenerateProposalsTest, TestRealDownSampled) {
1e-4);
}
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
// Similar to TestRealDownSampled but for rotated boxes with angle info.
const float angle = 0;
@@ -522,7 +521,7 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotatedAngle0) {
ERMatXf rois_gt(rois_gt_xyxy.rows(), 6);
// Batch ID
rois_gt.block(0, 0, rois_gt.rows(), 1) =
- rois_gt_xyxy.block(0, 0, rois_gt.rows(), 0);
+ rois_gt_xyxy.block(0, 0, rois_gt.rows(), 1);
// rois_gt in [x_ctr, y_ctr, w, h] format
rois_gt.block(0, 1, rois_gt.rows(), 4) = utils::bbox_xyxy_to_ctrwh(
rois_gt_xyxy.block(0, 1, rois_gt.rows(), 4).array());
@@ -721,6 +720,5 @@ TEST(GenerateProposalsTest, TestRealDownSampledRotated) {
EXPECT_LE(std::abs(rois_data(i, 5) - expected_angle), 1e-4);
}
}
-#endif // CV_MAJOR_VERSION >= 3
} // namespace caffe2
diff --git a/caffe2/operators/generate_proposals_op_util_nms.h b/caffe2/operators/generate_proposals_op_util_nms.h
index b90fea8bb0..8c5234e347 100644
--- a/caffe2/operators/generate_proposals_op_util_nms.h
+++ b/caffe2/operators/generate_proposals_op_util_nms.h
@@ -169,274 +169,296 @@ std::vector<int> soft_nms_cpu_upright(
return keep;
}
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
namespace {
+const int INTERSECT_NONE = 0;
+const int INTERSECT_PARTIAL = 1;
+const int INTERSECT_FULL = 2;
+
+class RotatedRect {
+ public:
+ RotatedRect() {}
+ RotatedRect(
+ const Eigen::Vector2f& p_center,
+ const Eigen::Vector2f& p_size,
+ float p_angle)
+ : center(p_center), size(p_size), angle(p_angle) {}
+ void get_vertices(Eigen::Vector2f* pt) const {
+ // M_PI / 180. == 0.01745329251
+ double _angle = angle * 0.01745329251;
+ float b = (float)cos(_angle) * 0.5f;
+ float a = (float)sin(_angle) * 0.5f;
+
+ pt[0].x() = center.x() - a * size.y() - b * size.x();
+ pt[0].y() = center.y() + b * size.y() - a * size.x();
+ pt[1].x() = center.x() + a * size.y() - b * size.x();
+ pt[1].y() = center.y() - b * size.y() - a * size.x();
+ pt[2] = 2 * center - pt[0];
+ pt[3] = 2 * center - pt[1];
+ }
+ Eigen::Vector2f center;
+ Eigen::Vector2f size;
+ float angle;
+};
template <class Derived>
-cv::RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase<Derived>& box) {
+RotatedRect bbox_to_rotated_rect(const Eigen::ArrayBase<Derived>& box) {
CAFFE_ENFORCE_EQ(box.size(), 5);
// cv::RotatedRect takes angle to mean clockwise rotation, but RRPN bbox
// representation means counter-clockwise rotation.
- return cv::RotatedRect(
- cv::Point2f(box[0], box[1]), cv::Size2f(box[2], box[3]), -box[4]);
+ return RotatedRect(
+ Eigen::Vector2f(box[0], box[1]),
+ Eigen::Vector2f(box[2], box[3]),
+ -box[4]);
+}
+
+// Eigen doesn't seem to support 2d cross product, so we make one here
+float cross_2d(const Eigen::Vector2f& A, const Eigen::Vector2f& B) {
+ return A.x() * B.y() - B.x() * A.y();
}
-// TODO: cvfix_rotatedRectangleIntersection is a replacement function for
+// rotated_rect_intersection_pts is a replacement function for
// cv::rotatedRectangleIntersection, which has a bug due to float underflow
-// When OpenCV version is upgraded to be >= 4.0,
-// we can remove this replacement function.
// For anyone interested, here're the PRs on OpenCV:
// https://github.com/opencv/opencv/issues/12221
// https://github.com/opencv/opencv/pull/12222
-int cvfix_rotatedRectangleIntersection(
- const cv::RotatedRect& rect1,
- const cv::RotatedRect& rect2,
- cv::OutputArray intersectingRegion) {
+// Note that we do not check if the number of intersections is <= 8 in this case
+int rotated_rect_intersection_pts(
+ const RotatedRect& rect1,
+ const RotatedRect& rect2,
+ Eigen::Vector2f* intersections,
+ int& num) {
// Used to test if two points are the same
const float samePointEps = 0.00001f;
const float EPS = 1e-14;
+ num = 0; // number of intersections
- cv::Point2f vec1[4], vec2[4];
- cv::Point2f pts1[4], pts2[4];
-
- std::vector<cv::Point2f> intersection;
+ Eigen::Vector2f vec1[4], vec2[4], pts1[4], pts2[4];
- rect1.points(pts1);
- rect2.points(pts2);
-
- int ret = cv::INTERSECT_FULL;
+ rect1.get_vertices(pts1);
+ rect2.get_vertices(pts2);
// Specical case of rect1 == rect2
- {
- bool same = true;
+ bool same = true;
- for (int i = 0; i < 4; i++) {
- if (fabs(pts1[i].x - pts2[i].x) > samePointEps ||
- (fabs(pts1[i].y - pts2[i].y) > samePointEps)) {
- same = false;
- break;
- }
+ for (int i = 0; i < 4; i++) {
+ if (fabs(pts1[i].x() - pts2[i].x()) > samePointEps ||
+ (fabs(pts1[i].y() - pts2[i].y()) > samePointEps)) {
+ same = false;
+ break;
}
+ }
- if (same) {
- intersection.resize(4);
-
- for (int i = 0; i < 4; i++) {
- intersection[i] = pts1[i];
- }
-
- cv::Mat(intersection).copyTo(intersectingRegion);
-
- return cv::INTERSECT_FULL;
+ if (same) {
+ for (int i = 0; i < 4; i++) {
+ intersections[i] = pts1[i];
}
+ num = 4;
+ return INTERSECT_FULL;
}
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
for (int i = 0; i < 4; i++) {
- vec1[i].x = pts1[(i + 1) % 4].x - pts1[i].x;
- vec1[i].y = pts1[(i + 1) % 4].y - pts1[i].y;
-
- vec2[i].x = pts2[(i + 1) % 4].x - pts2[i].x;
- vec2[i].y = pts2[(i + 1) % 4].y - pts2[i].y;
+ vec1[i] = pts1[(i + 1) % 4] - pts1[i];
+ vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// Line test - test all line combos for intersection
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
- float x21 = pts2[j].x - pts1[i].x;
- float y21 = pts2[j].y - pts1[i].y;
-
- const auto& l1 = vec1[i];
- const auto& l2 = vec2[j];
// This takes care of parallel lines
- float det = l2.x * l1.y - l1.x * l2.y;
+ float det = cross_2d(vec2[j], vec1[i]);
if (std::fabs(det) <= EPS) {
continue;
}
- float t1 = (l2.x * y21 - l2.y * x21) / det;
- float t2 = (l1.x * y21 - l1.y * x21) / det;
+ auto vec12 = pts2[j] - pts1[i];
+
+ float t1 = cross_2d(vec2[j], vec12) / det;
+ float t2 = cross_2d(vec1[i], vec12) / det;
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
- float xi = pts1[i].x + vec1[i].x * t1;
- float yi = pts1[i].y + vec1[i].y * t1;
- intersection.push_back(cv::Point2f(xi, yi));
+ intersections[num++] = pts1[i] + t1 * vec1[i];
}
}
}
- if (!intersection.empty()) {
- ret = cv::INTERSECT_PARTIAL;
- }
-
// Check for vertices from rect1 inside rect2
- for (int i = 0; i < 4; i++) {
- // We do a sign test to see which side the point lies.
- // If the point all lie on the same sign for all 4 sides of the rect,
- // then there's an intersection
- int posSign = 0;
- int negSign = 0;
+ {
+ const auto& AB = vec2[0];
+ const auto& DA = vec2[3];
+ auto ABdotAB = AB.squaredNorm();
+ auto ADdotAD = DA.squaredNorm();
+ for (int i = 0; i < 4; i++) {
+ // assume ABCD is the rectangle, and P is the point to be judged
+ // P is inside ABCD iff. P's projection on AB lies within AB
+ // and P's projection on AD lies within AD
- float x = pts1[i].x;
- float y = pts1[i].y;
+ auto AP = pts1[i] - pts2[0];
- for (int j = 0; j < 4; j++) {
- // line equation: Ax + By + C = 0
- // see which side of the line this point is at
-
- // float causes underflow!
- // Original version:
- // float A = -vec2[j].y;
- // float B = vec2[j].x;
- // float C = -(A * pts2[j].x + B * pts2[j].y);
- // float s = A * x + B * y + C;
-
- double A = -vec2[j].y;
- double B = vec2[j].x;
- double C = -(A * pts2[j].x + B * pts2[j].y);
- double s = A * x + B * y + C;
-
- if (s >= 0) {
- posSign++;
- } else {
- negSign++;
- }
- }
+ auto APdotAB = AP.dot(AB);
+ auto APdotAD = -AP.dot(DA);
- if (posSign == 4 || negSign == 4) {
- intersection.push_back(pts1[i]);
+ if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
+ (APdotAD <= ADdotAD)) {
+ intersections[num++] = pts1[i];
+ }
}
}
// Reverse the check - check for vertices from rect2 inside rect1
- for (int i = 0; i < 4; i++) {
- // We do a sign test to see which side the point lies.
- // If the point all lie on the same sign for all 4 sides of the rect,
- // then there's an intersection
- int posSign = 0;
- int negSign = 0;
+ {
+ const auto& AB = vec1[0];
+ const auto& DA = vec1[3];
+ auto ABdotAB = AB.squaredNorm();
+ auto ADdotAD = DA.squaredNorm();
+ for (int i = 0; i < 4; i++) {
+ auto AP = pts2[i] - pts1[0];
- float x = pts2[i].x;
- float y = pts2[i].y;
+ auto APdotAB = AP.dot(AB);
+ auto APdotAD = -AP.dot(DA);
- for (int j = 0; j < 4; j++) {
- // line equation: Ax + By + C = 0
- // see which side of the line this point is at
-
- // float causes underflow!
- // Original version:
- // float A = -vec1[j].y;
- // float B = vec1[j].x;
- // float C = -(A * pts1[j].x + B * pts1[j].y);
- // float s = A*x + B*y + C;
-
- double A = -vec1[j].y;
- double B = vec1[j].x;
- double C = -(A * pts1[j].x + B * pts1[j].y);
- double s = A * x + B * y + C;
-
- if (s >= 0) {
- posSign++;
- } else {
- negSign++;
+ if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
+ (APdotAD <= ADdotAD)) {
+ intersections[num++] = pts2[i];
}
}
-
- if (posSign == 4 || negSign == 4) {
- intersection.push_back(pts2[i]);
- }
}
- // Get rid of dupes
- for (int i = 0; i < (int)intersection.size() - 1; i++) {
- for (size_t j = i + 1; j < intersection.size(); j++) {
- float dx = intersection[i].x - intersection[j].x;
- float dy = intersection[i].y - intersection[j].y;
- // can be a really small number, need double here
- double d2 = dx * dx + dy * dy;
-
- if (d2 < samePointEps * samePointEps) {
- // Found a dupe, remove it
- std::swap(intersection[j], intersection.back());
- intersection.pop_back();
- j--; // restart check
- }
+ return num ? INTERSECT_PARTIAL : INTERSECT_NONE;
+}
+
+// Compute convex hull using Graham scan algorithm
+int convex_hull_graham(
+ const Eigen::Vector2f* p,
+ const int& num_in,
+ Eigen::Vector2f* q,
+ bool shift_to_zero = false) {
+ CAFFE_ENFORCE(num_in >= 2);
+ std::vector<int> order;
+
+ // Step 1:
+ // Find point with minimum y
+ // if more than 1 points have the same minimum y,
+ // pick the one with the mimimum x.
+ int t = 0;
+ for (int i = 1; i < num_in; i++) {
+ if (p[i].y() < p[t].y() || (p[i].y() == p[t].y() && p[i].x() < p[t].x())) {
+ t = i;
}
}
+ auto& s = p[t]; // starting point
- if (intersection.empty()) {
- return cv::INTERSECT_NONE;
+ // Step 2:
+ // Subtract starting point from every points (for sorting in the next step)
+ for (int i = 0; i < num_in; i++) {
+ q[i] = p[i] - s;
}
- // If this check fails then it means we're getting dupes
- // CV_Assert(intersection.size() <= 8);
-
- // At this point, there might still be some edge cases failing the check above
- // However, it doesn't affect the result of polygon area,
- // even if the number of intersections is greater than 8.
- // Therefore, we just print out these cases for now instead of assertion.
- // TODO: These cases should provide good reference for improving the accuracy
- // for intersection computation above (for example, we should use
- // cross-product/dot-product of vectors instead of line equation to
- // judge the relationships between the points and line segments)
-
- if (intersection.size() > 8) {
- LOG(ERROR) << "Intersection size = " << intersection.size();
- LOG(ERROR) << "Rect 1:";
- for (int i = 0; i < 4; i++) {
- LOG(ERROR) << " (" << pts1[i].x << " ," << pts1[i].y << "),";
- }
- LOG(ERROR) << "Rect 2:";
- for (int i = 0; i < 4; i++) {
- LOG(ERROR) << " (" << pts2[i].x << " ," << pts2[i].y << "),";
- }
- LOG(ERROR) << "Intersections:";
- for (auto& p : intersection) {
- LOG(ERROR) << " (" << p.x << " ," << p.y << "),";
+ // Swap the starting point to position 0
+ std::swap(q[0], q[t]);
+
+ // Step 3:
+ // Sort point 1 ~ num_in according to their relative cross-product values
+ // (essentially sorting according to angles)
+ std::sort(
+ q + 1,
+ q + num_in,
+ [](const Eigen::Vector2f& A, const Eigen::Vector2f& B) -> bool {
+ float temp = cross_2d(A, B);
+ if (fabs(temp) < 1e-6) {
+ return A.squaredNorm() < B.squaredNorm();
+ } else {
+ return temp > 0;
+ }
+ });
+
+ // Step 4:
+ // Make sure there are at least 2 points (that don't overlap with each other)
+ // in the stack
+ int k; // index of the non-overlapped second point
+ for (k = 1; k < num_in; k++) {
+ if (q[k].squaredNorm() > 1e-8)
+ break;
+ }
+ if (k == num_in) {
+ // We reach the end, which means the convex hull is just one point
+ q[0] = p[t];
+ return 1;
+ }
+ q[1] = q[k];
+ int m = 2; // 2 elements in the stack
+ // Step 5:
+ // Finally we can start the scanning process.
+ // If we find a non-convex relationship between the 3 points,
+ // we pop the previous point from the stack until the stack only has two
+ // points, or the 3-point relationship is convex again
+ for (int i = k + 1; i < num_in; i++) {
+ while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
+ m--;
}
+ q[m++] = q[i];
+ }
+
+ // Step 6 (Optional):
+ // In general sense we need the original coordinates, so we
+ // need to shift the points back (reverting Step 2)
+ // But if we're only interested in getting the area/perimeter of the shape
+ // We can simply return.
+ if (!shift_to_zero) {
+ for (int i = 0; i < m; i++)
+ q[i] += s;
}
- cv::Mat(intersection).copyTo(intersectingRegion);
+ return m;
+}
- return ret;
+double polygon_area(const Eigen::Vector2f* q, const int& m) {
+ if (m <= 2)
+ return 0;
+ double area = 0;
+ for (int i = 1; i < m - 1; i++)
+ area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0]));
+ return area / 2.0;
}
/**
* Returns the intersection area of two rotated rectangles.
*/
double rotated_rect_intersection(
- const cv::RotatedRect& rect1,
- const cv::RotatedRect& rect2) {
- std::vector<cv::Point2f> intersectPts, orderedPts;
+ const RotatedRect& rect1,
+ const RotatedRect& rect2) {
+ // There are up to 16 intersections returned from
+ // rotated_rect_intersection_pts
+ Eigen::Vector2f intersectPts[16], orderedPts[16];
+ int num = 0; // number of intersections
// Find points of intersection
- // TODO: cvfix_rotatedRectangleIntersection is a replacement function for
+ // TODO: rotated_rect_intersection_pts is a replacement function for
// cv::rotatedRectangleIntersection, which has a bug due to float underflow
- // When OpenCV version is upgraded to be >= 4.0,
- // we can remove this replacement function and use the following instead:
- // auto ret = cv::rotatedRectangleIntersection(rect1, rect2, intersectPts);
// For anyone interested, here're the PRs on OpenCV:
// https://github.com/opencv/opencv/issues/12221
// https://github.com/opencv/opencv/pull/12222
- auto ret = cvfix_rotatedRectangleIntersection(rect1, rect2, intersectPts);
- if (intersectPts.size() <= 2) {
+ // Note: it doesn't matter if #intersections is greater than 8 here
+ auto ret = rotated_rect_intersection_pts(rect1, rect2, intersectPts, num);
+ CAFFE_ENFORCE(num <= 16);
+ if (num <= 2)
return 0.0;
- }
// If one rectangle is fully enclosed within another, return the area
// of the smaller one early.
- if (ret == cv::INTERSECT_FULL) {
- return std::min(rect1.size.area(), rect2.size.area());
+ if (ret == INTERSECT_FULL) {
+ return std::min(
+ rect1.size.x() * rect1.size.y(), rect2.size.x() * rect2.size.y());
}
// Convex Hull to order the intersection points in clockwise or
// counter-clockwise order and find the countour area.
- cv::convexHull(intersectPts, orderedPts);
- return cv::contourArea(orderedPts);
+ int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true);
+ return polygon_area(orderedPts, num_convex);
}
} // namespace
@@ -507,7 +529,7 @@ std::vector<int> nms_cpu_rotated(
auto heights = proposals.col(3);
EArrX areas = widths * heights;
- std::vector<cv::RotatedRect> rotated_rects(proposals.rows());
+ std::vector<RotatedRect> rotated_rects(proposals.rows());
for (int i = 0; i < proposals.rows(); ++i) {
rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i));
}
@@ -568,7 +590,7 @@ std::vector<int> soft_nms_cpu_rotated(
auto heights = proposals.col(3);
EArrX areas = widths * heights;
- std::vector<cv::RotatedRect> rotated_rects(proposals.rows());
+ std::vector<RotatedRect> rotated_rects(proposals.rows());
for (int i = 0; i < proposals.rows(); ++i) {
rotated_rects[i] = bbox_to_rotated_rect(proposals.row(i));
}
@@ -627,7 +649,6 @@ std::vector<int> soft_nms_cpu_rotated(
return keep;
}
-#endif // CV_MAJOR_VERSION >= 3
template <class Derived1, class Derived2>
std::vector<int> nms_cpu(
@@ -636,7 +657,6 @@ std::vector<int> nms_cpu(
const std::vector<int>& sorted_indices,
float thresh,
int topN = -1) {
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5);
if (proposals.cols() == 4) {
// Upright boxes
@@ -645,9 +665,6 @@ std::vector<int> nms_cpu(
// Rotated boxes with angle info
return nms_cpu_rotated(proposals, scores, sorted_indices, thresh, topN);
}
-#else
- return nms_cpu_upright(proposals, scores, sorted_indices, thresh, topN);
-#endif // CV_MAJOR_VERSION >= 3
}
// Greedy non-maximum suppression for proposed bounding boxes
@@ -686,7 +703,6 @@ std::vector<int> soft_nms_cpu(
float score_thresh = 0.001,
unsigned int method = 1,
int topN = -1) {
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
CAFFE_ENFORCE(proposals.cols() == 4 || proposals.cols() == 5);
if (proposals.cols() == 4) {
// Upright boxes
@@ -713,18 +729,6 @@ std::vector<int> soft_nms_cpu(
method,
topN);
}
-#else
- return soft_nms_cpu_upright(
- out_scores,
- proposals,
- scores,
- indices,
- sigma,
- overlap_thresh,
- score_thresh,
- method,
- topN);
-#endif // CV_MAJOR_VERSION >= 3
}
template <class Derived1, class Derived2, class Derived3>
diff --git a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc
index 8999fda4ac..372accae0a 100644
--- a/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc
+++ b/caffe2/operators/generate_proposals_op_util_nms_gpu_test.cc
@@ -379,13 +379,9 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) {
return;
const int box_dim = 5;
// Same boxes in TestNMS with (x_ctr, y_ctr, w, h, angle) format
- std::vector<float> boxes = {
- 30, 35, 41, 51, 0,
- 29.5, 36, 38, 49, 0,
- 24, 29.5, 33, 42, 0,
- 125, 120, 51, 41, 0,
- 127, 124.5, 57, 30, 0
- };
+ std::vector<float> boxes = {30, 35, 41, 51, 0, 29.5, 36, 38, 49,
+ 0, 24, 29.5, 33, 42, 0, 125, 120, 51,
+ 41, 0, 127, 124.5, 57, 30, 0};
std::vector<float> scores = {0.5f, 0.7f, 0.6f, 0.9f, 0.8f};
@@ -466,7 +462,6 @@ TEST(UtilsNMSTest, TestNMSGPURotatedAngle0) {
cuda_context.FinishDeviceComputation();
}
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
TEST(UtilsNMSTest, TestPerfRotatedNMS) {
if (!HasCudaGPU())
return;
@@ -678,6 +673,5 @@ TEST(UtilsNMSTest, GPUEqualsCPURotatedCorrectnessTest) {
}
}
}
-#endif // CV_MAJOR_VERSION >= 3
} // namespace caffe2
diff --git a/caffe2/operators/generate_proposals_op_util_nms_test.cc b/caffe2/operators/generate_proposals_op_util_nms_test.cc
index b7da35b697..8e8b5f17af 100644
--- a/caffe2/operators/generate_proposals_op_util_nms_test.cc
+++ b/caffe2/operators/generate_proposals_op_util_nms_test.cc
@@ -212,7 +212,6 @@ TEST(UtilsNMSTest, TestSoftNMS) {
}
}
-#if defined(CV_MAJOR_VERSION) && (CV_MAJOR_VERSION >= 3)
TEST(UtilsNMSTest, TestNMSRotatedAngle0) {
// Same inputs as TestNMS, but in RRPN format with angle 0 for testing
// nms_cpu_rotated
@@ -389,6 +388,42 @@ TEST(UtilsNMSTest, TestSoftNMSRotatedAngle0) {
TEST(UtilsNMSTest, RotatedBBoxOverlaps) {
{
+ // One box is fully within another box, the angle is irrelavant
+ int M = 2, N = 3;
+ Eigen::ArrayXXf boxes(M, 5);
+ for (int i = 0; i < M; i++) {
+ boxes.row(i) << 0, 0, 5, 6, (360.0 / M - 180.0);
+ }
+
+ Eigen::ArrayXXf query_boxes(N, 5);
+ for (int i = 0; i < N; i++) {
+ query_boxes.row(i) << 0, 0, 3, 3, (360.0 / M - 180.0);
+ }
+
+ Eigen::ArrayXXf expected(M, N);
+ // 0.3 == (3 * 3) / (5 * 6)
+ expected.fill(0.3);
+
+ auto actual = utils::bbox_overlaps_rotated(boxes, query_boxes);
+ EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
+ }
+
+ {
+ // Angle 0
+ Eigen::ArrayXXf boxes(1, 5);
+ boxes << 39.500000, 50.451096, 80.000000, 18.097809, -0.000000;
+
+ Eigen::ArrayXXf query_boxes(1, 5);
+ query_boxes << 39.120628, 41.014862, 79.241257, 36.427757, -0.000000;
+
+ Eigen::ArrayXXf expected(1, 1);
+ expected << 0.48346716237;
+
+ auto actual = utils::bbox_overlaps_rotated(boxes, query_boxes);
+ EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
+ }
+
+ {
// Simple case with angle 0 (upright boxes)
Eigen::ArrayXXf boxes(2, 5);
boxes << 10.5, 15.5, 21, 31, 0, 14.0, 17, 4, 10, 0;
@@ -436,6 +471,5 @@ TEST(UtilsNMSTest, RotatedBBoxOverlaps) {
EXPECT_TRUE(((expected - actual).abs() < 1e-6).all());
}
}
-#endif // CV_MAJOR_VERSION >= 3
} // namespace caffe2