1414__version__ = '1.6.0'
1515URL = 'https://github.com/rusty1s/pytorch_cluster'
1616
17- WITH_CUDA = torch .cuda .is_available () and CUDA_HOME is not None
17+ WITH_CUDA = False
18+ if torch .cuda .is_available ():
19+ WITH_CUDA = CUDA_HOME is not None or torch .version .hip
20+
1821suffices = ['cpu' , 'cuda' ] if WITH_CUDA else ['cpu' ]
1922if os .getenv ('FORCE_CUDA' , '0' ) == '1' :
2023 suffices = ['cuda' , 'cpu' ]
@@ -31,9 +34,12 @@ def get_extensions():
3134
3235 extensions_dir = osp .join ('csrc' )
3336 main_files = glob .glob (osp .join (extensions_dir , '*.cpp' ))
37+ # remove generated 'hip' files, in case of rebuilds
38+ main_files = [path for path in main_files if 'hip' not in path ]
3439
3540 for main , suffix in product (main_files , suffices ):
3641 define_macros = [('WITH_PYTHON' , None )]
42+ undef_macros = []
3743
3844 if sys .platform == 'win32' :
3945 define_macros += [('torchcluster_EXPORTS' , None )]
@@ -63,9 +69,17 @@ def get_extensions():
6369 define_macros += [('WITH_CUDA' , None )]
6470 nvcc_flags = os .getenv ('NVCC_FLAGS' , '' )
6571 nvcc_flags = [] if nvcc_flags == '' else nvcc_flags .split (' ' )
66- nvcc_flags += ['--expt-relaxed-constexpr' , '- O2' ]
72+ nvcc_flags += ['-O2' ]
6773 extra_compile_args ['nvcc' ] = nvcc_flags
6874
75+ if torch .version .hip :
76+ # USE_ROCM was added to later versions of PyTorch
77+ # Define here to support older PyTorch versions as well:
78+ define_macros += [('USE_ROCM' , None )]
79+ undef_macros += ['__HIP_NO_HALF_CONVERSIONS__' ]
80+ else :
81+ nvcc_flags += ['--expt-relaxed-constexpr' ]
82+
6983 name = main .split (os .sep )[- 1 ][:- 4 ]
7084 sources = [main ]
7185
@@ -83,6 +97,7 @@ def get_extensions():
8397 sources ,
8498 include_dirs = [extensions_dir ],
8599 define_macros = define_macros ,
100+ undef_macros = undef_macros ,
86101 extra_compile_args = extra_compile_args ,
87102 extra_link_args = extra_link_args ,
88103 )
@@ -99,6 +114,11 @@ def get_extensions():
99114 'scipy' ,
100115]
101116
117+ # work-around hipify abs paths
118+ include_package_data = True
119+ if torch .cuda .is_available () and torch .version .hip :
120+ include_package_data = False
121+
102122setup (
103123 name = 'torch_cluster' ,
104124 version = __version__ ,
@@ -125,4 +145,5 @@ def get_extensions():
125145 BuildExtension .with_options (no_python_abi_suffix = True , use_ninja = False )
126146 },
127147 packages = find_packages (),
148+ include_package_data = include_package_data ,
128149)
0 commit comments