summaryrefslogtreecommitdiff
path: root/binaries
diff options
context:
space:
mode:
authorSebastian Messmer <messmer@fb.com>2018-09-24 22:52:14 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-24 22:57:05 -0700
commit17a65bf9b680ae04522c11462f5fc243e525f07c (patch)
treea39193987ed7a6e39857ecc2a8e5a02d7ce1c0e9 /binaries
parentdfa03e94ebf24b12e889f749c481ed687441cf75 (diff)
downloadpytorch-17a65bf9b680ae04522c11462f5fc243e525f07c.tar.gz
pytorch-17a65bf9b680ae04522c11462f5fc243e525f07c.tar.bz2
pytorch-17a65bf9b680ae04522c11462f5fc243e525f07c.zip
Removing some dependency edges from Blob to other caffe2 (#11923)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11923 This is pre-work to allow moving Blob to ATen/core, which cannot depend on caffe2 anymore. (1) Removing the Blob -> Tensor dependency allows us to move Blob to ATen/core and use it inside IValue without having to wait for the Tensor merge to be complete. (2) In the final Blob design, we want it to be a very small class that doesn't have any special treatment for Tensor (or to be more correct, doesn't allow storing Tensor anymore), so this is anyhow the direction we want to go. This changes call sites that will have to be moved to IValue later, but they cannot be moved to IValue directly, because for that, IValue first needs to be able to store Blob, which in turn first needs this diff and some other changes coming up in future diffs. Codemods: $ codemod --extensions h,hpp,c,cpp,cc "([a-zA-Z0-9_]+)\\.IsTensorType\\(" "BlobIsTensorType(\\1, " $ codemod --extensions h,hpp,c,cpp,cc "([a-zA-Z0-9_]+)->IsTensorType\\(" "BlobIsTensorType(*\\1, " $ codemod --extensions h,hpp,c,cpp,cc "([a-zA-Z0-9_]+)\\.GetMutableTensor\\(" "BlobGetMutableTensor(\\1, " $ codemod --extensions h,hpp,c,cpp,cc "([a-zA-Z0-9_]+)->GetMutableTensor\\(" "BlobGetMutableTensor(*\\1, " It is, however, not only these codemods, because regex-based refactoring was only able to match a small number of the call sites. To catch more, I would've needed an AST-aware tool like clangr, which I didn't figure out how to use. Reviewed By: ezyang Differential Revision: D9979976 fbshipit-source-id: 2ea17724e223b5b73b44f99362727759ca689e61
Diffstat (limited to 'binaries')
-rw-r--r--binaries/benchmark_helper.cc6
-rw-r--r--binaries/speed_benchmark.cc2
2 files changed, 4 insertions, 4 deletions
diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc
index 001c8e965f..f481a6292c 100644
--- a/binaries/benchmark_helper.cc
+++ b/binaries/benchmark_helper.cc
@@ -163,7 +163,7 @@ void loadInput(
CAFFE_THROW("Not support GPU on mobile.");
#endif
} else {
- caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
+ caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {
@@ -200,7 +200,7 @@ void fillInputBlob(
int protos_size = tensor_kv.second.protos_size();
caffe2::TensorProto* tensor_proto =
tensor_kv.second.mutable_protos(iteration % protos_size);
- caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
+ caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
if (tensor_proto->data_type() == caffe2::TensorProto::STRING) {
int total_size = tensor_proto->string_data_size();
for (size_t i = 0; i < total_size; i++) {
@@ -298,7 +298,7 @@ void writeOutput(
#endif
} else {
writeTextOutput<caffe2::CPUContext, caffe2::TensorCPU>(
- workspace->GetBlob(name)->GetMutableTensor(caffe2::CPU),
+ BlobGetMutableTensor(workspace->GetBlob(name), caffe2::CPU),
output_prefix,
name);
}
diff --git a/binaries/speed_benchmark.cc b/binaries/speed_benchmark.cc
index 5914e3f58b..fd502cf3c0 100644
--- a/binaries/speed_benchmark.cc
+++ b/binaries/speed_benchmark.cc
@@ -137,7 +137,7 @@ int main(int argc, char** argv) {
if (blob == nullptr) {
blob = workspace->CreateBlob(input_names[i]);
}
- caffe2::TensorCPU* tensor = blob->GetMutableTensor(caffe2::CPU);
+ caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {