Skip to content

Commit b3d70bd

Browse files
dscieburusty1s
andauthored
Extend FPS with an extra ptr argument (#180)
* Extend FPS with an extra ptr argument * update * update * update --------- Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent b8fd424 commit b3d70bd

File tree

4 files changed

+56
-15
lines changed

4 files changed

+56
-15
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ jobs:
3131
3232
- name: Install main package
3333
run: |
34+
pip install scipy==1.10.1 # Python 3.8 support
3435
python setup.py develop
3536
3637
- name: Run test-suite

test/test_fps.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,27 @@ def test_fps(dtype, device):
2525
[+2, -2],
2626
], dtype, device)
2727
batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
28+
ptr_list = [0, 4, 8]
29+
ptr = torch.tensor(ptr_list, device=device)
2830

2931
out = fps(x, batch, random_start=False)
3032
assert out.tolist() == [0, 2, 4, 6]
3133

3234
out = fps(x, batch, ratio=0.5, random_start=False)
3335
assert out.tolist() == [0, 2, 4, 6]
3436

35-
out = fps(x, batch, ratio=torch.tensor(0.5, device=device),
36-
random_start=False)
37+
ratio = torch.tensor(0.5, device=device)
38+
out = fps(x, batch, ratio=ratio, random_start=False)
3739
assert out.tolist() == [0, 2, 4, 6]
3840

39-
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device),
40-
random_start=False)
41+
out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False)
42+
assert out.tolist() == [0, 2, 4, 6]
43+
44+
out = fps(x, ptr=ptr, ratio=0.5, random_start=False)
45+
assert out.tolist() == [0, 2, 4, 6]
46+
47+
ratio = torch.tensor([0.5, 0.5], device=device)
48+
out = fps(x, batch, ratio=ratio, random_start=False)
4149
assert out.tolist() == [0, 2, 4, 6]
4250

4351
out = fps(x, random_start=False)

torch_cluster/fps.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,42 @@
1-
from typing import Optional, Union
1+
from typing import List, Optional, Union
22

33
import torch
44
from torch import Tensor
55

6+
import torch_cluster.typing
7+
8+
9+
@torch.jit._overload # noqa
10+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
11+
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
12+
pass # pragma: no cover
13+
14+
15+
@torch.jit._overload # noqa
16+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
17+
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
18+
pass # pragma: no cover
19+
620

721
@torch.jit._overload # noqa
8-
def fps(src, batch, ratio, random_start, batch_size): # noqa
9-
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa
22+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
23+
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
1024
pass # pragma: no cover
1125

1226

1327
@torch.jit._overload # noqa
14-
def fps(src, batch, ratio, random_start, batch_size): # noqa
15-
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa
28+
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
29+
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
1630
pass # pragma: no cover
1731

1832

1933
def fps( # noqa
2034
src: torch.Tensor,
2135
batch: Optional[Tensor] = None,
22-
ratio: Optional[Union[torch.Tensor, float]] = None,
36+
ratio: Optional[Union[Tensor, float]] = None,
2337
random_start: bool = True,
2438
batch_size: Optional[int] = None,
39+
ptr: Optional[Union[Tensor, List[int]]] = None,
2540
):
2641
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
2742
Learning on Point Sets in a Metric Space"
@@ -40,6 +55,10 @@ def fps( # noqa
4055
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
4156
batch_size (int, optional): The number of examples :math:`B`.
4257
Automatically calculated if not given. (default: :obj:`None`)
58+
ptr (torch.Tensor or [int], optional): If given, batch assignment will
59+
be determined based on boundaries in CSR representation, *e.g.*,
60+
:obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`.
61+
(default: :obj:`None`)
4362
4463
:rtype: :class:`LongTensor`
4564
@@ -52,7 +71,6 @@ def fps( # noqa
5271
batch = torch.tensor([0, 0, 0, 0])
5372
index = fps(src, batch, ratio=0.5)
5473
"""
55-
5674
r: Optional[Tensor] = None
5775
if ratio is None:
5876
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
@@ -62,6 +80,17 @@ def fps( # noqa
6280
r = ratio
6381
assert r is not None
6482

83+
if ptr is not None:
84+
if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST:
85+
return torch.ops.torch_cluster.fps_ptr_list(
86+
src, ptr, r, random_start)
87+
88+
if isinstance(ptr, list):
89+
return torch.ops.torch_cluster.fps(
90+
src, torch.tensor(ptr, device=src.device), r, random_start)
91+
else:
92+
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
93+
6594
if batch is not None:
6695
assert src.size(0) == batch.numel()
6796
if batch_size is None:
@@ -70,9 +99,9 @@ def fps( # noqa
7099
deg = src.new_zeros(batch_size, dtype=torch.long)
71100
deg.scatter_add_(0, batch, torch.ones_like(batch))
72101

73-
ptr = deg.new_zeros(batch_size + 1)
74-
torch.cumsum(deg, 0, out=ptr[1:])
102+
ptr_vec = deg.new_zeros(batch_size + 1)
103+
torch.cumsum(deg, 0, out=ptr_vec[1:])
75104
else:
76-
ptr = torch.tensor([0, src.size(0)], device=src.device)
105+
ptr_vec = torch.tensor([0, src.size(0)], device=src.device)
77106

78-
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
107+
return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)

torch_cluster/typing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import torch
2+
3+
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')

0 commit comments

Comments
 (0)