summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2015-08-07 21:06:29 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2015-08-07 21:06:29 -0700
commiteb3e1149a2fcc9c48d268ffe2319d872081e4c3b (patch)
treea38d46ade530575dd447a0b52471709f4a98fd09 /python
parent311e47bc1a6dabea8bcad7cf23b5332be8850b39 (diff)
parentc80f6f585549334022996395015cd490cc2644c8 (diff)
downloadcaffeonacl-eb3e1149a2fcc9c48d268ffe2319d872081e4c3b.tar.gz
caffeonacl-eb3e1149a2fcc9c48d268ffe2319d872081e4c3b.tar.bz2
caffeonacl-eb3e1149a2fcc9c48d268ffe2319d872081e4c3b.zip
Merge pull request #2813 from longjon/net-spec-imp
Python net spec cleanups and top-less layers
Diffstat (limited to 'python')
-rw-r--r--python/caffe/net_spec.py36
-rw-r--r--python/caffe/test/test_net_spec.py15
2 files changed, 38 insertions, 13 deletions
diff --git a/python/caffe/net_spec.py b/python/caffe/net_spec.py
index 1b4814a4..31cde7ad 100644
--- a/python/caffe/net_spec.py
+++ b/python/caffe/net_spec.py
@@ -18,7 +18,7 @@ for specifying nets. In particular, the automatically generated layer names
are not guaranteed to be forward-compatible.
"""
-from collections import OrderedDict
+from collections import OrderedDict, Counter
from .proto import caffe_pb2
from google import protobuf
@@ -44,10 +44,8 @@ def to_proto(*tops):
"""Generate a NetParameter that contains all layers needed to compute
all arguments."""
- if not isinstance(tops, tuple):
- tops = (tops,)
layers = OrderedDict()
- autonames = {}
+ autonames = Counter()
for top in tops:
top.fn._to_proto(layers, {}, autonames)
net = caffe_pb2.NetParameter()
@@ -89,6 +87,9 @@ class Top(object):
return to_proto(self)
+ def _to_proto(self, layers, names, autonames):
+ return self.fn._to_proto(layers, names, autonames)
+
class Function(object):
"""A Function specifies a layer, its parameters, and its inputs (which
@@ -107,11 +108,18 @@ class Function(object):
del self.params['in_place']
self.tops = tuple(Top(self, n) for n in range(self.ntop))
- def _get_name(self, top, names, autonames):
+ def _get_name(self, names, autonames):
+ if self not in names and self.ntop > 0:
+ names[self] = self._get_top_name(self.tops[0], names, autonames)
+ elif self not in names:
+ autonames[self.type_name] += 1
+ names[self] = self.type_name + str(autonames[self.type_name])
+ return names[self]
+
+ def _get_top_name(self, top, names, autonames):
if top not in names:
- n = autonames.setdefault(top.fn.type_name, 1)
autonames[top.fn.type_name] += 1
- names[top] = top.fn.type_name + str(n)
+ names[top] = top.fn.type_name + str(autonames[top.fn.type_name])
return names[top]
def _to_proto(self, layers, names, autonames):
@@ -119,7 +127,7 @@ class Function(object):
return
bottom_names = []
for inp in self.inputs:
- inp.fn._to_proto(layers, names, autonames)
+ inp._to_proto(layers, names, autonames)
bottom_names.append(layers[inp.fn].top[inp.n])
layer = caffe_pb2.LayerParameter()
layer.type = self.type_name
@@ -129,8 +137,8 @@ class Function(object):
layer.top.extend(layer.bottom)
else:
for top in self.tops:
- layer.top.append(self._get_name(top, names, autonames))
- layer.name = self._get_name(self.tops[0], names, autonames)
+ layer.top.append(self._get_top_name(top, names, autonames))
+ layer.name = self._get_name(names, autonames)
for k, v in six.iteritems(self.params):
# special case to handle generic *params
@@ -163,10 +171,10 @@ class NetSpec(object):
def to_proto(self):
names = {v: k for k, v in six.iteritems(self.tops)}
- autonames = {}
+ autonames = Counter()
layers = OrderedDict()
for name, top in six.iteritems(self.tops):
- top.fn._to_proto(layers, names, autonames)
+ top._to_proto(layers, names, autonames)
net = caffe_pb2.NetParameter()
net.layer.extend(layers.values())
return net
@@ -180,7 +188,9 @@ class Layers(object):
def __getattr__(self, name):
def layer_fn(*args, **kwargs):
fn = Function(name, args, kwargs)
- if fn.ntop == 1:
+ if fn.ntop == 0:
+ return fn
+ elif fn.ntop == 1:
return fn.tops[0]
else:
return fn.tops
diff --git a/python/caffe/test/test_net_spec.py b/python/caffe/test/test_net_spec.py
index 909a101b..b4595e65 100644
--- a/python/caffe/test/test_net_spec.py
+++ b/python/caffe/test/test_net_spec.py
@@ -41,6 +41,14 @@ def anon_lenet(batch_size):
loss = L.SoftmaxWithLoss(ip2, label)
return loss.to_proto()
+def silent_net():
+ n = caffe.NetSpec()
+ n.data, n.data2 = L.DummyData(shape=[dict(dim=[3]), dict(dim=[4, 2])],
+ ntop=2)
+ n.silence_data = L.Silence(n.data, ntop=0)
+ n.silence_data2 = L.Silence(n.data2, ntop=0)
+ return n.to_proto()
+
class TestNetSpec(unittest.TestCase):
def load_net(self, net_proto):
f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
@@ -65,3 +73,10 @@ class TestNetSpec(unittest.TestCase):
net_proto.layer[6].top)
net = self.load_net(net_proto)
self.assertEqual(len(net.layers), 9)
+
+ def test_zero_tops(self):
+ """Test net construction for top-less layers."""
+
+ net_proto = silent_net()
+ net = self.load_net(net_proto)
+ self.assertEqual(len(net.forward()), 0)