Skip to content

Commit c0bcffb

Browse files
committed
type conversion to int for equality check in bitwise test
1 parent 3e27790 commit c0bcffb

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

tests/core/conversion/converters/test_bitwise.cpp

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,29 @@ std::string gen_test_graph() {
1111
return (%3))IR";
1212
}
1313

14-
#define test_bitwise_not(dtype) \
15-
TEST(Converters, ATenBitwiseNot##dtype##ConvertsCorrectly) { \
16-
const auto graph = gen_test_graph(); \
17-
\
18-
auto g = std::make_shared<torch::jit::Graph>(); \
19-
torch::jit::parseIR(graph, g.get()); \
20-
\
21-
at::Tensor in; \
22-
if (strcmp(#dtype, "Integer") == 0) \
23-
in = at::randint(-128, 128, {10}, {at::kCUDA}).toType(at::kInt); \
24-
if (strcmp(#dtype, "Boolean") == 0) \
25-
in = at::randint(0, 1, {10}, {at::kCUDA}).toType(at::kBool); \
26-
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); \
27-
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); \
28-
\
29-
in = at::clone(in); \
30-
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); \
31-
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); \
32-
\
33-
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));\
14+
#define test_bitwise_not(dtype) \
15+
TEST(Converters, ATenBitwiseNot##dtype##ConvertsCorrectly) { \
16+
const auto graph = gen_test_graph(); \
17+
\
18+
auto g = std::make_shared<torch::jit::Graph>(); \
19+
torch::jit::parseIR(graph, g.get()); \
20+
\
21+
at::Tensor in; \
22+
if (strcmp(#dtype, "Integer") == 0) \
23+
in = at::randint(-128, 128, {10}, {at::kCUDA}).toType(at::kInt); \
24+
if (strcmp(#dtype, "Boolean") == 0) \
25+
in = at::randint(0, 1, {10}, {at::kCUDA}).toType(at::kBool); \
26+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); \
27+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); \
28+
\
29+
in = at::clone(in); \
30+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); \
31+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); \
32+
\
33+
auto jit_int = jit_results[0].toType(at::kInt); \
34+
auto trt_int = trt_results[0].toType(at::kInt); \
35+
\
36+
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_int, trt_int)); \
3437
}
3538

3639
test_bitwise_not(Integer);

0 commit comments

Comments
 (0)