author     Michael Antonov <michael.antonov@oculus.com>  2018-10-18 12:47:12 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-10-18 12:49:01 -0700
commit     63cd051867b830ec62ecdf16da319e4522c506fd (patch)
tree       972722fb9b493347239778ec869105b433b18f81
parent     2c566a17c763ab000ebeba1d3c01762bae814e42 (diff)
Guard all Caffe2 protobuf string serializations with CAFFE_ENFORCE (#12799)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12799

Updated all non-test uses of protobuf::MessageLite::SerializeAsString to call SerializeAsString_EnforceCheck, so that the return value is checked and an exception is thrown if serialization fails. Most of the affected code was called from classes derived from BlobSerializerBase. Didn't touch most tests and ENFORCE calls because they usually do checks anyway.

Reviewed By: ezyang

Differential Revision: D10416438

fbshipit-source-id: cb842e3e26b0918829d71267a375d4dd40600d58
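
For context, a minimal sketch of the call-site pattern this change applies throughout (hypothetical, not part of the commit; the function SaveProtos and its signature are made up for illustration):

// Hypothetical call site illustrating the pattern applied by this commit.
#include <string>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/proto/caffe2.pb.h"

std::string SaveProtos(const caffe2::TensorProtos& protos) {
  // Before: an unchecked call; if serialization fails (e.g. the message
  // exceeds protobuf's 2GB limit), it can silently return an empty or
  // truncated string.
  //   return protos.SerializeAsString();

  // After: the checked helper throws via CAFFE_ENFORCE on failure,
  // optionally tagging the error with a caller-supplied location string.
  return caffe2::SerializeAsString_EnforceCheck(protos, "SaveProtos");
}

The helper itself is added to caffe2/core/blob_serialization.{h,cc} in the hunks below.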
Diffstat:
-rw-r--r--  binaries/convert_caffe_image_db.cc    |  3
-rw-r--r--  caffe2/core/blob_serialization.cc     | 23
-rw-r--r--  caffe2/core/blob_serialization.h      | 18
-rw-r--r--  caffe2/core/blob_test.cc              |  2
-rw-r--r--  caffe2/core/db.cc                     |  4
-rw-r--r--  caffe2/core/int8_serialization.cc     |  2
-rw-r--r--  caffe2/core/qtensor_serialization.h   |  2
-rw-r--r--  caffe2/db/protodb.cc                  |  5
-rw-r--r--  caffe2/operators/counter_ops.cc       |  2
-rw-r--r--  caffe2/operators/dataset_ops.cc       |  4
-rw-r--r--  caffe2/operators/index_ops.cc         |  2
-rw-r--r--  caffe2/operators/map_ops.h            |  4
-rw-r--r--  caffe2/python/pybind_state.cc         |  3
-rw-r--r--  caffe2/sgd/iter_op.cc                 |  2
14 files changed, 58 insertions, 18 deletions
diff --git a/binaries/convert_caffe_image_db.cc b/binaries/convert_caffe_image_db.cc
index de7efbf65b..dca13d6e97 100644
--- a/binaries/convert_caffe_image_db.cc
+++ b/binaries/convert_caffe_image_db.cc
@@ -79,7 +79,7 @@ int main(int argc, char** argv) {
data->add_dims(datum.channels());
data->set_byte_data(buffer, datum.data().size());
}
- transaction->Put(cursor->key(), protos.SerializeAsString());
+ transaction->Put(cursor->key(), SerializeAsString_EnforceCheck(protos));
if (++count % FLAGS_batch_size == 0) {
transaction->Commit();
LOG(INFO) << "Converted " << count << " items so far.";
@@ -88,4 +88,3 @@ int main(int argc, char** argv) {
LOG(INFO) << "A total of " << count << " items processed.";
return 0;
}
-
diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc
index 281b5bd4f7..82dedf63d6 100644
--- a/caffe2/core/blob_serialization.cc
+++ b/caffe2/core/blob_serialization.cc
@@ -47,7 +47,7 @@ class StringSerializer : public BlobSerializerBase {
blob_proto.set_name(name);
blob_proto.set_type("std::string");
blob_proto.set_content(*static_cast<const std::string*>(pointer));
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};
@@ -134,7 +134,7 @@ void TensorSerializer::SerializeWithChunkSize(
tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
acceptor(
c10::str(name, kChunkIdSeparator, chunkStart / chunk_size),
- blob_proto.SerializeAsString());
+ SerializeBlobProtoAsString_EnforceCheck(blob_proto));
};
#ifndef __ANDROID__
@@ -543,6 +543,25 @@ void TensorDeserializer::Deserialize(const TensorProto& proto, Tensor* tensor) {
context->FinishDeviceComputation();
}
+////////////////////////////////////////////////////////////////////////////////
+// Serialization Helpers
+////////////////////////////////////////////////////////////////////////////////
+
+std::string SerializeAsString_EnforceCheck(
+ const google::protobuf::MessageLite& msg,
+ const char* error_location) {
+ std::string serialize_output;
+ bool result = msg.SerializeToString(&serialize_output);
+ if (!error_location) {
+ CAFFE_ENFORCE(result, "protobuf::SerializeToString failed");
+ } else {
+ CAFFE_ENFORCE(result,
+ "protobuf::SerializeToString failed for ", error_location);
+ }
+ return serialize_output;
+}
+
+
namespace {
// Serialize Tensor
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Tensor>()), TensorSerializer);
diff --git a/caffe2/core/blob_serialization.h b/caffe2/core/blob_serialization.h
index 2c2d590711..90a700f439 100644
--- a/caffe2/core/blob_serialization.h
+++ b/caffe2/core/blob_serialization.h
@@ -184,6 +184,24 @@ inline void CopyFromProtoWithCast(
}
} // namespace detail
+
+////////////////////////////////////////////////////////////////////////////////
+// Serialization Helpers
+////////////////////////////////////////////////////////////////////////////////
+
+// Converts MessageLite to string while also checking that SerializeAsString
+// succeeds. Pass description of class/function of the call if you'd
+// like it appended to the error message.
+std::string SerializeAsString_EnforceCheck(
+ const google::protobuf::MessageLite&,
+ const char* error_location = nullptr);
+
+// Convert BlobProto to string with success checks.
+inline std::string SerializeBlobProtoAsString_EnforceCheck(
+ const BlobProto& blob) {
+ return SerializeAsString_EnforceCheck(blob, blob.name().c_str());
+}
+
} // namespace caffe2
#endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_
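
The header comments above describe an optional error_location argument on the generic helper, while the BlobProto overload passes blob.name() automatically. A hypothetical caller-side sketch (not part of this diff), assuming CAFFE_ENFORCE failures surface as caffe2::EnforceNotMet from caffe2/core/logging.h:

#include <iostream>
#include <string>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/logging.h"

// Hypothetical wrapper: reports a serialization failure instead of
// letting the enforce exception propagate.
bool TrySerializeBlob(const caffe2::BlobProto& blob_proto, std::string* out) {
  try {
    // SerializeBlobProtoAsString_EnforceCheck forwards blob.name() as the
    // error_location, so the thrown message identifies the failing blob.
    *out = caffe2::SerializeBlobProtoAsString_EnforceCheck(blob_proto);
    return true;
  } catch (const caffe2::EnforceNotMet& e) {
    std::cerr << "Blob serialization failed: " << e.what() << std::endl;
    return false;
  }
}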
diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc
index c65b860bcb..662eab7673 100644
--- a/caffe2/core/blob_test.cc
+++ b/caffe2/core/blob_test.cc
@@ -65,7 +65,7 @@ class BlobTestFooSerializer : public BlobSerializerBase {
reinterpret_cast<const char*>(
&static_cast<const BlobTestFoo*>(pointer)->val),
sizeof(int32_t)));
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};
diff --git a/caffe2/core/db.cc b/caffe2/core/db.cc
index 67b0f1ffe2..16c6509299 100644
--- a/caffe2/core/db.cc
+++ b/caffe2/core/db.cc
@@ -186,8 +186,8 @@ void DBReaderSerializer::Serialize(
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("DBReader");
- blob_proto.set_content(proto.SerializeAsString());
- acceptor(name, blob_proto.SerializeAsString());
+ blob_proto.set_content(SerializeAsString_EnforceCheck(proto));
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
void DBReaderDeserializer::Deserialize(const BlobProto& proto, Blob* blob) {
diff --git a/caffe2/core/int8_serialization.cc b/caffe2/core/int8_serialization.cc
index 7a18e16a2b..dc22b12a99 100644
--- a/caffe2/core/int8_serialization.cc
+++ b/caffe2/core/int8_serialization.cc
@@ -51,7 +51,7 @@ class Int8TensorCPUSerializer : public BlobSerializerBase {
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
}
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
private:
diff --git a/caffe2/core/qtensor_serialization.h b/caffe2/core/qtensor_serialization.h
index d9881030f1..007174368a 100644
--- a/caffe2/core/qtensor_serialization.h
+++ b/caffe2/core/qtensor_serialization.h
@@ -55,7 +55,7 @@ void QTensorSerializer<Context>::Serialize(
proto.set_is_signed(qtensor.is_signed());
detail::CopyToProtoWithCast(
qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_);
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
template <class Context>
diff --git a/caffe2/db/protodb.cc b/caffe2/db/protodb.cc
index fdaaaf57f1..68b74724a7 100644
--- a/caffe2/db/protodb.cc
+++ b/caffe2/db/protodb.cc
@@ -20,7 +20,10 @@ class ProtoDBCursor : public Cursor {
void SeekToFirst() override { iter_ = 0; }
void Next() override { ++iter_; }
string key() override { return proto_->protos(iter_).name(); }
- string value() override { return proto_->protos(iter_).SerializeAsString(); }
+ string value() override {
+ return
+ SerializeAsString_EnforceCheck(proto_->protos(iter_), "ProtoDBCursor");
+ }
bool Valid() override { return iter_ < proto_->protos_size(); }
private:
diff --git a/caffe2/operators/counter_ops.cc b/caffe2/operators/counter_ops.cc
index 79a6b51057..2a2278c313 100644
--- a/caffe2/operators/counter_ops.cc
+++ b/caffe2/operators/counter_ops.cc
@@ -155,7 +155,7 @@ class CounterSerializer : public BlobSerializerBase {
proto.add_int64_data(
(*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer))
->retrieve());
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};
diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc
index 87ed0433c2..b0a34f813b 100644
--- a/caffe2/operators/dataset_ops.cc
+++ b/caffe2/operators/dataset_ops.cc
@@ -1451,7 +1451,7 @@ class TreeCursorSerializer : public BlobSerializerBase {
}
blob_proto.set_content(os.str());
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};
@@ -1513,7 +1513,7 @@ void SharedTensorVectorPtrSerializer::Serialize(
blob_proto.set_name(name);
blob_proto.set_type("std::shared_ptr<std::vector<TensorCPU>>");
blob_proto.set_content("");
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
};
void SharedTensorVectorPtrDeserializer::Deserialize(
diff --git a/caffe2/operators/index_ops.cc b/caffe2/operators/index_ops.cc
index b6da99e99e..5c6488d59e 100644
--- a/caffe2/operators/index_ops.cc
+++ b/caffe2/operators/index_ops.cc
@@ -381,7 +381,7 @@ class IndexSerializer : public BlobSerializerBase {
os << base->maxElements() << " " << base->isFrozen();
blob_proto.set_content(os.str());
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
private:
diff --git a/caffe2/operators/map_ops.h b/caffe2/operators/map_ops.h
index 7b64808709..fa2c4a865b 100644
--- a/caffe2/operators/map_ops.h
+++ b/caffe2/operators/map_ops.h
@@ -225,8 +225,8 @@ class MapSerializer : public BlobSerializerBase {
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type(MapTypeTraits<KEY_T, VALUE_T>::MapTypeName());
- blob_proto.set_content(tensor_protos.SerializeAsString());
- acceptor(name, blob_proto.SerializeAsString());
+ blob_proto.set_content(SerializeAsString_EnforceCheck(tensor_protos));
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
};
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index 61474fe7f6..4f08acda9a 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -586,7 +586,8 @@ void addObjectMethods(py::module& m) {
const auto& meta = GetGradientForOp(def, output_gradients);
std::vector<py::bytes> grad_ops;
for (const auto& op : meta.ops_) {
- grad_ops.push_back(op.SerializeAsString());
+ grad_ops.push_back(
+ SerializeAsString_EnforceCheck(op, "addObjectMethods"));
}
return std::pair<std::vector<py::bytes>, std::vector<GradientWrapper>>{
grad_ops, meta.g_input_};
diff --git a/caffe2/sgd/iter_op.cc b/caffe2/sgd/iter_op.cc
index 222b20fb6f..8b851e0ba1 100644
--- a/caffe2/sgd/iter_op.cc
+++ b/caffe2/sgd/iter_op.cc
@@ -17,7 +17,7 @@ void MutexSerializer::Serialize(
blob_proto.set_name(name);
blob_proto.set_type("std::unique_ptr<std::mutex>");
blob_proto.set_content("");
- acceptor(name, blob_proto.SerializeAsString());
+ acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {