diff options
author | Trevor Killeen <killeentm@gmail.com> | 2017-09-27 17:37:40 -0700 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-09-27 21:32:36 -0400 |
commit | 095805036c51cf304c3b284b377db800a7d8d5a8 (patch) | |
tree | 08c5eae0dd4d8629694d252a6a4427a1fab99664 /tools/cwrap | |
parent | e9fe0d8e6c6f996dd64adb6b45d6cb83eda90053 (diff) | |
download | pytorch-095805036c51cf304c3b284b377db800a7d8d5a8.tar.gz pytorch-095805036c51cf304c3b284b377db800a7d8d5a8.tar.bz2 pytorch-095805036c51cf304c3b284b377db800a7d8d5a8.zip |
re-enable out-of-place bernoulli for cuda tensors
Diffstat (limited to 'tools/cwrap')
-rw-r--r-- | tools/cwrap/plugins/BeforeAfterCall.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/tools/cwrap/plugins/BeforeAfterCall.py b/tools/cwrap/plugins/BeforeAfterCall.py index 8685790c73..28ba1a266d 100644 --- a/tools/cwrap/plugins/BeforeAfterCall.py +++ b/tools/cwrap/plugins/BeforeAfterCall.py @@ -9,10 +9,14 @@ class BeforeAfterCall(CWrapPlugin): def insert_snippet(self, template, option, offset, name): prepend_str = option.get(name) + if isinstance(prepend_str, dict): + backend = option['backends'][0] + prepend_str = prepend_str.get(backend, None) + if prepend_str is None: return if '$' in prepend_str: - before_call_template = Template(option[name]) + before_call_template = Template(prepend_str) args = {'arg' + str(i): self.cwrap.get_arg_accessor(arg, option) for i, arg in enumerate(option['arguments'])} prepend_str = before_call_template.substitute(args) |