summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJiyan Yang <chocjy@fb.com>2019-04-23 10:05:57 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-23 10:14:08 -0700
commit714344a976033e2c28595ac79252f67a81691766 (patch)
tree08b90165b53e4d8f1aee3784cacb89fe4471ccfe
parente3f150462185239f5a92d39916c3e3ca12fe80e0 (diff)
downloadpytorch-714344a976033e2c28595ac79252f67a81691766.tar.gz
pytorch-714344a976033e2c28595ac79252f67a81691766.tar.bz2
pytorch-714344a976033e2c28595ac79252f67a81691766.zip
Specify to use Float16UniformFill if necessary in sparse lookup layer (#18499)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18499 If the init op is not fp16 compatible, it should throw. However, in the special case where the original init op is UniformFill, we replace it with Float16UniformFill Reviewed By: kennyhorror Differential Revision: D14627209 fbshipit-source-id: eb427772874a732ca8b3a25d06670d119ce8ac14
-rw-r--r--caffe2/python/layers/sparse_lookup.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/caffe2/python/layers/sparse_lookup.py b/caffe2/python/layers/sparse_lookup.py
index d920bd1ba6..73c937358b 100644
--- a/caffe2/python/layers/sparse_lookup.py
+++ b/caffe2/python/layers/sparse_lookup.py
@@ -103,12 +103,17 @@ class SparseLookup(ModelLayer):
self.weight_init = weight_init or default_init_op
+ # If fp16 is used, make sure fp16 init op is used
if self.trainer_version == "fp16":
- assert self.weight_init[0] in self._fp16_compatible_init_op_types,\
- "Fp16 training is enabled. Init op for weight parameter must be fp16"\
+ # if init op is UniformFill, we replace it directly
+ if self.weight_init[0] == "UniformFill":
+ self.weight_init = ("Float16UniformFill", self.weight_init[1])
+ assert self.weight_init[0] in self._fp16_compatible_init_op_types, (
+ "Fp16 training is enabled. Init op for weight parameter must be fp16 "
"compatibale. Got {}. Supported ops: {}".format(
self.weight_init[0],
self._fp16_compatible_init_op_types)
+ )
if _is_id_list(self.input_record):
sparse_key = self.input_record.items()