@@ -11,26 +11,29 @@ std::string gen_test_graph() {
1111 return (%3))IR" ;
1212}
1313
14- #define test_bitwise_not (dtype ) \
15- TEST (Converters, ATenBitwiseNot##dtype##ConvertsCorrectly) { \
16- const auto graph = gen_test_graph (); \
17- \
18- auto g = std::make_shared<torch::jit::Graph>(); \
19- torch::jit::parseIR (graph, g.get ()); \
20- \
21- at::Tensor in; \
22- if (strcmp (#dtype, " Integer" ) == 0 ) \
23- in = at::randint (-128 , 128 , {10 }, {at::kCUDA }).toType (at::kInt ); \
24- if (strcmp (#dtype, " Boolean" ) == 0 ) \
25- in = at::randint (0 , 1 , {10 }, {at::kCUDA }).toType (at::kBool ); \
26- auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {}); \
27- auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in}); \
28- \
29- in = at::clone (in); \
30- params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {}); \
31- auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in}); \
32- \
33- ASSERT_TRUE (torch_tensorrt::tests::util::exactlyEqual (jit_results[0 ], trt_results[0 ]));\
14+ #define test_bitwise_not (dtype ) \
15+ TEST (Converters, ATenBitwiseNot##dtype##ConvertsCorrectly) { \
16+ const auto graph = gen_test_graph (); \
17+ \
18+ auto g = std::make_shared<torch::jit::Graph>(); \
19+ torch::jit::parseIR (graph, g.get ()); \
20+ \
21+ at::Tensor in; \
22+ if (strcmp (#dtype, " Integer" ) == 0 ) \
23+ in = at::randint (-128 , 128 , {10 }, {at::kCUDA }).toType (at::kInt ); \
24+ if (strcmp (#dtype, " Boolean" ) == 0 ) \
25+ in = at::randint (0 , 1 , {10 }, {at::kCUDA }).toType (at::kBool ); \
26+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {}); \
27+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in}); \
28+ \
29+ in = at::clone (in); \
30+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {}); \
31+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in}); \
32+ \
33+ auto jit_int = jit_results[0 ].toType (at::kInt ); \
34+ auto trt_int = trt_results[0 ].toType (at::kInt ); \
35+ \
36+ ASSERT_TRUE (torch_tensorrt::tests::util::exactlyEqual (jit_int, trt_int)); \
3437 }
3538
3639test_bitwise_not (Integer);
0 commit comments