summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeff Donahue <jeff.donahue@gmail.com>2015-08-21 17:29:06 -0700
committerJeff Donahue <jeff.donahue@gmail.com>2015-09-03 16:31:28 -0700
commitc2484747d813d616bcb504d97f93071b26bb372d (patch)
treecc48a3a82a52f6789df9a502fd65e0a24c066490
parent66823b59d70097f4ccbe3631b102ef238c08535b (diff)
downloadcaffeonacl-c2484747d813d616bcb504d97f93071b26bb372d.tar.gz
caffeonacl-c2484747d813d616bcb504d97f93071b26bb372d.tar.bz2
caffeonacl-c2484747d813d616bcb504d97f93071b26bb372d.zip
NetSpec: don't require lists to specify single-element repeated fields
-rw-r--r--python/caffe/net_spec.py10
-rw-r--r--python/caffe/test/test_net_spec.py3
2 files changed, 9 insertions, 4 deletions
diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py
index 77a0e007..93fc0192 100644
--- a/python/caffe/net_spec.py
+++ b/python/caffe/net_spec.py
@@ -56,8 +56,14 @@ def to_proto(*tops):
def assign_proto(proto, name, val):
"""Assign a Python object to a protobuf message, based on the Python
type (in recursive fashion). Lists become repeated fields/messages, dicts
- become messages, and other types are assigned directly."""
-
+ become messages, and other types are assigned directly. For convenience,
+ repeated fields whose values are not lists are converted to single-element
+ lists; e.g., `my_repeated_int_field=3` is converted to
+ `my_repeated_int_field=[3]`."""
+
+ is_repeated_field = hasattr(getattr(proto, name), 'extend')
+ if is_repeated_field and not isinstance(val, list):
+ val = [val]
if isinstance(val, list):
if isinstance(val[0], dict):
for item in val:
diff --git a/python/caffe/test/test_net_spec.py b/python/caffe/test/test_net_spec.py
index b4595e65..fee3c0aa 100644
--- a/python/caffe/test/test_net_spec.py
+++ b/python/caffe/test/test_net_spec.py
@@ -43,8 +43,7 @@ def anon_lenet(batch_size):
def silent_net():
n = caffe.NetSpec()
- n.data, n.data2 = L.DummyData(shape=[dict(dim=[3]), dict(dim=[4, 2])],
- ntop=2)
+ n.data, n.data2 = L.DummyData(shape=dict(dim=3), ntop=2)
n.silence_data = L.Silence(n.data, ntop=0)
n.silence_data2 = L.Silence(n.data2, ntop=0)
return n.to_proto()