diff options
author | Qinqing Zheng <enosair@users.noreply.github.com> | 2018-02-23 10:41:52 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-23 10:41:52 -0800 |
commit | b3fdfa7bd666716981d71347b713d6f4dda0c0c4 (patch) | |
tree | cf1332f992735a7408b19917d1c453adeef5c35a | |
parent | 232837a75e5aebe81c51783ab679bc388df33a78 (diff) | |
download | pytorch-b3fdfa7bd666716981d71347b713d6f4dda0c0c4.tar.gz pytorch-b3fdfa7bd666716981d71347b713d6f4dda0c0c4.tar.bz2 pytorch-b3fdfa7bd666716981d71347b713d6f4dda0c0c4.zip |
[DT] [4/n] Make epoch_group explicit for JobRunner (#2018)
-rw-r--r-- | caffe2/python/checkpoint.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/caffe2/python/checkpoint.py b/caffe2/python/checkpoint.py index dc3f325230..8cd82537cb 100644 --- a/caffe2/python/checkpoint.py +++ b/caffe2/python/checkpoint.py @@ -745,7 +745,9 @@ def epoch_limiter(num_epochs): init_net = core.Net('epoch_counter_init') counter = init_net.CreateCounter([], init_count=num_epochs - 1) Task(step=init_net) - epoch_net = core.Net('epoch_countdown') - finished = epoch_net.CountDown(counter) - output = Task(step=epoch_net, outputs=finished).outputs()[0] + + with Job.current().epoch_group: + epoch_net = core.Net('epoch_countdown') + finished = epoch_net.CountDown(counter) + output = Task(step=epoch_net, outputs=finished).outputs()[0] Job.current().add_stop_signal(output) |