Skip to content

Commit c275dd0

Browse files
committed
tests: Update to new exception operator
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 359d07a commit c275dd0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/core/lowering/test_exception_elimination_pass.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,17 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
4242
g->insertNode(bool_node);
4343
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
4444
auto if_block0 = if_node->addBlock();
45-
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0);
45+
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
4646
if_block0->appendNode(exception_node);
4747
auto if_block1 = if_node->addBlock();
4848
g->insertNode(if_node);
4949
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
5050
g->insertNode(cat_node);
5151
g->registerOutput(cat_node->output());
5252

53+
std::cout << "Source Graph: " << *g << std::endl;
5354
torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
55+
std::cout << "Modified Graph: " << *g << std::endl;
5456
for (auto node : g->nodes()) {
5557
EXPECT_NE(node, if_node);
5658
}
@@ -97,14 +99,16 @@ TEST(LoweringPasses, EliminateExceptionOrPassPattern_Block1) {
9799
auto if_node = g->create(torch::jit::prim::If, {bool_node->output()}, 0);
98100
auto if_block0 = if_node->addBlock();
99101
auto if_block1 = if_node->addBlock();
100-
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val}, 0);
102+
auto exception_node = g->create(torch::jit::prim::RaiseException, {except_val, none_const_val}, 0);
101103
if_block1->appendNode(exception_node);
102104
g->insertNode(if_node);
103105
auto cat_node = g->create(torch::jit::aten::cat, {list_node->output(), zero_const_val});
104106
g->insertNode(cat_node);
105107
g->registerOutput(cat_node->output());
106108

109+
std::cout << "Source Graph: " << *g << std::endl;
107110
torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern(g);
111+
std::cout << "Modified Graph: " << *g << std::endl;
108112
for (auto node : g->nodes()) {
109113
EXPECT_NE(node, if_node);
110114
}

0 commit comments

Comments
 (0)