diff --git a/devito/passes/iet/langbase.py b/devito/passes/iet/langbase.py
index a99fcdcb41..56672552a4 100644
--- a/devito/passes/iet/langbase.py
+++ b/devito/passes/iet/langbase.py
@@ -12,6 +12,7 @@ from devito.passes import is_on_device
 from devito.passes.iet.engine import iet_pass
 from devito.symbolics import Byref, CondNe, SizeOf
+from sympy import Ge
 from devito.tools import as_list, is_integer, prod
 from devito.types import Symbol, QueueID, Wildcard
 
@@ -56,11 +57,18 @@ class LangBB(metaclass=LangMeta):
     """
 
     @classmethod
-    def _get_num_devices(cls):
+    def _get_num_devices(cls, platform):
         """
         Get the number of accessible devices.
+
+        Returns a tuple ``(ngpus, call_ngpus)``: the Symbol receiving the device
+        count for ``platform`` and the ``cls['num-devices']`` call that stores
+        its result in that Symbol.
         """
-        raise NotImplementedError
+        ngpus = Symbol(name='ngpus', dtype='int32')
+        devicetype = as_list(cls[platform])
+        call_ngpus = cls['num-devices'](devicetype, retobj=ngpus)
+        return ngpus, call_ngpus
 
     @classmethod
     def _map_to(cls, f, imask=None, qid=None):
@@ -426,9 +434,29 @@ def _make_setdevice_seq(iet, nodes=()):
             devicetype = as_list(self.langbb[self.platform])
             deviceid = self.deviceid
 
+            # Emit a runtime check that aborts if `deviceid` is not an existing device
+            ngpus, call_ngpus = self.langbb._get_num_devices(self.platform)
+
+            # The language name is interpolated in Python: passed as a `%s`
+            # argument it would presumably be emitted as a bare, unquoted token
+            errmsg = f'"{self.langbb["name"]}: Error - device %d >= %d devices\\n"'
+            validation = Conditional(
+                Ge(deviceid, ngpus),
+                List(body=[
+                    Call('printf', [errmsg, deviceid, ngpus]),
+                    Call('exit', [1])
+                ])
+            )
+
+            device_setup = List(body=[
+                call_ngpus,
+                validation,
+                self.langbb['set-device']([deviceid] + devicetype)
+            ])
+
             return list(nodes) + [Conditional(
                 CondNe(deviceid, -1),
-                self.langbb['set-device']([deviceid] + devicetype)
+                device_setup
             )]
 
         def _make_setdevice_mpi(iet, objcomm, nodes=()):
@@ -441,7 +469,24 @@ def _make_setdevice_mpi(iet, objcomm, nodes=()):
 
             ngpus, call_ngpus = self.langbb._get_num_devices(self.platform)
 
-            osdd_then = self.langbb['set-device']([deviceid] + devicetype)
+            # Runtime check for an explicitly requested device ID. The language name
+            # is interpolated in Python: passed as a `%s` argument it would
+            # presumably be emitted as a bare, unquoted token
+            # NOTE(review): under MPI, `MPI_Abort` may be preferable to `exit`
+            errmsg = f'"{self.langbb["name"]}: Error - device %d >= %d devices\\n"'
+            validation = Conditional(
+                Ge(deviceid, ngpus),
+                List(body=[
+                    Call('printf', [errmsg, deviceid, ngpus]),
+                    Call('exit', [1])
+                ])
+            )
+
+            osdd_then = List(body=[
+                call_ngpus,
+                validation,
+                self.langbb['set-device']([deviceid] + devicetype)
+            ])
             osdd_else = self.langbb['set-device']([rank % ngpus] + devicetype)
 
             return list(nodes) + [Conditional(
diff --git a/tests/test_gpu_openacc.py b/tests/test_gpu_openacc.py
index 8c4813db0b..2b5fc7d6fa 100644
--- a/tests/test_gpu_openacc.py
+++ b/tests/test_gpu_openacc.py
@@ -200,6 +200,30 @@ def test_op_apply(self):
 
         assert np.all(np.array(u.data[0, :, :, :]) == time_steps)
 
+    def test_device_validation_error_message(self):
+        """Test that OpenACC device validation includes helpful error messages."""
+        grid = Grid(shape=(3, 3, 3))
+
+        u = TimeFunction(name='u', grid=grid, dtype=np.int32)
+
+        op = Operator(Eq(u.forward, u + 1), platform='nvidiaX', language='openacc')
+
+        # Check that the generated code contains device validation
+        code = str(op)
+
+        # Should contain device count check
+        assert 'acc_get_num_devices' in code, "Missing OpenACC device count check"
+
+        # Should contain validation condition
+        assert 'deviceid >= ngpus' in code, "Missing OpenACC device ID " + \
+            "validation condition"
+
+        # Should contain error message
+        assert 'Error - device' in code, "Missing error message"
+
+        # Should contain exit call to prevent undefined behavior
+        assert 'exit(1)' in code, "Missing exit call on validation failure"
+
     def iso_acoustic(self, opt):
         shape = (101, 101)
         extent = (1000, 1000)
diff --git a/tests/test_gpu_openmp.py b/tests/test_gpu_openmp.py
index 7150d66eb2..3ed106859d 100644
--- a/tests/test_gpu_openmp.py
+++ b/tests/test_gpu_openmp.py
@@ -20,8 +20,13 @@ def test_init_omp_env(self):
 
         op = Operator(Eq(u.forward, u.dx+1), language='openmp')
 
-        assert str(op.body.init[0].body[0]) ==\
-            'if (deviceid != -1)\n{\n omp_set_default_device(deviceid);\n}'
+        # With device validation, the generated code now includes validation logic
+        init_code = str(op.body.init[0].body[0])
+        assert 'if (deviceid != -1)' in init_code
+        assert 'int ngpus = omp_get_num_devices()' in init_code
+        assert 'if (deviceid >= ngpus)' in init_code
+        assert 'Error - device' in init_code
+        assert 'omp_set_default_device(deviceid)' in init_code
 
     @pytest.mark.parallel(mode=1)
     def test_init_omp_env_w_mpi(self, mode):
@@ -31,14 +36,41 @@ def test_init_omp_env_w_mpi(self, mode):
 
         op = Operator(Eq(u.forward, u.dx+1), language='openmp')
 
-        assert str(op.body.init[0].body[0]) ==\
-            ('if (deviceid != -1)\n'
-             '{\n omp_set_default_device(deviceid);\n}\n'
-             'else\n'
-             '{\n int rank = 0;\n'
-             ' MPI_Comm_rank(comm,&rank);\n'
-             ' int ngpus = omp_get_num_devices();\n'
-             ' omp_set_default_device((rank)%(ngpus));\n}')
+        # With device validation, the MPI case also includes validation for explicit
+        # deviceid
+        init_code = str(op.body.init[0].body[0])
+        assert 'if (deviceid != -1)' in init_code
+        assert 'int ngpus = omp_get_num_devices()' in init_code
+        # For MPI case with explicit deviceid, should have validation
+        assert 'if (deviceid >= ngpus)' in init_code
+        assert 'Error - device' in init_code
+        # Should still have MPI rank-based assignment in else clause
+        assert 'int rank = 0' in init_code
+        assert 'MPI_Comm_rank(comm,&rank)' in init_code
+        assert '(rank)%(ngpus)' in init_code
+
+    def test_device_validation_error_message(self):
+        """Test that device validation includes helpful error messages."""
+        grid = Grid(shape=(3, 3, 3))
+
+        u = TimeFunction(name='u', grid=grid)
+
+        op = Operator(Eq(u.forward, u.dx+1), language='openmp')
+
+        # Check that the generated code contains device validation
+        code = str(op)
+
+        # Should contain device count check
+        assert 'omp_get_num_devices()' in code, "Missing device count check"
+
+        # Should contain validation condition
+        assert 'deviceid >= ngpus' in code, "Missing device ID validation condition"
+
+        # Should contain error message
+        assert 'Error - device' in code, "Missing error message"
+
+        # Should contain exit call to prevent undefined behavior
+        assert 'exit(1)' in code, "Missing exit call on validation failure"
 
     def test_basic(self):
         grid = Grid(shape=(3, 3, 3))