summaryrefslogtreecommitdiff
path: root/tools/cwrap
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2017-09-27 17:37:40 -0700
committerSoumith Chintala <soumith@gmail.com>2017-09-27 21:32:36 -0400
commit095805036c51cf304c3b284b377db800a7d8d5a8 (patch)
tree08c5eae0dd4d8629694d252a6a4427a1fab99664 /tools/cwrap
parente9fe0d8e6c6f996dd64adb6b45d6cb83eda90053 (diff)
downloadpytorch-095805036c51cf304c3b284b377db800a7d8d5a8.tar.gz
pytorch-095805036c51cf304c3b284b377db800a7d8d5a8.tar.bz2
pytorch-095805036c51cf304c3b284b377db800a7d8d5a8.zip
re-enable out-of-place bernoulli for cuda tensors
Diffstat (limited to 'tools/cwrap')
-rw-r--r--tools/cwrap/plugins/BeforeAfterCall.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/tools/cwrap/plugins/BeforeAfterCall.py b/tools/cwrap/plugins/BeforeAfterCall.py
index 8685790c73..28ba1a266d 100644
--- a/tools/cwrap/plugins/BeforeAfterCall.py
+++ b/tools/cwrap/plugins/BeforeAfterCall.py
@@ -9,10 +9,14 @@ class BeforeAfterCall(CWrapPlugin):
def insert_snippet(self, template, option, offset, name):
prepend_str = option.get(name)
+ if isinstance(prepend_str, dict):
+ backend = option['backends'][0]
+ prepend_str = prepend_str.get(backend, None)
+
if prepend_str is None:
return
if '$' in prepend_str:
- before_call_template = Template(option[name])
+ before_call_template = Template(prepend_str)
args = {'arg' + str(i): self.cwrap.get_arg_accessor(arg, option) for i, arg
in enumerate(option['arguments'])}
prepend_str = before_call_template.substitute(args)