diff options
author | Jiyan Yang <chocjy@fb.com> | 2019-04-23 10:05:57 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-23 10:14:08 -0700 |
commit | 714344a976033e2c28595ac79252f67a81691766 (patch) | |
tree | 08b90165b53e4d8f1aee3784cacb89fe4471ccfe | |
parent | e3f150462185239f5a92d39916c3e3ca12fe80e0 (diff) | |
download | pytorch-714344a976033e2c28595ac79252f67a81691766.tar.gz pytorch-714344a976033e2c28595ac79252f67a81691766.tar.bz2 pytorch-714344a976033e2c28595ac79252f67a81691766.zip |
Specify to use Float16UniformFill if necessary in sparse lookup layer (#18499)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18499
If the init op is not fp16 compatible, it should throw.
However, in the special case where the original init op is UniformFill,
we replace it with Float16UniformFill
Reviewed By: kennyhorror
Differential Revision: D14627209
fbshipit-source-id: eb427772874a732ca8b3a25d06670d119ce8ac14
-rw-r--r-- | caffe2/python/layers/sparse_lookup.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/caffe2/python/layers/sparse_lookup.py b/caffe2/python/layers/sparse_lookup.py index d920bd1ba6..73c937358b 100644 --- a/caffe2/python/layers/sparse_lookup.py +++ b/caffe2/python/layers/sparse_lookup.py @@ -103,12 +103,17 @@ class SparseLookup(ModelLayer): self.weight_init = weight_init or default_init_op + # If fp16 is used, make sure fp16 init op is used if self.trainer_version == "fp16": - assert self.weight_init[0] in self._fp16_compatible_init_op_types,\ - "Fp16 training is enabled. Init op for weight parameter must be fp16"\ + # if init op is UniformFill, we replace it directly + if self.weight_init[0] == "UniformFill": + self.weight_init = ("Float16UniformFill", self.weight_init[1]) + assert self.weight_init[0] in self._fp16_compatible_init_op_types, ( + "Fp16 training is enabled. Init op for weight parameter must be fp16 " "compatibale. Got {}. Supported ops: {}".format( self.weight_init[0], self._fp16_compatible_init_op_types) + ) if _is_id_list(self.input_record): sparse_key = self.input_record.items() |