diff options
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/core/blob_serialization.cc | 23 | ||||
-rw-r--r-- | caffe2/core/blob_serialization.h | 6 | ||||
-rw-r--r-- | caffe2/core/blob_serializer_base.h | 12 | ||||
-rw-r--r-- | caffe2/core/blob_test.cc | 15 | ||||
-rw-r--r-- | caffe2/core/db.cc | 7 | ||||
-rw-r--r-- | caffe2/core/db.h | 3 | ||||
-rw-r--r-- | caffe2/core/int8_serialization.cc | 6 | ||||
-rw-r--r-- | caffe2/core/qtensor_serialization.h | 9 | ||||
-rw-r--r-- | caffe2/operators/counter_ops.cc | 8 | ||||
-rw-r--r-- | caffe2/operators/dataset_ops.cc | 12 | ||||
-rw-r--r-- | caffe2/operators/dataset_ops.h | 3 | ||||
-rw-r--r-- | caffe2/operators/index_ops.cc | 6 | ||||
-rw-r--r-- | caffe2/operators/map_ops.h | 7 | ||||
-rw-r--r-- | caffe2/sgd/iter_op.cc | 5 | ||||
-rw-r--r-- | caffe2/sgd/iter_op.h | 3 |
15 files changed, 79 insertions, 46 deletions
diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc index cc0939387d..96a32ad8d1 100644 --- a/caffe2/core/blob_serialization.cc +++ b/caffe2/core/blob_serialization.cc @@ -37,15 +37,16 @@ class StringSerializer : public BlobSerializerBase { * otherwise this function produces a fatal error. */ void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(blob.IsType<std::string>()); + CAFFE_ENFORCE(typeMeta.Match<std::string>()); BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type("std::string"); - blob_proto.set_content(blob.template Get<std::string>()); + blob_proto.set_content(*static_cast<const std::string*>(pointer)); acceptor(name, blob_proto.SerializeAsString()); } }; @@ -70,7 +71,8 @@ void SerializeBlob( std::unique_ptr<BlobSerializerBase> serializer( CreateSerializer(blob.meta().id())); CAFFE_ENFORCE(serializer, "No known serializer for ", blob.meta().name()); - serializer->SerializeWithChunkSize(blob, name, acceptor, chunk_size); + serializer->SerializeWithChunkSize( + blob.GetRaw(), blob.meta(), name, acceptor, chunk_size); } // The blob serialization member function implementation. @@ -86,19 +88,22 @@ std::string SerializeBlob(const Blob& blob, const string& name) { } void TensorSerializer::Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) { - this->SerializeWithChunkSize(blob, name, acceptor, kDefaultChunkSize); + this->SerializeWithChunkSize( + pointer, typeMeta, name, acceptor, kDefaultChunkSize); } void TensorSerializer::SerializeWithChunkSize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor, int chunk_size) { - CAFFE_ENFORCE(blob.IsType<Tensor>()); - const auto& tensor = blob.template Get<Tensor>(); + CAFFE_ENFORCE(typeMeta.Match<Tensor>()); + const auto& tensor = *static_cast<const Tensor*>(pointer); if (chunk_size == kNoChunking) { chunk_size = tensor.size() + 1; // to account for empty tensors } else if (chunk_size == kDefaultChunkSize) { diff --git a/caffe2/core/blob_serialization.h b/caffe2/core/blob_serialization.h index 597cc49df0..2c2d590711 100644 --- a/caffe2/core/blob_serialization.h +++ b/caffe2/core/blob_serialization.h @@ -70,11 +70,13 @@ class CAFFE2_API TensorSerializer : public BlobSerializerBase { * otherwise this function produces a fatal error. */ void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override; void SerializeWithChunkSize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor, int chunk_size) override; diff --git a/caffe2/core/blob_serializer_base.h b/caffe2/core/blob_serializer_base.h index 4e0e3e4d6d..ad282f31fe 100644 --- a/caffe2/core/blob_serializer_base.h +++ b/caffe2/core/blob_serializer_base.h @@ -43,16 +43,20 @@ class BlobSerializerBase { * serailizer can use it to save blob in several chunks * acceptor should be thread-safe */ - virtual void Serialize(const Blob& blob, const std::string& name, - SerializationAcceptor acceptor) = 0; + virtual void Serialize( + const void* pointer, + TypeMeta typeMeta, + const std::string& name, + SerializationAcceptor acceptor) = 0; virtual void SerializeWithChunkSize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const std::string& name, SerializationAcceptor acceptor, int /*chunk_size*/) { // Base implementation. - Serialize(blob, name, acceptor); + Serialize(pointer, typeMeta, name, acceptor); } }; diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc index b8dfe82abc..30e6dbb69b 100644 --- a/caffe2/core/blob_test.cc +++ b/caffe2/core/blob_test.cc @@ -51,17 +51,19 @@ class BlobTestFooSerializer : public BlobSerializerBase { * otherwise this function produces a fatal error. */ void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(blob.IsType<BlobTestFoo>()); + CAFFE_ENFORCE(typeMeta.Match<BlobTestFoo>()); BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type("BlobTestFoo"); // For simplicity we will just serialize the 4-byte content as a string. blob_proto.set_content(std::string( - reinterpret_cast<const char*>(&(blob.Get<BlobTestFoo>().val)), + reinterpret_cast<const char*>( + &static_cast<const BlobTestFoo*>(pointer)->val), sizeof(int32_t))); acceptor(name, blob_proto.SerializeAsString()); } @@ -942,11 +944,12 @@ class DummyTypeSerializer : public BlobSerializerBase { DummyTypeSerializer() {} ~DummyTypeSerializer() {} void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(blob.IsType<DummyType>()); - const auto& container = blob.template Get<DummyType>(); + CAFFE_ENFORCE(typeMeta.Match<DummyType>()); + const auto& container = *static_cast<const DummyType*>(pointer); for (int k = 0; k < container.n_chunks; ++k) { std::string serialized_chunk = container.serialize(name, k); acceptor(c10::str(name, kChunkIdSeparator, k), serialized_chunk); diff --git a/caffe2/core/db.cc b/caffe2/core/db.cc index c0031cb066..67b0f1ffe2 100644 --- a/caffe2/core/db.cc +++ b/caffe2/core/db.cc @@ -170,11 +170,12 @@ REGISTER_CAFFE2_DB(MiniDB, MiniDB); REGISTER_CAFFE2_DB(minidb, MiniDB); void DBReaderSerializer::Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) { - CAFFE_ENFORCE(blob.IsType<DBReader>()); - auto& reader = blob.Get<DBReader>(); + CAFFE_ENFORCE(typeMeta.Match<DBReader>()); + const auto& reader = *static_cast<const DBReader*>(pointer); DBReaderProto proto; proto.set_name(name); proto.set_source(reader.source_); diff --git a/caffe2/core/db.h b/caffe2/core/db.h index f6044ff35f..ff7461ae0c 100644 --- a/caffe2/core/db.h +++ b/caffe2/core/db.h @@ -295,7 +295,8 @@ class CAFFE2_API DBReaderSerializer : public BlobSerializerBase { * otherwise this function produces a fatal error. */ void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) override; }; diff --git a/caffe2/core/int8_serialization.cc b/caffe2/core/int8_serialization.cc index 190cf5797f..7a18e16a2b 100644 --- a/caffe2/core/int8_serialization.cc +++ b/caffe2/core/int8_serialization.cc @@ -11,10 +11,12 @@ namespace int8 { class Int8TensorCPUSerializer : public BlobSerializerBase { public: void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override { - const auto& tensor = blob.template Get<Int8TensorCPU>(); + CAFFE_ENFORCE(typeMeta.Match<Int8TensorCPU>()); + const auto& tensor = *static_cast<const Int8TensorCPU*>(pointer); BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type("Int8TensorCPU"); diff --git a/caffe2/core/qtensor_serialization.h b/caffe2/core/qtensor_serialization.h index 8efac029ee..d9881030f1 100644 --- a/caffe2/core/qtensor_serialization.h +++ b/caffe2/core/qtensor_serialization.h @@ -17,7 +17,8 @@ class QTensorSerializer : public BlobSerializerBase { * Serializes a Blob. Note that this blob has to contain QTensor<Context>. */ void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override; @@ -34,10 +35,12 @@ class QTensorDeserializer : public BlobDeserializerBase { template <class Context> void QTensorSerializer<Context>::Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) { - const auto& qtensor = blob.template Get<QTensor<Context>>(); + CAFFE_ENFORCE(typeMeta.Match<QTensor<Context>>()); + const auto& qtensor = *static_cast<const QTensor<Context>*>(pointer); BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type(kQTensorBlobQType); diff --git a/caffe2/operators/counter_ops.cc b/caffe2/operators/counter_ops.cc index 50e4b9448a..79a6b51057 100644 --- a/caffe2/operators/counter_ops.cc +++ b/caffe2/operators/counter_ops.cc @@ -139,10 +139,11 @@ class CounterSerializer : public BlobSerializerBase { ~CounterSerializer() {} void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(blob.IsType<std::unique_ptr<Counter<int64_t>>>()); + CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<Counter<int64_t>>>()); BlobProto blob_proto; blob_proto.set_name(name); @@ -152,7 +153,8 @@ class CounterSerializer : public BlobSerializerBase { proto.set_data_type(TensorProto_DataType_INT64); proto.add_dims(1); proto.add_int64_data( - blob.template Get<std::unique_ptr<Counter<int64_t>>>()->retrieve()); + (*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer)) + ->retrieve()); acceptor(name, blob_proto.SerializeAsString()); } }; diff --git a/caffe2/operators/dataset_ops.cc b/caffe2/operators/dataset_ops.cc index 28952650b7..87ed0433c2 100644 --- a/caffe2/operators/dataset_ops.cc +++ b/caffe2/operators/dataset_ops.cc @@ -1419,10 +1419,13 @@ class TreeCursorSerializer : public BlobSerializerBase { ~TreeCursorSerializer() {} void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override { - auto& cursor = blob.template Get<std::unique_ptr<TreeCursor>>(); + CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<TreeCursor>>()); + const auto& cursor = + *static_cast<const std::unique_ptr<TreeCursor>*>(pointer); BlobProto blob_proto; // serialize offsets as a tensor @@ -1495,7 +1498,8 @@ REGISTER_BLOB_DESERIALIZER(std::unique_ptr<TreeCursor>, TreeCursorDeserializer); } // namespace void SharedTensorVectorPtrSerializer::Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) { /* This is dummy serialize that doesn't save anything. If saving the content @@ -1504,7 +1508,7 @@ void SharedTensorVectorPtrSerializer::Serialize( LastNWindowCollectorOp and ReservoirSamplingOp if this serializer actually saves the content. */ - CAFFE_ENFORCE(blob.IsType<std::shared_ptr<std::vector<TensorCPU>>>()); + CAFFE_ENFORCE(typeMeta.Match<std::shared_ptr<std::vector<TensorCPU>>>()); BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type("std::shared_ptr<std::vector<TensorCPU>>"); diff --git a/caffe2/operators/dataset_ops.h b/caffe2/operators/dataset_ops.h index 47a5260c83..08b095379e 100644 --- a/caffe2/operators/dataset_ops.h +++ b/caffe2/operators/dataset_ops.h @@ -196,7 +196,8 @@ using TensorVectorPtr = std::unique_ptr<std::vector<Tensor>>; class SharedTensorVectorPtrSerializer : public BlobSerializerBase { public: void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) override; }; diff --git a/caffe2/operators/index_ops.cc b/caffe2/operators/index_ops.cc index 2fb8f3b338..b6da99e99e 100644 --- a/caffe2/operators/index_ops.cc +++ b/caffe2/operators/index_ops.cc @@ -348,10 +348,12 @@ class IndexSerializer : public BlobSerializerBase { ~IndexSerializer() {} void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, SerializationAcceptor acceptor) override { - auto& base = blob.template Get<std::unique_ptr<IndexBase>>(); + CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<IndexBase>>()); + const auto& base = *static_cast<const std::unique_ptr<IndexBase>*>(pointer); Blob tensor_blob; auto* tensor_out = BlobGetMutableTensor(&tensor_blob, CPU); diff --git a/caffe2/operators/map_ops.h b/caffe2/operators/map_ops.h index 52cf8d1a8a..7b64808709 100644 --- a/caffe2/operators/map_ops.h +++ b/caffe2/operators/map_ops.h @@ -195,11 +195,12 @@ class MapSerializer : public BlobSerializerBase { using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType; void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) override { - CAFFE_ENFORCE(blob.IsType<MapType>()); - const MapType& map_data = blob.template Get<MapType>(); + CAFFE_ENFORCE(typeMeta.Match<MapType>()); + const MapType& map_data = *static_cast<const MapType*>(pointer); int64_t sz = map_data.size(); Tensor key_tensor(CPU); key_tensor.Resize(sz); diff --git a/caffe2/sgd/iter_op.cc b/caffe2/sgd/iter_op.cc index ac964018b9..222b20fb6f 100644 --- a/caffe2/sgd/iter_op.cc +++ b/caffe2/sgd/iter_op.cc @@ -8,10 +8,11 @@ namespace caffe2 { void MutexSerializer::Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) { - CAFFE_ENFORCE(blob.IsType<std::unique_ptr<std::mutex>>()); + CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<std::mutex>>()); BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type("std::unique_ptr<std::mutex>"); diff --git a/caffe2/sgd/iter_op.h b/caffe2/sgd/iter_op.h index 22ec8d252c..86a4c676bc 100644 --- a/caffe2/sgd/iter_op.h +++ b/caffe2/sgd/iter_op.h @@ -88,7 +88,8 @@ class MutexSerializer : public BlobSerializerBase { * fatal error. */ void Serialize( - const Blob& blob, + const void* pointer, + TypeMeta typeMeta, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) override; }; |