Skip to content

Commit 303e889

Browse files
committed
fixed flake8 errors
1 parent 9016cdb commit 303e889

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

torch_cluster/radius.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from typing import Optional
22

33
import torch
4-
import scipy.spatial
5-
64

75
@torch.jit.script
86
def sample(col: torch.Tensor, count: int) -> torch.Tensor:
@@ -87,7 +85,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
8785
assert x.size(0) == batch_x.size(0)
8886
assert y.size(0) == batch_y.size(0)
8987

90-
9188
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
9289
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
9390

@@ -104,6 +101,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
104101
return torch.stack([row[mask], col[mask]], dim=0)
105102
"""
106103

104+
107105
def radius_graph(x: torch.Tensor, r: float,
108106
batch: Optional[torch.Tensor] = None, loop: bool = False,
109107
max_num_neighbors: int = 32,
@@ -144,7 +142,7 @@ def radius_graph(x: torch.Tensor, r: float,
144142
row, col = (col, row) if flow == 'source_to_target' else (row, col)
145143
else:
146144
row, col = (col, row) if flow == 'target_to_source' else (row, col)
147-
145+
148146
if not loop:
149147
mask = row != col
150148
row, col = row[mask], col[mask]

0 commit comments

Comments
 (0)