summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorChristian Puhrsch <cpuhrsch@fb.com>2018-09-24 10:39:10 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2018-09-24 10:40:10 -0700
commita9e6a673aec6c479447c61f3bcc5c10ddd1a099f (patch)
treea4f82c5af8650b06cece993cc13689a6a30dddb4 /test
parent117885128073c9c2b32f4b33c6c79df3895b7071 (diff)
downloadpytorch-a9e6a673aec6c479447c61f3bcc5c10ddd1a099f.tar.gz
pytorch-a9e6a673aec6c479447c61f3bcc5c10ddd1a099f.tar.bz2
pytorch-a9e6a673aec6c479447c61f3bcc5c10ddd1a099f.zip
Remove caffe2::Tensor::capacity_nbytes, at::Tensor::to##name##Data, (#11876)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11876 Modern C++ api instead of macros, item() is aligned with Python frontend. caffe2::Tensor::capacity_nbytes is effecitvely unused and confusing w.r.t. caffe2::Tensor::nbytes(). codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCByte "item<uint8_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCLong "item<int64_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCInt "item<int32_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCDouble "item<double>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toByteData "data<uint8_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toLongData "data<int64_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toIntData "data<int32_t>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toDoubleData "data<double>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toFloatData "data<float>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCByte "item<uint8_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCLong "item<int64_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCInt "item<int32_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCDouble "item<double>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toByteData "data<uint8_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toLongData "data<int64_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toIntData "data<int32_t>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toDoubleData "data<double>" codemod -d hphp --extensions cc,cpp,cu,cuh,h,py,hpp,mm toFloatData "data<float>" codemod -d caffe2 --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCComplexDouble "item<std::complex<double>>" codemod -d tc --extensions cc,cpp,cu,cuh,h,py,hpp,mm toCFloat "item<float>" Reviewed By: ezyang Differential Revision: D9948572 fbshipit-source-id: 70c9f5390d92b82c85fdd5f8a5aebca338ab413c
Diffstat (limited to 'test')
-rw-r--r--test/cpp/api/any.cpp6
-rw-r--r--test/cpp/api/integration.cpp14
-rw-r--r--test/cpp/api/jit.cpp6
-rw-r--r--test/cpp/api/misc.cpp2
-rw-r--r--test/cpp/api/module.cpp6
-rw-r--r--test/cpp/api/modules.cpp16
-rw-r--r--test/cpp/api/optim.cpp8
-rw-r--r--test/cpp/api/parallel.cpp8
-rw-r--r--test/cpp/api/rnn.cpp12
-rw-r--r--test/cpp/api/serialize.cpp18
-rw-r--r--test/cpp/api/tensor.cpp8
11 files changed, 52 insertions, 52 deletions
diff --git a/test/cpp/api/any.cpp b/test/cpp/api/any.cpp
index 0d8e98c415..22eda0d100 100644
--- a/test/cpp/api/any.cpp
+++ b/test/cpp/api/any.cpp
@@ -71,7 +71,7 @@ TEST_F(
ASSERT_TRUE(
any.forward(std::string("a"), std::string("ab"), std::string("abc"))
.sum()
- .toCInt() == 6);
+ .item<int32_t>() == 6);
}
TEST_F(AnyModuleTest, WrongArgumentType) {
@@ -232,10 +232,10 @@ TEST_F(AnyModuleTest, ConvertsVariableToTensorCorrectly) {
// mismatch).
AnyModule any(M{});
ASSERT_TRUE(
- any.forward(torch::autograd::Variable(torch::ones(5))).sum().toCFloat() ==
+ any.forward(torch::autograd::Variable(torch::ones(5))).sum().item<float>() ==
5);
// at::Tensors that are not variables work too.
- ASSERT_EQ(any.forward(at::ones(5)).sum().toCFloat(), 5);
+ ASSERT_EQ(any.forward(at::ones(5)).sum().item<float>(), 5);
}
namespace torch {
diff --git a/test/cpp/api/integration.cpp b/test/cpp/api/integration.cpp
index 131b0440a4..b2d10097b2 100644
--- a/test/cpp/api/integration.cpp
+++ b/test/cpp/api/integration.cpp
@@ -63,10 +63,10 @@ class CartPole {
}
void step(int action) {
- auto x = state[0].toCFloat();
- auto x_dot = state[1].toCFloat();
- auto theta = state[2].toCFloat();
- auto theta_dot = state[3].toCFloat();
+ auto x = state[0].item<float>();
+ auto x_dot = state[1].item<float>();
+ auto theta = state[2].item<float>();
+ auto theta_dot = state[3].item<float>();
auto force = (action == 1) ? force_mag : -force_mag;
auto costheta = std::cos(theta);
@@ -222,7 +222,7 @@ bool test_mnist(
torch::NoGradGuard guard;
auto result = std::get<1>(forward_op(tedata).max(1));
torch::Tensor correct = (result == telabel).toType(torch::kFloat32);
- return correct.sum().toCFloat() > telabel.size(0) * 0.8;
+ return correct.sum().item<float>() > telabel.size(0) * 0.8;
}
struct IntegrationTest : torch::test::SeedingFixture {};
@@ -251,7 +251,7 @@ TEST_F(IntegrationTest, CartPole) {
auto out = forward(state);
auto probs = torch::Tensor(std::get<0>(out));
auto value = torch::Tensor(std::get<1>(out));
- auto action = probs.multinomial(1)[0].toCInt();
+ auto action = probs.multinomial(1)[0].item<int32_t>();
// Compute the log prob of a multinomial distribution.
// This should probably be actually implemented in autogradpp...
auto p = probs / probs.sum(-1, true);
@@ -274,7 +274,7 @@ TEST_F(IntegrationTest, CartPole) {
std::vector<torch::Tensor> policy_loss;
std::vector<torch::Tensor> value_loss;
for (auto i = 0U; i < saved_log_probs.size(); i++) {
- auto r = rewards[i] - saved_values[i].toCFloat();
+ auto r = rewards[i] - saved_values[i].item<float>();
policy_loss.push_back(-r * saved_log_probs[i]);
value_loss.push_back(
torch::smooth_l1_loss(saved_values[i], torch::ones(1) * rewards[i]));
diff --git a/test/cpp/api/jit.cpp b/test/cpp/api/jit.cpp
index 34b3e8f630..9aa6968df7 100644
--- a/test/cpp/api/jit.cpp
+++ b/test/cpp/api/jit.cpp
@@ -20,10 +20,10 @@ TEST(TorchScriptTest, CanCompileMultipleFunctions) {
auto a = torch::ones(1);
auto b = torch::ones(1);
- ASSERT_EQ(1, module->run_method("test_mul", a, b).toTensor().toCLong());
+ ASSERT_EQ(1, module->run_method("test_mul", a, b).toTensor().item<int64_t>());
- ASSERT_EQ(2, module->run_method("test_relu", a, b).toTensor().toCLong());
+ ASSERT_EQ(2, module->run_method("test_relu", a, b).toTensor().item<int64_t>());
ASSERT_TRUE(
- 0x200 == module->run_method("test_while", a, b).toTensor().toCLong());
+ 0x200 == module->run_method("test_while", a, b).toTensor().item<int64_t>());
}
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index ca716d0ac0..b85cb9dcc1 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -49,5 +49,5 @@ TEST(NNInitTest, CanInitializeTensorThatRequiresGrad) {
tensor.fill_(1),
"a leaf Variable that requires grad "
"has been used in an in-place operation");
- ASSERT_EQ(torch::nn::init::ones_(tensor).sum().toCInt(), 12);
+ ASSERT_EQ(torch::nn::init::ones_(tensor).sum().item<int32_t>(), 12);
}
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index f2bca9501a..70d05d4240 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -41,13 +41,13 @@ TEST_F(ModuleTest, ZeroGrad) {
for (auto& parameter : module->parameters()) {
auto grad = parameter->grad();
ASSERT_TRUE(grad.defined());
- ASSERT_NE(grad.sum().toCFloat(), 0);
+ ASSERT_NE(grad.sum().item<float>(), 0);
}
module->zero_grad();
for (auto& parameter : module->parameters()) {
auto grad = parameter->grad();
ASSERT_TRUE(grad.defined());
- ASSERT_EQ(grad.sum().toCFloat(), 0);
+ ASSERT_EQ(grad.sum().item<float>(), 0);
}
}
@@ -72,7 +72,7 @@ TEST_F(ModuleTest, ZeroGradWithUndefined) {
ASSERT_TRUE(module.x.grad().defined());
ASSERT_FALSE(module.y.grad().defined());
- ASSERT_EQ(module.x.grad().sum().toCFloat(), 0);
+ ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
}
TEST_F(ModuleTest, CanGetName) {
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 11e54a97a1..fd9416eb3b 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -134,7 +134,7 @@ TEST_F(ModulesTest, SimpleContainer) {
ASSERT_EQ(x.ndimension(), 2);
ASSERT_EQ(x.size(0), 1000);
ASSERT_EQ(x.size(1), 100);
- ASSERT_EQ(x.min().toCFloat(), 0);
+ ASSERT_EQ(x.min().item<float>(), 0);
}
TEST_F(ModulesTest, EmbeddingBasic) {
@@ -181,12 +181,12 @@ TEST_F(ModulesTest, Dropout) {
y.backward();
ASSERT_EQ(y.ndimension(), 1);
ASSERT_EQ(y.size(0), 100);
- ASSERT_LT(y.sum().toCFloat(), 130); // Probably
- ASSERT_GT(y.sum().toCFloat(), 70); // Probably
+ ASSERT_LT(y.sum().item<float>(), 130); // Probably
+ ASSERT_GT(y.sum().item<float>(), 70); // Probably
dropout->eval();
y = dropout->forward(x);
- ASSERT_EQ(y.sum().toCFloat(), 100);
+ ASSERT_EQ(y.sum().item<float>(), 100);
}
TEST_F(ModulesTest, Parameters) {
@@ -228,15 +228,15 @@ TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
TEST_F(ModulesTest, FunctionalWithTorchFunction) {
auto functional = Functional(torch::relu);
- ASSERT_EQ(functional(torch::ones({})).toCFloat(), 1);
- ASSERT_EQ(functional(torch::ones({})).toCFloat(), 1);
- ASSERT_EQ(functional(torch::ones({}) * -1).toCFloat(), 0);
+ ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
+ ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
+ ASSERT_EQ(functional(torch::ones({}) * -1).item<float>(), 0);
}
TEST_F(ModulesTest, FunctionalArgumentBinding) {
auto functional =
Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
- ASSERT_EQ(functional(torch::ones({})).toCFloat(), 0);
+ ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
}
TEST_F(ModulesTest, BatchNormStateful) {
diff --git a/test/cpp/api/optim.cpp b/test/cpp/api/optim.cpp
index 03f7ed92a9..944a31ca7e 100644
--- a/test/cpp/api/optim.cpp
+++ b/test/cpp/api/optim.cpp
@@ -44,7 +44,7 @@ bool test_optimizer_xor(Options options) {
auto labels = torch::empty({kBatchSize});
for (size_t i = 0; i < kBatchSize; i++) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
- labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
+ labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
}
inputs.set_requires_grad(true);
optimizer.zero_grad();
@@ -54,7 +54,7 @@ bool test_optimizer_xor(Options options) {
optimizer.step();
- running_loss = running_loss * 0.99 + loss.toCFloat() * 0.01;
+ running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
if (epoch > kMaximumNumberOfEpochs) {
std::cout << "Loss is too high after epoch " << epoch << ": "
<< running_loss << std::endl;
@@ -286,14 +286,14 @@ TEST(OptimTest, ZeroGrad) {
for (const auto& parameter : model->parameters()) {
ASSERT_TRUE(parameter->grad().defined());
- ASSERT_GT(parameter->grad().sum().toCFloat(), 0);
+ ASSERT_GT(parameter->grad().sum().item<float>(), 0);
}
optimizer.zero_grad();
for (const auto& parameter : model->parameters()) {
ASSERT_TRUE(parameter->grad().defined());
- ASSERT_EQ(parameter->grad().sum().toCFloat(), 0);
+ ASSERT_EQ(parameter->grad().sum().item<float>(), 0);
}
}
diff --git a/test/cpp/api/parallel.cpp b/test/cpp/api/parallel.cpp
index 71bcc542f8..a191078236 100644
--- a/test/cpp/api/parallel.cpp
+++ b/test/cpp/api/parallel.cpp
@@ -38,7 +38,7 @@ TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) {
ASSERT_TRUE(input.grad().defined());
ASSERT_TRUE(input.grad().device().is_cpu());
- ASSERT_EQ(input.grad().sum().toCInt(), 10);
+ ASSERT_EQ(input.grad().sum().item<int32_t>(), 10);
}
TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
@@ -62,11 +62,11 @@ TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
ASSERT_TRUE(a.grad().defined());
ASSERT_EQ(a.grad().device(), torch::Device(torch::kCUDA, 0));
- ASSERT_EQ(a.grad().sum().toCInt(), 5);
+ ASSERT_EQ(a.grad().sum().item<int32_t>(), 5);
ASSERT_TRUE(b.grad().defined());
ASSERT_EQ(b.grad().device(), torch::Device(torch::kCUDA, 1));
- ASSERT_EQ(b.grad().sum().toCInt(), 5);
+ ASSERT_EQ(b.grad().sum().item<int32_t>(), 5);
}
TEST_F(ParallelTest, Replicate_MultiCUDA) {
@@ -226,6 +226,6 @@ TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) {
const auto device_count = torch::cuda::device_count();
ASSERT_EQ(output.numel(), device_count);
for (size_t i = 0; i < device_count; ++i) {
- ASSERT_EQ(output[i].toCInt(), i);
+ ASSERT_EQ(output[i].item<int32_t>(), i);
}
}
diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp
index 96ffd37eb0..e0d511fb09 100644
--- a/test/cpp/api/rnn.cpp
+++ b/test/cpp/api/rnn.cpp
@@ -56,7 +56,7 @@ bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
loss.backward();
optimizer.step();
- running_loss = running_loss * 0.99 + loss.toCFloat() * 0.01;
+ running_loss = running_loss * 0.99 + loss.item<float>() * 0.01;
if (epoch > max_epoch) {
return false;
}
@@ -81,7 +81,7 @@ void check_lstm_sizes(RNNOutput output) {
ASSERT_EQ(output.state.size(3), 64); // 64 hidden dims
// Something is in the hiddens
- ASSERT_GT(output.state.norm().toCFloat(), 0);
+ ASSERT_GT(output.state.norm().item<float>(), 0);
}
struct RNNTest : torch::test::SeedingFixture {};
@@ -103,7 +103,7 @@ TEST_F(RNNTest, CheckOutputSizes) {
torch::Tensor diff = next.state - output.state;
// Hiddens changed
- ASSERT_GT(diff.abs().sum().toCFloat(), 1e-3);
+ ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
}
TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
@@ -137,7 +137,7 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
for (size_t i = 0; i < 3 * 4 * 2; i++) {
- ASSERT_LT(std::abs(flat[i].toCFloat() - c_out[i]), 1e-3);
+ ASSERT_LT(std::abs(flat[i].item<float>() - c_out[i]), 1e-3);
}
ASSERT_EQ(out.state.ndimension(), 4); // (hx, cx) x layers x B x 2
@@ -163,7 +163,7 @@ TEST_F(RNNTest, CheckOutputValuesMatchPyTorch) {
1.0931,
1.4911};
for (size_t i = 0; i < 16; i++) {
- ASSERT_LT(std::abs(flat[i].toCFloat() - h_out[i]), 1e-3);
+ ASSERT_LT(std::abs(flat[i].item<float>() - h_out[i]), 1e-3);
}
}
@@ -206,7 +206,7 @@ TEST_F(RNNTest, Sizes_CUDA) {
torch::Tensor diff = next.state - output.state;
// Hiddens changed
- ASSERT_GT(diff.abs().sum().toCFloat(), 1e-3);
+ ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
}
TEST_F(RNNTest, EndToEndLSTM_CUDA) {
diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp
index a37c00c2e3..0612029f53 100644
--- a/test/cpp/api/serialize.cpp
+++ b/test/cpp/api/serialize.cpp
@@ -90,7 +90,7 @@ TEST(Serialize, XOR) {
auto labels = torch::empty({batch_size});
for (size_t i = 0; i < batch_size; i++) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
- labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
+ labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
}
auto x = model->forward<torch::Tensor>(inputs);
return torch::binary_cross_entropy(x, labels);
@@ -112,7 +112,7 @@ TEST(Serialize, XOR) {
loss.backward();
optimizer.step();
- running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
+ running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
ASSERT_LT(epoch, 3000);
epoch++;
}
@@ -122,7 +122,7 @@ TEST(Serialize, XOR) {
torch::load(model2, tempfile.str());
auto loss = getLoss(model2, 100);
- ASSERT_LT(loss.toCFloat(), 0.1);
+ ASSERT_LT(loss.item<float>(), 0.1);
}
TEST(Serialize, Optim) {
@@ -188,9 +188,9 @@ TEST(Serialize, Optim) {
const auto& name = p.key;
// Model 1 and 3 should be the same
ASSERT_TRUE(
- param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
+ param1[name].norm().item<float>() == param3[name].norm().item<float>());
ASSERT_TRUE(
- param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
+ param1[name].norm().item<float>() != param2[name].norm().item<float>());
}
}
@@ -202,7 +202,7 @@ TEST(Serialize, Optim) {
// auto labels = torch::empty({batch_size});
// for (size_t i = 0; i < batch_size; i++) {
// inputs[i] = torch::randint(2, {2}, torch::kInt64);
-// labels[i] = inputs[i][0].toCLong() ^ inputs[i][1].toCLong();
+// labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
// }
// auto x = model->forward<torch::Tensor>(inputs);
// return torch::binary_cross_entropy(x, labels);
@@ -224,7 +224,7 @@ TEST(Serialize, Optim) {
// loss.backward();
// optimizer.step();
//
-// running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
+// running_loss = running_loss * 0.99 + loss.sum().item<float>() * 0.01;
// ASSERT_LT(epoch, 3000);
// epoch++;
// }
@@ -234,7 +234,7 @@ TEST(Serialize, Optim) {
// torch::load(model2, tempfile.str());
//
// auto loss = getLoss(model2, 100);
-// ASSERT_LT(loss.toCFloat(), 0.1);
+// ASSERT_LT(loss.item<float>(), 0.1);
//
// model2->to(torch::kCUDA);
// torch::test::TempFile tempfile2;
@@ -242,5 +242,5 @@ TEST(Serialize, Optim) {
// torch::load(model3, tempfile2.str());
//
// loss = getLoss(model3, 100);
-// ASSERT_LT(loss.toCFloat(), 0.1);
+// ASSERT_LT(loss.item<float>(), 0.1);
// }
diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp
index ad14298d86..3996132cc8 100644
--- a/test/cpp/api/tensor.cpp
+++ b/test/cpp/api/tensor.cpp
@@ -104,7 +104,7 @@ TEST(TensorTest, ContainsCorrectValueForSingleValue) {
auto tensor = at::tensor(123);
ASSERT_EQ(tensor.numel(), 1);
ASSERT_EQ(tensor.dtype(), at::kInt);
- ASSERT_EQ(tensor[0].toCInt(), 123);
+ ASSERT_EQ(tensor[0].item<int32_t>(), 123);
tensor = at::tensor(123.456f);
ASSERT_EQ(tensor.numel(), 1);
@@ -189,7 +189,7 @@ TEST(TensorTest, FromBlob) {
auto tensor = torch::from_blob(v.data(), v.size(), torch::kInt32);
ASSERT_TRUE(tensor.is_variable());
ASSERT_EQ(tensor.numel(), 3);
- ASSERT_EQ(tensor[0].toCInt(), 1);
- ASSERT_EQ(tensor[1].toCInt(), 2);
- ASSERT_EQ(tensor[2].toCInt(), 3);
+ ASSERT_EQ(tensor[0].item<int32_t>(), 1);
+ ASSERT_EQ(tensor[1].item<int32_t>(), 2);
+ ASSERT_EQ(tensor[2].item<int32_t>(), 3);
}