#pragma once #include "caffe2/core/context.h" #include "caffe2/core/operator.h" namespace caffe2 { // RecordShapeOp records the shape of the input tensor to a vector of int. You // mostly don't need this operator explicitly, and it is mostly used in the // autodiff process. template class ShapeOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; template explicit ShapeOp(Args&&... args) : Operator(std::forward(args)...), axes_(OperatorBase ::GetRepeatedArgument("axes")) {} bool RunOnDevice() override { auto& data = Input(DATA); int numDims = data.dim(); int numAxes = axes_.size(); if (numAxes == 0) { auto* output = Output(0, {numDims}, at::dtype()); int64_t* output_data = output->template mutable_data(); context_.CopyBytesSameDevice( numDims * sizeof(int64_t), data.sizes().data(), output_data); return true; } auto* output = Output(0, {numAxes}, at::dtype()); auto src = reinterpret_cast(data.sizes().data()); auto out = reinterpret_cast(output->template mutable_data()); for (int i = 0; i < numAxes; i++) { auto axis = axes_[i]; CAFFE_ENFORCE_LT(axis, numDims, "Axis out of range"); CAFFE_ENFORCE_GE(axis, 0, "Each axis should be non-negative"); context_.CopyBytesSameDevice( sizeof(int64_t), src + axis * sizeof(int64_t), out); out += sizeof(int64_t); } return true; } INPUT_TAGS(DATA); private: vector axes_; }; } // namespace caffe2