Commit 4cfbf9b

add necessary tests and update the documentation
1 parent 99662c5 commit 4cfbf9b

7 files changed, +236 -2 lines changed
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
:orphan:

MUSA training (Advanced)
========================

**Audience:** Users looking to train models on MooreThreads devices using the MUSA accelerator.

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

----

MUSAAccelerator Overview
------------------------

torch_musa is a Python extension package for PyTorch that enables full utilization of the computing power of
MooreThreads graphics cards. Combined with PyTorch, users can harness the power of MooreThreads graphics cards
through torch_musa.

PyTorch Lightning automatically finds shared weights and ties them after the modules are moved to the
MUSA device under the hood. It ensures that the weights among the modules are shared, not copied
independently.
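
A minimal sketch of such a tie (the module and layer names below are illustrative, not part of this commit):

.. code-block:: python

    import torch.nn as nn
    import pytorch_lightning as L


    class WeightSharingModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            # tie layer_3 to layer_1: both attributes now point at the same
            # Parameter, and Lightning keeps the tie intact after the move to MUSA
            self.layer_3.weight = self.layer_1.weight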
Example:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.utils.data as data
    import torchvision as tv

    import pytorch_lightning as L


    # Step 1: Define a LightningModule
    class LitAutoEncoder(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

        def forward(self, x):
            # in lightning, forward defines the prediction/inference actions
            embedding = self.encoder(x)
            return embedding

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop. It is independent of forward
            x, _ = batch
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer


    def main():
        # -------------------
        # Step 2: Define data
        # -------------------
        dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
        train, val = data.random_split(dataset, [55000, 5000])

        # -------------------
        # Step 3: Train
        # -------------------
        autoencoder = LitAutoEncoder()
        # accelerator="auto" and accelerator="musa" are also supported
        trainer = L.Trainer(accelerator="gpu")
        trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))


    if __name__ == "__main__":
        main()

----

MUSA
----

MUSA is the library that interfaces PyTorch with MooreThreads graphics cards.
For more information, check out `MUSA <https://github.com/MooreThreads/torch_musa>`_.
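
A quick sanity check that the stack is wired up (a sketch, assuming torch_musa is installed; torch_musa
exposes a ``torch.musa`` namespace that mirrors ``torch.cuda``):

.. code-block:: python

    import torch
    import torch_musa  # noqa: F401  (importing registers the "musa" device type)

    print(torch.musa.is_available())
    x = torch.randn(2, 2, device="musa")  # allocate a tensor on the MooreThreads GPU
    print(x.device)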

src/lightning/fabric/utilities/testing/_runif.py

Lines changed: 10 additions & 0 deletions
@@ -23,6 +23,7 @@
 from lightning.fabric.accelerators import XLAAccelerator
 from lightning.fabric.accelerators.cuda import num_cuda_devices
 from lightning.fabric.accelerators.mps import MPSAccelerator
+from lightning.fabric.accelerators.musa import MUSAAccelerator
 from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 
@@ -36,6 +37,7 @@ def _runif_reasons(
     bf16_cuda: bool = False,
     tpu: bool = False,
     mps: Optional[bool] = None,
+    musa: Optional[bool] = None,
     skip_windows: bool = False,
     standalone: bool = False,
     deepspeed: bool = False,
@@ -53,6 +55,8 @@ def _runif_reasons(
         tpu: Require that TPU is available.
         mps: If True: Require that MPS (Apple Silicon) is available,
             if False: Explicitly Require that MPS is not available
+        musa: If True: Require that MUSA (MooreThreads device) is available,
+            if False: Explicitly Require that MUSA is not available
         skip_windows: Skip for Windows platform.
         standalone: Mark the test as standalone, our CI will run it in a separate process.
             This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set.
@@ -107,6 +111,12 @@ def _runif_reasons(
             reasons.append("MPS")
         elif not mps and MPSAccelerator.is_available():
             reasons.append("not MPS")
+
+    if musa is not None:
+        if musa and not MUSAAccelerator.is_available():
+            reasons.append("MUSA")
+        elif not musa and MUSAAccelerator.is_available():
+            reasons.append("not MUSA")
 
     if standalone:
         if os.getenv("PL_RUN_STANDALONE_TESTS", "0") != "1":
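
The new ``musa`` flag is consumed through the ``RunIf`` marker the same way as the existing ``mps`` flag. A sketch of typical usage (the test bodies here are illustrative, not part of this commit):

import torch

from tests_fabric.helpers.runif import RunIf


@RunIf(musa=True)  # skipped unless a MUSA accelerator is available
def test_runs_only_on_musa():
    assert torch.ones(1, device="musa").device.type == "musa"


@RunIf(musa=False)  # skipped when a MUSA accelerator IS available
def test_runs_only_without_musa():
    assert torch.ones(1).device.type == "cpu"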

src/lightning/pytorch/utilities/testing/_runif.py

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,7 @@ def _runif_reasons(
     bf16_cuda: bool = False,
     tpu: bool = False,
     mps: Optional[bool] = None,
+    musa: Optional[bool] = None,
     skip_windows: bool = False,
     standalone: bool = False,
     deepspeed: bool = False,
@@ -56,6 +57,8 @@ def _runif_reasons(
         tpu: Require that TPU is available.
         mps: If True: Require that MPS (Apple Silicon) is available,
             if False: Explicitly Require that MPS is not available
+        musa: If True: Require that MUSA (MooreThreads device) is available,
+            if False: Explicitly Require that MUSA is not available
         skip_windows: Skip for Windows platform.
         standalone: Mark the test as standalone, our CI will run it in a separate process.
             This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set.
@@ -79,6 +82,7 @@ def _runif_reasons(
         bf16_cuda=bf16_cuda,
         tpu=tpu,
         mps=mps,
+        musa=musa,
         skip_windows=skip_windows,
         standalone=standalone,
         deepspeed=deepspeed,
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest
import torch

from lightning.fabric.accelerators.musa import MUSAAccelerator
from tests_fabric.helpers.runif import RunIf

_MAYBE_MUSA = "musa" if MUSAAccelerator.is_available() else "cpu"


@mock.patch("lightning.fabric.accelerators.musa.num_musa_devices", return_value=2)
@RunIf(musa=True)
def test_auto_device_count(_):
    assert MUSAAccelerator.auto_device_count() == 2


@RunIf(musa=True)
def test_musa_availability():
    assert MUSAAccelerator.is_available()


def test_init_device_with_wrong_device_type():
    with pytest.raises(ValueError, match="Device should be MUSA"):
        MUSAAccelerator().setup_device(torch.device("cpu"))


@RunIf(musa=True)
@pytest.mark.parametrize(
    ("devices", "expected"),
    [
        ([], []),
        ([1], [torch.device(_MAYBE_MUSA, 1)]),
        ([3, 1], [torch.device(_MAYBE_MUSA, 3), torch.device(_MAYBE_MUSA, 1)]),
    ],
)
def test_get_parallel_devices(devices, expected):
    assert MUSAAccelerator.get_parallel_devices(devices) == expected


@mock.patch("torch.musa.set_device")
@mock.patch("torch.musa.get_device_capability", return_value=(7, 0))
def test_set_musa_device(_, set_device_mock):
    device = torch.device(_MAYBE_MUSA, 1)
    MUSAAccelerator().setup_device(device)
    set_device_mock.assert_called_once_with(device)
tests/tests_fabric/utilities/test_distributed.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from lightning_utilities.core.imports import RequirementCache
 
 import lightning.fabric
-from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
+from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator, MUSAAccelerator
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
@@ -40,7 +40,7 @@ def spawn_launch(fn, parallel_devices):
     """Copied from ``tests_pytorch.core.test_results.spawn_launch``"""
     # TODO: the accelerator and cluster_environment should be optional to just launch processes, but this requires lazy
     # initialization to be implemented
-    device_to_accelerator = {"cuda": CUDAAccelerator, "mps": MPSAccelerator, "cpu": CPUAccelerator}
+    device_to_accelerator = {"cuda": CUDAAccelerator, "mps": MPSAccelerator, "cpu": CPUAccelerator, "musa": MUSAAccelerator}
     accelerator_cls = device_to_accelerator[parallel_devices[0].type]
     strategy = DDPStrategy(
         accelerator=accelerator_cls(),
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import pytest

from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import MUSAAccelerator
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf


@RunIf(musa=True)
def test_musa_availability():
    assert MUSAAccelerator.is_available()


def test_warning_if_musa_not_used(musa_count_1):
    with pytest.warns(UserWarning, match="GPU available but not used"):
        Trainer(accelerator="cpu")


@RunIf(musa=True)
@pytest.mark.parametrize("accelerator_value", ["musa", MUSAAccelerator()])
def test_trainer_musa_accelerator(accelerator_value):
    trainer = Trainer(accelerator=accelerator_value, devices=1)
    assert isinstance(trainer.accelerator, MUSAAccelerator)
    assert trainer.num_devices == 1


@RunIf(musa=True)
@mock.patch("torch.musa.set_device")
def test_set_musa_device(set_device_mock, tmp_path, monkeypatch):
    monkeypatch.setenv("MUSA_DEVICE_ORDER", "PCI_BUS_ID")  # or any other value as needed
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        fast_dev_run=True,
        accelerator="gpu",
        devices=1,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    trainer.fit(model)
    set_device_mock.assert_called_once()

tests/tests_pytorch/conftest.py

Lines changed: 24 additions & 0 deletions
@@ -203,6 +203,30 @@ def cuda_count_2(monkeypatch):
 def cuda_count_4(monkeypatch):
     mock_cuda_count(monkeypatch, 4)
 
+
+def mock_musa_count(monkeypatch, n: int) -> None:
+    monkeypatch.setattr(lightning.fabric.accelerators.musa, "num_musa_devices", lambda: n)
+    monkeypatch.setattr(lightning.pytorch.accelerators.musa, "num_musa_devices", lambda: n)
+
+
+@pytest.fixture
+def musa_count_0(monkeypatch):
+    mock_musa_count(monkeypatch, 0)
+
+
+@pytest.fixture
+def musa_count_1(monkeypatch):
+    mock_musa_count(monkeypatch, 1)
+
+
+@pytest.fixture
+def musa_count_2(monkeypatch):
+    mock_musa_count(monkeypatch, 2)
+
+
+@pytest.fixture
+def musa_count_4(monkeypatch):
+    mock_musa_count(monkeypatch, 4)
+
 
 def mock_mps_count(monkeypatch, n: int) -> None:
     monkeypatch.setattr(lightning.fabric.accelerators.mps, "_get_all_available_mps_gpus", lambda: [0] if n > 0 else [])
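
These fixtures let MUSA-dependent behavior be tested without MooreThreads hardware: each one monkeypatches ``num_musa_devices`` in both the fabric and pytorch packages. A hypothetical test using one of them (not part of this commit):

def test_musa_count_is_mocked(musa_count_2):
    from lightning.fabric.accelerators.musa import num_musa_devices

    # the fixture replaced num_musa_devices with `lambda: 2`
    assert num_musa_devices() == 2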
