diff options
author | Yan Zhu <yzhu@fb.com> | 2018-11-14 17:19:14 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-11-14 17:21:08 -0800 |
commit | 2356c8d542f7f8fda2d3568372c82393cd42aa49 (patch) | |
tree | 8e1560afe241e68e1e7e77c3ea81005277012406 | |
parent | fed8d8975a3d5b99d963748ed6926e3837cd8098 (diff) | |
download | pytorch-2356c8d542f7f8fda2d3568372c82393cd42aa49.tar.gz pytorch-2356c8d542f7f8fda2d3568372c82393cd42aa49.tar.bz2 pytorch-2356c8d542f7f8fda2d3568372c82393cd42aa49.zip |
device inference for Adam (#13990)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13990
to make sure ITER blob lives on CPU.
Reviewed By: xianjiec
Differential Revision: D13056070
fbshipit-source-id: 148edbf745e50e886da3eb99d4e485d11c1924e2
-rw-r--r-- | caffe2/python/core_test.py | 14 | ||||
-rw-r--r-- | caffe2/sgd/adam_op.cc | 9 |
2 files changed, 23 insertions, 0 deletions
diff --git a/caffe2/python/core_test.py b/caffe2/python/core_test.py index 2f6dedbfd8..6c23d88543 100644 --- a/caffe2/python/core_test.py +++ b/caffe2/python/core_test.py @@ -707,6 +707,20 @@ class TestInferDevice(test_util.TestCase): outputs=["fc_1"] ) + def test_infer_device_adam(self): + in_options = [self.cuda_option] * 6 + in_options[5] = self.cpu_option + out_options = [self.cuda_option] * 4 + self._test_op( + "Adam", + in_options, + out_options, + op_option=self.cuda_option, + inputs=["param", "moment_1", "moment_2", "grad", "lr", "iter"], + outputs=["output_param", "output_moment_1", "output_moment_2", + "output_grad"] + ) + def test_infer_device_cross_device(self): self._test_op("CopyGPUToCPU", self.cuda_option, self.cpu_option) self._test_op("CopyCPUToGPU", self.cpu_option, self.cuda_option) diff --git a/caffe2/sgd/adam_op.cc b/caffe2/sgd/adam_op.cc index 623e93a07e..d12f6765cc 100644 --- a/caffe2/sgd/adam_op.cc +++ b/caffe2/sgd/adam_op.cc @@ -7,6 +7,15 @@ OPERATOR_SCHEMA(Adam) .NumInputs(6) .NumOutputs(3, 4) .AllowInplace({{0, 0}, {1, 1}, {2, 2}}) + .DeviceInferenceFunction([](const OperatorDef& def) { + auto op_device = + def.has_device_option() ? def.device_option() : DeviceOption(); + vector<DeviceOption> in_dev(def.input_size(), op_device); + vector<DeviceOption> out_dev(def.output_size(), op_device); + // ITER input lives on CPU + in_dev[5] = DeviceOption(); + return std::make_pair(in_dev, out_dev); + }) .SetDoc(R"DOC( Computes the Adam update (https://arxiv.org/abs/1412.6980) for an |