-rw-r--r--  caffe2/image/image_input_op.cc                        9
-rw-r--r--  caffe2/image/image_input_op.h                        90
-rw-r--r--  caffe2/python/helpers/tools.py                       16
-rw-r--r--  caffe2/python/operator_test/image_input_op_test.py  241
4 files changed, 266 insertions, 90 deletions
diff --git a/caffe2/image/image_input_op.cc b/caffe2/image/image_input_op.cc
index 49ff80455d..478e2b640f 100644
--- a/caffe2/image/image_input_op.cc
+++ b/caffe2/image/image_input_op.cc
@@ -6,7 +6,7 @@ REGISTER_CPU_OPERATOR(ImageInput, ImageInputOp<CPUContext>);
OPERATOR_SCHEMA(ImageInput)
.NumInputs(0, 1)
- .NumOutputs(2)
+ .NumOutputs(2, INT_MAX)
.TensorInferenceFunction(
[](const OperatorDef& def, const vector<TensorShape>& /* unused */ ) {
vector<TensorShape> out(2);
@@ -75,9 +75,14 @@ The dimension of the output image will always be cropxcrop
.Arg("db", "Name of the database (if not passed as input)")
.Arg("db_type", "Type of database (if not passed as input)."
" Defaults to leveldb")
+ .Arg("output_sizes", "The sizes of any outputs besides the data and label "
+ "(should have a number of elements equal to the number of additional "
+ "outputs)")
.Input(0, "reader", "The input reader (a db::DBReader)")
.Output(0, "data", "Tensor containing the images")
- .Output(1, "label", "Tensor containing the labels");
+ .Output(1, "label", "Tensor containing the labels")
+ .Output(2, "additional outputs", "Any outputs after the first 2 will be "
+ "Tensors read from the input TensorProtos");
NO_GRADIENT(ImageInput);
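
[Note] With the schema now accepting 2..INT_MAX outputs, a minimal sketch of driving it from Python, assuming a db::DBReader blob named 'DB' already exists in the workspace; blob names and sizes are illustrative:

    from caffe2.python import core, workspace

    # Records in 'DB' are assumed to hold TensorProtos of
    # (image, label, extra_float, extra_ints).
    op = core.CreateOperator(
        'ImageInput',
        ['DB'],
        ['data', 'label', 'extra_float', 'extra_ints'],  # >2 outputs now allowed
        batch_size=32,
        color=3,
        minsize=256,
        crop=227,
        is_test=True,
        output_sizes=[1, 10],  # per-item size of each additional output
    )
    workspace.RunOperatorOnce(op)
    extra = workspace.FetchBlob('extra_ints')  # shape (32, 10)
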
diff --git a/caffe2/image/image_input_op.h b/caffe2/image/image_input_op.h
index a08dbf9754..8d8a32fa8f 100644
--- a/caffe2/image/image_input_op.h
+++ b/caffe2/image/image_input_op.h
@@ -63,8 +63,10 @@ class ImageInputOp final
CPUContext cpu_context_;
TensorCPU prefetched_image_;
TensorCPU prefetched_label_;
+ vector<TensorCPU> prefetched_additional_outputs_;
Tensor<Context> prefetched_image_on_device_;
Tensor<Context> prefetched_label_on_device_;
+ vector<Tensor<Context>> prefetched_additional_outputs_on_device_;
// Default parameters for images
PerImageArg default_arg_;
int batch_size_;
@@ -105,6 +107,8 @@ ImageInputOp<Context>::ImageInputOp(
Workspace* ws)
: PrefetchOperator<Context>(operator_def, ws),
reader_(nullptr),
+ prefetched_additional_outputs_(OutputSize() - 2),
+ prefetched_additional_outputs_on_device_(OutputSize() - 2),
batch_size_(
OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
multiple_label_(
@@ -137,6 +141,10 @@ ImageInputOp<Context>::ImageInputOp(
"std_per_channel",
{OperatorBase::template GetSingleArgument<float>("std", 1.)});
+ vector<int> additional_output_sizes =
+ OperatorBase::template GetRepeatedArgument<int>(
+ "output_sizes", vector<int>(OutputSize() - 2, 1));
+
default_arg_.bounding_params = {
false,
OperatorBase::template GetSingleArgument<int>("bounding_ymin", -1),
@@ -180,6 +188,13 @@ ImageInputOp<Context>::ImageInputOp(
"The mean and std. dev vectors must be of the same size.");
CAFFE_ENFORCE(mean_.size() == 1 || mean_.size() == 3,
"The mean and std. dev vectors must be of size 1 or 3");
+ CAFFE_ENFORCE(
+ !use_caffe_datum_ || OutputSize() == 2,
+ "There can only be 2 outputs if the Caffe datum format is used");
+ CAFFE_ENFORCE(
+ additional_output_sizes.size() == OutputSize() - 2,
+ "If the output sizes are specified, they must be specified for all "
+ "additional outputs");
if (default_arg_.bounding_params.ymin < 0
|| default_arg_.bounding_params.xmin < 0
@@ -255,6 +270,11 @@ ImageInputOp<Context>::ImageInputOp(
} else {
prefetched_label_.Resize(vector<TIndex>(1, batch_size_));
}
+
+ for (int i = 0; i < additional_output_sizes.size(); ++i) {
+ prefetched_additional_outputs_[i].Resize(
+ TIndex(batch_size_), TIndex(additional_output_sizes[i]));
+ }
}
template <class Context>
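
[Note] The constructor's sizing rules above, restated as a sketch (values illustrative): each additional output gets a pre-sized (batch_size, output_sizes[i]) prefetch buffer, and omitting output_sizes defaults every size to 1:

    batch_size = 32
    num_outputs = 4                  # data, label, plus two extras
    output_sizes = [1, 10]           # must supply OutputSize() - 2 entries
    # if omitted: output_sizes = [1] * (num_outputs - 2)

    shapes = [(batch_size, s) for s in output_sizes]
    # [(32, 1), (32, 10)] -- one 2-D buffer per additional output
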
@@ -319,9 +339,15 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
CAFFE_ENFORCE(protos.ParseFromString(value));
const TensorProto& image_proto = protos.protos(0);
const TensorProto& label_proto = protos.protos(1);
- if (protos.protos_size() == 3) {
+ vector<TensorProto> additional_output_protos;
+
+ for (int i = 2; i < OutputSize(); ++i) {
+ additional_output_protos.push_back(protos.protos(i));
+ }
+
+ if (protos.protos_size() == OutputSize() + 1) {
// We have bounding box information
- const TensorProto& bounding_proto = protos.protos(2);
+ const TensorProto& bounding_proto = protos.protos(OutputSize());
DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32);
DCHECK_EQ(bounding_proto.int32_data_size(), 4);
info.bounding_params.valid = true;
@@ -392,6 +418,30 @@ bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
} else {
LOG(FATAL) << "Unsupported label type.";
}
+
+ for (int i = 0; i < additional_output_protos.size(); ++i) {
+ auto additional_output_proto = additional_output_protos[i];
+
+ if (additional_output_proto.data_type() == TensorProto::FLOAT) {
+ float* additional_output =
+ prefetched_additional_outputs_[i].template mutable_data<float>() +
+ item_id * additional_output_proto.float_data_size();
+
+ for (int j = 0; j < additional_output_proto.float_data_size(); ++j) {
+ additional_output[j] = additional_output_proto.float_data(j);
+ }
+ } else if (additional_output_proto.data_type() == TensorProto::INT32) {
+ int* additional_output =
+ prefetched_additional_outputs_[i].template mutable_data<int>() +
+ item_id * additional_output_proto.int32_data_size();
+
+ for (int j = 0; j < additional_output_proto.int32_data_size(); ++j) {
+ additional_output[j] = additional_output_proto.int32_data(j);
+ }
+ } else {
+ LOG(FATAL) << "Unsupported output type.";
+ }
+ }
}
//
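
[Note] For reference, a sketch of the record layout the parser above expects, written with the same proto fields the test below uses; values are made up. protos(0) is the image, protos(1) the label, protos(2..N-1) the additional outputs, and an optional trailing 4-element INT32 proto carries the bounding box (hence the protos_size() == OutputSize() + 1 check):

    from caffe2.proto import caffe2_pb2

    record = caffe2_pb2.TensorProtos()

    img = record.protos.add()
    img.data_type = caffe2_pb2.TensorProto.STRING   # e.g. an encoded JPEG
    img.string_data.append(b'<jpeg bytes>')

    label = record.protos.add()
    label.data_type = caffe2_pb2.TensorProto.INT32
    label.int32_data.append(7)

    extra = record.protos.add()                     # one additional FLOAT output
    extra.data_type = caffe2_pb2.TensorProto.FLOAT
    extra.float_data.append(0.25)

    bbox = record.protos.add()                      # optional bounding box
    bbox.data_type = caffe2_pb2.TensorProto.INT32
    bbox.int32_data.extend([3, 5, 100, 120])        # ymin, xmin, height, width

    value = record.SerializeToString()              # what the DB stores
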
@@ -664,6 +714,20 @@ bool ImageInputOp<Context>::Prefetch() {
} else {
LOG(FATAL) << "Unsupported label type.";
}
+
+ for (int i = 2; i < OutputSize(); ++i) {
+ TensorProto additional_output_proto = protos.protos(i);
+
+ if (additional_output_proto.data_type() == TensorProto::FLOAT) {
+ prefetched_additional_outputs_[i - 2]
+ .template mutable_data<float>();
+ } else if (
+ additional_output_proto.data_type() == TensorProto::INT32) {
+ prefetched_additional_outputs_[i - 2].template mutable_data<int>();
+ } else {
+ LOG(FATAL) << "Unsupported output type.";
+ }
+ }
}
}
@@ -700,6 +764,11 @@ bool ImageInputOp<Context>::Prefetch() {
if (!std::is_same<Context, CPUContext>::value) {
prefetched_image_on_device_.CopyFrom(prefetched_image_, &context_);
prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
+
+ for (int i = 0; i < prefetched_additional_outputs_on_device_.size(); ++i) {
+ prefetched_additional_outputs_on_device_[i].CopyFrom(
+ prefetched_additional_outputs_[i], &context_);
+ }
}
return true;
}
@@ -708,11 +777,23 @@ template <class Context>
bool ImageInputOp<Context>::CopyPrefetched() {
auto* image_output = OperatorBase::Output<Tensor<Context> >(0);
auto* label_output = OperatorBase::Output<Tensor<Context> >(1);
+ vector<Tensor<Context>*> additional_outputs_output;
+
+ for (int i = 2; i < OutputSize(); ++i) {
+ additional_outputs_output.push_back(
+ OperatorBase::Output<Tensor<Context>>(i));
+ }
+
// Note(jiayq): The if statement below should be optimized away by the
// compiler since std::is_same is a constexpr.
if (std::is_same<Context, CPUContext>::value) {
image_output->CopyFrom(prefetched_image_, &context_);
label_output->CopyFrom(prefetched_label_, &context_);
+
+ for (int i = 0; i < additional_outputs_output.size(); ++i) {
+ additional_outputs_output[i]->CopyFrom(
+ prefetched_additional_outputs_[i], &context_);
+ }
} else {
if (gpu_transform_) {
if (!mean_std_copied_) {
@@ -741,6 +822,11 @@ bool ImageInputOp<Context>::CopyPrefetched() {
image_output->CopyFrom(prefetched_image_on_device_, &context_);
}
label_output->CopyFrom(prefetched_label_on_device_, &context_);
+
+ for (int i = 0; i < additional_outputs_output.size(); ++i) {
+ additional_outputs_output[i]->CopyFrom(
+ prefetched_additional_outputs_on_device_[i], &context_);
+ }
}
return true;
}
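
[Note] On the non-CPU path the additional outputs ride through the same prefetch/copy machinery as data and label. A sketch of running the op under a GPU device scope and fetching one of the extra blobs (device setup and names are illustrative):

    from caffe2.proto import caffe2_pb2
    from caffe2.python import core, workspace

    with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA, 0)):
        op = core.CreateOperator(
            'ImageInput', ['DB'], ['data', 'label', 'extra_ints'],
            batch_size=32, minsize=256, crop=227, color=3, is_test=True,
            use_gpu_transform=1, output_sizes=[10],
        )
    workspace.RunOperatorOnce(op)
    extra = workspace.FetchBlob('extra_ints')  # copied back to host, shape (32, 10)
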
diff --git a/caffe2/python/helpers/tools.py b/caffe2/python/helpers/tools.py
index 308934799d..df0525fa7d 100644
--- a/caffe2/python/helpers/tools.py
+++ b/caffe2/python/helpers/tools.py
@@ -13,18 +13,18 @@ def image_input(
if (use_gpu_transform):
kwargs['use_gpu_transform'] = 1 if use_gpu_transform else 0
# GPU transform will handle NHWC -> NCHW
- data, label = model.net.ImageInput(
- blob_in, [blob_out[0], blob_out[1]], **kwargs
- )
+ outputs = model.net.ImageInput(blob_in, blob_out, **kwargs)
pass
else:
- data, label = model.net.ImageInput(
- blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs
+ outputs = model.net.ImageInput(
+ blob_in, [blob_out[0] + '_nhwc'] + blob_out[1:], **kwargs
)
- data = model.net.NHWC2NCHW(data, blob_out[0])
+ outputs_list = list(outputs)
+ outputs_list[0] = model.net.NHWC2NCHW(outputs_list[0], blob_out[0])
+ outputs = tuple(outputs_list)
else:
- data, label = model.net.ImageInput(blob_in, blob_out, **kwargs)
- return data, label
+ outputs = model.net.ImageInput(blob_in, blob_out, **kwargs)
+ return outputs
def video_input(model, blob_in, blob_out, **kwargs):
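
[Note] With the helper generalized to return every output, callers can unpack as many blobs as they request. A sketch using the helper directly (reader setup and blob names are assumptions):

    from caffe2.python import model_helper
    from caffe2.python.helpers.tools import image_input

    model = model_helper.ModelHelper(name='example')
    reader = model.param_init_net.CreateDB(
        [], 'reader', db='/path/to/lmdb', db_type='lmdb')

    # Now returns all requested outputs, not just (data, label).
    data, label, extra = image_input(
        model, reader, ['data', 'label', 'extra'],
        batch_size=32, minsize=256, crop=227, color=3, is_test=True,
        output_sizes=[10],
    )
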
diff --git a/caffe2/python/operator_test/image_input_op_test.py b/caffe2/python/operator_test/image_input_op_test.py
index 86b82b40f7..cf7e4dbd93 100644
--- a/caffe2/python/operator_test/image_input_op_test.py
+++ b/caffe2/python/operator_test/image_input_op_test.py
@@ -119,8 +119,9 @@ def caffe2_img(img):
# Bounding box is ymin, xmin, height, width
-def create_test(output_dir, width, height, default_bound,
- minsize, crop, means, stds, count, multiple_label, num_labels):
+def create_test(output_dir, width, height, default_bound, minsize, crop, means,
+ stds, count, multiple_label, num_labels, output1=None,
+ output2_size=None):
print("Creating a temporary lmdb database of %d pictures..." % (count))
if default_bound is None:
@@ -189,7 +190,22 @@ def create_test(output_dir, width, height, default_bound,
label_tensor.int32_data.append(idx)
expected_label = binary_labels
- expected_results.append([caffe2_img(img_expected), expected_label])
+ if output1:
+ output1_tensor = tensor_protos.protos.add()
+ output1_tensor.data_type = 1 # float data
+ output1_tensor.float_data.append(output1)
+
+ output2 = []
+ if output2_size:
+ output2_tensor = tensor_protos.protos.add()
+ output2_tensor.data_type = 2 # int32 data
+ values = np.random.randint(1024, size=output2_size)
+ for val in values.tolist():
+ output2.append(val)
+ output2_tensor.int32_data.append(val)
+
+ expected_results.append(
+ [caffe2_img(img_expected), expected_label, output1, output2])
if not do_default_bound:
bounding_tensor = tensor_protos.protos.add()
@@ -206,9 +222,107 @@ def create_test(output_dir, width, height, default_bound,
return expected_results
+def run_test(
+ size_tuple, means, stds, multiple_label, num_labels, dc, validator,
+ output1=None, output2_size=None):
+ # TODO: Does not test on GPU and does not test use_gpu_transform
+ # WARNING: Using ModelHelper automatically does NHWC to NCHW
+ # transformation if needed.
+ width, height, minsize, crop = size_tuple
+ means = [float(m) for m in means]
+ stds = [float(s) for s in stds]
+ out_dir = tempfile.mkdtemp()
+ count_images = 2 # One with bounding box and one without
+ expected_images = create_test(
+ out_dir,
+ width=width,
+ height=height,
+ default_bound=(3, 5, height - 3, width - 5),
+ minsize=minsize,
+ crop=crop,
+ means=means,
+ stds=stds,
+ count=count_images,
+ multiple_label=multiple_label,
+ num_labels=num_labels,
+ output1=output1,
+ output2_size=output2_size
+ )
+ for device_option in dc:
+ with hu.temp_workspace():
+ reader_net = core.Net('reader')
+ reader_net.CreateDB(
+ [],
+ 'DB',
+ db=out_dir,
+ db_type="lmdb"
+ )
+ workspace.RunNetOnce(reader_net)
+ outputs = ['data', 'label']
+ output_sizes = []
+ if output1:
+ outputs.append('output1')
+ output_sizes.append(1)
+ if output2_size:
+ outputs.append('output2')
+ output_sizes.append(output2_size)
+ imageop = core.CreateOperator(
+ 'ImageInput',
+ ['DB'],
+ outputs,
+ batch_size=count_images,
+ color=3,
+ minsize=minsize,
+ crop=crop,
+ is_test=True,
+ bounding_ymin=3,
+ bounding_xmin=5,
+ bounding_height=height - 3,
+ bounding_width=width - 5,
+ mean_per_channel=means,
+ std_per_channel=stds,
+ use_gpu_transform=(device_option.device_type == 1),
+ multiple_label=multiple_label,
+ num_labels=num_labels,
+ output_sizes=output_sizes
+ )
+
+ imageop.device_option.CopyFrom(device_option)
+ main_net = core.Net('main')
+ main_net.Proto().op.extend([imageop])
+ workspace.RunNetOnce(main_net)
+ validator(expected_images, device_option, count_images)
+ # End for
+ # End with
+ # End for
+ shutil.rmtree(out_dir)
+# end run_test
+
+
@unittest.skipIf('cv2' not in sys.modules, 'python-opencv is not installed')
@unittest.skipIf('lmdb' not in sys.modules, 'python-lmdb is not installed')
class TestImport(hu.HypothesisTestCase):
+ def validate_image_and_label(
+ self, expected_images, device_option, count_images, multiple_label):
+ l = workspace.FetchBlob('label')
+ result = workspace.FetchBlob('data').astype(np.int32)
+ # If we don't use_gpu_transform, the output is in NHWC
+ # Our reference output is CHW so we swap
+ if device_option.device_type != 1:
+ expected = [img.swapaxes(0, 1).swapaxes(1, 2) for
+ (img, _, _, _) in expected_images]
+ else:
+ expected = [img for (img, _, _, _) in expected_images]
+ for i in range(count_images):
+ if multiple_label == 0:
+ self.assertEqual(l[i], expected_images[i][1])
+ else:
+ self.assertEqual(
+ (l[i] - expected_images[i][1] > 0).sum(), 0)
+ self.assertEqual((expected[i] - result[i] > 1).sum(), 0)
+ # End for
+ # end validate_image_and_label
+
@given(size_tuple=st.tuples(
st.integers(min_value=8, max_value=4096),
st.integers(min_value=8, max_value=4096)).flatmap(lambda t: st.tuples(
@@ -228,81 +342,52 @@ class TestImport(hu.HypothesisTestCase):
def test_imageinput(
self, size_tuple, means, stds, multiple_label,
num_labels, gc, dc):
- # TODO: Does not test on GPU and does not test use_gpu_transform
- # WARNING: Using ModelHelper automatically does NHWC to NCHW
- # transformation if needed.
- width, height, minsize, crop = size_tuple
- means = [float(m) for m in means]
- stds = [float(s) for s in stds]
- out_dir = tempfile.mkdtemp()
- count_images = 2 # One with bounding box and one without
- expected_images = create_test(
- out_dir,
- width=width,
- height=height,
- default_bound=(3, 5, height - 3, width - 5),
- minsize=minsize,
- crop=crop,
- means=means,
- stds=stds,
- count=count_images,
- multiple_label=multiple_label,
- num_labels=num_labels,
- )
- for device_option in dc:
- with hu.temp_workspace():
- reader_net = core.Net('reader')
- reader_net.CreateDB(
- [],
- 'DB',
- db=out_dir,
- db_type="lmdb"
- )
- workspace.RunNetOnce(reader_net)
- imageop = core.CreateOperator(
- 'ImageInput',
- ['DB'],
- ["data", "label"],
- batch_size=count_images,
- color=3,
- minsize=minsize,
- crop=crop,
- is_test=True,
- bounding_ymin=3,
- bounding_xmin=5,
- bounding_height=height - 3,
- bounding_width=width - 5,
- mean_per_channel=means,
- std_per_channel=stds,
- use_gpu_transform=(device_option.device_type == 1),
- multiple_label=multiple_label,
- num_labels=num_labels,
- )
-
- imageop.device_option.CopyFrom(device_option)
- main_net = core.Net('main')
- main_net.Proto().op.extend([imageop])
- workspace.RunNetOnce(main_net)
- l = workspace.FetchBlob('label')
- result = workspace.FetchBlob('data').astype(np.int32)
- # If we don't use_gpu_transform, the output is in NHWC
- # Our reference output is CHW so we swap
- if device_option.device_type != 1:
- expected = [img.swapaxes(0, 1).swapaxes(1, 2) for
- (img, _) in expected_images]
- else:
- expected = [img for (img, _) in expected_images]
- for i in range(count_images):
- if multiple_label == 0:
- self.assertEqual(l[i], expected_images[i][1])
- else:
- self.assertEqual(
- (l[i] - expected_images[i][1] > 0).sum(), 0)
- self.assertEqual((expected[i] - result[i] > 1).sum(), 0)
- # End for
- # End with
- # End for
- shutil.rmtree(out_dir)
+ def validator(expected_images, device_option, count_images):
+ self.validate_image_and_label(
+ expected_images, device_option, count_images, multiple_label)
+ # End validator
+ run_test(
+ size_tuple, means, stds, multiple_label, num_labels, dc,
+ validator)
+ # End test_imageinput
+
+ @given(size_tuple=st.tuples(
+ st.integers(min_value=8, max_value=4096),
+ st.integers(min_value=8, max_value=4096)).flatmap(lambda t: st.tuples(
+ st.just(t[0]), st.just(t[1]),
+ st.just(min(t[0] - 6, t[1] - 4)),
+ st.integers(min_value=1, max_value=min(t[0] - 6, t[1] - 4)))),
+ means=st.tuples(st.integers(min_value=0, max_value=255),
+ st.integers(min_value=0, max_value=255),
+ st.integers(min_value=0, max_value=255)),
+ stds=st.tuples(st.floats(min_value=1, max_value=10),
+ st.floats(min_value=1, max_value=10),
+ st.floats(min_value=1, max_value=10)),
+ multiple_label=st.integers(0, 1),
+ num_labels=st.integers(min_value=8, max_value=4096),
+ output1=st.floats(min_value=1, max_value=10),
+ output2_size=st.integers(min_value=2, max_value=10),
+ **hu.gcs)
+ @settings(verbosity=Verbosity.verbose)
+ def test_imageinput_with_additional_outputs(
+ self, size_tuple, means, stds, multiple_label,
+ num_labels, output1, output2_size, gc, dc):
+ def validator(expected_images, device_option, count_images):
+ self.validate_image_and_label(
+ expected_images, device_option, count_images, multiple_label)
+
+ output1_result = workspace.FetchBlob('output1')
+ output2_result = workspace.FetchBlob('output2')
+
+ for i in range(count_images):
+ self.assertEqual(output1_result[i], expected_images[i][2])
+ self.assertEqual(
+ (output2_result[i] - expected_images[i][3] > 0).sum(), 0)
+ # End for
+ # End validator
+ run_test(
+ size_tuple, means, stds, multiple_label, num_labels, dc,
+ validator, output1, output2_size)
# End test_imageinput