Skip to content

Commit 7c5a6b7

Browse files
authored
Enable ROCm build (#149)
1 parent 2738738 commit 7c5a6b7

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

csrc/version.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
#include "macros.h"
77

88
#ifdef WITH_CUDA
9+
#ifdef USE_ROCM
10+
#include <hip/hip_version.h>
11+
#else
912
#include <cuda.h>
1013
#endif
14+
#endif
1115

1216
#ifdef _WIN32
1317
#ifdef WITH_PYTHON
@@ -23,7 +27,11 @@ PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
2327
namespace cluster {
2428
CLUSTER_API int64_t cuda_version() noexcept {
2529
#ifdef WITH_CUDA
30+
#ifdef USE_ROCM
31+
return HIP_VERSION;
32+
#else
2633
return CUDA_VERSION;
34+
#endif
2735
#else
2836
return -1;
2937
#endif

setup.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
__version__ = '1.6.0'
1515
URL = 'https://github.com/rusty1s/pytorch_cluster'
1616

17-
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
17+
WITH_CUDA = False
18+
if torch.cuda.is_available():
19+
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
20+
1821
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
1922
if os.getenv('FORCE_CUDA', '0') == '1':
2023
suffices = ['cuda', 'cpu']
@@ -31,9 +34,12 @@ def get_extensions():
3134

3235
extensions_dir = osp.join('csrc')
3336
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
37+
# remove generated 'hip' files, in case of rebuilds
38+
main_files = [path for path in main_files if 'hip' not in path]
3439

3540
for main, suffix in product(main_files, suffices):
3641
define_macros = [('WITH_PYTHON', None)]
42+
undef_macros = []
3743

3844
if sys.platform == 'win32':
3945
define_macros += [('torchcluster_EXPORTS', None)]
@@ -63,9 +69,17 @@ def get_extensions():
6369
define_macros += [('WITH_CUDA', None)]
6470
nvcc_flags = os.getenv('NVCC_FLAGS', '')
6571
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
66-
nvcc_flags += ['--expt-relaxed-constexpr', '-O2']
72+
nvcc_flags += ['-O2']
6773
extra_compile_args['nvcc'] = nvcc_flags
6874

75+
if torch.version.hip:
76+
# USE_ROCM was added to later versions of PyTorch
77+
# Define here to support older PyTorch versions as well:
78+
define_macros += [('USE_ROCM', None)]
79+
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
80+
else:
81+
nvcc_flags += ['--expt-relaxed-constexpr']
82+
6983
name = main.split(os.sep)[-1][:-4]
7084
sources = [main]
7185

@@ -83,6 +97,7 @@ def get_extensions():
8397
sources,
8498
include_dirs=[extensions_dir],
8599
define_macros=define_macros,
100+
undef_macros=undef_macros,
86101
extra_compile_args=extra_compile_args,
87102
extra_link_args=extra_link_args,
88103
)
@@ -99,6 +114,11 @@ def get_extensions():
99114
'scipy',
100115
]
101116

117+
# work-around hipify abs paths
118+
include_package_data = True
119+
if torch.cuda.is_available() and torch.version.hip:
120+
include_package_data = False
121+
102122
setup(
103123
name='torch_cluster',
104124
version=__version__,
@@ -125,4 +145,5 @@ def get_extensions():
125145
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
126146
},
127147
packages=find_packages(),
148+
include_package_data=include_package_data,
128149
)

0 commit comments

Comments
 (0)