diff options
Diffstat (limited to 'torch/legacy/nn/utils.py')
-rw-r--r-- | torch/legacy/nn/utils.py | 32 |
1 files changed, 22 insertions, 10 deletions
diff --git a/torch/legacy/nn/utils.py b/torch/legacy/nn/utils.py index 8a76117f03..0432a6e3a0 100644 --- a/torch/legacy/nn/utils.py +++ b/torch/legacy/nn/utils.py @@ -13,6 +13,8 @@ import torch # > net1:type('torch.cuda.FloatTensor', tensorCache) # > net2:type('torch.cuda.FloatTensor', tensorCache) # > nn.utils.recursiveType(anotherTensor, 'torch.cuda.FloatTensor', tensorCache) + + def recursiveType(param, type, tensorCache={}): from .Criterion import Criterion from .Module import Module @@ -28,12 +30,13 @@ def recursiveType(param, type, tensorCache={}): newparam = tensorCache[key] else: newparam = torch.Tensor().type(type) - storageType = type.replace('Tensor','Storage') + storageType = type.replace('Tensor', 'Storage') param_storage = param.storage() if param_storage: storage_key = param_storage._cdata if storage_key not in tensorCache: - tensorCache[storage_key] = torch._import_dotted_name(storageType)(param_storage.size()).copy_(param_storage) + tensorCache[storage_key] = torch._import_dotted_name( + storageType)(param_storage.size()).copy_(param_storage) newparam.set_( tensorCache[storage_key], param.storage_offset(), @@ -44,6 +47,7 @@ def recursiveType(param, type, tensorCache={}): param = newparam return param + def recursiveResizeAs(t1, t2): if isinstance(t2, list): t1 = t1 if isinstance(t1, list) else [t1] @@ -56,20 +60,22 @@ def recursiveResizeAs(t1, t2): t1 = t1 if torch.is_tensor(t1) else t2.new() t1.resize_as_(t2) else: - raise RuntimeError("Expecting nested tensors or tables. Got " + \ - type(t1).__name__ + " and " + type(t2).__name__ + "instead") + raise RuntimeError("Expecting nested tensors or tables. Got " + + type(t1).__name__ + " and " + type(t2).__name__ + "instead") return t1, t2 + def recursiveFill(t2, val): if isinstance(t2, list): t2 = [recursiveFill(x, val) for x in t2] elif torch.is_tensor(t2): t2.fill_(val) else: - raise RuntimeError("expecting tensor or table thereof. Got " + \ - type(t2).__name__ + " instead") + raise RuntimeError("expecting tensor or table thereof. Got " + + type(t2).__name__ + " instead") return t2 + def recursiveAdd(t1, val=1, t2=None): if t2 is None: t2 = val @@ -81,10 +87,11 @@ def recursiveAdd(t1, val=1, t2=None): elif torch.is_tensor(t1) and torch.is_tensor(t2): t1.add_(val, t2) else: - raise RuntimeError("expecting nested tensors or tables. Got " + \ - type(t1).__name__ + " and " + type(t2).__name__ + " instead") + raise RuntimeError("expecting nested tensors or tables. Got " + + type(t1).__name__ + " and " + type(t2).__name__ + " instead") return t1, t2 + def recursiveCopy(t1, t2): if isinstance(t2, list): t1 = t1 if isinstance(t1, list) else [t1] @@ -94,10 +101,11 @@ def recursiveCopy(t1, t2): t1 = t1 if torch.is_tensor(t1) else t2.new() t1.resize_as_(t2).copy_(t2) else: - raise RuntimeError("expecting nested tensors or tables. Got " + \ - type(t1).__name__ + " and " + type(t2).__name__ + " instead") + raise RuntimeError("expecting nested tensors or tables. Got " + + type(t1).__name__ + " and " + type(t2).__name__ + " instead") return t1, t2 + def addSingletondimension(*args): view = None if len(args) < 3: @@ -109,6 +117,7 @@ def addSingletondimension(*args): view.set_(t) return view.unsqueeze_(dim) + def contiguousView(output, input, *args): if output is None: output = input.new() @@ -123,9 +132,12 @@ def contiguousView(output, input, *args): # go over specified fields and clear them. accepts # nn.clearState(self, ['_buffer', '_buffer2']) and # nn.clearState(self, '_buffer', '_buffer2') + + def clear(self, *args): if len(args) == 1 and isinstance(args[0], list): args = args[0] + def _clear(f): if not hasattr(self, f): return |