summaryrefslogtreecommitdiff
path: root/caffe2/operators/jsd_op.cc
diff options
context:
space:
mode:
authorDmytro Dzhulgakov <dzhulgakov@fb.com>2018-03-05 19:56:24 -0800
committerDmytro Dzhulgakov <dzhulgakov@users.noreply.github.com>2018-03-06 00:33:11 -0800
commitfe3c22cd243a60bd750995ede784f760eab8a2f7 (patch)
tree522b2b045c756270293500d29daaf12296050834 /caffe2/operators/jsd_op.cc
parent08dbd966429b0fad982baa4744b129ed762a83aa (diff)
downloadpytorch-fe3c22cd243a60bd750995ede784f760eab8a2f7.tar.gz
pytorch-fe3c22cd243a60bd750995ede784f760eab8a2f7.tar.bz2
pytorch-fe3c22cd243a60bd750995ede784f760eab8a2f7.zip
[GanH/Easy]Fix blob dim
as titled
Diffstat (limited to 'caffe2/operators/jsd_op.cc')
-rw-r--r--caffe2/operators/jsd_op.cc4
1 files changed, 2 insertions, 2 deletions
diff --git a/caffe2/operators/jsd_op.cc b/caffe2/operators/jsd_op.cc
index 69d23a31fb..1d7f49e09b 100644
--- a/caffe2/operators/jsd_op.cc
+++ b/caffe2/operators/jsd_op.cc
@@ -48,7 +48,7 @@ bool BernoulliJSDOp<float, CPUContext>::RunOnDevice() {
auto* L = Output(0); // JSD loss output
int N = X.size();
CAFFE_ENFORCE_EQ(T.size(), N);
- L->Resize(N);
+ L->ResizeLike(X);
auto* x_data = X.data<float>();
auto* t_data = T.data<float>();
auto* l_data = L->mutable_data<float>();
@@ -69,7 +69,7 @@ bool BernoulliJSDGradientOp<float, CPUContext>::RunOnDevice() {
auto& T = Input(2);
auto* gi = Output(0);
int N = X.size();
- gi->Resize(N);
+ gi->ResizeLike(X);
auto* go_data = go.data<float>();
auto* x_data = X.data<float>();
auto* t_data = T.data<float>();