diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-08-07 21:06:29 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-08-07 21:06:29 -0700 |
commit | eb3e1149a2fcc9c48d268ffe2319d872081e4c3b (patch) | |
tree | a38d46ade530575dd447a0b52471709f4a98fd09 /python | |
parent | 311e47bc1a6dabea8bcad7cf23b5332be8850b39 (diff) | |
parent | c80f6f585549334022996395015cd490cc2644c8 (diff) | |
download | caffeonacl-eb3e1149a2fcc9c48d268ffe2319d872081e4c3b.tar.gz caffeonacl-eb3e1149a2fcc9c48d268ffe2319d872081e4c3b.tar.bz2 caffeonacl-eb3e1149a2fcc9c48d268ffe2319d872081e4c3b.zip |
Merge pull request #2813 from longjon/net-spec-imp
Python net spec cleanups and top-less layers
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/net_spec.py | 36
-rw-r--r-- | python/caffe/test/test_net_spec.py | 15
2 files changed, 38 insertions(+), 13 deletions(-)
diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py index 1b4814a4..31cde7ad 100644 --- a/python/caffe/net_spec.py +++ b/python/caffe/net_spec.py @@ -18,7 +18,7 @@ for specifying nets. In particular, the automatically generated layer names are not guaranteed to be forward-compatible. """ -from collections import OrderedDict +from collections import OrderedDict, Counter from .proto import caffe_pb2 from google import protobuf @@ -44,10 +44,8 @@ def to_proto(*tops): """Generate a NetParameter that contains all layers needed to compute all arguments.""" - if not isinstance(tops, tuple): - tops = (tops,) layers = OrderedDict() - autonames = {} + autonames = Counter() for top in tops: top.fn._to_proto(layers, {}, autonames) net = caffe_pb2.NetParameter() @@ -89,6 +87,9 @@ class Top(object): return to_proto(self) + def _to_proto(self, layers, names, autonames): + return self.fn._to_proto(layers, names, autonames) + class Function(object): """A Function specifies a layer, its parameters, and its inputs (which @@ -107,11 +108,18 @@ class Function(object): del self.params['in_place'] self.tops = tuple(Top(self, n) for n in range(self.ntop)) - def _get_name(self, top, names, autonames): + def _get_name(self, names, autonames): + if self not in names and self.ntop > 0: + names[self] = self._get_top_name(self.tops[0], names, autonames) + elif self not in names: + autonames[self.type_name] += 1 + names[self] = self.type_name + str(autonames[self.type_name]) + return names[self] + + def _get_top_name(self, top, names, autonames): if top not in names: - n = autonames.setdefault(top.fn.type_name, 1) autonames[top.fn.type_name] += 1 - names[top] = top.fn.type_name + str(n) + names[top] = top.fn.type_name + str(autonames[top.fn.type_name]) return names[top] def _to_proto(self, layers, names, autonames): @@ -119,7 +127,7 @@ class Function(object): return bottom_names = [] for inp in self.inputs: - inp.fn._to_proto(layers, names, autonames) + inp._to_proto(layers, names, autonames)
bottom_names.append(layers[inp.fn].top[inp.n]) layer = caffe_pb2.LayerParameter() layer.type = self.type_name @@ -129,8 +137,8 @@ class Function(object): layer.top.extend(layer.bottom) else: for top in self.tops: - layer.top.append(self._get_name(top, names, autonames)) - layer.name = self._get_name(self.tops[0], names, autonames) + layer.top.append(self._get_top_name(top, names, autonames)) + layer.name = self._get_name(names, autonames) for k, v in six.iteritems(self.params): # special case to handle generic *params @@ -163,10 +171,10 @@ class NetSpec(object): def to_proto(self): names = {v: k for k, v in six.iteritems(self.tops)} - autonames = {} + autonames = Counter() layers = OrderedDict() for name, top in six.iteritems(self.tops): - top.fn._to_proto(layers, names, autonames) + top._to_proto(layers, names, autonames) net = caffe_pb2.NetParameter() net.layer.extend(layers.values()) return net @@ -180,7 +188,9 @@ class Layers(object): def __getattr__(self, name): def layer_fn(*args, **kwargs): fn = Function(name, args, kwargs) - if fn.ntop == 1: + if fn.ntop == 0: + return fn + elif fn.ntop == 1: return fn.tops[0] else: return fn.tops diff --git a/python/caffe/test/test_net_spec.py b/python/caffe/test/test_net_spec.py index 909a101b..b4595e65 100644 --- a/python/caffe/test/test_net_spec.py +++ b/python/caffe/test/test_net_spec.py @@ -41,6 +41,14 @@ def anon_lenet(batch_size): loss = L.SoftmaxWithLoss(ip2, label) return loss.to_proto() +def silent_net(): + n = caffe.NetSpec() + n.data, n.data2 = L.DummyData(shape=[dict(dim=[3]), dict(dim=[4, 2])], + ntop=2) + n.silence_data = L.Silence(n.data, ntop=0) + n.silence_data2 = L.Silence(n.data2, ntop=0) + return n.to_proto() + class TestNetSpec(unittest.TestCase): def load_net(self, net_proto): f = tempfile.NamedTemporaryFile(mode='w+', delete=False) @@ -65,3 +73,10 @@ class TestNetSpec(unittest.TestCase): net_proto.layer[6].top) net = self.load_net(net_proto) self.assertEqual(len(net.layers), 9) + + def test_zero_tops(self): + """Test net construction for top-less layers.""" + + net_proto = silent_net() + net = self.load_net(net_proto) + self.assertEqual(len(net.forward()), 0)