Skip to content

Commit 42d0d66

Browse files
committed
Define more linker requirements
1 parent 1963f39 commit 42d0d66

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

pytensor/link/basic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@ class PerformLinker(LocalLinker):
283283
284284
"""
285285

286+
required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
287+
incompatible_rewrites: tuple[str, ...] = ("cxx",)
288+
286289
def __init__(
287290
self, allow_gc: bool | None = None, schedule: Callable | None = None
288291
) -> None:
@@ -584,6 +587,9 @@ class JITLinker(PerformLinker):
584587
585588
"""
586589

590+
required_rewrites: tuple[str, ...] = ("minimum_compile",)
591+
incompatible_rewrites: tuple[str, ...] = ()
592+
587593
@abstractmethod
588594
def fgraph_convert(
589595
self, fgraph, order, input_storage, output_storage, storage_map, **kwargs

pytensor/link/vm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,10 @@ class VMLinker(LocalLinker):
812812
813813
"""
814814

815+
# We can only set these correctly after `__init__`, as it depends on `c_thunks`
816+
required_rewrites: tuple[str, ...] = ("minimum_compile",)
817+
incompatible_rewrites: tuple[str, ...] = ()
818+
815819
def __init__(
816820
self,
817821
allow_gc=None,
@@ -834,6 +838,9 @@ def __init__(
834838
self.lazy = lazy
835839
if c_thunks is None:
836840
c_thunks = bool(config.cxx)
841+
if not c_thunks:
842+
self.required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
843+
self.incompatible_rewrites: tuple[str, ...] = ("cxx",)
837844
self.c_thunks = c_thunks
838845
self.allow_partial_eval = allow_partial_eval
839846
self.updated_vars = {}

tests/compile/test_mode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,12 @@ def test_NoOutputFromInplace():
5656

5757
def test_including():
5858
mode = Mode(linker="py", optimizer="merge")
59-
assert set(mode._optimizer.include) == {"minimum_compile", "merge"}
59+
assert set(mode._optimizer.include) == {"minimum_compile", "py_only", "merge"}
6060

6161
new_mode = mode.including("fast_compile")
6262
assert set(new_mode._optimizer.include) == {
6363
"minimum_compile",
64+
"py_only",
6465
"merge",
6566
"fast_compile",
6667
}

0 commit comments

Comments
 (0)