summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBram Wasti <bwasti@fb.com>2019-02-05 12:30:31 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-05 12:56:58 -0800
commita9713d07b0ab2845f15c6ca64cb13931556d0a5a (patch)
treed9dd6b7c6bd267cca0f066b2b4ae4ceda04c7ec0
parent3df7b321cc5b23a3c9946f0fe34727b88c3bbb75 (diff)
downloadpytorch-a9713d07b0ab2845f15c6ca64cb13931556d0a5a.tar.gz
pytorch-a9713d07b0ab2845f15c6ca64cb13931556d0a5a.tar.bz2
pytorch-a9713d07b0ab2845f15c6ca64cb13931556d0a5a.zip
Expose HeatmapMaxKeypoints to torch (#16528)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16528 .. Reviewed By: smessmer Differential Revision: D13866214 fbshipit-source-id: 2ca79037fc070bade5542345af5ce09f88beda44
-rw-r--r--caffe2/operators/heatmap_max_keypoint_op.cc13
-rw-r--r--caffe2/operators/heatmap_max_keypoint_op.h7
-rw-r--r--caffe2/python/operator_test/torch_integration_test.py59
-rw-r--r--torch/csrc/jit/register_caffe2_ops.cpp2
4 files changed, 79 insertions, 2 deletions
diff --git a/caffe2/operators/heatmap_max_keypoint_op.cc b/caffe2/operators/heatmap_max_keypoint_op.cc
index c9e12dd325..6c1b3f5141 100644
--- a/caffe2/operators/heatmap_max_keypoint_op.cc
+++ b/caffe2/operators/heatmap_max_keypoint_op.cc
@@ -2,6 +2,19 @@
#include "caffe2/utils/eigen_utils.h"
namespace caffe2 {
+using HeatmapMaxKeypointFloatCPUImpl = HeatmapMaxKeypointOp<float, CPUContext>;
+DEFINE_FUNCTION_SCHEMA_OPERATOR(
+ HeatmapMaxKeypoint,
+ (std::vector<c10::Argument>{
+ c10::Argument("in0"),
+ c10::Argument("in1"),
+ c10::Argument("should_output_softmax", BoolType::get()),
+ }),
+ (std::vector<c10::Argument>{
+ c10::Argument("out0"),
+ }),
+ HeatmapMaxKeypointFloatCPUImpl);
+
namespace {
REGISTER_CPU_OPERATOR(
diff --git a/caffe2/operators/heatmap_max_keypoint_op.h b/caffe2/operators/heatmap_max_keypoint_op.h
index 3a612ee4a9..adfdf48b7f 100644
--- a/caffe2/operators/heatmap_max_keypoint_op.h
+++ b/caffe2/operators/heatmap_max_keypoint_op.h
@@ -10,11 +10,14 @@
namespace caffe2 {
+DECLARE_FUNCTION_SCHEMA_OPERATOR(HeatmapMaxKeypoint);
+
template <typename T, class Context>
class HeatmapMaxKeypointOp final : public Operator<Context> {
public:
- HeatmapMaxKeypointOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
+ template <class... Args>
+ HeatmapMaxKeypointOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
should_output_softmax_(this->template GetSingleArgument<bool>(
"should_output_softmax",
false)) {}
diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py
index f46fb52b97..671d8aaa47 100644
--- a/caffe2/python/operator_test/torch_integration_test.py
+++ b/caffe2/python/operator_test/torch_integration_test.py
@@ -9,6 +9,7 @@ from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np
+from scipy import interpolate
def generate_rois(roi_counts, im_dims):
assert len(roi_counts) == len(im_dims)
@@ -250,3 +251,61 @@ class TorchIntegration(hu.HypothesisTestCase):
torch.testing.assert_allclose(torch.tensor(scores_ref), a)
torch.testing.assert_allclose(torch.tensor(boxes_ref), b)
torch.testing.assert_allclose(torch.tensor(classes_ref), c)
+
+
+ def test_heatmap_max_keypoint(self):
+ NUM_TEST_ROI = 14
+ NUM_KEYPOINTS = 19
+ HEATMAP_SIZE = 56
+ np.random.seed(0)
+
+ # initial coordinates and interpolate HEATMAP_SIZE from it
+ HEATMAP_SMALL_SIZE = 4
+ bboxes_in = 500 * np.random.rand(NUM_TEST_ROI, 4).astype(np.float32)
+ # only bbox with smaller first coordiantes
+ for i in range(NUM_TEST_ROI):
+ if bboxes_in[i][0] > bboxes_in[i][2]:
+ tmp = bboxes_in[i][2]
+ bboxes_in[i][2] = bboxes_in[i][0]
+ bboxes_in[i][0] = tmp
+ if bboxes_in[i][1] > bboxes_in[i][3]:
+ tmp = bboxes_in[i][3]
+ bboxes_in[i][3] = bboxes_in[i][1]
+ bboxes_in[i][1] = tmp
+
+ # initial randomized coordiantes for heatmaps and expand it with interpolation
+ init = np.random.rand(
+ NUM_TEST_ROI,
+ NUM_KEYPOINTS,
+ HEATMAP_SMALL_SIZE,
+ HEATMAP_SMALL_SIZE).astype(np.float32)
+ heatmaps_in = np.zeros((NUM_TEST_ROI, NUM_KEYPOINTS,
+ HEATMAP_SIZE, HEATMAP_SIZE)).astype(np.float32)
+ for roi in range(NUM_TEST_ROI):
+ for keyp in range(NUM_KEYPOINTS):
+ f = interpolate.interp2d(
+ np.arange(0, 1, 1.0 / HEATMAP_SMALL_SIZE),
+ np.arange(0, 1, 1.0 / HEATMAP_SMALL_SIZE),
+ init[roi][keyp],
+ kind='cubic')
+ heatmaps_in[roi][keyp] = f(
+ np.arange(0, 1, 1.0 / HEATMAP_SIZE),
+ np.arange(0, 1, 1.0 / HEATMAP_SIZE))
+ def heatmap_max_keypoint_ref():
+ ref_op = core.CreateOperator(
+ 'HeatmapMaxKeypoint',
+ ['heatmaps_in', 'bboxes_in'],
+ ['keypoints_out'],
+ should_output_softmax = True,
+ )
+ workspace.FeedBlob("heatmaps_in", heatmaps_in)
+ workspace.FeedBlob("bboxes_in", bboxes_in)
+ workspace.RunOperatorOnce(ref_op)
+ return workspace.FetchBlob("keypoints_out")
+
+ keypoints_ref = heatmap_max_keypoint_ref()
+
+ keypoints_torch = torch.ops._caffe2.HeatmapMaxKeypoint(
+ torch.tensor(heatmaps_in), torch.tensor(bboxes_in),
+ True)
+ torch.testing.assert_allclose(torch.tensor(keypoints_ref), keypoints_torch)
diff --git a/torch/csrc/jit/register_caffe2_ops.cpp b/torch/csrc/jit/register_caffe2_ops.cpp
index 09262389b5..9d93930da3 100644
--- a/torch/csrc/jit/register_caffe2_ops.cpp
+++ b/torch/csrc/jit/register_caffe2_ops.cpp
@@ -4,6 +4,7 @@
#include "caffe2/operators/generate_proposals_op.h"
#include "caffe2/operators/bbox_transform_op.h"
#include "caffe2/operators/box_with_nms_limit_op.h"
+#include "caffe2/operators/heatmap_max_keypoint_op.h"
#define REGISTER_CAFFE2_OP(name) \
static caffe2::CAFFE2_STRUCT_OP_REGISTRATION_##name CAFFE2_STRUCT_OP_REGISTRATION_DEFN_TORCH_##name; \
@@ -14,3 +15,4 @@ REGISTER_CAFFE2_OP(RoIAlign);
REGISTER_CAFFE2_OP(GenerateProposals);
REGISTER_CAFFE2_OP(BBoxTransform);
REGISTER_CAFFE2_OP(BoxWithNMSLimit);
+REGISTER_CAFFE2_OP(HeatmapMaxKeypoint);