diff options
author | Tristan Rice <tristanr@fb.com> | 2018-12-17 15:59:45 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-17 16:02:16 -0800 |
commit | e650a848721a39707dcdde3e36a5ca05f3da0510 (patch) | |
tree | 1eaddf5e7e45a2ef2aceabe58bef805b6c5c0091 | |
parent | e0b261a35b1483e15c7b81a476851abfdb69602a (diff) | |
download | pytorch-e650a848721a39707dcdde3e36a5ca05f3da0510.tar.gz pytorch-e650a848721a39707dcdde3e36a5ca05f3da0510.tar.bz2 pytorch-e650a848721a39707dcdde3e36a5ca05f3da0510.zip |
caffe2/python/task: added __repr__ methods to all task definitions (#15250)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15250
This adds `__repr__` methods to all of the classes under task.py. This makes the objects much easier to interact with when using them in an interactive manner, such as in a Jupyter notebook.
The default `__repr__` method just returns the object ID which is very unhelpful.
Reviewed By: hanli0612
Differential Revision: D13475758
fbshipit-source-id: 6e1b166ec35163b9776c797b6a2e0d002560cd29
-rw-r--r-- | caffe2/python/task.py | 25 | ||||
-rw-r--r-- | caffe2/python/task_test.py | 24 |
2 files changed, 49 insertions, 0 deletions
diff --git a/caffe2/python/task.py b/caffe2/python/task.py index 161ba4f5bb..eb7ad4edfa 100644 --- a/caffe2/python/task.py +++ b/caffe2/python/task.py @@ -52,6 +52,10 @@ class Cluster(object): def node_kwargs(self): return self._node_kwargs + def __repr__(self): + return "Cluster(nodes={}, node_kwargs={})".format( + self.nodes(), self.node_kwargs()) + @context.define_context(allow_default=True) class Node(object): @@ -85,6 +89,9 @@ class Node(object): def __str__(self): return self._name + def __repr__(self): + return "Node(name={}, kwargs={})".format(self._name, self._kwargs) + def kwargs(self): return self._kwargs @@ -345,6 +352,10 @@ class TaskGroup(object): def workspace_type(self): return self._workspace_type + def __repr__(self): + return "TaskGroup(tasks={}, workspace_type={}, remote_nets={})".format( + self.tasks(), self.workspace_type(), self.remote_nets()) + class TaskOutput(object): """ @@ -389,6 +400,9 @@ class TaskOutput(object): else: return fetched_vals + def __repr__(self): + return "TaskOutput(names={}, values={})".format(self.names, self._values) + def final_output(blob_or_record): """ @@ -424,6 +438,9 @@ class TaskOutputList(object): offset += num assert offset == len(values), 'Wrong number of output values.' + def __repr__(self): + return "TaskOutputList(outputs={})".format(self.outputs) + @context.define_context() class Task(object): @@ -625,6 +642,10 @@ class Task(object): self.get_step() self._already_used = True + def __repr__(self): + return "Task(name={}, node={}, outputs={})".format( + self.name, self.node, self.outputs()) + class SetupNets(object): """ @@ -668,3 +689,7 @@ class SetupNets(object): def exit(self, exit_net): return self.exit_nets + + def __repr__(self): + return "SetupNets(init_nets={}, exit_nets={})".format( + self.init_nets, self.exit_nets) diff --git a/caffe2/python/task_test.py b/caffe2/python/task_test.py new file mode 100644 index 0000000000..f1c51bc5b4 --- /dev/null +++ b/caffe2/python/task_test.py @@ -0,0 +1,24 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest +from caffe2.python import task + + +class TestTask(unittest.TestCase): + def testRepr(self): + cases = [ + (task.Cluster(), "Cluster(nodes=[], node_kwargs={})"), + (task.Node(), "Node(name=local, kwargs={})"), + ( + task.TaskGroup(), + "TaskGroup(tasks=[], workspace_type=None, remote_nets=[])", + ), + (task.TaskOutput([]), "TaskOutput(names=[], values=None)"), + (task.Task(), "Task(name=local/task, node=local, outputs=[])"), + (task.SetupNets(), "SetupNets(init_nets=None, exit_nets=None)"), + ] + for obj, want in cases: + self.assertEqual(obj.__repr__(), want) |