diff options
author | Jonathan L Long <jonlong@cs.berkeley.edu> | 2015-07-23 20:35:42 -0700 |
---|---|---|
committer | Jonathan L Long <jonlong@cs.berkeley.edu> | 2015-07-23 20:35:42 -0700 |
commit | 96c2fe1de80c9752b992c4578a3ce46028d21fc5 (patch) | |
tree | 67a38c71c407bd8e7669b9ba8dd3e16d4b0a4ecd | |
parent | f16195aa8b1569e0260c48d4159b7d2ce0ea2fab (diff) | |
download | caffeonacl-96c2fe1de80c9752b992c4578a3ce46028d21fc5.tar.gz caffeonacl-96c2fe1de80c9752b992c4578a3ce46028d21fc5.tar.bz2 caffeonacl-96c2fe1de80c9752b992c4578a3ce46028d21fc5.zip |
[pycaffe] allow layers to have names different from their first tops
Previously, net spec only allowed names to be assigned to Tops, giving
layers the names of their first tops. Now, names can be assigned to
Functions, which become layer names in serialization. Unnamed Functions
still get named after their first top, if present, or autogenerated, if
not. (This will allow top-less layers in a natural way.)
-rw-r--r-- | python/caffe/net_spec.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py index 5fb26ed4..16f30008 100644 --- a/python/caffe/net_spec.py +++ b/python/caffe/net_spec.py @@ -108,7 +108,15 @@ class Function(object): del self.params['in_place'] self.tops = tuple(Top(self, n) for n in range(self.ntop)) - def _get_name(self, top, names, autonames): + def _get_name(self, names, autonames): + if self not in names and self.ntop > 0: + names[self] = self._get_top_name(self.tops[0], names, autonames) + elif self not in names: + autonames[self.type_name] += 1 + names[self] = self.type_name + str(autonames[self.type_name]) + return names[self] + + def _get_top_name(self, top, names, autonames): if top not in names: autonames[top.fn.type_name] += 1 names[top] = top.fn.type_name + str(autonames[top.fn.type_name]) @@ -129,8 +137,8 @@ class Function(object): layer.top.extend(layer.bottom) else: for top in self.tops: - layer.top.append(self._get_name(top, names, autonames)) - layer.name = self._get_name(self.tops[0], names, autonames) + layer.top.append(self._get_top_name(top, names, autonames)) + layer.name = self._get_name(names, autonames) for k, v in six.iteritems(self.params): # special case to handle generic *params |