summaryrefslogtreecommitdiff
path: root/docs/source/hub.rst
blob: 298aa298945c202566565831db3a9ab122fc075e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
torch.hub
===================================
Pytorch Hub is a pre-trained model repository designed to facilitate research reproducibility.

Publishing models
-----------------

Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights)
to a github repository by adding a simple ``hubconf.py`` file;

``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function with
the following signature.

::

    def entrypoint_name(pretrained=False, *args, **kwargs):
        ...

How to implement an entrypoint?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Here is a code snippet from pytorch/vision repository, which specifies an entrypoint
for ``resnet18`` model. You can see a full script in
`pytorch/vision repo <https://github.com/pytorch/vision/blob/master/hubconf.py>`_

::

    dependencies = ['torch', 'math']

    def resnet18(pretrained=False, *args, **kwargs):
        """
        Resnet18 model
        pretrained (bool): a recommended kwargs for all entrypoints
        args & kwargs are arguments for the function
        """
        ######## Call the model in the repo ###############
        from torchvision.models.resnet import resnet18 as _resnet18
        model = _resnet18(*args, **kwargs)
        ######## End of call ##############################
        # The following logic is REQUIRED
        if pretrained:
            # For weights saved in local repo
            # model.load_state_dict(<path_to_saved_file>)

            # For weights saved elsewhere
            checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
            model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
        return model

- ``dependencies`` variable is a **list** of package names required to to run the model.
- Pretrained weights can either be stored local in the github repo, or loadable by
  ``model_zoo.load()``.
- ``pretrained`` controls whether to load the pre-trained weights provided by repo owners.
- ``args`` and ``kwargs`` are passed along to the real callable function.
- Docstring of the function works as a help message, explaining what does the model do and what
  are the allowed arguments.
- Entrypoint function should **ALWAYS** return a model(nn.module).

Important Notice
^^^^^^^^^^^^^^^^

- The published models should be at least in a branch/tag. It can't be a random commit.

Loading models from Hub
-----------------------

Users can load the pre-trained models using ``torch.hub.load()`` API.


.. automodule:: torch.hub
.. autofunction:: load

Here's an example loading ``resnet18`` entrypoint from ``pytorch/vision`` repo.

::

    hub_model = hub.load(
        'pytorch/vision:master', # repo_owner/repo_name:branch
        'resnet18', # entrypoint
        1234, # args for callable [not applicable to resnet]
        pretrained=True) # kwargs for callable

Where are my downloaded model & weights saved?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The locations are used in the order of

- hub_dir: user specified path. It can be set in the following ways:
  - Setting the environment variable ``TORCH_HUB_DIR``
  - Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``
- ``~/.torch/hub``

.. autofunction:: set_dir

Caching logic
^^^^^^^^^^^^^

By default, we don't clean up files after loading it. Hub uses the cache by default if it already exists in ``hub_dir``.

Users can force a reload by calling ``hub.load(..., force_reload=True)``. This will delete
the existing github folder and downloaded weights, reinitialize a fresh download. This is useful
when updates are published to the same branch, users can keep up with the latest release.