11from typing import Optional
22
33import torch
4- import scipy .spatial
5-
64
75@torch .jit .script
86def sample (col : torch .Tensor , count : int ) -> torch .Tensor :
@@ -87,7 +85,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
8785 assert x .size (0 ) == batch_x .size (0 )
8886 assert y .size (0 ) == batch_y .size (0 )
8987
90-
9188 x = torch .cat ([x , 2 * r * batch_x .view (- 1 , 1 ).to (x .dtype )], dim = - 1 )
9289 y = torch .cat ([y , 2 * r * batch_y .view (- 1 , 1 ).to (y .dtype )], dim = - 1 )
9390
@@ -104,6 +101,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
104101 return torch.stack([row[mask], col[mask]], dim=0)
105102 """
106103
104+
107105def radius_graph (x : torch .Tensor , r : float ,
108106 batch : Optional [torch .Tensor ] = None , loop : bool = False ,
109107 max_num_neighbors : int = 32 ,
@@ -144,7 +142,7 @@ def radius_graph(x: torch.Tensor, r: float,
144142 row , col = (col , row ) if flow == 'source_to_target' else (row , col )
145143 else :
146144 row , col = (col , row ) if flow == 'target_to_source' else (row , col )
147-
145+
148146 if not loop :
149147 mask = row != col
150148 row , col = row [mask ], col [mask ]
0 commit comments