@@ -293,33 +293,68 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
293293 false_const_val->setType (c10::BoolType::get ());
294294 torch::jit::IValue neg_one (-1 );
295295 auto neg_one_const_val = g->insertConstant (neg_one);
296- auto dict_node = g->createDict (ins_key_val->type (), x->type (), torch::jit::ArrayRef<torch::jit::Value*>(), torch::jit::ArrayRef<torch::jit::Value*>());
296+ auto dict_node = g->createDict (
297+ ins_key_val->type (),
298+ x->type (),
299+ torch::jit::ArrayRef<torch::jit::Value*>(),
300+ torch::jit::ArrayRef<torch::jit::Value*>());
297301 g->insertNode (dict_node);
298- auto set_node = g->create (torch::jit::Symbol::fromQualString (" aten::_set_item" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val, x}, 0 );
302+ auto set_node = g->create (
303+ torch::jit::Symbol::fromQualString (" aten::_set_item" ),
304+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val, x},
305+ 0 );
299306 g->insertNode (set_node);
300- auto get_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val}, 1 );
307+ auto get_node = g->create (
308+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
309+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val},
310+ 1 );
301311 g->insertNode (get_node);
302- auto lt_node = g->create (torch::jit::Symbol::fromQualString (" aten::lt" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), y}, 1 );
312+ auto lt_node = g->create (
313+ torch::jit::Symbol::fromQualString (" aten::lt" ),
314+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), y},
315+ 1 );
303316 g->insertNode (lt_node);
304- auto list_node = g->createList (at::OptionalType::create (lt_node->output ()->type ()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output ()});
317+ auto list_node = g->createList (
318+ at::OptionalType::create (lt_node->output ()->type ()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output ()});
305319 g->insertNode (list_node);
306- auto dtype_node = g->create (torch::jit::Symbol::fromQualString (" prim::dtype" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()}, 1 );
320+ auto dtype_node = g->create (
321+ torch::jit::Symbol::fromQualString (" prim::dtype" ),
322+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()},
323+ 1 );
307324 dtype_node->output ()->setType (neg_one_const_val->type ());
308325 g->insertNode (dtype_node);
309- auto device_node = g->create (torch::jit::Symbol::fromQualString (" prim::device" ), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()}, 1 );
326+ auto device_node = g->create (
327+ torch::jit::Symbol::fromQualString (" prim::device" ),
328+ torch::jit::ArrayRef<torch::jit::Value*>{get_node->output ()},
329+ 1 );
310330 device_node->output ()->setType (c10::DeviceObjType::get ());
311331 g->insertNode (device_node);
312- auto tensor_node = g->create (torch::jit::Symbol::fromQualString (" aten::tensor" ), torch::jit::ArrayRef<torch::jit::Value*>{neg_one_const_val, dtype_node->output (), device_node->output (), false_const_val}, 1 );
332+ auto tensor_node = g->create (
333+ torch::jit::Symbol::fromQualString (" aten::tensor" ),
334+ torch::jit::ArrayRef<torch::jit::Value*>{
335+ neg_one_const_val, dtype_node->output (), device_node->output (), false_const_val},
336+ 1 );
313337 g->insertNode (tensor_node);
314- auto index_put_node = g->create (torch::jit::Symbol::fromQualString (" aten::index_put_" ),
315- torch::jit::ArrayRef<torch::jit::Value*>{get_node->output (), list_node->output (), tensor_node->output (), false_const_val}, 1 );
338+ auto index_put_node = g->create (
339+ torch::jit::Symbol::fromQualString (" aten::index_put_" ),
340+ torch::jit::ArrayRef<torch::jit::Value*>{
341+ get_node->output (), list_node->output (), tensor_node->output (), false_const_val},
342+ 1 );
316343 g->insertNode (index_put_node);
317- auto out_set_node = g->create (torch::jit::Symbol::fromQualString (" aten::_set_item" ),
318- torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val, get_node->output ()}, 0 );
344+ auto out_set_node = g->create (
345+ torch::jit::Symbol::fromQualString (" aten::_set_item" ),
346+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val, get_node->output ()},
347+ 0 );
319348 g->insertNode (out_set_node);
320- auto get_ins_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val}, 1 );
349+ auto get_ins_node = g->create (
350+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
351+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), ins_key_val},
352+ 1 );
321353 g->insertNode (get_ins_node);
322- auto get_outs_node = g->create (torch::jit::Symbol::fromQualString (" aten::__getitem__" ), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val}, 1 );
354+ auto get_outs_node = g->create (
355+ torch::jit::Symbol::fromQualString (" aten::__getitem__" ),
356+ torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output (), outs_key_val},
357+ 1 );
323358 g->insertNode (get_outs_node);
324359 g->registerOutput (get_ins_node->output ());
325360 g->registerOutput (get_outs_node->output ());
@@ -337,10 +372,9 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
337372 input_types.insert ({g->inputs ()[i], {at::kFloat }});
338373 }
339374 auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs (inputs_map, input_types);
340- auto segmented_blocks =
341- torch_tensorrt::core::partitioning::Partition (g->block (), input_ivalues_map, partition_info);
375+ auto segmented_blocks = torch_tensorrt::core::partitioning::Partition (g->block (), input_ivalues_map, partition_info);
342376
343- int torch_block_cnt = 0 , trt_block_cnt = 0 ;
377+ int torch_block_cnt = 0 , trt_block_cnt = 0 ;
344378 for (const auto & segmented_block : segmented_blocks) {
345379 if (segmented_block.target () == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT ) {
346380 ++trt_block_cnt;
@@ -353,12 +387,12 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
353387 bool input_dict = false ;
354388 auto dict_type = dict_node->output ()->type ();
355389 for (auto in : segmented_block.raw_inputs ()) {
356- if (in->type ()->isSubtypeOf (dict_type)){
390+ if (in->type ()->isSubtypeOf (dict_type)) {
357391 input_dict = true ;
358392 }
359393 }
360394 for (auto out : segmented_block.raw_outputs ()) {
361- if (out->type ()->isSubtypeOf (dict_type)){
395+ if (out->type ()->isSubtypeOf (dict_type)) {
362396 output_dict = true ;
363397 }
364398 }
0 commit comments