author    | Ailing Zhang <ailzhang@fb.com> | 2018-10-29 18:41:04 -0700
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-10-29 18:43:14 -0700
commit    | 4a3baec96158c14f0157ea5a5a8a3454b291ed34 (patch)
tree      | 54a17c25d1bac57ed9e157dbe743be5748371e24 /docs
parent    | 955a01562dd67d84dd9b22800fc6604fab0a40ff (diff)
Hub Implementation (#12228)
Summary:
[Edit: after applying colesbury's suggestions]
* The Hub module enables users to share code plus pretrained weights through GitHub repos.
Example usage:
```python
hub_model = hub.load(
    'ailzhang/vision:hub',  # repo_owner/repo_name:branch
    'wrapper1',             # entrypoint
    1234,                   # args for callable [not applicable to resnet18]
    pretrained=True)        # kwargs for callable
```
* Protocol on repo owner side: example https://github.com/ailzhang/vision/tree/hub
* "Published" models must live on at least a branch or tag; an arbitrary commit cannot be loaded.
* The repo owner should define the following in `hubconf.py`:
* A function/entrypoint with the signature `def wrapper1(pretrained=False, *args, **kwargs):`
* `pretrained` controls whether to load the pretrained weights published by the repo owner.
* `args` and `kwargs` are passed through to the callable `resnet18`; the repo owner should clearly document them in the docstring.
```python
def wrapper1(pretrained=False, *args, **kwargs):
    """
    pretrained (bool): a recommended kwarg for all entrypoints
    args & kwargs are arguments for the function
    """
    from torch.utils import model_zoo
    from torchvision.models.resnet import resnet18
    model = resnet18(*args, **kwargs)
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    if pretrained:
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model
```
* `hub_dir`
* `hub_dir` specifies where the intermediate files/folders will be saved. By default this is `~/.torch/hub`.
* Users can change it by either setting the environment variable `TORCH_HUB_DIR` or calling `hub.set_dir(PATH_TO_HUB_DIR)`.
* By default, files are not cleaned up after loading, so the cache can be reused on the next call.
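The precedence described above (explicit `hub.set_dir` setting, then the `TORCH_HUB_DIR` environment variable, then the default) can be sketched as follows. This is a simplified illustration of the resolution order, not the actual `torch.hub` source:

```python
import os

def resolve_hub_dir(explicit_dir=None):
    """Return the hub directory: explicit setting > env var > default."""
    if explicit_dir is not None:            # set via hub.set_dir(PATH_TO_HUB_DIR)
        return explicit_dir
    env_dir = os.environ.get('TORCH_HUB_DIR')
    if env_dir is not None:                 # set via the environment
        return env_dir
    return os.path.expanduser('~/.torch/hub')  # default location
```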
* Cache logic:
* The cache in `hub_dir` is used by default if it exists.
* Users can force a fresh reload by calling `hub.load(..., force_reload=True)`.
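The cache check amounts to the following sketch, where `download_repo` is a hypothetical stand-in for the fetch step; the real implementation lives inside `hub.load`:

```python
import os

def get_repo(repo_dir, download_repo, force_reload=False):
    """Reuse the cached checkout unless it is missing or force_reload is set."""
    if force_reload or not os.path.exists(repo_dir):
        download_repo(repo_dir)   # fresh fetch into hub_dir
    return repo_dir               # otherwise the cached copy is reused
```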
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12228
Differential Revision: D10511470
Pulled By: ailzhang
fbshipit-source-id: 12ac27f01d33653f06b2483655546492f82cce38
Diffstat (limited to 'docs')
-rw-r--r-- | docs/source/hub.rst   | 6 |
-rw-r--r-- | docs/source/index.rst | 1 |
2 files changed, 7 insertions, 0 deletions
diff --git a/docs/source/hub.rst b/docs/source/hub.rst
new file mode 100644
index 0000000000..2966d0dff9
--- /dev/null
+++ b/docs/source/hub.rst
@@ -0,0 +1,6 @@
+torch.hub
+===================================
+
+.. automodule:: torch.hub
+.. autofunction:: load
+.. autofunction:: set_dir
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 3e26aa8f5d..e74578a4fd 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,6 +42,7 @@ PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
    data
    dlpack
    ffi
+   hub
    model_zoo
    onnx
    torch.distributed.deprecated <distributed_deprecated>