summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan L Long <jonlong@cs.berkeley.edu>2015-07-23 20:35:42 -0700
committerJonathan L Long <jonlong@cs.berkeley.edu>2015-07-23 20:35:42 -0700
commit96c2fe1de80c9752b992c4578a3ce46028d21fc5 (patch)
tree67a38c71c407bd8e7669b9ba8dd3e16d4b0a4ecd
parentf16195aa8b1569e0260c48d4159b7d2ce0ea2fab (diff)
downloadcaffeonacl-96c2fe1de80c9752b992c4578a3ce46028d21fc5.tar.gz
caffeonacl-96c2fe1de80c9752b992c4578a3ce46028d21fc5.tar.bz2
caffeonacl-96c2fe1de80c9752b992c4578a3ce46028d21fc5.zip
[pycaffe] allow layers to have names different from their first tops
Previously, net spec only allowed names to be assigned to Tops, giving layers the names of their first tops. Now, names can be assigned to Functions, which become layer names in serialization. Unnamed Functions still get named after their first top, if present, or autogenerated, if not. (This will allow top-less layers in a natural way.)
-rw-r--r--python/caffe/net_spec.py14
1 files changed, 11 insertions, 3 deletions
diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py
index 5fb26ed4..16f30008 100644
--- a/python/caffe/net_spec.py
+++ b/python/caffe/net_spec.py
@@ -108,7 +108,15 @@ class Function(object):
del self.params['in_place']
self.tops = tuple(Top(self, n) for n in range(self.ntop))
- def _get_name(self, top, names, autonames):
+ def _get_name(self, names, autonames):
+ if self not in names and self.ntop > 0:
+ names[self] = self._get_top_name(self.tops[0], names, autonames)
+ elif self not in names:
+ autonames[self.type_name] += 1
+ names[self] = self.type_name + str(autonames[self.type_name])
+ return names[self]
+
+ def _get_top_name(self, top, names, autonames):
if top not in names:
autonames[top.fn.type_name] += 1
names[top] = top.fn.type_name + str(autonames[top.fn.type_name])
@@ -129,8 +137,8 @@ class Function(object):
layer.top.extend(layer.bottom)
else:
for top in self.tops:
- layer.top.append(self._get_name(top, names, autonames))
- layer.name = self._get_name(self.tops[0], names, autonames)
+ layer.top.append(self._get_top_name(top, names, autonames))
+ layer.name = self._get_name(names, autonames)
for k, v in six.iteritems(self.params):
# special case to handle generic *params