Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-

### Deprecated

- Deprecated `to_torchscript` method due to deprecation of TorchScript in PyTorch ([#21397](https://github.com/Lightning-AI/pytorch-lightning/pull/21397))

### Removed

-
Expand Down
12 changes: 11 additions & 1 deletion src/lightning/pytorch/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_deprecation, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import (
_METRIC,
Expand Down Expand Up @@ -1498,6 +1498,11 @@ def to_torchscript(
scripted you should override this method. In case you want to return multiple modules, we recommend using a
dictionary.

.. deprecated::
``LightningModule.to_torchscript`` has been deprecated in v2.7 and will be removed in v2.8.
TorchScript is deprecated in PyTorch. Use ``torch.export.export()`` for model exporting instead.
See https://pytorch.org/docs/stable/export.html for more information.

Args:
file_path: Path where to save the torchscript. Default: None (no file saved).
method: Whether to use TorchScript's script or trace method. Default: 'script'
Expand Down Expand Up @@ -1536,6 +1541,11 @@ def forward(self, x):
defined or not.

"""
rank_zero_deprecation(
"`LightningModule.to_torchscript` has been deprecated in v2.7 and will be removed in v2.8. "
"TorchScript is deprecated in PyTorch. Use `torch.export.export()` for model exporting instead. "
"See https://pytorch.org/docs/stable/export.html for more information."
)
mode = self.training

if method == "script":
Expand Down
3 changes: 2 additions & 1 deletion tests/tests_pytorch/helpers/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def test_models(tmp_path, data_class, model_class):
if dm is not None:
trainer.test(model, datamodule=dm)

model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
model.to_torchscript()
if data_class:
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)

Expand Down
54 changes: 41 additions & 13 deletions tests/tests_pytorch/models/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
Expand All @@ -36,7 +37,8 @@ def test_torchscript_input_output(modelclass):
if isinstance(model, BoringModel):
model.example_input_array = torch.randn(5, 32)

script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)

model.eval()
Expand All @@ -59,7 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass):
if isinstance(model, BoringModel):
model.example_input_array = torch.randn(5, 32)

script = model.to_torchscript(method="trace")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(method="trace")
assert isinstance(script, torch.jit.ScriptModule)

model.eval()
Expand All @@ -74,7 +77,8 @@ def test_torchscript_input_output_trace():
"""Test that traced LightningModule forward works with example_inputs."""
model = BoringModel()
example_inputs = torch.randn(1, 32)
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
assert isinstance(script, torch.jit.ScriptModule)

model.eval()
Expand All @@ -99,7 +103,8 @@ def test_torchscript_device(device_str):
model = BoringModel().to(device)
model.example_input_array = torch.randn(5, 32)

script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert next(script.parameters()).device == device
script_output = script(model.example_input_array.to(device))
assert script_output.device == device
Expand All @@ -121,19 +126,22 @@ def test_torchscript_device_with_check_inputs(device_str):

check_inputs = torch.rand(5, 32)

script = model.to_torchscript(method="trace", check_inputs=check_inputs)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
assert isinstance(script, torch.jit.ScriptModule)


def test_torchscript_retain_training_state():
"""Test that torchscript export does not alter the training mode of original model."""
model = BoringModel()
model.train(True)
script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert model.training
assert not script.training
model.train(False)
_ = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
_ = model.to_torchscript()
assert not model.training
assert not script.training

Expand All @@ -142,7 +150,8 @@ def test_torchscript_retain_training_state():
def test_torchscript_properties(modelclass):
"""Test that scripted LightningModule has unnecessary methods removed."""
model = modelclass()
script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")
assert not callable(getattr(script, "training_step", None))
Expand All @@ -153,7 +162,8 @@ def test_torchscript_save_load(tmp_path, modelclass):
"""Test that scripted LightningModule is correctly saved and can be loaded."""
model = modelclass()
output_file = str(tmp_path / "model.pt")
script = model.to_torchscript(file_path=output_file)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(file_path=output_file)
loaded_script = torch.jit.load(output_file)
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))

Expand All @@ -170,7 +180,8 @@ class DummyFileSystem(LocalFileSystem): ...

model = modelclass()
output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmp_path, "model.pt")
script = model.to_torchscript(file_path=output_file)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(file_path=output_file)

fs = get_filesystem(output_file)
with fs.open(output_file, "rb") as f:
Expand All @@ -184,7 +195,10 @@ def test_torchcript_invalid_method():
model = BoringModel()
model.train(True)

with pytest.raises(ValueError, match="only supports 'script' or 'trace'"):
with (
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
pytest.raises(ValueError, match="only supports 'script' or 'trace'"),
):
model.to_torchscript(method="temp")


Expand All @@ -193,7 +207,10 @@ def test_torchscript_with_no_input():
model = BoringModel()
model.example_input_array = None

with pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"):
with (
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"),
):
model.to_torchscript(method="trace")


Expand Down Expand Up @@ -224,6 +241,17 @@ def forward(self, inputs):

lm = Parent()
assert not lm._jit_is_scripting
script = lm.to_torchscript(method="script")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = lm.to_torchscript(method="script")
assert not lm._jit_is_scripting
assert isinstance(script, torch.jit.RecursiveScriptModule)


def test_to_torchscript_deprecation():
"""Test that to_torchscript raises a deprecation warning."""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

with pytest.warns(LightningDeprecationWarning, match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
Loading