summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYan Zhu <yzhu@fb.com>2018-11-14 17:19:14 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-11-14 17:21:08 -0800
commit2356c8d542f7f8fda2d3568372c82393cd42aa49 (patch)
tree8e1560afe241e68e1e7e77c3ea81005277012406
parentfed8d8975a3d5b99d963748ed6926e3837cd8098 (diff)
downloadpytorch-2356c8d542f7f8fda2d3568372c82393cd42aa49.tar.gz
pytorch-2356c8d542f7f8fda2d3568372c82393cd42aa49.tar.bz2
pytorch-2356c8d542f7f8fda2d3568372c82393cd42aa49.zip
device inference for Adam (#13990)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/13990 to make sure ITER blob lives on CPU. Reviewed By: xianjiec Differential Revision: D13056070 fbshipit-source-id: 148edbf745e50e886da3eb99d4e485d11c1924e2
-rw-r--r--caffe2/python/core_test.py14
-rw-r--r--caffe2/sgd/adam_op.cc9
2 files changed, 23 insertions, 0 deletions
diff --git a/caffe2/python/core_test.py b/caffe2/python/core_test.py
index 2f6dedbfd8..6c23d88543 100644
--- a/caffe2/python/core_test.py
+++ b/caffe2/python/core_test.py
@@ -707,6 +707,20 @@ class TestInferDevice(test_util.TestCase):
outputs=["fc_1"]
)
+ def test_infer_device_adam(self):
+ in_options = [self.cuda_option] * 6
+ in_options[5] = self.cpu_option
+ out_options = [self.cuda_option] * 4
+ self._test_op(
+ "Adam",
+ in_options,
+ out_options,
+ op_option=self.cuda_option,
+ inputs=["param", "moment_1", "moment_2", "grad", "lr", "iter"],
+ outputs=["output_param", "output_moment_1", "output_moment_2",
+ "output_grad"]
+ )
+
def test_infer_device_cross_device(self):
self._test_op("CopyGPUToCPU", self.cuda_option, self.cpu_option)
self._test_op("CopyCPUToGPU", self.cpu_option, self.cuda_option)
diff --git a/caffe2/sgd/adam_op.cc b/caffe2/sgd/adam_op.cc
index 623e93a07e..d12f6765cc 100644
--- a/caffe2/sgd/adam_op.cc
+++ b/caffe2/sgd/adam_op.cc
@@ -7,6 +7,15 @@ OPERATOR_SCHEMA(Adam)
.NumInputs(6)
.NumOutputs(3, 4)
.AllowInplace({{0, 0}, {1, 1}, {2, 2}})
+ .DeviceInferenceFunction([](const OperatorDef& def) {
+ auto op_device =
+ def.has_device_option() ? def.device_option() : DeviceOption();
+ vector<DeviceOption> in_dev(def.input_size(), op_device);
+ vector<DeviceOption> out_dev(def.output_size(), op_device);
+ // ITER input lives on CPU
+ in_dev[5] = DeviceOption();
+ return std::make_pair(in_dev, out_dev);
+ })
.SetDoc(R"DOC(
Computes the Adam update (https://arxiv.org/abs/1412.6980) for an