author     Adam Paszke <adam.paszke@gmail.com>    2017-01-21 12:18:36 +0100
committer  Soumith Chintala <soumith@gmail.com>   2017-01-22 18:02:40 -0500
commit     f8ae34706e14fe7bb7826d9723dda4bb6a960a4a (patch)
tree       b90051e4229b73a69e9be6fd4c839f4417a938e2 /test/test_optim.py
parent     f8e89fbe1123f6788992b70361f13ad498665327 (diff)
Port L-BFGS from Lua optim
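
L-BFGS may re-evaluate the objective several times per step, so it relies on the closure-taking form of optimizer.step(); that is why the tests below move zero_grad() and backward() into the closures they pass to step(). A minimal usage sketch of the ported optimizer (the variable names and the toy objective are illustrative only, not taken from this diff):

    import torch
    from torch.autograd import Variable
    import torch.optim as optim

    # one trainable parameter; L-BFGS may call the closure more than once per step
    x = Variable(torch.randn(2), requires_grad=True)
    optimizer = optim.LBFGS([x], lr=5e-2, max_iter=5)

    def closure():
        optimizer.zero_grad()          # clear gradients from the previous evaluation
        loss = (x - 3).pow(2).sum()    # toy quadratic objective
        loss.backward()                # populate x.grad
        return loss                    # step() re-uses the returned loss

    for _ in range(20):
        optimizer.step(closure)
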
Diffstat (limited to 'test/test_optim.py')
-rw-r--r--    test/test_optim.py    45
1 file changed, 33 insertions, 12 deletions
diff --git a/test/test_optim.py b/test/test_optim.py
index 6172887ac7..ddd33c9a58 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -36,12 +36,19 @@ class TestOptim(TestCase):
initial_dist = params.data.dist(solution)
def eval():
+ optimizer.zero_grad()
loss = rosenbrock(params)
loss.backward()
+ # loss.backward() will give **slightly** different
+ # gradients than drosenbrock, because of a different ordering
+ # of floating point operations. In most cases it doesn't matter,
+ # but some optimizers are so sensitive that they can temporarily
+ # diverge up to 1e-4, just to converge again. This makes the
+ # comparison more stable.
+ params.grad.data.copy_(drosenbrock(params.data))
return loss
for i in range(2000):
- optimizer.zero_grad()
optimizer.step(eval)
old_fn(lambda _: (rosenbrock(params_t), drosenbrock(params_t)),
params_t, state)
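
For reference, the rosenbrock and drosenbrock helpers used by the hunk above live elsewhere in this file and are untouched by this commit; they evaluate the 2-D Rosenbrock function and its hand-derived gradient, roughly as sketched here (exact signatures may differ):

    import torch

    def rosenbrock(tensor):
        x, y = tensor
        return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

    def drosenbrock(tensor):
        x, y = tensor
        # partial derivatives of the expression above, written out by hand
        return torch.DoubleTensor((-400 * x * (y - x ** 2) - 2 * (1 - x),
                                   200 * (y - x ** 2)))

This analytic gradient is what the new eval() closure copies into params.grad, so the comparison against the Lua implementation is not perturbed by autograd's floating point ordering.
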
@@ -56,21 +63,21 @@ class TestOptim(TestCase):
optimizer = constructor(weight, bias)
def fn():
+ optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
- return (y + bias).abs().sum()
+ loss = (y + bias).pow(2).sum()
+ loss.backward()
+ return loss
initial_value = fn().data[0]
for i in range(200):
- weight.grad.data.zero_()
- bias.grad.data.zero_()
- fn().backward()
- optimizer.step()
+ optimizer.step(fn)
- self.assertLessEqual(fn().data[0], initial_value)
+ self.assertLess(fn().data[0], initial_value)
- def _test_basic_cases(self, constructor):
+ def _test_basic_cases(self, constructor, ignore_multidevice=False):
self._test_basic_cases_template(
torch.randn(10, 5),
torch.randn(10),
@@ -94,12 +101,12 @@ class TestOptim(TestCase):
constructor
)
# Multi-GPU
- if not torch.cuda.device_count() > 1:
+ if not torch.cuda.device_count() > 1 or ignore_multidevice:
return
self._test_basic_cases_template(
- torch.randn(10, 5).cuda(),
- torch.randn(10).cuda(),
- torch.randn(5).cuda(),
+ torch.randn(10, 5).cuda(0),
+ torch.randn(10).cuda(1),
+ torch.randn(5).cuda(0),
constructor
)
@@ -275,6 +282,20 @@ class TestOptim(TestCase):
lr=1e-3)
)
+ def test_lbfgs(self):
+ self._test_rosenbrock(
+ lambda params: optim.LBFGS(params),
+ wrap_old_fn(old_optim.lbfgs)
+ )
+ self._test_rosenbrock(
+ lambda params: optim.LBFGS(params, lr=5e-2, max_iter=5),
+ wrap_old_fn(old_optim.lbfgs, learningRate=5e-2, maxIter=5)
+ )
+ self._test_basic_cases(
+ lambda weight, bias: optim.LBFGS([weight, bias]),
+ ignore_multidevice=True
+ )
+
def test_invalid_param_type(self):
with self.assertRaises(TypeError):
optim.SGD(Variable(torch.randn(5, 5)), lr=3)
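
The new test_lbfgs compares the port against the legacy Lua-style implementation via wrap_old_fn, mapping the old config keys (learningRate, maxIter) onto the new keyword arguments (lr, max_iter). wrap_old_fn is defined elsewhere in this file; a sketch of what such an adapter looks like, assuming the legacy functions take (closure, params, config, state) in the Lua optim style:

    def wrap_old_fn(old_fn, **config):
        # curry the Lua-style config dict so _test_rosenbrock can call
        # old_fn(closure, params, state) uniformly for every optimizer
        def wrapper(closure, params, state):
            return old_fn(closure, params, config, state)
        return wrapper
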