diff options
author | Jeff Donahue <jeff.donahue@gmail.com> | 2015-08-21 17:29:06 -0700 |
---|---|---|
committer | Jeff Donahue <jeff.donahue@gmail.com> | 2015-09-03 16:31:28 -0700 |
commit | c2484747d813d616bcb504d97f93071b26bb372d (patch) | |
tree | cc48a3a82a52f6789df9a502fd65e0a24c066490 /python | |
parent | 66823b59d70097f4ccbe3631b102ef238c08535b (diff) | |
download | caffeonacl-c2484747d813d616bcb504d97f93071b26bb372d.tar.gz caffeonacl-c2484747d813d616bcb504d97f93071b26bb372d.tar.bz2 caffeonacl-c2484747d813d616bcb504d97f93071b26bb372d.zip |
NetSpec: don't require lists to specify single-element repeated fields
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/net_spec.py | 10 | ||||
-rw-r--r-- | python/caffe/test/test_net_spec.py | 3 |
2 files changed, 9 insertions, 4 deletions
diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py index 77a0e007..93fc0192 100644 --- a/python/caffe/net_spec.py +++ b/python/caffe/net_spec.py @@ -56,8 +56,14 @@ def to_proto(*tops): def assign_proto(proto, name, val): """Assign a Python object to a protobuf message, based on the Python type (in recursive fashion). Lists become repeated fields/messages, dicts - become messages, and other types are assigned directly.""" - + become messages, and other types are assigned directly. For convenience, + repeated fields whose values are not lists are converted to single-element + lists; e.g., `my_repeated_int_field=3` is converted to + `my_repeated_int_field=[3]`.""" + + is_repeated_field = hasattr(getattr(proto, name), 'extend') + if is_repeated_field and not isinstance(val, list): + val = [val] if isinstance(val, list): if isinstance(val[0], dict): for item in val: diff --git a/python/caffe/test/test_net_spec.py b/python/caffe/test/test_net_spec.py index b4595e65..fee3c0aa 100644 --- a/python/caffe/test/test_net_spec.py +++ b/python/caffe/test/test_net_spec.py @@ -43,8 +43,7 @@ def anon_lenet(batch_size): def silent_net(): n = caffe.NetSpec() - n.data, n.data2 = L.DummyData(shape=[dict(dim=[3]), dict(dim=[4, 2])], - ntop=2) + n.data, n.data2 = L.DummyData(shape=dict(dim=3), ntop=2) n.silence_data = L.Silence(n.data, ntop=0) n.silence_data2 = L.Silence(n.data2, ntop=0) return n.to_proto() |