Diffstat (limited to 'torch/legacy/nn/utils.py')
-rw-r--r--  torch/legacy/nn/utils.py  32
1 file changed, 22 insertions(+), 10 deletions(-)
diff --git a/torch/legacy/nn/utils.py b/torch/legacy/nn/utils.py
index 8a76117f03..0432a6e3a0 100644
--- a/torch/legacy/nn/utils.py
+++ b/torch/legacy/nn/utils.py
@@ -13,6 +13,8 @@ import torch
 # > net1:type('torch.cuda.FloatTensor', tensorCache)
 # > net2:type('torch.cuda.FloatTensor', tensorCache)
 # > nn.utils.recursiveType(anotherTensor, 'torch.cuda.FloatTensor', tensorCache)
+
+
 def recursiveType(param, type, tensorCache={}):
     from .Criterion import Criterion
     from .Module import Module
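
For reference, the Lua pattern in the comment above maps to roughly this Python usage (net1, net2, and anotherTensor are hypothetical; the shared dict is what keeps converted tensors aliased across calls):

    tensorCache = {}
    net1.type('torch.cuda.FloatTensor', tensorCache)
    net2.type('torch.cuda.FloatTensor', tensorCache)
    recursiveType(anotherTensor, 'torch.cuda.FloatTensor', tensorCache)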
@@ -28,12 +30,13 @@ def recursiveType(param, type, tensorCache={}):
                 newparam = tensorCache[key]
             else:
                 newparam = torch.Tensor().type(type)
-                storageType = type.replace('Tensor','Storage')
+                storageType = type.replace('Tensor', 'Storage')
                 param_storage = param.storage()
                 if param_storage:
                     storage_key = param_storage._cdata
                     if storage_key not in tensorCache:
-                        tensorCache[storage_key] = torch._import_dotted_name(storageType)(param_storage.size()).copy_(param_storage)
+                        tensorCache[storage_key] = torch._import_dotted_name(
+                            storageType)(param_storage.size()).copy_(param_storage)
                     newparam.set_(
                         tensorCache[storage_key],
                         param.storage_offset(),
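
A minimal sketch of what the storage-level cache buys, assuming the legacy type-string API: two tensors viewing one storage still alias each other after conversion.

    cache = {}
    base = torch.DoubleTensor(4).fill_(1)
    part = base.narrow(0, 0, 2)  # shares base's storage
    a = recursiveType(base, 'torch.FloatTensor', cache)
    b = recursiveType(part, 'torch.FloatTensor', cache)
    # a and b now point into one shared FloatStorage, preserving the aliasing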
@@ -44,6 +47,7 @@ def recursiveType(param, type, tensorCache={}):
             param = newparam
     return param
 
+
 def recursiveResizeAs(t1, t2):
     if isinstance(t2, list):
         t1 = t1 if isinstance(t1, list) else [t1]
@@ -56,20 +60,22 @@ def recursiveResizeAs(t1, t2):
         t1 = t1 if torch.is_tensor(t1) else t2.new()
         t1.resize_as_(t2)
     else:
-        raise RuntimeError("Expecting nested tensors or tables. Got " + \
-              type(t1).__name__ + " and " + type(t2).__name__ + "instead")
+        raise RuntimeError("Expecting nested tensors or tables. Got " +
+                           type(t1).__name__ + " and " + type(t2).__name__ + " instead")
     return t1, t2
 
+
 def recursiveFill(t2, val):
     if isinstance(t2, list):
         t2 = [recursiveFill(x, val) for x in t2]
     elif torch.is_tensor(t2):
         t2.fill_(val)
     else:
-        raise RuntimeError("expecting tensor or table thereof. Got " + \
-              type(t2).__name__ + " instead")
+        raise RuntimeError("expecting tensor or table thereof. Got " +
+                           type(t2).__name__ + " instead")
     return t2
 
+
 def recursiveAdd(t1, val=1, t2=None):
     if t2 is None:
         t2 = val
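
Both helpers recurse through nested lists of tensors; a hypothetical sketch of how they compose:

    outs = [torch.Tensor(), [torch.Tensor()]]
    refs = [torch.randn(3), [torch.randn(2, 2)]]
    outs, refs = recursiveResizeAs(outs, refs)  # outs now mirrors refs' shapes
    recursiveFill(outs, 0)                      # zeros every tensor in the nest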
@@ -81,10 +87,11 @@ def recursiveAdd(t1, val=1, t2=None):
     elif torch.is_tensor(t1) and torch.is_tensor(t2):
         t1.add_(val, t2)
     else:
-        raise RuntimeError("expecting nested tensors or tables. Got " + \
-              type(t1).__name__ + " and " + type(t2).__name__ + " instead")
+        raise RuntimeError("expecting nested tensors or tables. Got " +
+                           type(t1).__name__ + " and " + type(t2).__name__ + " instead")
     return t1, t2
 
+
 def recursiveCopy(t1, t2):
     if isinstance(t2, list):
         t1 = t1 if isinstance(t1, list) else [t1]
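
Here t1.add_(val, t2) is the legacy scaled-add form, i.e. t1 += val * t2; a quick sketch:

    x = torch.ones(3)
    recursiveAdd(x, 2, torch.ones(3))  # x becomes [3., 3., 3.]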
@@ -94,10 +101,11 @@ def recursiveCopy(t1, t2):
         t1 = t1 if torch.is_tensor(t1) else t2.new()
         t1.resize_as_(t2).copy_(t2)
     else:
-        raise RuntimeError("expecting nested tensors or tables. Got " + \
-              type(t1).__name__ + " and " + type(t2).__name__ + " instead")
+        raise RuntimeError("expecting nested tensors or tables. Got " +
+                           type(t1).__name__ + " and " + type(t2).__name__ + " instead")
     return t1, t2
 
+
 def addSingletondimension(*args):
     view = None
     if len(args) < 3:
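
recursiveCopy resizes each destination tensor to its source before copying; a hypothetical sketch:

    dst = [torch.Tensor(), torch.Tensor()]
    src = [torch.randn(2), torch.randn(3)]
    dst, src = recursiveCopy(dst, src)  # each dst[i] is resized to src[i] and copied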
@@ -109,6 +117,7 @@ def addSingletondimension(*args):
     view.set_(t)
     return view.unsqueeze_(dim)
 
+
 def contiguousView(output, input, *args):
     if output is None:
         output = input.new()
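
addSingletondimension returns a view with an extra size-1 axis, allocating the view tensor when one isn't passed in; a sketch:

    t = torch.randn(3, 4)
    v = addSingletondimension(t, 0)  # view of shape (1, 3, 4), sharing t's data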
@@ -123,9 +132,12 @@ def contiguousView(output, input, *args):
 # go over specified fields and clear them. accepts
 # nn.clearState(self, ['_buffer', '_buffer2']) and
 # nn.clearState(self, '_buffer', '_buffer2')
+
+
 def clear(self, *args):
     if len(args) == 1 and isinstance(args[0], list):
         args = args[0]
+
     def _clear(f):
         if not hasattr(self, f):
             return
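
The two call forms described in the comment, sketched against a hypothetical module that owns those attributes:

    clear(module, ['_buffer', '_buffer2'])  # list form
    clear(module, '_buffer', '_buffer2')    # varargs form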