|
6 | 6 | #include "torch/csrc/jit/ir/irparser.h" |
7 | 7 | #include "torch/csrc/jit/ir/subgraph_matcher.h" |
8 | 8 |
|
9 | | -TEST(LoweringPasses, ReduceToCorrectly) { |
10 | | - std::string source_graph = R"IR( |
11 | | - graph(%x, %device, %dtype, %nb, %copy, %format): |
12 | | - %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format) |
13 | | - return (%out))IR"; |
14 | | - std::string target_graph = R"IR( |
15 | | - graph(%x, %device, %dtype, %nb, %copy, %format): |
16 | | - %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format) |
17 | | - return (%out))IR"; |
18 | | - |
19 | | - torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( |
20 | | - torch_tensorrt::core::util::logging::LogLevel::kGRAPH); |
21 | | - auto sg = std::make_shared<torch::jit::Graph>(); |
22 | | - torch::jit::parseIR(source_graph, &*sg); |
23 | | - torch_tensorrt::core::lowering::passes::ReduceToOperation(sg); |
24 | | - |
25 | | - auto tg = std::make_shared<torch::jit::Graph>(); |
26 | | - torch::jit::parseIR(target_graph, &*tg); |
27 | | - |
28 | | - ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); |
29 | | -} |
30 | | - |
31 | 9 | TEST(LoweringPasses, ReduceToDtypeLayoutCorrectly) { |
32 | 10 | std::string source_graph = R"IR( |
33 | | - graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format): |
34 | | - %out : Tensor = aten::to(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format) |
| 11 | + graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format): |
| 12 | + %out : Tensor = aten::to(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format) |
35 | 13 | return (%out))IR"; |
36 | 14 | std::string target_graph = R"IR( |
37 | | - graph(%x, %device, %dtype, %layout, %pm, %nb, %copy, %format): |
38 | | - %out : Tensor = aten::to(%x, %dtype, %nb, %copy, %format) |
| 15 | + graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format): |
| 16 | + %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format) |
39 | 17 | return (%out))IR"; |
40 | 18 |
|
41 | 19 | torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( |
|
0 commit comments