summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTristan Rice <tristanr@fb.com>2018-12-17 15:59:45 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-12-17 16:02:16 -0800
commite650a848721a39707dcdde3e36a5ca05f3da0510 (patch)
tree1eaddf5e7e45a2ef2aceabe58bef805b6c5c0091
parente0b261a35b1483e15c7b81a476851abfdb69602a (diff)
downloadpytorch-e650a848721a39707dcdde3e36a5ca05f3da0510.tar.gz
pytorch-e650a848721a39707dcdde3e36a5ca05f3da0510.tar.bz2
pytorch-e650a848721a39707dcdde3e36a5ca05f3da0510.zip
caffe2/python/task: added __repr__ methods to all task definitions (#15250)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15250 This adds `__repr__` methods to all of the classes under task.py. This makes the objects much easier to interact with when using them in an interactive manner, such as in a Jupyter notebook. The default `__repr__` method just returns the object ID which is very unhelpful. Reviewed By: hanli0612 Differential Revision: D13475758 fbshipit-source-id: 6e1b166ec35163b9776c797b6a2e0d002560cd29
-rw-r--r--caffe2/python/task.py25
-rw-r--r--caffe2/python/task_test.py24
2 files changed, 49 insertions, 0 deletions
diff --git a/caffe2/python/task.py b/caffe2/python/task.py
index 161ba4f5bb..eb7ad4edfa 100644
--- a/caffe2/python/task.py
+++ b/caffe2/python/task.py
@@ -52,6 +52,10 @@ class Cluster(object):
def node_kwargs(self):
return self._node_kwargs
+ def __repr__(self):
+ return "Cluster(nodes={}, node_kwargs={})".format(
+ self.nodes(), self.node_kwargs())
+
@context.define_context(allow_default=True)
class Node(object):
@@ -85,6 +89,9 @@ class Node(object):
def __str__(self):
return self._name
+ def __repr__(self):
+ return "Node(name={}, kwargs={})".format(self._name, self._kwargs)
+
def kwargs(self):
return self._kwargs
@@ -345,6 +352,10 @@ class TaskGroup(object):
def workspace_type(self):
return self._workspace_type
+ def __repr__(self):
+ return "TaskGroup(tasks={}, workspace_type={}, remote_nets={})".format(
+ self.tasks(), self.workspace_type(), self.remote_nets())
+
class TaskOutput(object):
"""
@@ -389,6 +400,9 @@ class TaskOutput(object):
else:
return fetched_vals
+ def __repr__(self):
+ return "TaskOutput(names={}, values={})".format(self.names, self._values)
+
def final_output(blob_or_record):
"""
@@ -424,6 +438,9 @@ class TaskOutputList(object):
offset += num
assert offset == len(values), 'Wrong number of output values.'
+ def __repr__(self):
+ return "TaskOutputList(outputs={})".format(self.outputs)
+
@context.define_context()
class Task(object):
@@ -625,6 +642,10 @@ class Task(object):
self.get_step()
self._already_used = True
+ def __repr__(self):
+ return "Task(name={}, node={}, outputs={})".format(
+ self.name, self.node, self.outputs())
+
class SetupNets(object):
"""
@@ -668,3 +689,7 @@ class SetupNets(object):
def exit(self, exit_net):
return self.exit_nets
+
+ def __repr__(self):
+ return "SetupNets(init_nets={}, exit_nets={})".format(
+ self.init_nets, self.exit_nets)
diff --git a/caffe2/python/task_test.py b/caffe2/python/task_test.py
new file mode 100644
index 0000000000..f1c51bc5b4
--- /dev/null
+++ b/caffe2/python/task_test.py
@@ -0,0 +1,24 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+from caffe2.python import task
+
+
+class TestTask(unittest.TestCase):
+ def testRepr(self):
+ cases = [
+ (task.Cluster(), "Cluster(nodes=[], node_kwargs={})"),
+ (task.Node(), "Node(name=local, kwargs={})"),
+ (
+ task.TaskGroup(),
+ "TaskGroup(tasks=[], workspace_type=None, remote_nets=[])",
+ ),
+ (task.TaskOutput([]), "TaskOutput(names=[], values=None)"),
+ (task.Task(), "Task(name=local/task, node=local, outputs=[])"),
+ (task.SetupNets(), "SetupNets(init_nets=None, exit_nets=None)"),
+ ]
+ for obj, want in cases:
+ self.assertEqual(obj.__repr__(), want)