summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/core/blob_serialization.cc23
-rw-r--r--caffe2/core/blob_serialization.h6
-rw-r--r--caffe2/core/blob_serializer_base.h12
-rw-r--r--caffe2/core/blob_test.cc15
-rw-r--r--caffe2/core/db.cc7
-rw-r--r--caffe2/core/db.h3
-rw-r--r--caffe2/core/int8_serialization.cc6
-rw-r--r--caffe2/core/qtensor_serialization.h9
-rw-r--r--caffe2/operators/counter_ops.cc8
-rw-r--r--caffe2/operators/dataset_ops.cc12
-rw-r--r--caffe2/operators/dataset_ops.h3
-rw-r--r--caffe2/operators/index_ops.cc6
-rw-r--r--caffe2/operators/map_ops.h7
-rw-r--r--caffe2/sgd/iter_op.cc5
-rw-r--r--caffe2/sgd/iter_op.h3
15 files changed, 79 insertions, 46 deletions
diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc
index cc0939387d..96a32ad8d1 100644
--- a/caffe2/core/blob_serialization.cc
+++ b/caffe2/core/blob_serialization.cc
@@ -37,15 +37,16 @@ class StringSerializer : public BlobSerializerBase {
* otherwise this function produces a fatal error.
*/
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
- CAFFE_ENFORCE(blob.IsType<std::string>());
+ CAFFE_ENFORCE(typeMeta.Match<std::string>());
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("std::string");
- blob_proto.set_content(blob.template Get<std::string>());
+ blob_proto.set_content(*static_cast<const std::string*>(pointer));
acceptor(name, blob_proto.SerializeAsString());
}
};
@@ -70,7 +71,8 @@ void SerializeBlob(
std::unique_ptr<BlobSerializerBase> serializer(
CreateSerializer(blob.meta().id()));
CAFFE_ENFORCE(serializer, "No known serializer for ", blob.meta().name());
- serializer->SerializeWithChunkSize(blob, name, acceptor, chunk_size);
+ serializer->SerializeWithChunkSize(
+ blob.GetRaw(), blob.meta(), name, acceptor, chunk_size);
}
// The blob serialization member function implementation.
@@ -86,19 +88,22 @@ std::string SerializeBlob(const Blob& blob, const string& name) {
}
void TensorSerializer::Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) {
- this->SerializeWithChunkSize(blob, name, acceptor, kDefaultChunkSize);
+ this->SerializeWithChunkSize(
+ pointer, typeMeta, name, acceptor, kDefaultChunkSize);
}
void TensorSerializer::SerializeWithChunkSize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor,
int chunk_size) {
- CAFFE_ENFORCE(blob.IsType<Tensor>());
- const auto& tensor = blob.template Get<Tensor>();
+ CAFFE_ENFORCE(typeMeta.Match<Tensor>());
+ const auto& tensor = *static_cast<const Tensor*>(pointer);
if (chunk_size == kNoChunking) {
chunk_size = tensor.size() + 1; // to account for empty tensors
} else if (chunk_size == kDefaultChunkSize) {
diff --git a/caffe2/core/blob_serialization.h b/caffe2/core/blob_serialization.h
index 597cc49df0..2c2d590711 100644
--- a/caffe2/core/blob_serialization.h
+++ b/caffe2/core/blob_serialization.h
@@ -70,11 +70,13 @@ class CAFFE2_API TensorSerializer : public BlobSerializerBase {
* otherwise this function produces a fatal error.
*/
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override;
void SerializeWithChunkSize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor,
int chunk_size) override;
diff --git a/caffe2/core/blob_serializer_base.h b/caffe2/core/blob_serializer_base.h
index 4e0e3e4d6d..ad282f31fe 100644
--- a/caffe2/core/blob_serializer_base.h
+++ b/caffe2/core/blob_serializer_base.h
@@ -43,16 +43,20 @@ class BlobSerializerBase {
* serailizer can use it to save blob in several chunks
* acceptor should be thread-safe
*/
- virtual void Serialize(const Blob& blob, const std::string& name,
- SerializationAcceptor acceptor) = 0;
+ virtual void Serialize(
+ const void* pointer,
+ TypeMeta typeMeta,
+ const std::string& name,
+ SerializationAcceptor acceptor) = 0;
virtual void SerializeWithChunkSize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const std::string& name,
SerializationAcceptor acceptor,
int /*chunk_size*/) {
// Base implementation.
- Serialize(blob, name, acceptor);
+ Serialize(pointer, typeMeta, name, acceptor);
}
};
diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc
index b8dfe82abc..30e6dbb69b 100644
--- a/caffe2/core/blob_test.cc
+++ b/caffe2/core/blob_test.cc
@@ -51,17 +51,19 @@ class BlobTestFooSerializer : public BlobSerializerBase {
* otherwise this function produces a fatal error.
*/
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
- CAFFE_ENFORCE(blob.IsType<BlobTestFoo>());
+ CAFFE_ENFORCE(typeMeta.Match<BlobTestFoo>());
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("BlobTestFoo");
// For simplicity we will just serialize the 4-byte content as a string.
blob_proto.set_content(std::string(
- reinterpret_cast<const char*>(&(blob.Get<BlobTestFoo>().val)),
+ reinterpret_cast<const char*>(
+ &static_cast<const BlobTestFoo*>(pointer)->val),
sizeof(int32_t)));
acceptor(name, blob_proto.SerializeAsString());
}
@@ -942,11 +944,12 @@ class DummyTypeSerializer : public BlobSerializerBase {
DummyTypeSerializer() {}
~DummyTypeSerializer() {}
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
- CAFFE_ENFORCE(blob.IsType<DummyType>());
- const auto& container = blob.template Get<DummyType>();
+ CAFFE_ENFORCE(typeMeta.Match<DummyType>());
+ const auto& container = *static_cast<const DummyType*>(pointer);
for (int k = 0; k < container.n_chunks; ++k) {
std::string serialized_chunk = container.serialize(name, k);
acceptor(c10::str(name, kChunkIdSeparator, k), serialized_chunk);
diff --git a/caffe2/core/db.cc b/caffe2/core/db.cc
index c0031cb066..67b0f1ffe2 100644
--- a/caffe2/core/db.cc
+++ b/caffe2/core/db.cc
@@ -170,11 +170,12 @@ REGISTER_CAFFE2_DB(MiniDB, MiniDB);
REGISTER_CAFFE2_DB(minidb, MiniDB);
void DBReaderSerializer::Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) {
- CAFFE_ENFORCE(blob.IsType<DBReader>());
- auto& reader = blob.Get<DBReader>();
+ CAFFE_ENFORCE(typeMeta.Match<DBReader>());
+ const auto& reader = *static_cast<const DBReader*>(pointer);
DBReaderProto proto;
proto.set_name(name);
proto.set_source(reader.source_);
diff --git a/caffe2/core/db.h b/caffe2/core/db.h
index f6044ff35f..ff7461ae0c 100644
--- a/caffe2/core/db.h
+++ b/caffe2/core/db.h
@@ -295,7 +295,8 @@ class CAFFE2_API DBReaderSerializer : public BlobSerializerBase {
* otherwise this function produces a fatal error.
*/
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) override;
};
diff --git a/caffe2/core/int8_serialization.cc b/caffe2/core/int8_serialization.cc
index 190cf5797f..7a18e16a2b 100644
--- a/caffe2/core/int8_serialization.cc
+++ b/caffe2/core/int8_serialization.cc
@@ -11,10 +11,12 @@ namespace int8 {
class Int8TensorCPUSerializer : public BlobSerializerBase {
public:
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
- const auto& tensor = blob.template Get<Int8TensorCPU>();
+ CAFFE_ENFORCE(typeMeta.Match<Int8TensorCPU>());
+ const auto& tensor = *static_cast<const Int8TensorCPU*>(pointer);
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("Int8TensorCPU");
diff --git a/caffe2/core/qtensor_serialization.h b/caffe2/core/qtensor_serialization.h
index 8efac029ee..d9881030f1 100644
--- a/caffe2/core/qtensor_serialization.h
+++ b/caffe2/core/qtensor_serialization.h
@@ -17,7 +17,8 @@ class QTensorSerializer : public BlobSerializerBase {
* Serializes a Blob. Note that this blob has to contain QTensor<Context>.
*/
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override;
@@ -34,10 +35,12 @@ class QTensorDeserializer : public BlobDeserializerBase {
template <class Context>
void QTensorSerializer<Context>::Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) {
- const auto& qtensor = blob.template Get<QTensor<Context>>();
+ CAFFE_ENFORCE(typeMeta.Match<QTensor<Context>>());
+ const auto& qtensor = *static_cast<const QTensor<Context>*>(pointer);
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type(kQTensorBlobQType);
diff --git a/caffe2/operators/counter_ops.cc b/caffe2/operators/counter_ops.cc
index 50e4b9448a..79a6b51057 100644
--- a/caffe2/operators/counter_ops.cc
+++ b/caffe2/operators/counter_ops.cc
@@ -139,10 +139,11 @@ class CounterSerializer : public BlobSerializerBase {
~CounterSerializer() {}
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
- CAFFE_ENFORCE(blob.IsType<std::unique_ptr<Counter<int64_t>>>());
+ CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<Counter<int64_t>>>());
BlobProto blob_proto;
blob_proto.set_name(name);
@@ -152,7 +153,8 @@ class CounterSerializer : public BlobSerializerBase {
proto.set_data_type(TensorProto_DataType_INT64);
proto.add_dims(1);
proto.add_int64_data(
- blob.template Get<std::unique_ptr<Counter<int64_t>>>()->retrieve());
+ (*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer))
+ ->retrieve());
acceptor(name, blob_proto.SerializeAsString());
}
};
diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc
index 28952650b7..87ed0433c2 100644
--- a/caffe2/operators/dataset_ops.cc
+++ b/caffe2/operators/dataset_ops.cc
@@ -1419,10 +1419,13 @@ class TreeCursorSerializer : public BlobSerializerBase {
~TreeCursorSerializer() {}
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
- auto& cursor = blob.template Get<std::unique_ptr<TreeCursor>>();
+ CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<TreeCursor>>());
+ const auto& cursor =
+ *static_cast<const std::unique_ptr<TreeCursor>*>(pointer);
BlobProto blob_proto;
// serialize offsets as a tensor
@@ -1495,7 +1498,8 @@ REGISTER_BLOB_DESERIALIZER(std::unique_ptr<TreeCursor>, TreeCursorDeserializer);
} // namespace
void SharedTensorVectorPtrSerializer::Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) {
/* This is dummy serialize that doesn't save anything. If saving the content
@@ -1504,7 +1508,7 @@ void SharedTensorVectorPtrSerializer::Serialize(
LastNWindowCollectorOp and ReservoirSamplingOp if this serializer actually
saves the content.
*/
- CAFFE_ENFORCE(blob.IsType<std::shared_ptr<std::vector<TensorCPU>>>());
+ CAFFE_ENFORCE(typeMeta.Match<std::shared_ptr<std::vector<TensorCPU>>>());
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("std::shared_ptr<std::vector<TensorCPU>>");
diff --git a/caffe2/operators/dataset_ops.h b/caffe2/operators/dataset_ops.h
index 47a5260c83..08b095379e 100644
--- a/caffe2/operators/dataset_ops.h
+++ b/caffe2/operators/dataset_ops.h
@@ -196,7 +196,8 @@ using TensorVectorPtr = std::unique_ptr<std::vector<Tensor>>;
class SharedTensorVectorPtrSerializer : public BlobSerializerBase {
public:
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) override;
};
diff --git a/caffe2/operators/index_ops.cc b/caffe2/operators/index_ops.cc
index 2fb8f3b338..b6da99e99e 100644
--- a/caffe2/operators/index_ops.cc
+++ b/caffe2/operators/index_ops.cc
@@ -348,10 +348,12 @@ class IndexSerializer : public BlobSerializerBase {
~IndexSerializer() {}
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
- auto& base = blob.template Get<std::unique_ptr<IndexBase>>();
+ CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<IndexBase>>());
+ const auto& base = *static_cast<const std::unique_ptr<IndexBase>*>(pointer);
Blob tensor_blob;
auto* tensor_out = BlobGetMutableTensor(&tensor_blob, CPU);
diff --git a/caffe2/operators/map_ops.h b/caffe2/operators/map_ops.h
index 52cf8d1a8a..7b64808709 100644
--- a/caffe2/operators/map_ops.h
+++ b/caffe2/operators/map_ops.h
@@ -195,11 +195,12 @@ class MapSerializer : public BlobSerializerBase {
using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) override {
- CAFFE_ENFORCE(blob.IsType<MapType>());
- const MapType& map_data = blob.template Get<MapType>();
+ CAFFE_ENFORCE(typeMeta.Match<MapType>());
+ const MapType& map_data = *static_cast<const MapType*>(pointer);
int64_t sz = map_data.size();
Tensor key_tensor(CPU);
key_tensor.Resize(sz);
diff --git a/caffe2/sgd/iter_op.cc b/caffe2/sgd/iter_op.cc
index ac964018b9..222b20fb6f 100644
--- a/caffe2/sgd/iter_op.cc
+++ b/caffe2/sgd/iter_op.cc
@@ -8,10 +8,11 @@
namespace caffe2 {
void MutexSerializer::Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) {
- CAFFE_ENFORCE(blob.IsType<std::unique_ptr<std::mutex>>());
+ CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<std::mutex>>());
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("std::unique_ptr<std::mutex>");
diff --git a/caffe2/sgd/iter_op.h b/caffe2/sgd/iter_op.h
index 22ec8d252c..86a4c676bc 100644
--- a/caffe2/sgd/iter_op.h
+++ b/caffe2/sgd/iter_op.h
@@ -88,7 +88,8 @@ class MutexSerializer : public BlobSerializerBase {
* fatal error.
*/
void Serialize(
- const Blob& blob,
+ const void* pointer,
+ TypeMeta typeMeta,
const string& name,
BlobSerializerBase::SerializationAcceptor acceptor) override;
};