summaryrefslogtreecommitdiff
path: root/caffe2/image
diff options
context:
space:
mode:
authorKevin Wilfong <kevinwilfong@fb.com>2017-08-01 14:19:12 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2017-08-01 14:36:05 -0700
commit60cb55461e269a31421b73fe6e51d90b88acb131 (patch)
tree867efd1e5a5227a7af785f4890ca7188e9d68f32 /caffe2/image
parent3a99698734b3230b09fa641c98d3e17a973b50b1 (diff)
downloadpytorch-60cb55461e269a31421b73fe6e51d90b88acb131.tar.gz
pytorch-60cb55461e269a31421b73fe6e51d90b88acb131.tar.bz2
pytorch-60cb55461e269a31421b73fe6e51d90b88acb131.zip
Caffe2: Support additional outputs in ImageInputOp
Summary: This allows users to add an arbitrary of additional outputs to ImageInputOp. These are populated by reading additional TensorProto values from the TensorProtos from the DBReader, and converting them into Tensors. Similar to labels, only ints and floats are supported, and multiple values are supported. Reviewed By: panshen1 Differential Revision: D5502019 fbshipit-source-id: 5a8b61b3a8549272a112e8e02cd613d8f9a271ba
Diffstat (limited to 'caffe2/image')
-rw-r--r--caffe2/image/image_input_op.cc9
-rw-r--r--caffe2/image/image_input_op.h90
2 files changed, 95 insertions, 4 deletions
diff --git a/caffe2/image/image_input_op.cc b/caffe2/image/image_input_op.cc
index 49ff80455d..478e2b640f 100644
--- a/caffe2/image/image_input_op.cc
+++ b/caffe2/image/image_input_op.cc
@@ -6,7 +6,7 @@ REGISTER_CPU_OPERATOR(ImageInput, ImageInputOp<CPUContext>);
OPERATOR_SCHEMA(ImageInput)
.NumInputs(0, 1)
- .NumOutputs(2)
+ .NumOutputs(2, INT_MAX)
.TensorInferenceFunction(
[](const OperatorDef& def, const vector<TensorShape>& /* unused */ ) {
vector<TensorShape> out(2);
@@ -75,9 +75,14 @@ The dimension of the output image will always be cropxcrop
.Arg("db", "Name of the database (if not passed as input)")
.Arg("db_type", "Type of database (if not passed as input)."
" Defaults to leveldb")
+ .Arg("output_sizes", "The sizes of any outputs besides the data and label "
+ "(should have a number of elements equal to the number of additional "
+ "outputs)")
.Input(0, "reader", "The input reader (a db::DBReader)")
.Output(0, "data", "Tensor containing the images")
- .Output(1, "label", "Tensor containing the labels");
+ .Output(1, "label", "Tensor containing the labels")
+ .Output(2, "additional outputs", "Any outputs after the first 2 will be "
+ "Tensors read from the input TensorProtos");
NO_GRADIENT(ImageInput);
diff --git a/caffe2/image/image_input_op.h b/caffe2/image/image_input_op.h
index a08dbf9754..8d8a32fa8f 100644
--- a/caffe2/image/image_input_op.h
+++ b/caffe2/image/image_input_op.h
@@ -63,8 +63,10 @@ class ImageInputOp final
CPUContext cpu_context_;
TensorCPU prefetched_image_;
TensorCPU prefetched_label_;
+ vector<TensorCPU> prefetched_additional_outputs_;
Tensor<Context> prefetched_image_on_device_;
Tensor<Context> prefetched_label_on_device_;
+ vector<Tensor<Context>> prefetched_additional_outputs_on_device_;
// Default parameters for images
PerImageArg default_arg_;
int batch_size_;
@@ -105,6 +107,8 @@ ImageInputOp<Context>::ImageInputOp(
Workspace* ws)
: PrefetchOperator<Context>(operator_def, ws),
reader_(nullptr),
+ prefetched_additional_outputs_(OutputSize() - 2),
+ prefetched_additional_outputs_on_device_(OutputSize() - 2),
batch_size_(
OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
multiple_label_(
@@ -137,6 +141,10 @@ ImageInputOp<Context>::ImageInputOp(
"std_per_channel",
{OperatorBase::template GetSingleArgument<float>("std", 1.)});
+ vector<int> additional_output_sizes =
+ OperatorBase::template GetRepeatedArgument<int>(
+ "output_sizes", vector<int>(OutputSize() - 2, 1));
+
default_arg_.bounding_params = {
false,
OperatorBase::template GetSingleArgument<int>("bounding_ymin", -1),
@@ -180,6 +188,13 @@ ImageInputOp<Context>::ImageInputOp(
"The mean and std. dev vectors must be of the same size.");
CAFFE_ENFORCE(mean_.size() == 1 || mean_.size() == 3,
"The mean and std. dev vectors must be of size 1 or 3");
+ CAFFE_ENFORCE(
+ !use_caffe_datum_ || OutputSize() == 2,
+ "There can only be 2 outputs if the Caffe datum format is used");
+ CAFFE_ENFORCE(
+ additional_output_sizes.size() == OutputSize() - 2,
+ "If the output sizes are specified, they must be specified for all "
+ "additional outputs");
if (default_arg_.bounding_params.ymin < 0
|| default_arg_.bounding_params.xmin < 0
@@ -255,6 +270,11 @@ ImageInputOp<Context>::ImageInputOp(
} else {
prefetched_label_.Resize(vector<TIndex>(1, batch_size_));
}
+
+ for (int i = 0; i < additional_output_sizes.size(); ++i) {
+ prefetched_additional_outputs_[i].Resize(
+ TIndex(batch_size_), TIndex(additional_output_sizes[i]));
+ }
}
template <class Context>
@@ -319,9 +339,15 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
CAFFE_ENFORCE(protos.ParseFromString(value));
const TensorProto& image_proto = protos.protos(0);
const TensorProto& label_proto = protos.protos(1);
- if (protos.protos_size() == 3) {
+ vector<TensorProto> additional_output_protos;
+
+ for (int i = 2; i < OutputSize(); ++i) {
+ additional_output_protos.push_back(protos.protos(i));
+ }
+
+ if (protos.protos_size() == OutputSize() + 1) {
// We have bounding box information
- const TensorProto& bounding_proto = protos.protos(2);
+ const TensorProto& bounding_proto = protos.protos(OutputSize());
DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32);
DCHECK_EQ(bounding_proto.int32_data_size(), 4);
info.bounding_params.valid = true;
@@ -392,6 +418,30 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
} else {
LOG(FATAL) << "Unsupported label type.";
}
+
+ for (int i = 0; i < additional_output_protos.size(); ++i) {
+ auto additional_output_proto = additional_output_protos[i];
+
+ if (additional_output_proto.data_type() == TensorProto::FLOAT) {
+ float* additional_output =
+ prefetched_additional_outputs_[i].template mutable_data<float>() +
+ item_id * additional_output_proto.float_data_size();
+
+ for (int j = 0; j < additional_output_proto.float_data_size(); ++j) {
+ additional_output[j] = additional_output_proto.float_data(j);
+ }
+ } else if (additional_output_proto.data_type() == TensorProto::INT32) {
+ int* additional_output =
+ prefetched_additional_outputs_[i].template mutable_data<int>() +
+ item_id * additional_output_proto.int32_data_size();
+
+ for (int j = 0; j < additional_output_proto.int32_data_size(); ++j) {
+ additional_output[j] = additional_output_proto.int32_data(j);
+ }
+ } else {
+ LOG(FATAL) << "Unsupported output type.";
+ }
+ }
}
//
@@ -664,6 +714,20 @@ bool ImageInputOp<Context>::Prefetch() {
} else {
LOG(FATAL) << "Unsupported label type.";
}
+
+ for (int i = 2; i < OutputSize(); ++i) {
+ TensorProto additional_output_proto = protos.protos(i);
+
+ if (additional_output_proto.data_type() == TensorProto::FLOAT) {
+ prefetched_additional_outputs_[i - 2]
+ .template mutable_data<float>();
+ } else if (
+ additional_output_proto.data_type() == TensorProto::INT32) {
+ prefetched_additional_outputs_[i - 2].template mutable_data<int>();
+ } else {
+ LOG(FATAL) << "Unsupported output type.";
+ }
+ }
}
}
@@ -700,6 +764,11 @@ bool ImageInputOp<Context>::Prefetch() {
if (!std::is_same<Context, CPUContext>::value) {
prefetched_image_on_device_.CopyFrom(prefetched_image_, &context_);
prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
+
+ for (int i = 0; i < prefetched_additional_outputs_on_device_.size(); ++i) {
+ prefetched_additional_outputs_on_device_[i].CopyFrom(
+ prefetched_additional_outputs_[i], &context_);
+ }
}
return true;
}
@@ -708,11 +777,23 @@ template <class Context>
bool ImageInputOp<Context>::CopyPrefetched() {
auto* image_output = OperatorBase::Output<Tensor<Context> >(0);
auto* label_output = OperatorBase::Output<Tensor<Context> >(1);
+ vector<Tensor<Context>*> additional_outputs_output;
+
+ for (int i = 2; i < OutputSize(); ++i) {
+ additional_outputs_output.push_back(
+ OperatorBase::Output<Tensor<Context>>(i));
+ }
+
// Note(jiayq): The if statement below should be optimized away by the
// compiler since std::is_same is a constexpr.
if (std::is_same<Context, CPUContext>::value) {
image_output->CopyFrom(prefetched_image_, &context_);
label_output->CopyFrom(prefetched_label_, &context_);
+
+ for (int i = 0; i < additional_outputs_output.size(); ++i) {
+ additional_outputs_output[i]->CopyFrom(
+ prefetched_additional_outputs_[i], &context_);
+ }
} else {
if (gpu_transform_) {
if (!mean_std_copied_) {
@@ -741,6 +822,11 @@ bool ImageInputOp<Context>::CopyPrefetched() {
image_output->CopyFrom(prefetched_image_on_device_, &context_);
}
label_output->CopyFrom(prefetched_label_on_device_, &context_);
+
+ for (int i = 0; i < additional_outputs_output.size(); ++i) {
+ additional_outputs_output[i]->CopyFrom(
+ prefetched_additional_outputs_on_device_[i], &context_);
+ }
}
return true;
}