author      Jerry Zhang <jerryzh@fb.com>    2018-07-27 10:50:54 -0700
committer   Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2018-07-27 10:56:39 -0700
commit      aebf3b47aef7c7bbdfe30f74dc35ed0ba727bdc1 (patch)
tree        19aa553ff72cb2ff7cc922ba5bdc0a2679b83ccf /caffe2/image
parent      94439d7df4d158023ea964db9afacaa2e1370074 (diff)
download    pytorch-aebf3b47aef7c7bbdfe30f74dc35ed0ba727bdc1.tar.gz
            pytorch-aebf3b47aef7c7bbdfe30f74dc35ed0ba727bdc1.tar.bz2
            pytorch-aebf3b47aef7c7bbdfe30f74dc35ed0ba727bdc1.zip
Remove template parameter from Tensor (#9939)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9939
Pull Request resolved: https://github.com/facebookresearch/weakly-supervised-action-detection/pull/13
Pull Request resolved: https://github.com/pytorch/translate/pull/166
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9125
Closes https://github.com/pytorch/pytorch/pull/9125
Use inheritance for polymorphism, and remove template parameter
This diff only changes the templating at call sites; the core implementations will change later.
Before this change, the Caffe2 Tensor class was fixed at compile time to bind to a particular device/context. With this change, the device becomes a runtime property (stored inside the tensor), while the same semantics are preserved: for example, one still has to specify a device type in order to create a Tensor, so there are no uninitialized tensors. More specifically, the changes are:
1. We added an extra *DeviceType* argument to most of the Tensor constructors, e.g. Tensor(DeviceType type) (see the sketch after this list).
2. The semantics of the constructor Tensor(const Tensor<SrcContext>& src, ContextForCopy* context) have changed: the second context is passed in so that we can call the templated Copy function. Previously it could be of a different context than the source and target; now, if it is provided, we enforce that it has the same device type as src.
3. To preserve the 'get-or-construct' semantics of Blob, we added a specialized getter, Blob::GetMutableTensor, which verifies both that the Blob contains a Tensor and that it is of the correct type.
4. Tensor is no longer default-constructible (since there are no unknown-device tensors), so some of the code handling STL containers needs to change.
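As a rough illustration of the call-site changes listed above, here is a minimal sketch. It is not code from this diff; the include paths and the exact signature of Blob::GetMutableTensor are assumptions, and the numbered comments mirror the items above.

```cpp
#include <vector>

#include "caffe2/core/blob.h"     // assumed include paths
#include "caffe2/core/context.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {

void ExampleCallSite(Blob* blob, CPUContext* cpu_context) {
  // (1) The device type is now a constructor argument; there is no way to
  //     make an uninitialized, unknown-device tensor.
  Tensor cpu_tensor(CPU);  // was: TensorCPU cpu_tensor;
  cpu_tensor.Resize(4, 4);

  // (2) CopyFrom still takes a context, but when provided it must have the
  //     same device type as the source (both CPU here).
  Tensor copy(CPU);
  copy.CopyFrom(cpu_tensor, cpu_context);

  // (3) Blob keeps its get-or-construct semantics through the specialized
  //     getter (the device-type argument is an assumption).
  Tensor* owned = blob->GetMutableTensor(CPU);
  owned->Resize(2, 2);

  // (4) Tensor is no longer default-constructible, so containers are filled
  //     with emplace_back(device_type) rather than resize(n).
  std::vector<Tensor> outputs;
  outputs.emplace_back(CPU);  // was: outputs.resize(1);
}

}  // namespace caffe2
```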
Note: Some changes are postponed just to keep this diff smaller. Please see the `TODO`s.
Reviewed By: ezyang, houseroad
Differential Revision: D9024330
fbshipit-source-id: e0b8295d2dc6ebe2963383ded5af799ad17164ba
Diffstat (limited to 'caffe2/image')
-rw-r--r--   caffe2/image/image_input_op.h   45
-rw-r--r--   caffe2/image/transform_gpu.cu   33
-rw-r--r--   caffe2/image/transform_gpu.h     9
3 files changed, 49 insertions, 38 deletions
diff --git a/caffe2/image/image_input_op.h b/caffe2/image/image_input_op.h
index 6bf232977d..97b942d924 100644
--- a/caffe2/image/image_input_op.h
+++ b/caffe2/image/image_input_op.h
@@ -87,12 +87,12 @@ class ImageInputOp final
   unique_ptr<db::DBReader> owned_reader_;
   const db::DBReader* reader_;
   CPUContext cpu_context_;
-  TensorCPU prefetched_image_;
-  TensorCPU prefetched_label_;
+  Tensor prefetched_image_{CPU};
+  Tensor prefetched_label_{CPU};
   vector<TensorCPU> prefetched_additional_outputs_;
-  Tensor<Context> prefetched_image_on_device_;
-  Tensor<Context> prefetched_label_on_device_;
-  vector<Tensor<Context>> prefetched_additional_outputs_on_device_;
+  Tensor prefetched_image_on_device_{Context::GetDeviceType()};
+  Tensor prefetched_label_on_device_{Context::GetDeviceType()};
+  vector<Tensor> prefetched_additional_outputs_on_device_;
   // Default parameters for images
   PerImageArg default_arg_;
   int batch_size_;
@@ -118,8 +118,8 @@ class ImageInputOp final
   int crop_;
   std::vector<float> mean_;
   std::vector<float> std_;
-  Tensor<Context> mean_gpu_;
-  Tensor<Context> std_gpu_;
+  Tensor mean_gpu_{Context::GetDeviceType()};
+  Tensor std_gpu_{Context::GetDeviceType()};
   bool mirror_;
   bool is_test_;
   bool use_caffe_datum_;
@@ -154,8 +154,6 @@ ImageInputOp<Context>::ImageInputOp(
     Workspace* ws)
     : PrefetchOperator<Context>(operator_def, ws),
       reader_(nullptr),
-      prefetched_additional_outputs_(OutputSize() - 2),
-      prefetched_additional_outputs_on_device_(OutputSize() - 2),
       batch_size_(
           OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
       label_type_(static_cast<LABEL_TYPE>(
@@ -385,6 +383,9 @@ ImageInputOp<Context>::ImageInputOp(
   }

   for (int i = 0; i < additional_output_sizes.size(); ++i) {
+    prefetched_additional_outputs_on_device_.emplace_back(
+        Context::GetDeviceType());
+    prefetched_additional_outputs_.emplace_back(CPU);
     prefetched_additional_outputs_[i].Resize(
         TIndex(batch_size_), TIndex(additional_output_sizes[i]));
   }
@@ -1207,12 +1208,12 @@ bool ImageInputOp<Context>::Prefetch() {
   // If the context is not CPUContext, we will need to do a copy in the
   // prefetch function as well.
   if (!std::is_same<Context, CPUContext>::value) {
-    prefetched_image_on_device_.CopyFrom(prefetched_image_, &context_);
-    prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
+    prefetched_image_on_device_.CopyFrom(prefetched_image_, &cpu_context_);
+    prefetched_label_on_device_.CopyFrom(prefetched_label_, &cpu_context_);

     for (int i = 0; i < prefetched_additional_outputs_on_device_.size(); ++i) {
       prefetched_additional_outputs_on_device_[i].CopyFrom(
-          prefetched_additional_outputs_[i], &context_);
+          prefetched_additional_outputs_[i], &cpu_context_);
     }
   }

@@ -1223,13 +1224,13 @@
 template <class Context>
 bool ImageInputOp<Context>::CopyPrefetched() {
-  auto* image_output = OperatorBase::Output<Tensor<Context> >(0);
-  auto* label_output = OperatorBase::Output<Tensor<Context> >(1);
-  vector<Tensor<Context>*> additional_outputs_output;
+  auto type = Context::GetDeviceType();
+  auto* image_output = OperatorBase::Output<Tensor>(0, type);
+  auto* label_output = OperatorBase::Output<Tensor>(1, type);
+  vector<Tensor*> additional_outputs_output;
   for (int i = 2; i < OutputSize(); ++i) {
-    additional_outputs_output.push_back(
-        OperatorBase::Output<Tensor<Context>>(i));
+    additional_outputs_output.push_back(OperatorBase::Output<Tensor>(i, type));
   }

   // Note(jiayq): The if statement below should be optimized away by the
@@ -1249,10 +1250,12 @@ bool ImageInputOp<Context>::CopyPrefetched() {
       mean_gpu_.Resize(mean_.size());
       std_gpu_.Resize(std_.size());

-      context_.template Copy<float, CPUContext, Context>(
-          mean_.size(), mean_.data(), mean_gpu_.template mutable_data<float>());
-      context_.template Copy<float, CPUContext, Context>(
-          std_.size(), std_.data(), std_gpu_.template mutable_data<float>());
+      context_.template CopyFromCPU<float>(
+          mean_.size(),
+          mean_.data(),
+          mean_gpu_.template mutable_data<float>());
+      context_.template CopyFromCPU<float>(
+          std_.size(), std_.data(), std_gpu_.template mutable_data<float>());
       mean_std_copied_ = true;
     }
     // GPU transform kernel allows explicitly setting output type
diff --git a/caffe2/image/transform_gpu.cu b/caffe2/image/transform_gpu.cu
index c6d8d77533..bb557429f5 100644
--- a/caffe2/image/transform_gpu.cu
+++ b/caffe2/image/transform_gpu.cu
@@ -50,9 +50,12 @@ __global__ void transform_kernel(

 template <typename T_IN, typename T_OUT, class Context>
-bool TransformOnGPU(Tensor<Context>& X, Tensor<Context> *Y,
-                    Tensor<Context>& mean, Tensor<Context>& std,
-                    Context *context) {
+bool TransformOnGPU(
+    Tensor& X,
+    Tensor* Y,
+    Tensor& mean,
+    Tensor& std,
+    Context* context) {
   // data comes in as NHWC
   const int N = X.dim32(0), C = X.dim32(3), H = X.dim32(1), W = X.dim32(2);
   // data goes out as NCHW
@@ -68,16 +71,18 @@ bool TransformOnGPU(Tensor<Context>& X, Tensor<Context> *Y,
   return true;
 };

-template bool TransformOnGPU<uint8_t, float, CUDAContext>(Tensor<CUDAContext>& X,
-                                                           Tensor<CUDAContext> *Y,
-                                                           Tensor<CUDAContext>& mean,
-                                                           Tensor<CUDAContext>& std,
-                                                           CUDAContext *context);
-
-template bool TransformOnGPU<uint8_t, float16, CUDAContext>(Tensor<CUDAContext>& X,
-                                                             Tensor<CUDAContext> *Y,
-                                                             Tensor<CUDAContext>& mean,
-                                                             Tensor<CUDAContext>& std,
-                                                             CUDAContext *context);
+template bool TransformOnGPU<uint8_t, float, CUDAContext>(
+    Tensor& X,
+    Tensor* Y,
+    Tensor& mean,
+    Tensor& std,
+    CUDAContext* context);
+
+template bool TransformOnGPU<uint8_t, float16, CUDAContext>(
+    Tensor& X,
+    Tensor* Y,
+    Tensor& mean,
+    Tensor& std,
+    CUDAContext* context);

 }  // namespace caffe2
diff --git a/caffe2/image/transform_gpu.h b/caffe2/image/transform_gpu.h
index a19b5251f5..3ca11ce159 100644
--- a/caffe2/image/transform_gpu.h
+++ b/caffe2/image/transform_gpu.h
@@ -31,9 +31,12 @@ namespace caffe2 {

 template <typename T_IN, typename T_OUT, class Context>
-bool TransformOnGPU(Tensor<Context>& X, Tensor<Context>* Y,
-                    Tensor<Context>& mean, Tensor<Context>& std,
-                    Context* context);
+bool TransformOnGPU(
+    Tensor& X,
+    Tensor* Y,
+    Tensor& mean,
+    Tensor& std,
+    Context* context);

 }  // namespace caffe2
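As a closing illustration of the host-to-device copy pattern that the hunks above adopt, here is a hedged sketch: the helper name and its parameters are invented for illustration, while the Resize and CopyFromCPU calls mirror those in image_input_op.h.

```cpp
#include <vector>

#include "caffe2/core/context.h"  // assumed include paths
#include "caffe2/core/tensor.h"

namespace caffe2 {

// Hypothetical helper, not part of this diff: copies host-side float values
// into a Tensor that was constructed with Context::GetDeviceType().
template <class Context>
void CopyMeanToDevice(
    const std::vector<float>& mean,
    Tensor* mean_on_device,
    Context* context) {
  mean_on_device->Resize(mean.size());
  // Was: context->template Copy<float, CPUContext, Context>(...);
  // the typed host-to-device copy now goes through CopyFromCPU<T>.
  context->template CopyFromCPU<float>(
      mean.size(),
      mean.data(),
      mean_on_device->mutable_data<float>());
}

}  // namespace caffe2
```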