@@ -30,3 +30,28 @@ TEST(Converters, ATenTopKConvertsCorrectly) {
3030 ASSERT_TRUE (
3131 torch_tensorrt::tests::util::almostEqual (jit_results[1 ], trt_results[1 ].reshape_as (jit_results[1 ]), 2e-6 ));
3232}
33+
34+ TEST (Converters, ATenMaxDimConvertsCorrectly) {
35+ const auto graph = R"IR(
36+ graph(%x.1 : Tensor):
37+ %2 : int = prim::Constant[value=0]()
38+ %3 : bool = prim::Constant[value=0]()
39+ %4 : Tensor, %5 : Tensor = aten::max(%x.1, %2, %3)
40+ return (%4, %5))IR" ;
41+
42+ auto g = std::make_shared<torch::jit::Graph>();
43+ torch::jit::parseIR (graph, g.get ());
44+
45+ auto in = at::rand ({2 , 3 , 5 , 5 }, {at::kCUDA });
46+
47+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
48+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
49+
50+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
51+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
52+
53+ ASSERT_TRUE (
54+ torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
55+ ASSERT_TRUE (
56+ torch_tensorrt::tests::util::almostEqual (jit_results[1 ], trt_results[1 ].reshape_as (jit_results[1 ]), 2e-6 ));
57+ }
0 commit comments