summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--caffe2/operators/roi_align_gradient_op.cc255
-rw-r--r--caffe2/operators/roi_align_gradient_op.cu231
-rw-r--r--caffe2/operators/roi_align_gradient_op.h43
-rw-r--r--caffe2/operators/roi_align_op.cc376
-rw-r--r--caffe2/operators/roi_align_op.cu181
-rw-r--r--caffe2/operators/roi_align_op.h47
-rw-r--r--caffe2/operators/roi_align_op_gpu_test.cc266
-rw-r--r--modules/detectron/roi_align_op.cc98
-rw-r--r--modules/detectron/roi_align_op.cu363
-rw-r--r--modules/detectron/roi_align_op.h89
10 files changed, 1399 insertions, 550 deletions
diff --git a/caffe2/operators/roi_align_gradient_op.cc b/caffe2/operators/roi_align_gradient_op.cc
new file mode 100644
index 0000000000..1cc4103a53
--- /dev/null
+++ b/caffe2/operators/roi_align_gradient_op.cc
@@ -0,0 +1,255 @@
+#include "roi_align_gradient_op.h"
+
+#include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+namespace {
+
+template <typename T>
+void bilinear_interpolate_gradient(
+ const int height,
+ const int width,
+ T y,
+ T x,
+ T& w1,
+ T& w2,
+ T& w3,
+ T& w4,
+ int& x_low,
+ int& x_high,
+ int& y_low,
+ int& y_high,
+ const int /*index*/ /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ w1 = w2 = w3 = w4 = 0.;
+ x_low = x_high = y_low = y_high = -1;
+ return;
+ }
+
+ if (y <= 0) {
+ y = 0;
+ }
+ if (x <= 0) {
+ x = 0;
+ }
+
+ y_low = (int)y;
+ x_low = (int)x;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+
+ // reference in forward
+ // T v1 = bottom_data[y_low * width + x_low];
+ // T v2 = bottom_data[y_low * width + x_high];
+ // T v3 = bottom_data[y_high * width + x_low];
+ // T v4 = bottom_data[y_high * width + x_high];
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ return;
+}
+
+template <class T>
+inline void add(const T& val, T* address) {
+ *address += val;
+}
+
+template <typename T>
+void ROIAlignBackwardFeature(
+ const int nthreads,
+ const T* top_diff,
+ const int /*num_rois*/,
+ const T& spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,
+ T* bottom_diff,
+ const T* bottom_rois,
+ int rois_cols) {
+ DCHECK(rois_cols == 4 || rois_cols == 5);
+
+ for (int index = 0; index < nthreads; index++) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T* offset_bottom_rois = bottom_rois + n * rois_cols;
+ int roi_batch_ind = 0;
+ if (rois_cols == 5) {
+ roi_batch_ind = offset_bottom_rois[0];
+ offset_bottom_rois++;
+ }
+
+ // Do not using rounding; this implementation detail is critical
+ T roi_start_w = offset_bottom_rois[0] * spatial_scale;
+ T roi_start_h = offset_bottom_rois[1] * spatial_scale;
+ T roi_end_w = offset_bottom_rois[2] * spatial_scale;
+ T roi_end_h = offset_bottom_rois[3] * spatial_scale;
+ // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
+ // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
+ // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
+ // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
+
+ // Force malformed ROIs to be 1x1
+ T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
+ T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+ T* offset_bottom_diff =
+ bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+ int top_offset = (n * channels + c) * pooled_height * pooled_width;
+ const T* offset_top_diff = top_diff + top_offset;
+ const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ const T y = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T x = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+
+ bilinear_interpolate_gradient(
+ height,
+ width,
+ y,
+ x,
+ w1,
+ w2,
+ w3,
+ w4,
+ x_low,
+ x_high,
+ y_low,
+ y_high,
+ index);
+
+ T g1 = top_diff_this_bin * w1 / count;
+ T g2 = top_diff_this_bin * w2 / count;
+ T g3 = top_diff_this_bin * w3 / count;
+ T g4 = top_diff_this_bin * w4 / count;
+
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ // atomic add is not needed for now since it is single threaded
+ add(static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
+ add(static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
+ add(static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
+ add(static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
+ } // if
+ } // ix
+ } // iy
+ } // for
+} // ROIAlignBackward
+
+} // namespace
+
+template <>
+bool RoIAlignGradientOp<float, CPUContext>::RunOnDevice() {
+ auto& X = Input(0); // Input data to pool
+ auto& R = Input(1); // RoIs
+ auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op
+ // (aka "gradOutput")
+ auto* dX = Output(0); // Gradient of net w.r.t. input to "forward" op
+ // (aka "gradInput")
+
+ CAFFE_ENFORCE_EQ(R.ndim(), 2);
+ // if R has 5 columns, the first column is the index, otherwise 0
+ CAFFE_ENFORCE(R.dim32(1) == 4 || R.dim32(1) == 5);
+
+ dX->ResizeLike(X);
+
+ // Must zero-out dX before accumulating gradients
+ // (TODO): Kaiming - is this safe?
+ math::Set<float, CPUContext>(
+ dX->size(), 0.f, dX->mutable_data<float>(), &context_);
+
+ if (dY.size() > 0) { // Handle possibly empty gradient if there were no rois
+ ROIAlignBackwardFeature<float>(
+ dY.size(),
+ dY.data<float>(),
+ R.dim32(0),
+ spatial_scale_,
+ X.dim32(1),
+ X.dim32(2),
+ X.dim32(3),
+ pooled_height_,
+ pooled_width_,
+ sampling_ratio_,
+ dX->mutable_data<float>(),
+ R.data<float>(),
+ R.dim32(1));
+ }
+ return true;
+}
+
+REGISTER_CPU_OPERATOR(RoIAlignGradient, RoIAlignGradientOp<float, CPUContext>);
+
+// Input: X, rois, dY (aka "gradOutput");
+// Output: dX (aka "gradInput")
+OPERATOR_SCHEMA(RoIAlignGradient)
+ .NumInputs(3)
+ .NumOutputs(1)
+ .Input(0, "X", "See RoIPoolF.")
+ .Input(1, "RoIs", "See RoIPoolF.")
+ .Input(2, "dY", "Gradient of forward output 0 (Y)")
+ .Output(0, "dX", "Gradient of forward input 0 (X)");
+
+namespace {
+
+class GetRoIAlignGradient : public GradientMakerBase {
+ using GradientMakerBase::GradientMakerBase;
+ vector<OperatorDef> GetGradientDefs() override {
+ return SingleGradientDef(
+ "RoIAlignGradient",
+ "",
+ vector<string>{I(0), I(1), GO(0)},
+ vector<string>{GI(0)});
+ }
+};
+
+} // namespace
+
+REGISTER_GRADIENT(RoIAlign, GetRoIAlignGradient);
+
+} // namespace caffe2
diff --git a/caffe2/operators/roi_align_gradient_op.cu b/caffe2/operators/roi_align_gradient_op.cu
new file mode 100644
index 0000000000..702b8c7102
--- /dev/null
+++ b/caffe2/operators/roi_align_gradient_op.cu
@@ -0,0 +1,231 @@
+#include "roi_align_gradient_op.h"
+
+#include <stdio.h>
+#include <cfloat>
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+namespace {
+
+template <typename T>
+inline __device__ T gpu_atomic_add(const T val, T* address);
+
+template <>
+inline __device__ float gpu_atomic_add(const float val, float* address) {
+ return atomicAdd(address, val);
+}
+
+template <typename T>
+__device__ void bilinear_interpolate_gradient(
+ const int height,
+ const int width,
+ T y,
+ T x,
+ T& w1,
+ T& w2,
+ T& w3,
+ T& w4,
+ int& x_low,
+ int& x_high,
+ int& y_low,
+ int& y_high,
+ const int index /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ w1 = w2 = w3 = w4 = 0.;
+ x_low = x_high = y_low = y_high = -1;
+ return;
+ }
+
+ if (y <= 0) {
+ y = 0;
+ }
+ if (x <= 0) {
+ x = 0;
+ }
+
+ y_low = (int)y;
+ x_low = (int)x;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+
+ // reference in forward
+ // T v1 = bottom_data[y_low * width + x_low];
+ // T v2 = bottom_data[y_low * width + x_high];
+ // T v3 = bottom_data[y_high * width + x_low];
+ // T v4 = bottom_data[y_high * width + x_high];
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ return;
+}
+
+template <typename T>
+__global__ void RoIAlignBackwardFeature(
+ const int nthreads,
+ const T* top_diff,
+ const int num_rois,
+ const T spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,
+ T* bottom_diff,
+ const T* bottom_rois) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T* offset_bottom_rois = bottom_rois + n * 5;
+ int roi_batch_ind = offset_bottom_rois[0];
+
+ // Do not using rounding; this implementation detail is critical
+ T roi_start_w = offset_bottom_rois[1] * spatial_scale;
+ T roi_start_h = offset_bottom_rois[2] * spatial_scale;
+ T roi_end_w = offset_bottom_rois[3] * spatial_scale;
+ T roi_end_h = offset_bottom_rois[4] * spatial_scale;
+ // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+ // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+ // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+ // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+ // Force malformed ROIs to be 1x1
+ T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+ T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+ T* offset_bottom_diff =
+ bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+ int top_offset = (n * channels + c) * pooled_height * pooled_width;
+ const T* offset_top_diff = top_diff + top_offset;
+ const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+ {
+ const T y = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T x = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+
+ bilinear_interpolate_gradient(
+ height,
+ width,
+ y,
+ x,
+ w1,
+ w2,
+ w3,
+ w4,
+ x_low,
+ x_high,
+ y_low,
+ y_high,
+ index);
+
+ T g1 = top_diff_this_bin * w1 / count;
+ T g2 = top_diff_this_bin * w2 / count;
+ T g3 = top_diff_this_bin * w3 / count;
+ T g4 = top_diff_this_bin * w4 / count;
+
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ gpu_atomic_add(
+ static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
+ gpu_atomic_add(
+ static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
+ gpu_atomic_add(
+ static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
+ gpu_atomic_add(
+ static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
+ } // if
+ } // ix
+ } // iy
+ } // CUDA_1D_KERNEL_LOOP
+} // RoIAlignBackward
+
+} // namespace
+
+template <>
+bool RoIAlignGradientOp<float, CUDAContext>::RunOnDevice() {
+ auto& X = Input(0); // Input data to pool
+ auto& R = Input(1); // RoIs
+ auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op
+ // (aka "gradOutput")
+ auto* dX = Output(0); // Gradient of net w.r.t. input to "forward" op
+ // (aka "gradInput")
+
+ dX->ResizeLike(X);
+
+ // Must zero-out dX before accumulating gradients
+ // (TODO): Kaiming - is this safe?
+ math::Set<float, CUDAContext>(
+ dX->size(), 0.f, dX->mutable_data<float>(), &context_);
+
+ if (dY.size() > 0) { // Handle possibly empty gradient if there were no rois
+ RoIAlignBackwardFeature<float>
+ <<<CAFFE_GET_BLOCKS(dY.size()),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(
+ dY.size(),
+ dY.data<float>(),
+ R.dim32(0),
+ spatial_scale_,
+ X.dim32(1),
+ X.dim32(2),
+ X.dim32(3),
+ pooled_height_,
+ pooled_width_,
+ sampling_ratio_,
+ dX->mutable_data<float>(),
+ R.data<float>());
+ }
+ return true;
+}
+
+REGISTER_CUDA_OPERATOR(
+ RoIAlignGradient,
+ RoIAlignGradientOp<float, CUDAContext>);
+} // namespace caffe2
diff --git a/caffe2/operators/roi_align_gradient_op.h b/caffe2/operators/roi_align_gradient_op.h
new file mode 100644
index 0000000000..509825fbbf
--- /dev/null
+++ b/caffe2/operators/roi_align_gradient_op.h
@@ -0,0 +1,43 @@
+// Copyright 2004-present Facebook. All Rights Reserved.
+
+#ifndef ROI_ALIGN_OP_H_
+#define ROI_ALIGN_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+template <typename T, class Context>
+class RoIAlignGradientOp final : public Operator<Context> {
+ public:
+ RoIAlignGradientOp(const OperatorDef& def, Workspace* ws)
+ : Operator<Context>(def, ws),
+ spatial_scale_(
+ OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)),
+ pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
+ pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
+ sampling_ratio_(
+ OperatorBase::GetSingleArgument<int>("sampling_ratio", -1)) {
+ DCHECK_GT(spatial_scale_, 0);
+ DCHECK_GT(pooled_height_, 0);
+ DCHECK_GT(pooled_width_, 0);
+ DCHECK_GE(sampling_ratio_, 0);
+ }
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+
+ bool RunOnDevice() override {
+ CAFFE_NOT_IMPLEMENTED;
+ }
+
+ protected:
+ float spatial_scale_;
+ int pooled_height_;
+ int pooled_width_;
+ int sampling_ratio_;
+};
+
+} // namespace caffe2
+
+#endif // ROI_ALIGN_OP_H_
diff --git a/caffe2/operators/roi_align_op.cc b/caffe2/operators/roi_align_op.cc
new file mode 100644
index 0000000000..8f5d12ea2b
--- /dev/null
+++ b/caffe2/operators/roi_align_op.cc
@@ -0,0 +1,376 @@
+#include "roi_align_op.h"
+
+#include "caffe2/utils/eigen_utils.h"
+#include "caffe2/utils/math.h"
+
+#ifdef CAFFE2_USE_MKL
+#include "caffe2/mkl/operators/operator_fallback_mkl.h"
+#endif // CAFFE2_USE_MKL
+
+namespace caffe2 {
+namespace {
+
+template <typename T>
+struct PreCalc {
+ int pos1;
+ int pos2;
+ int pos3;
+ int pos4;
+ T w1;
+ T w2;
+ T w3;
+ T w4;
+};
+
+template <typename T>
+void pre_calc_for_bilinear_interpolate(
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int iy_upper,
+ const int ix_upper,
+ T roi_start_h,
+ T roi_start_w,
+ T bin_size_h,
+ T bin_size_w,
+ int roi_bin_grid_h,
+ int roi_bin_grid_w,
+ std::vector<PreCalc<T>>& pre_calc) {
+ int pre_calc_index = 0;
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ for (int iy = 0; iy < iy_upper; iy++) {
+ const T yy = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < ix_upper; ix++) {
+ const T xx = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+
+ T x = xx;
+ T y = yy;
+ // deal with: inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ PreCalc<T> pc;
+ pc.pos1 = 0;
+ pc.pos2 = 0;
+ pc.pos3 = 0;
+ pc.pos4 = 0;
+ pc.w1 = 0;
+ pc.w2 = 0;
+ pc.w3 = 0;
+ pc.w4 = 0;
+ pre_calc[pre_calc_index] = pc;
+ pre_calc_index += 1;
+ continue;
+ }
+
+ if (y <= 0) {
+ y = 0;
+ }
+ if (x <= 0) {
+ x = 0;
+ }
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ // save weights and indeces
+ PreCalc<T> pc;
+ pc.pos1 = y_low * width + x_low;
+ pc.pos2 = y_low * width + x_high;
+ pc.pos3 = y_high * width + x_low;
+ pc.pos4 = y_high * width + x_high;
+ pc.w1 = w1;
+ pc.w2 = w2;
+ pc.w3 = w3;
+ pc.w4 = w4;
+ pre_calc[pre_calc_index] = pc;
+
+ pre_calc_index += 1;
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void ROIAlignForward(
+ const int nthreads,
+ const T* bottom_data,
+ const T& spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,
+ const T* bottom_rois,
+ int roi_cols,
+ T* top_data,
+ StorageOrder order) {
+ DCHECK(roi_cols == 4 || roi_cols == 5);
+
+ int n_rois = nthreads / channels / pooled_width / pooled_height;
+ // (n, c, ph, pw) is an element in the pooled output
+ // can be parallelized using omp
+ // #pragma omp parallel for num_threads(32)
+ for (int n = 0; n < n_rois; n++) {
+ int index_n = n * channels * pooled_width * pooled_height;
+
+ // roi could have 4 or 5 columns
+ const T* offset_bottom_rois = bottom_rois + n * roi_cols;
+ int roi_batch_ind = 0;
+ if (roi_cols == 5) {
+ roi_batch_ind = offset_bottom_rois[0];
+ offset_bottom_rois++;
+ }
+
+ // Do not using rounding; this implementation detail is critical
+ T roi_start_w = offset_bottom_rois[0] * spatial_scale;
+ T roi_start_h = offset_bottom_rois[1] * spatial_scale;
+ T roi_end_w = offset_bottom_rois[2] * spatial_scale;
+ T roi_end_h = offset_bottom_rois[3] * spatial_scale;
+ // T roi_start_w = round(offset_bottom_rois[0] * spatial_scale);
+ // T roi_start_h = round(offset_bottom_rois[1] * spatial_scale);
+ // T roi_end_w = round(offset_bottom_rois[2] * spatial_scale);
+ // T roi_end_h = round(offset_bottom_rois[3] * spatial_scale);
+
+ // Force malformed ROIs to be 1x1
+ T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
+ T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+ // we want to precalculate indeces and weights shared by all chanels,
+ // this is the key point of optimiation
+ std::vector<PreCalc<T>> pre_calc(
+ roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
+ pre_calc_for_bilinear_interpolate(
+ height,
+ width,
+ pooled_height,
+ pooled_width,
+ roi_bin_grid_h,
+ roi_bin_grid_w,
+ roi_start_h,
+ roi_start_w,
+ bin_size_h,
+ bin_size_w,
+ roi_bin_grid_h,
+ roi_bin_grid_w,
+ pre_calc);
+
+ if (order == StorageOrder::NCHW) {
+ for (int c = 0; c < channels; c++) {
+ int index_n_c = index_n + c * pooled_width * pooled_height;
+ const T* offset_bottom_data =
+ bottom_data + (roi_batch_ind * channels + c) * height * width;
+ int pre_calc_index = 0;
+
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ int index = index_n_c + ph * pooled_width + pw;
+
+ T output_val = 0.;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ PreCalc<T> pc = pre_calc[pre_calc_index];
+ output_val += pc.w1 * offset_bottom_data[pc.pos1] +
+ pc.w2 * offset_bottom_data[pc.pos2] +
+ pc.w3 * offset_bottom_data[pc.pos3] +
+ pc.w4 * offset_bottom_data[pc.pos4];
+
+ pre_calc_index += 1;
+ }
+ }
+ output_val /= count;
+
+ top_data[index] = output_val;
+ } // for pw
+ } // for ph
+ } // for c
+ } // if nchw
+
+ if (order == StorageOrder::NHWC) {
+ const T* offset_bottom_data =
+ bottom_data + roi_batch_ind * channels * height * width;
+ int pre_calc_index = 0;
+
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ EVecXf output_vals = EVecXf::Zero(channels);
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ PreCalc<T> pc = pre_calc[pre_calc_index];
+
+ ConstEigenVectorMap<T> data_1(
+ offset_bottom_data + channels * pc.pos1, channels);
+ ConstEigenVectorMap<T> data_2(
+ offset_bottom_data + channels * pc.pos2, channels);
+ ConstEigenVectorMap<T> data_3(
+ offset_bottom_data + channels * pc.pos3, channels);
+ ConstEigenVectorMap<T> data_4(
+ offset_bottom_data + channels * pc.pos4, channels);
+
+ output_vals += pc.w1 * data_1 + pc.w2 * data_2 + pc.w3 * data_3 +
+ pc.w4 * data_4;
+
+ pre_calc_index += 1;
+ }
+ }
+ output_vals /= count;
+
+ int index_nhw = index_n + (ph * pooled_width + pw) * channels;
+ std::memcpy(
+ top_data + index_nhw, output_vals.data(), channels * sizeof(T));
+ } // for pw
+ } // for ph
+ } // if nhwc
+
+ } // for n
+}
+
+} // namespace
+
+template <>
+bool RoIAlignOp<float, CPUContext>::RunOnDevice() {
+ auto& X = Input(0); // Input data to pool, NCHW
+ auto& R = Input(1); // RoIs
+ auto* Y = Output(0); // RoI pooled data
+
+ if (R.size() == 0) {
+ // Handle empty rois
+ if (order_ == StorageOrder::NCHW) {
+ Y->Resize(0, X.dim32(1), pooled_height_, pooled_width_);
+ } else if (order_ == StorageOrder::NHWC) {
+ Y->Resize(0, pooled_height_, pooled_width_, X.dim32(3));
+ }
+ // The following mutable_data calls are needed to allocate the tensors
+ Y->mutable_data<float>();
+ return true;
+ }
+
+ CAFFE_ENFORCE_EQ(R.ndim(), 2);
+ // if R has 5 columns, the first column is the index, otherwise 0
+ CAFFE_ENFORCE(R.dim32(1) == 4 || R.dim32(1) == 5);
+
+ assert(sampling_ratio_ >= 0);
+
+ if (order_ == StorageOrder::NCHW) {
+ Y->Resize(R.dim32(0), X.dim32(1), pooled_height_, pooled_width_);
+ int output_size = Y->size();
+ ROIAlignForward<float>(
+ output_size,
+ X.data<float>(),
+ spatial_scale_,
+ X.dim32(1),
+ X.dim32(2),
+ X.dim32(3),
+ pooled_height_,
+ pooled_width_,
+ sampling_ratio_,
+ R.data<float>(),
+ R.dim32(1),
+ Y->mutable_data<float>(),
+ order_);
+ } else if (order_ == StorageOrder::NHWC) {
+ Y->Resize(R.dim32(0), pooled_height_, pooled_width_, X.dim32(3));
+ int output_size = Y->size();
+ ROIAlignForward<float>(
+ output_size,
+ X.data<float>(),
+ spatial_scale_,
+ X.dim32(3),
+ X.dim32(1),
+ X.dim32(2),
+ pooled_height_,
+ pooled_width_,
+ sampling_ratio_,
+ R.data<float>(),
+ R.dim32(1),
+ Y->mutable_data<float>(),
+ order_);
+ }
+
+ return true;
+}
+
+REGISTER_CPU_OPERATOR(RoIAlign, RoIAlignOp<float, CPUContext>);
+
+#ifdef CAFFE2_HAS_MKL_DNN
+REGISTER_MKL_OPERATOR(
+ RoIAlign,
+ mkl::MKLFallbackOp<RoIAlignOp<float, CPUContext>>);
+#endif // CAFFE2_HAS_MKL_DNN
+
+// Input: X, rois; Output: Y
+OPERATOR_SCHEMA(RoIAlign)
+ .NumInputs(2)
+ .NumOutputs(1)
+ .SetDoc(R"DOC(
+Region of Interest (RoI) align operation as used in Mask R-CNN.
+)DOC")
+ .Arg(
+ "spatial_scale",
+ "(float) default 1.0; Spatial scale of the input feature map X "
+ "relative to the input image. E.g., 0.0625 if X has a stride of 16 "
+ "w.r.t. the input image.")
+ .Arg("pooled_h", "(int) default 1; Pooled output Y's height.")
+ .Arg("pooled_w", "(int) default 1; Pooled output Y's width.")
+ .Arg(
+ "sampling_ratio",
+ "(int) default -1; number of sampling points in the interpolation grid "
+ "used to compute the output value of each pooled output bin. If > 0, "
+ "then exactly sampling_ratio x sampling_ratio grid points are used. If "
+ "<= 0, then an adaptive number of grid points are used (computed as "
+ "ceil(roi_width / pooled_w), and likewise for height).")
+ .Input(0, "X", "4D feature map input of shape (N, C, H, W).")
+ .Input(
+ 1,
+ "RoIs",
+ "2D input of shape (R, 5) specifying R RoIs with five columns "
+ "representing: batch index in [0, N - 1], x1, y1, x2, y2. The RoI "
+ "coordinates are in the coordinate system of the input image.")
+ .Output(
+ 0,
+ "Y",
+ "4D output of shape (R, C, pooled_h, pooled_w). The r-th batch element "
+ "is a pooled feature map cooresponding to the r-th RoI.");
+
+} // namespace caffe2
diff --git a/caffe2/operators/roi_align_op.cu b/caffe2/operators/roi_align_op.cu
new file mode 100644
index 0000000000..29676f31d6
--- /dev/null
+++ b/caffe2/operators/roi_align_op.cu
@@ -0,0 +1,181 @@
+#include "roi_align_op.h"
+
+#include <stdio.h>
+#include <cfloat>
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+namespace {
+
+template <typename T>
+__device__ T bilinear_interpolate(
+ const T* bottom_data,
+ const int height,
+ const int width,
+ T y,
+ T x,
+ const int index /* index for debug only*/) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ return 0;
+ }
+
+ if (y <= 0) {
+ y = 0;
+ }
+ if (x <= 0) {
+ x = 0;
+ }
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+ // do bilinear interpolation
+ T v1 = bottom_data[y_low * width + x_low];
+ T v2 = bottom_data[y_low * width + x_high];
+ T v3 = bottom_data[y_high * width + x_low];
+ T v4 = bottom_data[y_high * width + x_high];
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ return val;
+}
+
+template <typename T>
+__global__ void RoIAlignForward(
+ const int nthreads,
+ const T* bottom_data,
+ const T spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,
+ const T* bottom_rois,
+ T* top_data) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+
+ const T* offset_bottom_rois = bottom_rois + n * 5;
+ int roi_batch_ind = offset_bottom_rois[0];
+
+ // Do not using rounding; this implementation detail is critical
+ T roi_start_w = offset_bottom_rois[1] * spatial_scale;
+ T roi_start_h = offset_bottom_rois[2] * spatial_scale;
+ T roi_end_w = offset_bottom_rois[3] * spatial_scale;
+ T roi_end_h = offset_bottom_rois[4] * spatial_scale;
+ // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
+ // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
+ // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
+ // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
+
+ // Force malformed ROIs to be 1x1
+ T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+ T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+ T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+ T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+ const T* offset_bottom_data =
+ bottom_data + (roi_batch_ind * channels + c) * height * width;
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+ T output_val = 0.;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+ {
+ const T y = roi_start_h + ph * bin_size_h +
+ static_cast<T>(iy + .5f) * bin_size_h /
+ static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T x = roi_start_w + pw * bin_size_w +
+ static_cast<T>(ix + .5f) * bin_size_w /
+ static_cast<T>(roi_bin_grid_w);
+
+ T val = bilinear_interpolate(
+ offset_bottom_data, height, width, y, x, index);
+ output_val += val;
+ }
+ }
+ output_val /= count;
+
+ top_data[index] = output_val;
+ }
+}
+
+} // namespace
+
+template <>
+bool RoIAlignOp<float, CUDAContext>::RunOnDevice() {
+ auto& X = Input(0); // Input data to pool
+ auto& R = Input(1); // RoIs
+ auto* Y = Output(0); // RoI pooled data
+
+ if (R.size() == 0) {
+ // Handle empty rois
+ Y->Resize(0, X.dim32(1), pooled_height_, pooled_width_);
+ // The following mutable_data calls are needed to allocate the tensors
+ Y->mutable_data<float>();
+ return true;
+ }
+
+ assert(sampling_ratio_ >= 0);
+
+ Y->Resize(R.dim32(0), X.dim32(1), pooled_height_, pooled_width_);
+ int output_size = Y->size();
+ RoIAlignForward<float>
+ <<<CAFFE_GET_BLOCKS(output_size),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(
+ output_size,
+ X.data<float>(),
+ spatial_scale_,
+ X.dim32(1),
+ X.dim32(2),
+ X.dim32(3),
+ pooled_height_,
+ pooled_width_,
+ sampling_ratio_,
+ R.data<float>(),
+ Y->mutable_data<float>());
+ return true;
+}
+
+REGISTER_CUDA_OPERATOR(RoIAlign, RoIAlignOp<float, CUDAContext>);
+} // namespace caffe2
diff --git a/caffe2/operators/roi_align_op.h b/caffe2/operators/roi_align_op.h
new file mode 100644
index 0000000000..fc6f67c392
--- /dev/null
+++ b/caffe2/operators/roi_align_op.h
@@ -0,0 +1,47 @@
+// Copyright 2004-present Facebook. All Rights Reserved.
+
+#ifndef ROI_ALIGN_OP_H_
+#define ROI_ALIGN_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+template <typename T, class Context>
+class RoIAlignOp final : public Operator<Context> {
+ public:
+ RoIAlignOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<Context>(operator_def, ws),
+ order_(StringToStorageOrder(
+ OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
+ spatial_scale_(
+ OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)),
+ pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
+ pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
+ sampling_ratio_(
+ OperatorBase::GetSingleArgument<int>("sampling_ratio", -1)) {
+ DCHECK_GT(spatial_scale_, 0);
+ DCHECK_GT(pooled_height_, 0);
+ DCHECK_GT(pooled_width_, 0);
+ DCHECK_GE(sampling_ratio_, 0);
+ DCHECK(order_ == StorageOrder::NCHW || order_ == StorageOrder::NHWC);
+ }
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+
+ bool RunOnDevice() override {
+ CAFFE_NOT_IMPLEMENTED;
+ }
+
+ protected:
+ StorageOrder order_;
+ float spatial_scale_;
+ int pooled_height_;
+ int pooled_width_;
+ int sampling_ratio_;
+};
+
+} // namespace caffe2
+
+#endif // ROI_ALIGN_OP_H_
diff --git a/caffe2/operators/roi_align_op_gpu_test.cc b/caffe2/operators/roi_align_op_gpu_test.cc
new file mode 100644
index 0000000000..b72738b640
--- /dev/null
+++ b/caffe2/operators/roi_align_op_gpu_test.cc
@@ -0,0 +1,266 @@
+#include "caffe2/utils/eigen_utils.h"
+#include "roi_align_op.h"
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/core/flags.h"
+#include "caffe2/utils/math.h"
+#include "gtest/gtest.h"
+
+namespace caffe2 {
+namespace {
+
+template <class Context>
+void AddConstInput(
+ const vector<TIndex>& shape,
+ const float value,
+ const string& name,
+ Context* context,
+ Workspace* ws) {
+ Blob* blob = ws->CreateBlob(name);
+ auto* tensor = blob->GetMutable<Tensor<Context>>();
+ tensor->Resize(shape);
+ math::Set<float, Context>(
+ tensor->size(), value, tensor->template mutable_data<float>(), context);
+ return;
+}
+
+template <class Context>
+void AddInput(
+ const vector<TIndex>& shape,
+ const vector<float>& values,
+ const string& name,
+ Workspace* ws);
+
+template <>
+void AddInput<CPUContext>(
+ const vector<TIndex>& shape,
+ const vector<float>& values,
+ const string& name,
+ Workspace* ws) {
+ Blob* blob = ws->CreateBlob(name);
+ auto* tensor = blob->GetMutable<TensorCPU>();
+ tensor->Resize(shape);
+ EigenVectorMap<float> tensor_vec(
+ tensor->mutable_data<float>(), tensor->size());
+ tensor_vec.array() = utils::AsEArrXt(values);
+}
+
+template <>
+void AddInput<CUDAContext>(
+ const vector<TIndex>& shape,
+ const vector<float>& values,
+ const string& name,
+ Workspace* ws) {
+ TensorCPU tmp(shape);
+ EigenVectorMap<float> tmp_vec(tmp.mutable_data<float>(), tmp.size());
+ tmp_vec.array() = utils::AsEArrXt(values);
+
+ Blob* blob = ws->CreateBlob(name);
+ auto* tensor = blob->template GetMutable<Tensor<CUDAContext>>();
+ tensor->CopyFrom(tmp);
+}
+
+template <class Context>
+DeviceType GetDeviceType() {
+ return CPU;
+}
+template <>
+DeviceType GetDeviceType<CUDAContext>() {
+ return CUDA;
+}
+
+int randInt(int a, int b) {
+ static std::random_device rd;
+ static std::mt19937 gen(rd());
+ return std::uniform_int_distribution<int>(a, b)(gen);
+}
+
+struct TestParams {
+ int N;
+ int C;
+ int H;
+ int W;
+ int n_rois;
+ vector<float> rois_array;
+};
+
+template <class Context>
+void CreateAndRun(
+ TensorCPU* outResult,
+ const string& order,
+ const TestParams& test_params,
+ bool random_test) {
+ Workspace ws;
+ Context context;
+
+ if (random_test) {
+ const int N = test_params.N;
+ const int C = test_params.C;
+ const int H = test_params.H;
+ const int W = test_params.W;
+ vector<float> features(N * C * H * W);
+ std::iota(features.begin(), features.end(), 0);
+ // utils::AsEArrXt(features) /= features.size();
+ AddInput<Context>(vector<TIndex>{N, C, H, W}, features, "X", &ws);
+ const int n_rois = test_params.n_rois;
+ const vector<float>& rois = test_params.rois_array;
+ AddInput<Context>(vector<TIndex>{n_rois, 5}, rois, "R", &ws);
+ } else {
+ const int N = 2;
+ const int C = 3;
+ const int H = 100;
+ const int W = 110;
+ vector<float> features(N * C * H * W);
+ std::iota(features.begin(), features.end(), 0);
+ // utils::AsEArrXt(features) /= features.size();
+ AddInput<Context>(vector<TIndex>{N, C, H, W}, features, "X", &ws);
+ vector<float> rois{0, 0, 0, 79, 59,
+ 0, 0, 5.0005703, 52.63237, 43.69501495,
+ 0, 24.13628387, 7.51243401, 79, 46.06628418,
+ 0, 0, 7.50924301, 68.47792816, 46.03357315,
+ 0, 0, 23.09477997, 51.61448669, 59,
+ 0, 0, 39.52141571, 52.44710541, 59,
+ 0, 23.57396317, 29.98791885, 79, 59,
+ 0, 0, 41.90219116, 79, 59,
+ 0, 0, 23.30098343, 79, 59};
+ AddInput<Context>(vector<TIndex>{9, 5}, rois, "R", &ws);
+ }
+
+ std::vector<unique_ptr<OperatorBase>> ops;
+ EXPECT_TRUE(order == "NCHW" || order == "NHWC");
+ if (order == "NCHW") {
+ OperatorDef def;
+ def.set_name("test");
+ def.set_type("RoIAlign");
+ def.add_input("X");
+ def.add_input("R");
+ def.add_output("Y");
+ def.mutable_device_option()->set_device_type(GetDeviceType<Context>());
+ def.add_arg()->CopyFrom(MakeArgument("spatial_scale", 1.0f / 16.0f));
+ def.add_arg()->CopyFrom(MakeArgument("pooled_h", 6));
+ def.add_arg()->CopyFrom(MakeArgument("pooled_w", 8));
+ def.add_arg()->CopyFrom(MakeArgument("sampling_ratio", 2));
+
+ ops.push_back(CreateOperator(def, &ws));
+ } else if (order == "NHWC") {
+ OperatorDef def_roialign;
+ def_roialign.set_name("test");
+ def_roialign.set_type("RoIAlign");
+ def_roialign.add_input("X_NHWC");
+ def_roialign.add_input("R");
+ def_roialign.add_output("Y_NHWC");
+ def_roialign.mutable_device_option()->set_device_type(
+ GetDeviceType<Context>());
+ def_roialign.add_arg()->CopyFrom(
+ MakeArgument("spatial_scale", 1.0f / 16.0f));
+ def_roialign.add_arg()->CopyFrom(MakeArgument("pooled_h", 6));
+ def_roialign.add_arg()->CopyFrom(MakeArgument("pooled_w", 8));
+ def_roialign.add_arg()->CopyFrom(MakeArgument("sampling_ratio", 2));
+ def_roialign.add_arg()->CopyFrom(MakeArgument<string>("order", "NHWC"));
+
+ OperatorDef def_x;
+ def_x.set_name("test_x");
+ def_x.set_type("NCHW2NHWC");
+ def_x.add_input("X");
+ def_x.add_output("X_NHWC");
+ def_x.mutable_device_option()->set_device_type(GetDeviceType<Context>());
+
+ OperatorDef def_y;
+ def_y.set_name("test_y");
+ def_y.set_type("NHWC2NCHW");
+ def_y.add_input("Y_NHWC");
+ def_y.add_output("Y");
+ def_y.mutable_device_option()->set_device_type(GetDeviceType<Context>());
+
+ ops.push_back(CreateOperator(def_x, &ws));
+ ops.push_back(CreateOperator(def_roialign, &ws));
+ ops.push_back(CreateOperator(def_y, &ws));
+ }
+
+ for (auto const& op : ops) {
+ EXPECT_NE(nullptr, op.get());
+ EXPECT_TRUE(op->Run());
+ }
+
+ Blob* Y_blob = ws.GetBlob("Y");
+ EXPECT_NE(nullptr, Y_blob);
+
+ auto& Y = Y_blob->Get<Tensor<Context>>();
+ outResult->CopyFrom(Y, &context);
+}
+
+} // namespace
+
+TEST(RoiAlignTest, CheckCPUGPUEqual) {
+ if (!caffe2::HasCudaGPU())
+ return;
+
+ TensorCPU y_cpu;
+ TensorCPU y_gpu;
+ TensorCPU y_cpu_nhwc;
+
+ // tests using FAIR example
+ {
+ TestParams test_params;
+ CreateAndRun<CPUContext>(&y_cpu, "NCHW", test_params, false);
+ CreateAndRun<CUDAContext>(&y_gpu, "NCHW", test_params, false);
+ CreateAndRun<CPUContext>(&y_cpu_nhwc, "NHWC", test_params, false);
+
+ EXPECT_EQ(y_cpu.dims(), y_gpu.dims());
+ EXPECT_EQ(y_cpu.dims(), y_cpu_nhwc.dims());
+ ConstEigenVectorMap<float> y_cpu_vec(y_cpu.data<float>(), y_cpu.size());
+ ConstEigenVectorMap<float> y_gpu_vec(y_gpu.data<float>(), y_gpu.size());
+ ConstEigenVectorMap<float> y_cpu_nhwc_vec(
+ y_cpu_nhwc.data<float>(), y_cpu_nhwc.size());
+ int max_diff_idx = -1;
+ (y_cpu_vec - y_gpu_vec).cwiseAbs().maxCoeff(&max_diff_idx);
+ EXPECT_FLOAT_EQ(y_cpu_vec[max_diff_idx], y_gpu_vec[max_diff_idx]);
+
+ max_diff_idx = -1;
+ (y_cpu_vec - y_cpu_nhwc_vec).cwiseAbs().maxCoeff(&max_diff_idx);
+ EXPECT_FLOAT_EQ(y_cpu_vec[max_diff_idx], y_cpu_nhwc_vec[max_diff_idx]);
+ }
+
+ // random tests
+ const int random_test_numbers = 100;
+ for (int i = 0; i < random_test_numbers; i++) {
+ const int N = randInt(1, 5);
+ const int C = randInt(1, 5);
+ const int H = randInt(1, 50);
+ const int W = randInt(1, 50);
+ const int n_rois = randInt(0, 30);
+ vector<float> rois_array;
+ for (int n = 0; n < n_rois; n++) {
+ rois_array.push_back(randInt(0, N - 1));
+ int w1 = randInt(-20, W + 20);
+ int w2 = randInt(-20, W + 20);
+ int h1 = randInt(-20, H + 20);
+ int h2 = randInt(-20, H + 20);
+ rois_array.push_back(std::min(w1, w2));
+ rois_array.push_back(std::max(h1, h2));
+ rois_array.push_back(std::min(w1, w2));
+ rois_array.push_back(std::max(h1, h2));
+ }
+ TestParams test_params{N, C, H, W, n_rois, rois_array};
+
+ CreateAndRun<CPUContext>(&y_cpu, "NCHW", test_params, true);
+ CreateAndRun<CUDAContext>(&y_gpu, "NCHW", test_params, true);
+ CreateAndRun<CPUContext>(&y_cpu_nhwc, "NHWC", test_params, true);
+
+ EXPECT_EQ(y_cpu.dims(), y_gpu.dims());
+ EXPECT_EQ(y_cpu.dims(), y_cpu_nhwc.dims());
+ ConstEigenVectorMap<float> y_cpu_vec(y_cpu.data<float>(), y_cpu.size());
+ ConstEigenVectorMap<float> y_gpu_vec(y_gpu.data<float>(), y_gpu.size());
+ ConstEigenVectorMap<float> y_cpu_nhwc_vec(
+ y_cpu_nhwc.data<float>(), y_cpu_nhwc.size());
+ int max_diff_idx = -1;
+ (y_cpu_vec - y_gpu_vec).cwiseAbs().maxCoeff(&max_diff_idx);
+ EXPECT_FLOAT_EQ(y_cpu_vec[max_diff_idx], y_gpu_vec[max_diff_idx]);
+
+ max_diff_idx = -1;
+ (y_cpu_vec - y_cpu_nhwc_vec).cwiseAbs().maxCoeff(&max_diff_idx);
+ EXPECT_FLOAT_EQ(y_cpu_vec[max_diff_idx], y_cpu_nhwc_vec[max_diff_idx]);
+ }
+}
+
+} // namespace caffe2
diff --git a/modules/detectron/roi_align_op.cc b/modules/detectron/roi_align_op.cc
deleted file mode 100644
index 38094ff210..0000000000
--- a/modules/detectron/roi_align_op.cc
+++ /dev/null
@@ -1,98 +0,0 @@
-/**
- * Copyright (c) 2016-present, Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "roi_align_op.h"
-
-namespace caffe2 {
-
-REGISTER_CPU_OPERATOR(RoIAlign, RoIAlignOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(RoIAlignGradient, RoIAlignGradientOp<float, CPUContext>);
-
-OPERATOR_SCHEMA(RoIAlign)
- .NumInputs(2)
- .NumOutputs(1)
- .SetDoc(R"DOC(
-Region of Interest (RoI) align operation as used in Mask R-CNN.
-)DOC")
- .Arg(
- "spatial_scale",
- "(float) default 1.0; Spatial scale of the input feature map X "
- "relative to the input image. E.g., 0.0625 if X has a stride of 16 "
- "w.r.t. the input image.")
- .Arg(
- "pooled_h",
- "(int) default 1; Pooled output Y's height.")
- .Arg(
- "pooled_w",
- "(int) default 1; Pooled output Y's width.")
- .Arg(
- "sampling_ratio",
- "(int) default -1; number of sampling points in the interpolation grid "
- "used to compute the output value of each pooled output bin. If > 0, "
- "then exactly sampling_ratio x sampling_ratio grid points are used. If "
- "<= 0, then an adaptive number of grid points are used (computed as "
- "ceil(roi_width / pooled_w), and likewise for height)."
- )
- .Input(
- 0,
- "X",
- "4D feature map input of shape (N, C, H, W).")
- .Input(
- 1,
- "RoIs",
- "2D input of shape (R, 5) specifying R RoIs with five columns "
- "representing: batch index in [0, N - 1], x1, y1, x2, y2. The RoI "
- "coordinates are in the coordinate system of the input image.")
- .Output(
- 0,
- "Y",
- "4D output of shape (R, C, pooled_h, pooled_w). The r-th batch element "
- "is a pooled feature map cooresponding to the r-th RoI.");
-
-OPERATOR_SCHEMA(RoIAlignGradient)
- .NumInputs(3)
- .NumOutputs(1)
- .Input(
- 0,
- "X",
- "See RoIPoolF.")
- .Input(
- 1,
- "RoIs",
- "See RoIPoolF.")
- .Input(
- 2,
- "dY",
- "Gradient of forward output 0 (Y)")
- .Output(
- 0,
- "dX",
- "Gradient of forward input 0 (X)");
-
-class GetRoIAlignGradient : public GradientMakerBase {
- using GradientMakerBase::GradientMakerBase;
- vector<OperatorDef> GetGradientDefs() override {
- return SingleGradientDef(
- "RoIAlignGradient",
- "",
- vector<string>{I(0), I(1), GO(0)},
- vector<string>{GI(0)});
- }
-};
-
-REGISTER_GRADIENT(RoIAlign, GetRoIAlignGradient);
-
-} // namespace caffe2
diff --git a/modules/detectron/roi_align_op.cu b/modules/detectron/roi_align_op.cu
deleted file mode 100644
index 01d67f63b0..0000000000
--- a/modules/detectron/roi_align_op.cu
+++ /dev/null
@@ -1,363 +0,0 @@
-/**
- * Copyright (c) 2016-present, Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// RoIAlign for Mask R-CNN
-// This is the per-cell centered algined versio of RoIAlign.
-// This is the official version.
-
-#include <cfloat>
-
-#include "caffe2/core/context_gpu.h"
-#include "roi_align_op.h"
-
-#include <stdio.h>
-
-namespace caffe2 {
-
-namespace {
-
-template <typename T>
-inline __device__ T gpu_atomic_add(const T val, T* address);
-
-template <>
-inline __device__
-float gpu_atomic_add(const float val, float* address) {
- return atomicAdd(address, val);
-}
-
-template <typename T>
-__device__ T bilinear_interpolate(const T* bottom_data,
- const int height, const int width,
- T y, T x,
- const int index /* index for debug only*/) {
-
- // deal with cases that inverse elements are out of feature map boundary
- if (y < -1.0 || y > height || x < -1.0 || x > width) {
- //empty
- return 0;
- }
-
- if (y <= 0) y = 0;
- if (x <= 0) x = 0;
-
- int y_low = (int) y;
- int x_low = (int) x;
- int y_high;
- int x_high;
-
- if (y_low >= height - 1) {
- y_high = y_low = height - 1;
- y = (T) y_low;
- } else {
- y_high = y_low + 1;
- }
-
- if (x_low >= width - 1) {
- x_high = x_low = width - 1;
- x = (T) x_low;
- } else {
- x_high = x_low + 1;
- }
-
- T ly = y - y_low;
- T lx = x - x_low;
- T hy = 1. - ly, hx = 1. - lx;
- // do bilinear interpolation
- T v1 = bottom_data[y_low * width + x_low];
- T v2 = bottom_data[y_low * width + x_high];
- T v3 = bottom_data[y_high * width + x_low];
- T v4 = bottom_data[y_high * width + x_high];
- T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
-
- T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
-
- return val;
-}
-
-template <typename T>
-__global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
- const T spatial_scale, const int channels,
- const int height, const int width,
- const int pooled_height, const int pooled_width,
- const int sampling_ratio,
- const T* bottom_rois, T* top_data) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- // (n, c, ph, pw) is an element in the pooled output
- int pw = index % pooled_width;
- int ph = (index / pooled_width) % pooled_height;
- int c = (index / pooled_width / pooled_height) % channels;
- int n = index / pooled_width / pooled_height / channels;
-
- const T* offset_bottom_rois = bottom_rois + n * 5;
- int roi_batch_ind = offset_bottom_rois[0];
-
- // Do not using rounding; this implementation detail is critical
- T roi_start_w = offset_bottom_rois[1] * spatial_scale;
- T roi_start_h = offset_bottom_rois[2] * spatial_scale;
- T roi_end_w = offset_bottom_rois[3] * spatial_scale;
- T roi_end_h = offset_bottom_rois[4] * spatial_scale;
- // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
- // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
- // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
- // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
-
- // Force malformed ROIs to be 1x1
- T roi_width = max(roi_end_w - roi_start_w, (T)1.);
- T roi_height = max(roi_end_h - roi_start_h, (T)1.);
- T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
- T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
-
- const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
-
- // We use roi_bin_grid to sample the grid and mimic integral
- int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
- int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
-
- // We do average (integral) pooling inside a bin
- const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
-
- T output_val = 0.;
- for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
- {
- const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
- for (int ix = 0; ix < roi_bin_grid_w; ix ++)
- {
- const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
-
- T val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index);
- output_val += val;
- }
- }
- output_val /= count;
-
- top_data[index] = output_val;
- }
-}
-
-template <typename T>
-__device__ void bilinear_interpolate_gradient(
- const int height, const int width,
- T y, T x,
- T & w1, T & w2, T & w3, T & w4,
- int & x_low, int & x_high, int & y_low, int & y_high,
- const int index /* index for debug only*/) {
-
- // deal with cases that inverse elements are out of feature map boundary
- if (y < -1.0 || y > height || x < -1.0 || x > width) {
- //empty
- w1 = w2 = w3 = w4 = 0.;
- x_low = x_high = y_low = y_high = -1;
- return;
- }
-
- if (y <= 0) y = 0;
- if (x <= 0) x = 0;
-
- y_low = (int) y;
- x_low = (int) x;
-
- if (y_low >= height - 1) {
- y_high = y_low = height - 1;
- y = (T) y_low;
- } else {
- y_high = y_low + 1;
- }
-
- if (x_low >= width - 1) {
- x_high = x_low = width - 1;
- x = (T) x_low;
- } else {
- x_high = x_low + 1;
- }
-
- T ly = y - y_low;
- T lx = x - x_low;
- T hy = 1. - ly, hx = 1. - lx;
-
- // reference in forward
- // T v1 = bottom_data[y_low * width + x_low];
- // T v2 = bottom_data[y_low * width + x_high];
- // T v3 = bottom_data[y_high * width + x_low];
- // T v4 = bottom_data[y_high * width + x_high];
- // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
-
- w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
-
- return;
-}
-
-template <typename T>
-__global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
- const int num_rois, const T spatial_scale,
- const int channels, const int height, const int width,
- const int pooled_height, const int pooled_width,
- const int sampling_ratio,
- T* bottom_diff,
- const T* bottom_rois) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- // (n, c, ph, pw) is an element in the pooled output
- int pw = index % pooled_width;
- int ph = (index / pooled_width) % pooled_height;
- int c = (index / pooled_width / pooled_height) % channels;
- int n = index / pooled_width / pooled_height / channels;
-
- const T* offset_bottom_rois = bottom_rois + n * 5;
- int roi_batch_ind = offset_bottom_rois[0];
-
- // Do not using rounding; this implementation detail is critical
- T roi_start_w = offset_bottom_rois[1] * spatial_scale;
- T roi_start_h = offset_bottom_rois[2] * spatial_scale;
- T roi_end_w = offset_bottom_rois[3] * spatial_scale;
- T roi_end_h = offset_bottom_rois[4] * spatial_scale;
- // T roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
- // T roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
- // T roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
- // T roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
-
- // Force malformed ROIs to be 1x1
- T roi_width = max(roi_end_w - roi_start_w, (T)1.);
- T roi_height = max(roi_end_h - roi_start_h, (T)1.);
- T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
- T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
-
- T* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width;
-
- int top_offset = (n * channels + c) * pooled_height * pooled_width;
- const T* offset_top_diff = top_diff + top_offset;
- const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
-
- // We use roi_bin_grid to sample the grid and mimic integral
- int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
- int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
-
- // We do average (integral) pooling inside a bin
- const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
-
- for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1
- {
- const T y = roi_start_h + ph * bin_size_h + static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
- for (int ix = 0; ix < roi_bin_grid_w; ix ++)
- {
- const T x = roi_start_w + pw * bin_size_w + static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
-
- T w1, w2, w3, w4;
- int x_low, x_high, y_low, y_high;
-
- bilinear_interpolate_gradient(height, width, y, x,
- w1, w2, w3, w4,
- x_low, x_high, y_low, y_high,
- index);
-
- T g1 = top_diff_this_bin * w1 / count;
- T g2 = top_diff_this_bin * w2 / count;
- T g3 = top_diff_this_bin * w3 / count;
- T g4 = top_diff_this_bin * w4 / count;
-
- if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
- {
- gpu_atomic_add(static_cast<T>(g1), offset_bottom_diff + y_low * width + x_low);
- gpu_atomic_add(static_cast<T>(g2), offset_bottom_diff + y_low * width + x_high);
- gpu_atomic_add(static_cast<T>(g3), offset_bottom_diff + y_high * width + x_low);
- gpu_atomic_add(static_cast<T>(g4), offset_bottom_diff + y_high * width + x_high);
- } // if
- } // ix
- } // iy
- } // CUDA_1D_KERNEL_LOOP
-} // RoIAlignBackward
-
-
-} // namespace
-
-template<>
-bool RoIAlignOp<float, CUDAContext>::RunOnDevice() {
- auto& X = Input(0); // Input data to pool
- auto& R = Input(1); // RoIs
- auto* Y = Output(0); // RoI pooled data
-
- if (R.size() == 0) {
- // Handle empty rois
- Y->Resize(0, X.dim32(1), pooled_height_, pooled_width_);
- // The following mutable_data calls are needed to allocate the tensors
- Y->mutable_data<float>();
- return true;
- }
-
- assert(sampling_ratio_ >= 0);
-
- Y->Resize(R.dim32(0), X.dim32(1), pooled_height_, pooled_width_);
- int output_size = Y->size();
- RoIAlignForward<float>
- <<<CAFFE_GET_BLOCKS(output_size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- output_size,
- X.data<float>(),
- spatial_scale_,
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- pooled_height_,
- pooled_width_,
- sampling_ratio_,
- R.data<float>(),
- Y->mutable_data<float>());
- return true;
-}
-
-template<>
-bool RoIAlignGradientOp<float, CUDAContext>::RunOnDevice() {
- auto& X = Input(0); // Input data to pool
- auto& R = Input(1); // RoIs
- auto& dY = Input(2); // Gradient of net w.r.t. output of "forward" op
- // (aka "gradOutput")
- auto* dX = Output(0); // Gradient of net w.r.t. input to "forward" op
- // (aka "gradInput")
-
- dX->ResizeLike(X);
-
- // Must zero-out dX before accumulating gradients
- math::Set<float, CUDAContext>(
- dX->size(), 0.f, dX->mutable_data<float>(), &context_);
-
- if (dY.size() > 0) { // Handle possibly empty gradient if there were no rois
- RoIAlignBackwardFeature<float>
- <<<CAFFE_GET_BLOCKS(dY.size()),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context_.cuda_stream()>>>(
- dY.size(),
- dY.data<float>(),
- R.dim32(0),
- spatial_scale_,
- X.dim32(1),
- X.dim32(2),
- X.dim32(3),
- pooled_height_,
- pooled_width_,
- sampling_ratio_,
- dX->mutable_data<float>(),
- R.data<float>());
- }
- return true;
-}
-
-
-REGISTER_CUDA_OPERATOR(RoIAlign,
- RoIAlignOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(RoIAlignGradient,
- RoIAlignGradientOp<float, CUDAContext>);
-} // namespace caffe2
diff --git a/modules/detectron/roi_align_op.h b/modules/detectron/roi_align_op.h
deleted file mode 100644
index 8283d4b242..0000000000
--- a/modules/detectron/roi_align_op.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/**
- * Copyright (c) 2016-present, Facebook, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef ROI_ALIGN_OP_H_
-#define ROI_ALIGN_OP_H_
-
-#include "caffe2/core/context.h"
-#include "caffe2/core/logging.h"
-#include "caffe2/core/operator.h"
-#include "caffe2/utils/math.h"
-
-namespace caffe2 {
-
-template <typename T, class Context>
-class RoIAlignOp final : public Operator<Context> {
- public:
- RoIAlignOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
- spatial_scale_(
- OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)),
- pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
- pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
- sampling_ratio_(
- OperatorBase::GetSingleArgument<int>("sampling_ratio", -1)) {
- DCHECK_GT(spatial_scale_, 0);
- DCHECK_GT(pooled_height_, 0);
- DCHECK_GT(pooled_width_, 0);
- DCHECK_GE(sampling_ratio_, 0);
- }
- USE_OPERATOR_CONTEXT_FUNCTIONS;
-
- bool RunOnDevice() override {
- // No CPU implementation for now
- CAFFE_NOT_IMPLEMENTED;
- }
-
- protected:
- float spatial_scale_;
- int pooled_height_;
- int pooled_width_;
- int sampling_ratio_;
-};
-
-template <typename T, class Context>
-class RoIAlignGradientOp final : public Operator<Context> {
- public:
- RoIAlignGradientOp(const OperatorDef& def, Workspace* ws)
- : Operator<Context>(def, ws),
- spatial_scale_(
- OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)),
- pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
- pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
- sampling_ratio_(
- OperatorBase::GetSingleArgument<int>("sampling_ratio", -1)) {
- DCHECK_GT(spatial_scale_, 0);
- DCHECK_GT(pooled_height_, 0);
- DCHECK_GT(pooled_width_, 0);
- DCHECK_GE(sampling_ratio_, 0);
- }
- USE_OPERATOR_CONTEXT_FUNCTIONS;
-
- bool RunOnDevice() override {
- // No CPU implementation for now
- CAFFE_NOT_IMPLEMENTED;
- }
-
- protected:
- float spatial_scale_;
- int pooled_height_;
- int pooled_width_;
- int sampling_ratio_;
-};
-
-} // namespace caffe2
-
-#endif // ROI_ALIGN_OP_H_