summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorQinqing Zheng <enosair@users.noreply.github.com>2018-02-23 10:41:52 -0800
committerGitHub <noreply@github.com>2018-02-23 10:41:52 -0800
commitb3fdfa7bd666716981d71347b713d6f4dda0c0c4 (patch)
treecf1332f992735a7408b19917d1c453adeef5c35a
parent232837a75e5aebe81c51783ab679bc388df33a78 (diff)
downloadpytorch-b3fdfa7bd666716981d71347b713d6f4dda0c0c4.tar.gz
pytorch-b3fdfa7bd666716981d71347b713d6f4dda0c0c4.tar.bz2
pytorch-b3fdfa7bd666716981d71347b713d6f4dda0c0c4.zip
[DT] [4/n] Make epoch_group explicit for JobRunner (#2018)
-rw-r--r--caffe2/python/checkpoint.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/caffe2/python/checkpoint.py b/caffe2/python/checkpoint.py
index dc3f325230..8cd82537cb 100644
--- a/caffe2/python/checkpoint.py
+++ b/caffe2/python/checkpoint.py
@@ -745,7 +745,9 @@ def epoch_limiter(num_epochs):
init_net = core.Net('epoch_counter_init')
counter = init_net.CreateCounter([], init_count=num_epochs - 1)
Task(step=init_net)
- epoch_net = core.Net('epoch_countdown')
- finished = epoch_net.CountDown(counter)
- output = Task(step=epoch_net, outputs=finished).outputs()[0]
+
+ with Job.current().epoch_group:
+ epoch_net = core.Net('epoch_countdown')
+ finished = epoch_net.CountDown(counter)
+ output = Task(step=epoch_net, outputs=finished).outputs()[0]
Job.current().add_stop_signal(output)