diff options
author | Bram Wasti <bwasti@fb.com> | 2019-02-05 12:30:31 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-05 12:56:58 -0800 |
commit | a9713d07b0ab2845f15c6ca64cb13931556d0a5a (patch) | |
tree | d9dd6b7c6bd267cca0f066b2b4ae4ceda04c7ec0 | |
parent | 3df7b321cc5b23a3c9946f0fe34727b88c3bbb75 (diff) | |
download | pytorch-a9713d07b0ab2845f15c6ca64cb13931556d0a5a.tar.gz pytorch-a9713d07b0ab2845f15c6ca64cb13931556d0a5a.tar.bz2 pytorch-a9713d07b0ab2845f15c6ca64cb13931556d0a5a.zip |
Expose HeatmapMaxKeypoints to torch (#16528)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16528
..
Reviewed By: smessmer
Differential Revision: D13866214
fbshipit-source-id: 2ca79037fc070bade5542345af5ce09f88beda44
-rw-r--r-- | caffe2/operators/heatmap_max_keypoint_op.cc | 13 | ||||
-rw-r--r-- | caffe2/operators/heatmap_max_keypoint_op.h | 7 | ||||
-rw-r--r-- | caffe2/python/operator_test/torch_integration_test.py | 59 | ||||
-rw-r--r-- | torch/csrc/jit/register_caffe2_ops.cpp | 2 |
4 files changed, 79 insertions, 2 deletions
diff --git a/caffe2/operators/heatmap_max_keypoint_op.cc b/caffe2/operators/heatmap_max_keypoint_op.cc index c9e12dd325..6c1b3f5141 100644 --- a/caffe2/operators/heatmap_max_keypoint_op.cc +++ b/caffe2/operators/heatmap_max_keypoint_op.cc @@ -2,6 +2,19 @@ #include "caffe2/utils/eigen_utils.h" namespace caffe2 { +using HeatmapMaxKeypointFloatCPUImpl = HeatmapMaxKeypointOp<float, CPUContext>; +DEFINE_FUNCTION_SCHEMA_OPERATOR( + HeatmapMaxKeypoint, + (std::vector<c10::Argument>{ + c10::Argument("in0"), + c10::Argument("in1"), + c10::Argument("should_output_softmax", BoolType::get()), + }), + (std::vector<c10::Argument>{ + c10::Argument("out0"), + }), + HeatmapMaxKeypointFloatCPUImpl); + namespace { REGISTER_CPU_OPERATOR( diff --git a/caffe2/operators/heatmap_max_keypoint_op.h b/caffe2/operators/heatmap_max_keypoint_op.h index 3a612ee4a9..adfdf48b7f 100644 --- a/caffe2/operators/heatmap_max_keypoint_op.h +++ b/caffe2/operators/heatmap_max_keypoint_op.h @@ -10,11 +10,14 @@ namespace caffe2 { +DECLARE_FUNCTION_SCHEMA_OPERATOR(HeatmapMaxKeypoint); + template <typename T, class Context> class HeatmapMaxKeypointOp final : public Operator<Context> { public: - HeatmapMaxKeypointOp(const OperatorDef& operator_def, Workspace* ws) - : Operator<Context>(operator_def, ws), + template <class... Args> + HeatmapMaxKeypointOp(Args&&... args) + : Operator<Context>(std::forward<Args>(args)...), should_output_softmax_(this->template GetSingleArgument<bool>( "should_output_softmax", false)) {} diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index f46fb52b97..671d8aaa47 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -9,6 +9,7 @@ from hypothesis import given import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st import numpy as np +from scipy import interpolate def generate_rois(roi_counts, im_dims): assert len(roi_counts) == len(im_dims) @@ -250,3 +251,61 @@ class TorchIntegration(hu.HypothesisTestCase): torch.testing.assert_allclose(torch.tensor(scores_ref), a) torch.testing.assert_allclose(torch.tensor(boxes_ref), b) torch.testing.assert_allclose(torch.tensor(classes_ref), c) + + + def test_heatmap_max_keypoint(self): + NUM_TEST_ROI = 14 + NUM_KEYPOINTS = 19 + HEATMAP_SIZE = 56 + np.random.seed(0) + + # initial coordinates and interpolate HEATMAP_SIZE from it + HEATMAP_SMALL_SIZE = 4 + bboxes_in = 500 * np.random.rand(NUM_TEST_ROI, 4).astype(np.float32) + # only bbox with smaller first coordiantes + for i in range(NUM_TEST_ROI): + if bboxes_in[i][0] > bboxes_in[i][2]: + tmp = bboxes_in[i][2] + bboxes_in[i][2] = bboxes_in[i][0] + bboxes_in[i][0] = tmp + if bboxes_in[i][1] > bboxes_in[i][3]: + tmp = bboxes_in[i][3] + bboxes_in[i][3] = bboxes_in[i][1] + bboxes_in[i][1] = tmp + + # initial randomized coordiantes for heatmaps and expand it with interpolation + init = np.random.rand( + NUM_TEST_ROI, + NUM_KEYPOINTS, + HEATMAP_SMALL_SIZE, + HEATMAP_SMALL_SIZE).astype(np.float32) + heatmaps_in = np.zeros((NUM_TEST_ROI, NUM_KEYPOINTS, + HEATMAP_SIZE, HEATMAP_SIZE)).astype(np.float32) + for roi in range(NUM_TEST_ROI): + for keyp in range(NUM_KEYPOINTS): + f = interpolate.interp2d( + np.arange(0, 1, 1.0 / HEATMAP_SMALL_SIZE), + np.arange(0, 1, 1.0 / HEATMAP_SMALL_SIZE), + init[roi][keyp], + kind='cubic') + heatmaps_in[roi][keyp] = f( + np.arange(0, 1, 1.0 / HEATMAP_SIZE), + np.arange(0, 1, 1.0 / HEATMAP_SIZE)) + def heatmap_max_keypoint_ref(): + ref_op = core.CreateOperator( + 'HeatmapMaxKeypoint', + ['heatmaps_in', 'bboxes_in'], + ['keypoints_out'], + should_output_softmax = True, + ) + workspace.FeedBlob("heatmaps_in", heatmaps_in) + workspace.FeedBlob("bboxes_in", bboxes_in) + workspace.RunOperatorOnce(ref_op) + return workspace.FetchBlob("keypoints_out") + + keypoints_ref = heatmap_max_keypoint_ref() + + keypoints_torch = torch.ops._caffe2.HeatmapMaxKeypoint( + torch.tensor(heatmaps_in), torch.tensor(bboxes_in), + True) + torch.testing.assert_allclose(torch.tensor(keypoints_ref), keypoints_torch) diff --git a/torch/csrc/jit/register_caffe2_ops.cpp b/torch/csrc/jit/register_caffe2_ops.cpp index 09262389b5..9d93930da3 100644 --- a/torch/csrc/jit/register_caffe2_ops.cpp +++ b/torch/csrc/jit/register_caffe2_ops.cpp @@ -4,6 +4,7 @@ #include "caffe2/operators/generate_proposals_op.h" #include "caffe2/operators/bbox_transform_op.h" #include "caffe2/operators/box_with_nms_limit_op.h" +#include "caffe2/operators/heatmap_max_keypoint_op.h" #define REGISTER_CAFFE2_OP(name) \ static caffe2::CAFFE2_STRUCT_OP_REGISTRATION_##name CAFFE2_STRUCT_OP_REGISTRATION_DEFN_TORCH_##name; \ @@ -14,3 +15,4 @@ REGISTER_CAFFE2_OP(RoIAlign); REGISTER_CAFFE2_OP(GenerateProposals); REGISTER_CAFFE2_OP(BBoxTransform); REGISTER_CAFFE2_OP(BoxWithNMSLimit); +REGISTER_CAFFE2_OP(HeatmapMaxKeypoint); |