#include namespace caffe2 { REGISTER_CPU_OPERATOR(VideoInput, VideoInputOp); OPERATOR_SCHEMA(VideoInput) .NumInputs(0, 1) .NumOutputs(2, 4) .TensorInferenceFunction( [](const OperatorDef& def, const vector& /* unused */ /*in*/) { ArgumentHelper helper(def); int batch_size = helper.GetSingleArgument("batch_size", 0); int clip_per_video = helper.GetSingleArgument("clip_per_video", 1); int crop_height = helper.GetSingleArgument( "crop_height", helper.GetSingleArgument("crop_size", 0)); int crop_width = helper.GetSingleArgument( "crop_width", helper.GetSingleArgument("crop_size", 0)); int length_rgb = helper.GetSingleArgument("length_rgb", 0); int channels_rgb = helper.GetSingleArgument("channels_rgb", 3); int length_of = helper.GetSingleArgument("length_of", 0); int channels_of = helper.GetSingleArgument("channels_of", 2); // get the flags bool get_rgb = helper.GetSingleArgument("get_rgb", true); bool get_optical_flow = helper.GetSingleArgument("get_optical_flow", false); bool do_multi_label = helper.GetSingleArgument("do_multi_label", false); bool get_video_id = helper.GetSingleArgument("get_video_id", false); int output_size = 1; if (get_rgb) { output_size++; } if (get_optical_flow) { output_size++; } if (get_video_id) { output_size++; } int index = 0; vector out(output_size); CHECK_GT(crop_height, 0); CHECK_GT(crop_width, 0); batch_size *= clip_per_video; if (get_rgb) { out[index++] = CreateTensorShape( vector{batch_size, channels_rgb, length_rgb, crop_height, crop_width}, TensorProto::FLOAT); } if (get_optical_flow) { out[index++] = CreateTensorShape( vector{batch_size, channels_of, length_of, crop_height, crop_width}, TensorProto::FLOAT); } if (!do_multi_label) { out[index++] = CreateTensorShape( vector{1, batch_size}, TensorProto::INT32); } else { int num_of_class = helper.GetSingleArgument("num_of_class", 0); out[index++] = CreateTensorShape( vector{batch_size, num_of_class}, TensorProto::INT32); } if (get_video_id) { out[index] = CreateTensorShape( vector{1, batch_size}, TensorProto::INT32); } return out; }); NO_GRADIENT(VideoInput); } // namespace caffe2