@@ -15,26 +15,32 @@ namespace converters {
1515namespace impl {
1616namespace {
1717
18- bool add_split (ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list) {
18+ bool add_split (ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list, bool unbind ) {
1919 auto in = args[0 ].ITensor ();
20- auto axis = args[2 ].unwrapToInt ();
21- auto inDimSize = in->getDimensions ().d [axis];
22- auto numOutputs = 1 , numRemainder = 0 ;
20+ auto numOutputs = 1 , numRemainder = 0 , axis = 0 ;
2321 std::vector<int64_t > sizes;
2422
25- if (split_list) {
26- sizes = args[1 ].unwrapToIntList ().vec ();
27- numOutputs = sizes.size ();
23+ if (unbind) {
24+ axis = args[1 ].unwrapToInt ();
25+ numOutputs = in->getDimensions ().d [axis];
26+ sizes.insert (sizes.end (), numOutputs, 1 );
2827 } else {
29- auto split_size = args[1 ].unwrapToInt ();
30- numOutputs = inDimSize / split_size;
31- numRemainder = inDimSize % split_size;
32- for (int64_t i = 0 ; i < numOutputs; i++) {
33- sizes.push_back (split_size);
34- }
35- if (numRemainder) {
36- numOutputs += 1 ;
37- sizes.push_back (numRemainder);
28+ axis = args[2 ].unwrapToInt ();
29+ auto inDimSize = in->getDimensions ().d [axis];
30+ if (split_list) {
31+ sizes = args[1 ].unwrapToIntList ().vec ();
32+ numOutputs = sizes.size ();
33+ } else {
34+ auto split_size = args[1 ].unwrapToInt ();
35+ numOutputs = inDimSize / split_size;
36+ numRemainder = inDimSize % split_size;
37+ for (int64_t i = 0 ; i < numOutputs; i++) {
38+ sizes.push_back (split_size);
39+ }
40+ if (numRemainder) {
41+ numOutputs += 1 ;
42+ sizes.push_back (numRemainder);
43+ }
3844 }
3945 }
4046
@@ -340,19 +346,25 @@ auto select_registrations TORCHTRT_UNUSED =
340346 }})
341347 .pattern({" aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])" ,
342348 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
343- add_split (ctx, n, args, true );
349+ add_split (ctx, n, args, true , false );
344350 LOG_DEBUG (" Converted split op into a list of IValues" );
345351 return true ;
346352 }})
347353 .pattern({" aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])" ,
348354 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
349- add_split (ctx, n, args, false );
355+ add_split (ctx, n, args, false , false );
350356 LOG_DEBUG (" Converted split op into a list of IValues" );
351357 return true ;
352358 }})
353359 .pattern({" aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])" ,
354360 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
355- add_split (ctx, n, args, true );
361+ add_split (ctx, n, args, true , false );
362+ LOG_DEBUG (" Converted split op into a list of IValues" );
363+ return true ;
364+ }})
365+ .pattern({" aten::unbind.int(Tensor(a -> *) self, int dim=0) -> (Tensor[])" ,
366+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
367+ add_split (ctx, n, args, false , true );
356368 LOG_DEBUG (" Converted split op into a list of IValues" );
357369 return true ;
358370 }})
0 commit comments