#include "caffe2/operators/arg_ops.h"

#include <functional>

#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

// Computes, for every (outer, inner) position, the index along the reduced
// axis of the element that "wins" under `comp`.
//
// X is read as a row-major array of shape [prev_size, n, next_size] and Y is
// written as a row-major array of shape [prev_size, next_size].
//
// Y is first filled with 0, so index 0 is the initial winner. A later index k
// replaces the current winner only when comp(X[i, k, j], X[i, best, j]) is
// true; with a strict comparator (std::greater / std::less) the earliest of
// several equal elements is therefore kept.
//
// Args:
//   prev_size: number of slices before the reduced axis.
//   next_size: number of elements after the reduced axis.
//   n:         extent of the reduced axis.
//   comp:      binary predicate; comp(candidate, current_best).
//   X:         input data, prev_size * n * next_size elements.
//   Y:         output indices, prev_size * next_size elements.
//   context:   device context forwarded to math::Set.
template <typename T, class Compare, class Context>
void ComputeArgImpl(
    const int prev_size,
    const int next_size,
    const int n,
    const Compare& comp,
    const T* X,
    int64_t* Y,
    Context* context) {
  // Widen before multiplying: the offsets are products of int dimensions and
  // can exceed the range of int for large tensors.
  const int64_t inner = next_size;
  const int64_t slice_stride = static_cast<int64_t>(n) * inner;

  math::Set<int64_t, Context>(
      static_cast<int64_t>(prev_size) * inner, int64_t(0), Y, context);

  for (int64_t i = 0; i < prev_size; ++i) {
    const T* slice = X + i * slice_stride;
    int64_t* best = Y + i * inner;
    // k = 0 is already recorded as the initial winner, so start at k = 1.
    for (int64_t k = 1; k < n; ++k) {
      const T* candidates = slice + k * inner;
      for (int64_t j = 0; j < inner; ++j) {
        if (comp(candidates[j], slice[best[j] * inner + j])) {
          best[j] = k;
        }
      }
    }
  }
}

} // namespace

// CPU implementation of the arg-max reduction: for each (outer, inner)
// position writes to Y the index along the reduced axis selected by
// std::greater. Always returns true.
template <>
template <typename T>
bool ArgMaxReducer<CPUContext>::operator()(
    const int prev_size,
    const int next_size,
    const int n,
    const T* X,
    int64_t* Y,
    CPUContext* context) const {
  ComputeArgImpl(prev_size, next_size, n, std::greater<T>(), X, Y, context);
  return true;
}

// CPU implementation of the arg-min reduction: for each (outer, inner)
// position writes to Y the index along the reduced axis selected by
// std::less. Always returns true.
template <>
template <typename T>
bool ArgMinReducer<CPUContext>::operator()(
    const int prev_size,
    const int next_size,
    const int n,
    const T* X,
    int64_t* Y,
    CPUContext* context) const {
  ComputeArgImpl(prev_size, next_size, n, std::less<T>(), X, Y, context);
  return true;
}

REGISTER_CPU_OPERATOR(ArgMax, ArgOp<CPUContext, ArgMaxReducer<CPUContext>>);
REGISTER_CPU_OPERATOR(ArgMin, ArgOp<CPUContext, ArgMinReducer<CPUContext>>);

namespace {

// Shape/type inference shared by ArgMax and ArgMin.
//
// Reads the `axis` (default -1) and `keepdims` (default true) arguments from
// `def` and derives the single output shape from in[0]:
//   - dimensions before `axis` are copied unchanged,
//   - a dimension of size 1 is inserted when `keepdims` is true,
//   - dimensions after `axis` are copied unchanged.
// The output data type is always INT64.
//
// Only axis == -1 is remapped (to the last input dimension); any other
// negative value is used as given.
std::vector<TensorShape> InferTensor(
    const OperatorDef& def,
    const std::vector<TensorShape>& in) {
  std::vector<TensorShape> out(1);
  ArgumentHelper helper(def);
  int axis = helper.GetSingleArgument<int>("axis", -1);
  const bool keep_dims = helper.GetSingleArgument<bool>("keepdims", true);
  const auto& in_dims = in[0].dims();
  auto* out_dims = out[0].mutable_dims();
  if (axis == -1) {
    axis = in_dims.size() - 1;
  }
  for (int i = 0; i < axis; ++i) {
    out_dims->Add(in_dims.Get(i));
  }
  if (keep_dims) {
    out_dims->Add(1);
  }
  for (int i = axis + 1; i < in_dims.size(); ++i) {
    out_dims->Add(in_dims.Get(i));
  }
  out[0].set_data_type(TensorProto::INT64);
  return out;
}

} // namespace

OPERATOR_SCHEMA(ArgMax)
    .NumInputs(1)
    .NumOutputs(1)
    .TensorInferenceFunction(InferTensor)
    .SetDoc(R"DOC(
Retrieve the argmax of an axis dimension specified by the `axis`
argument. Given an input tensor and two arguments (`axis` and
`keepdims`), returns a tensor containing the indices of the largest
element along the given axis. If the `keepdims` arg is *True* (default),
the shape of the output tensor matches the input tensor except the
`axis` dimension equals 1. Else, the `axis` dimension of the output
tensor is removed.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/arg_ops.cc

Example

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "ArgMax",
    ["X"],
    ["Indices"],
    axis=2,
    keepdims=False
)

workspace.FeedBlob("X", (np.random.randint(10, size=(3,3,3))).astype(np.float32))
print("X:", workspace.FetchBlob("X"))
workspace.RunOperatorOnce(op)
print("Indices:", workspace.FetchBlob("Indices"))
```

**Result**

```
X: [[[4. 9. 6.]
  [6. 6. 1.]
  [9. 5. 4.]]

 [[6. 7. 4.]
  [7. 9. 1.]
  [3. 2. 8.]]

 [[3. 4. 6.]
  [5. 2. 7.]
  [1. 5. 7.]]]
Indices: [[1 0 0]
 [1 1 2]
 [2 2 2]]
```

)DOC")
    .Input(0, "X", "*(type: Tensor`<float>`)* Input tensor.")
    .Output(
        0,
        "Indices",
        "*(type: Tensor`<int64_t>`)* Tensor of indices for the largest values.")
    .Arg("axis", "*(type: int; default: -1)* The axis to get argmax.")
    .Arg(
        "keepdims",
        "*(type: bool; default: True)* If True (default), the output tensor "
        "shape will match the input tensor shape except the `axis` dimension "
        "equals 1. Else, the `axis` dimension of the output tensor is removed.");

OPERATOR_SCHEMA(ArgMin)
    .NumInputs(1)
    .NumOutputs(1)
    .TensorInferenceFunction(InferTensor)
    .SetDoc(R"DOC(
Retrieve the argmin of an axis dimension specified by the `axis`
argument. Given an input tensor and two arguments (`axis` and
`keepdims`), returns a tensor containing the indices of the smallest
element along the given axis. If the `keepdims` arg is *True* (default),
the shape of the output tensor matches the input tensor except the
`axis` dimension equals 1. Else, the `axis` dimension of the output
tensor is removed.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/arg_ops.cc

Example

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "ArgMin",
    ["X"],
    ["Indices"],
    axis=1
)

workspace.FeedBlob("X", (np.random.randint(10, size=(5,5))).astype(np.float32))
print("X:", workspace.FetchBlob("X"))
workspace.RunOperatorOnce(op)
print("Indices:", workspace.FetchBlob("Indices"))
```

**Result**

```
X: [[9. 4. 6. 4. 1.]
 [5. 9. 8. 3. 4.]
 [6. 1. 0. 2. 9.]
 [7. 8. 2. 4. 9.]
 [3. 9. 4. 9. 4.]]
Indices: [[4]
 [3]
 [2]
 [2]
 [0]]
```

)DOC")
    .Input(0, "X", "*(type: Tensor`<float>`)* Input tensor.")
    .Output(
        0,
        "Indices",
        "*(type: Tensor`<int64_t>`)* Tensor of indices for the smallest values.")
    .Arg("axis", "*(type: int; default: -1)* The axis to get argmin.")
    .Arg(
        "keepdims",
        "*(type: bool; default: True)* If True (default), the output tensor "
        "shape will match the input tensor shape except the `axis` dimension "
        "equals 1. Else, the `axis` dimension of the output tensor is removed.");

SHOULD_NOT_DO_GRADIENT(ArgMax);
SHOULD_NOT_DO_GRADIENT(ArgMin);

} // namespace caffe2