1- from typing import Optional , Union
1+ from typing import List , Optional , Union
22
33import torch
44from torch import Tensor
55
6+ import torch_cluster .typing
7+
8+
9+ @torch .jit ._overload # noqa
10+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
11+ # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
12+ pass # pragma: no cover
13+
14+
15+ @torch .jit ._overload # noqa
16+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
17+ # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
18+ pass # pragma: no cover
19+
620
721@torch .jit ._overload # noqa
8- def fps (src , batch , ratio , random_start , batch_size ): # noqa
9- # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa
22+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
23+ # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]] ) -> Tensor # noqa
1024 pass # pragma: no cover
1125
1226
1327@torch .jit ._overload # noqa
14- def fps (src , batch , ratio , random_start , batch_size ): # noqa
15- # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa
28+ def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
29+ # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]] ) -> Tensor # noqa
1630 pass # pragma: no cover
1731
1832
1933def fps ( # noqa
2034 src : torch .Tensor ,
2135 batch : Optional [Tensor ] = None ,
22- ratio : Optional [Union [torch . Tensor , float ]] = None ,
36+ ratio : Optional [Union [Tensor , float ]] = None ,
2337 random_start : bool = True ,
2438 batch_size : Optional [int ] = None ,
39+ ptr : Optional [Union [Tensor , List [int ]]] = None ,
2540):
2641 r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
2742 Learning on Point Sets in a Metric Space"
@@ -40,6 +55,10 @@ def fps( # noqa
4055 node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
4156 batch_size (int, optional): The number of examples :math:`B`.
4257 Automatically calculated if not given. (default: :obj:`None`)
58+ ptr (torch.Tensor or [int], optional): If given, batch assignment will
59+ be determined based on boundaries in CSR representation, *e.g.*,
60+ :obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`.
61+ (default: :obj:`None`)
4362
4463 :rtype: :class:`LongTensor`
4564
@@ -52,7 +71,6 @@ def fps( # noqa
5271 batch = torch.tensor([0, 0, 0, 0])
5372 index = fps(src, batch, ratio=0.5)
5473 """
55-
5674 r : Optional [Tensor ] = None
5775 if ratio is None :
5876 r = torch .tensor (0.5 , dtype = src .dtype , device = src .device )
@@ -62,6 +80,17 @@ def fps( # noqa
6280 r = ratio
6381 assert r is not None
6482
83+ if ptr is not None :
84+ if isinstance (ptr , list ) and torch_cluster .typing .WITH_PTR_LIST :
85+ return torch .ops .torch_cluster .fps_ptr_list (
86+ src , ptr , r , random_start )
87+
88+ if isinstance (ptr , list ):
89+ return torch .ops .torch_cluster .fps (
90+ src , torch .tensor (ptr , device = src .device ), r , random_start )
91+ else :
92+ return torch .ops .torch_cluster .fps (src , ptr , r , random_start )
93+
6594 if batch is not None :
6695 assert src .size (0 ) == batch .numel ()
6796 if batch_size is None :
@@ -70,9 +99,9 @@ def fps( # noqa
7099 deg = src .new_zeros (batch_size , dtype = torch .long )
71100 deg .scatter_add_ (0 , batch , torch .ones_like (batch ))
72101
73- ptr = deg .new_zeros (batch_size + 1 )
74- torch .cumsum (deg , 0 , out = ptr [1 :])
102+ ptr_vec = deg .new_zeros (batch_size + 1 )
103+ torch .cumsum (deg , 0 , out = ptr_vec [1 :])
75104 else :
76- ptr = torch .tensor ([0 , src .size (0 )], device = src .device )
105+ ptr_vec = torch .tensor ([0 , src .size (0 )], device = src .device )
77106
78- return torch .ops .torch_cluster .fps (src , ptr , r , random_start )
107+ return torch .ops .torch_cluster .fps (src , ptr_vec , r , random_start )
0 commit comments