summaryrefslogtreecommitdiff
path: root/torch/storage.py
blob: 104f9eec59667d653509b898c4ea6648b519a6a0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from ._utils import _type, _cuda, _range


class _StorageBase(object):
    """Python-level helpers shared by storage classes.

    The methods below rely on the concrete storage type to provide element
    access (``self[i]``, ``len(self)``, ``size()``), ``copy_``, the
    ``_cdata`` handle and the ``_share_*_`` / ``_new_using_*`` primitives.
    ``type`` and ``cuda`` are not defined in the class body; they are bound
    onto the class at module level.
    """
    is_cuda = False
    is_sparse = False

    def __str__(self):
        """Returns one element per line followed by a ``[type of size n]`` footer"""
        content = ' ' + '\n '.join(str(self[i]) for i in _range(len(self)))
        return content + '\n[{} of size {}]'.format(torch.typename(self), len(self))

    def __repr__(self):
        """Returns the same text as ``str(self)``"""
        return str(self)

    def __iter__(self):
        """Returns an iterator over the elements in index order"""
        # A generator expression is lazy on every Python version, whereas
        # ``map`` materializes the whole element list up front on Python 2.
        return (self[i] for i in _range(self.size()))

    def __copy__(self):
        """Returns a copy of this storage (same as :meth:`clone`)"""
        return self.clone()

    def __deepcopy__(self, memo):
        """Returns a clone of this storage, reusing a clone already in ``memo``

        Clones are recorded in the ``memo['torch']`` sub-dictionary, keyed by
        ``self._cdata``, so repeated requests within one deep copy return the
        same new storage object.
        """
        torch_memo = memo.setdefault('torch', {})
        if self._cdata in torch_memo:
            return torch_memo[self._cdata]
        new_storage = self.clone()
        torch_memo[self._cdata] = new_storage
        return new_storage

    def __reduce__(self):
        """Pickles the storage as its class plus the list of its elements"""
        return type(self), (self.tolist(),)

    def clone(self):
        """Returns a copy of this storage"""
        # Allocate a same-sized storage of the same class, then copy into it.
        return type(self)(self.size()).copy_(self)

    def tolist(self):
        """Returns a list containing the elements of this storage"""
        return list(self)

    def cpu(self):
        """Returns a CPU copy of this storage if it's not already on the CPU"""
        # The target is the class with the same name in the top-level
        # ``torch`` namespace.
        return self.type(getattr(torch, self.__class__.__name__))

    # The cast helpers below build the target type name from the module of
    # the current class, so the result is looked up next to this class.

    def double(self):
        """Casts this storage to double type"""
        return self.type(type(self).__module__ + '.DoubleStorage')

    def float(self):
        """Casts this storage to float type"""
        return self.type(type(self).__module__ + '.FloatStorage')

    def half(self):
        """Casts this storage to half type"""
        return self.type(type(self).__module__ + '.HalfStorage')

    def long(self):
        """Casts this storage to long type"""
        return self.type(type(self).__module__ + '.LongStorage')

    def int(self):
        """Casts this storage to int type"""
        return self.type(type(self).__module__ + '.IntStorage')

    def short(self):
        """Casts this storage to short type"""
        return self.type(type(self).__module__ + '.ShortStorage')

    def char(self):
        """Casts this storage to char type"""
        return self.type(type(self).__module__ + '.CharStorage')

    def byte(self):
        """Casts this storage to byte type"""
        return self.type(type(self).__module__ + '.ByteStorage')

    def pin_memory(self):
        """Copies the storage to pinned memory.

        A new storage of the same class and size is always allocated with
        the CUDA host allocator and filled from this one; no check is made
        for whether this storage is already pinned.

        Raises:
            TypeError: if this is a CUDA storage.
        """
        if self.is_cuda:
            raise TypeError("cannot pin '{0}' only CPU memory can be pinned"
                            .format(self.type()))
        # Imported here rather than at module level, so torch.cuda is only
        # loaded when pinning is actually requested.
        import torch.cuda
        allocator = torch.cuda._host_allocator()
        return type(self)(self.size(), allocator=allocator).copy_(self)

    def share_memory_(self):
        """Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Returns: self
        """
        from torch.multiprocessing import get_sharing_strategy
        if self.is_cuda:
            pass  # CUDA doesn't use POSIX shared memory
        elif get_sharing_strategy() == 'file_system':
            self._share_filename_()
        else:
            self._share_fd_()
        return self

    @classmethod
    def _new_shared(cls, size):
        """Creates a new storage in shared memory with the same data type"""
        from torch.multiprocessing import get_sharing_strategy
        if cls.is_cuda:
            # CUDA storages are constructed normally; see share_memory_.
            return cls(size)
        elif get_sharing_strategy() == 'file_system':
            return cls._new_using_filename(size)
        else:
            return cls._new_using_fd(size)


# Bind the helpers imported from ._utils as the ``type`` and ``cuda`` methods
# of _StorageBase. The cast helpers, cpu() and pin_memory() in the class body
# call ``self.type(...)`` and therefore depend on this binding.
_StorageBase.type = _type
_StorageBase.cuda = _cuda