diff --git a/.github/workflows/build-test-linux-x86_64.yml b/.github/workflows/build-test-linux-x86_64.yml
index 4c0f31b256..3918b0f839 100644
--- a/.github/workflows/build-test-linux-x86_64.yml
+++ b/.github/workflows/build-test-linux-x86_64.yml
@@ -136,7 +136,7 @@ jobs:
           cd tests/py
           cd dynamo
           python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
-          python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
+          python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_*
           python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
           popd
@@ -229,6 +229,8 @@ jobs:
           pushd .
           cd tests/py/dynamo
           python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
+          python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_*
+          popd

  L1-dynamo-compile-tests:
diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml
index 0b7a76cbd7..980884d7a7 100644
--- a/.github/workflows/build-test-windows.yml
+++ b/.github/workflows/build-test-windows.yml
@@ -135,7 +135,7 @@ jobs:
           pushd .
           cd tests/py/dynamo
           ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
-          ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
+          ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_*
          ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
           popd
@@ -219,6 +219,7 @@ jobs:
           pushd .
           cd tests/py/dynamo
           ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
+          ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_*
           popd

  L1-dynamo-compile-tests:
diff --git a/examples/dynamo/low_cpu_memory_compilation.py b/examples/dynamo/low_cpu_memory_compilation.py
new file mode 100644
index 0000000000..c508d3f0b4
--- /dev/null
+++ b/examples/dynamo/low_cpu_memory_compilation.py
@@ -0,0 +1,130 @@
+"""
+
+.. _low_cpu_memory_compilation:
+
+Low CPU Memory Compilation Example
+==================================
+
+This example demonstrates compiling a model with a bounded CPU (host) memory
+budget using Torch-TensorRT Dynamo. Limiting host RAM use is helpful on
+memory-constrained machines or when compiling very large models.
+
+Key notes:
+- The toy model below has roughly 430 MB of parameters. We set the CPU
+  memory budget to 2 GiB. At compile time, only about 900 MB of that budget
+  may remain available, while building the model as a single engine could
+  require up to 430 * 4 = 1720 MB of host memory. The model is therefore
+  partitioned into two subgraphs to fit the memory budget.
+- Performance impact varies by model. When the number of TensorRT engines
+  created is small, the impact is typically minimal.
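+- Rough budget math (illustrative, derived from the partitioner's defaults):
+  the partitioner reserves the configured budget minus the host RSS already
+  in use, then targets about one quarter of the remainder per engine, since
+  TensorRT may transiently need ~4x an engine's weight bytes while building.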
+ +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt as torchtrt +from torch_tensorrt.dynamo.conversion import CompilationSettings + + +class net(nn.Module): + def __init__(self): + super().__init__() + # Intentionally large layers to stress host memory during compilation. + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn2 = nn.BatchNorm2d(1024) + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + +model = net().eval() +model.to("cuda") +inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + +enabled_precisions = {torch.float} +use_python_runtime = False + +compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "enable_resource_partitioning": True, + "cpu_memory_budget": 2 * 1024 * 1024 * 1024, # 2 GiB in bytes +} + +settings = CompilationSettings(**compilation_options) +with torchtrt.dynamo.Debugger( + log_level="debug", + logging_dir="/home/profile/logging/moe", + engine_builder_monitor=False, +): + + exp_program = torch.export.export(model, tuple(inputs)) + trt_gm = torchtrt.dynamo.compile( + exp_program, + inputs=inputs, + **compilation_options, + ) + + # Expect two back-to-back TensorRT engines due to partitioning under the memory budget. + print(trt_gm) + + +""" +You should be able to see two back-to-back TensorRT engines in the graph + +Graph Structure: + + Inputs: List[Tensor: (1, 1024, 224, 224)@float32] + ... + TRT Engine #1 - Submodule name: _run_on_acc_0_resource_split_0 + Engine Inputs: List[Tensor: (1, 1024, 224, 224)@float32] + Number of Operators in Engine: 9 + Engine Outputs: List[Tensor: (1, 1024, 112, 112)@float32] + ... + TRT Engine #2 - Submodule name: _run_on_acc_0_resource_split_1 + Engine Inputs: List[Tensor: (1, 1024, 112, 112)@float32] + Number of Operators in Engine: 3 + Engine Outputs: List[Tensor: (1, 10)@float32] + ... 
+     Outputs: List[Tensor: (1, 10)@float32]
+
+  ------------------------- Aggregate Stats -------------------------
+
+   Average Number of Operators per TRT Engine: 6.0
+   Most Operators in a TRT Engine: 9
+
+  ********** Recommendations **********
+
+   - For minimal graph segmentation, select min_block_size=9 which would generate 1 TRT engine(s)
+   - For moderate graph segmentation, select min_block_size=6 which would generate 1 TRT engine(s)
+   - The current level of graph segmentation is equivalent to selecting min_block_size=3 which generates 2 TRT engine(s)
+GraphModule(
+  (_run_on_acc_0_resource_split_0): TorchTensorRTModule()
+  (_run_on_acc_0_resource_split_1): TorchTensorRTModule()
+)
+
+
+
+def forward(self, x):
+    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+    _run_on_acc_0_resource_split_0 = self._run_on_acc_0_resource_split_0(x);  x = None
+    _run_on_acc_0_resource_split_1 = self._run_on_acc_0_resource_split_1(_run_on_acc_0_resource_split_0);  _run_on_acc_0_resource_split_0 = None
+    return pytree.tree_unflatten((_run_on_acc_0_resource_split_1,), self._out_spec)
+)
+"""
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 129f9e3d38..c7aec9a684 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -40,6 +40,9 @@
     post_lowering,
     pre_export_lowering,
 )
+from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
+    resource_partition,
+)
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
     get_cpu_memory_usage,
@@ -105,6 +108,8 @@ def cross_compile_for_windows(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
+    cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -179,6 +184,8 @@ def cross_compile_for_windows(
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and CPU memory is limited.
+        cpu_memory_budget (Optional[int]): The maximum amount of CPU memory, in bytes, to use during compilation. Oversized accelerated subgraphs are split so that each engine build fits within this budget; compilation fails if the budget cannot be met. Defaults to None, which bounds compilation by the currently available system memory.
        **kwargs: Any,
    Returns:
        torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -334,6 +341,8 @@ def cross_compile_for_windows(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "enable_resource_partitioning": enable_resource_partitioning,
+        "cpu_memory_budget": cpu_memory_budget,
     }

     # disable the following settings is not supported for cross compilation for windows feature
@@ -448,6 +457,8 @@ def compile(
     autocast_calibration_dataloader: Optional[
         torch.utils.data.DataLoader
     ] = _defaults.AUTOCAST_CALIBRATION_DATALOADER,
+    cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
+    enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -532,6 +543,8 @@ def compile(
         autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
         autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None.
         autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
+        enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and CPU memory is limited.
+        cpu_memory_budget (Optional[int]): The maximum amount of CPU memory, in bytes, to use during compilation. Oversized accelerated subgraphs are split so that each engine build fits within this budget; compilation fails if the budget cannot be met. Defaults to None, which bounds compilation by the currently available system memory.
**kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -732,6 +745,8 @@ def compile( "autocast_max_output_threshold": autocast_max_output_threshold, "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction, "autocast_calibration_dataloader": autocast_calibration_dataloader, + "enable_resource_partitioning": enable_resource_partitioning, + "cpu_memory_budget": cpu_memory_budget, } logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB") settings = CompilationSettings(**compilation_options) @@ -905,6 +920,12 @@ def preserve_module_specs( require_full_compilation=settings.require_full_compilation, ) + if settings.enable_resource_partitioning: + partitioned_module = resource_partition( + partitioned_module, + cpu_memory_budget=settings.cpu_memory_budget, + ) + dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators # The global partitioner leaves non-TRT nodes as-is @@ -928,6 +949,7 @@ def preserve_module_specs( for attr in dir(gm): if attr.startswith("_frozen_param"): delattr(gm, attr) + for name, _ in partitioned_module.named_children(): submodule = getattr(partitioned_module, name) # filter on the GraphModule @@ -1390,7 +1412,7 @@ def convert_exported_program_to_serialized_trt_engine( ) flattened_input_list = get_flat_args_with_check( - exported_program, list(trt_arg_inputs), trt_kwarg_inputs # type: ignore + exported_program, list(trt_arg_inputs), trt_kwarg_inputs )[0] try: diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 7bdeaa0382..fd093d402f 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -64,6 +64,8 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD = 512 AUTOCAST_MAX_DEPTH_OF_REDUCTION = None AUTOCAST_CALIBRATION_DATALOADER = None +ENABLE_RESOURCE_PARTITIONING = False +CPU_MEMORY_BUDGET = None if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 6027219e5d..8bc3e06faa 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -14,6 +14,7 @@ AUTOCAST_MAX_DEPTH_OF_REDUCTION, AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, + CPU_MEMORY_BUDGET, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, @@ -22,6 +23,7 @@ ENABLE_AUTOCAST, ENABLE_CROSS_COMPILE_FOR_WINDOWS, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + ENABLE_RESOURCE_PARTITIONING, ENABLE_WEIGHT_STREAMING, ENABLED_PRECISIONS, ENGINE_CAPABILITY, @@ -168,6 +170,8 @@ class CompilationSettings: autocast_calibration_dataloader: Optional[torch.utils.data.DataLoader] = ( AUTOCAST_CALIBRATION_DATALOADER ) + enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING + cpu_memory_budget: int = CPU_MEMORY_BUDGET def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index e2f544c2a7..72d0be42c7 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -230,7 +230,7 @@ def partition_graph(self) -> torch.fx.GraphModule: # Tag the accelerated nodes and split the graph accordingly self.tag(subgraphs) - return self.split() + return self.split(remove_tag=True) def starter_nodes(self) -> Tuple[NodeSet, NodeSet]: """Generates starter nodes for partitioning + 
segmentation""" diff --git a/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py new file mode 100644 index 0000000000..b55fc0d873 --- /dev/null +++ b/py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py @@ -0,0 +1,184 @@ +from collections import defaultdict +from functools import lru_cache +from typing import Any, Callable, Dict, List, Set, Tuple + +import torch +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.ops import aten + +ATOMIC_SUBGRAPHS = [] + + +def register_atomic_subgraph( + init_args: Tuple[Any, ...] = tuple(), + is_core_aten: bool = False, +) -> Callable[[torch.nn.Module], torch.nn.Module]: + + def decorator(subgraph: torch.nn.Module) -> torch.nn.Module: + ATOMIC_SUBGRAPHS.append((subgraph, init_args, is_core_aten)) + return subgraph + + return decorator + + +@register_atomic_subgraph(init_args=(aten.silu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.gelu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True) +class ConvBNActivation(torch.nn.Module): # type: ignore[misc] + def __init__(self, activation: torch._ops.OpOverload) -> None: + super().__init__() + self.activation = activation + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + bn_weight: torch.Tensor, + bn_bias: torch.Tensor, + running_mean: torch.Tensor, + running_var: torch.Tensor, + momentum: float, + eps: float, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = aten._native_batch_norm_legit_no_training.default( + x, bn_weight, bn_bias, running_mean, running_var, momentum, eps + )[0] + x = self.activation(x) + return x + + +@register_atomic_subgraph(init_args=(aten.silu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.gelu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.relu.default,), is_core_aten=True) +@register_atomic_subgraph(init_args=(aten.sigmoid.default,), is_core_aten=True) +class ConvActivation(torch.nn.Module): # type: ignore[misc] + def __init__(self, activation: torch._ops.OpOverload) -> None: + super().__init__() + self.activation = activation + + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + x = self.activation(x) + return x + + +@register_atomic_subgraph(init_args=(), is_core_aten=True) +class MulAdd(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor + ) -> torch.Tensor: + x = aten.mul.Tensor(x, weight) + x = aten.add.Tensor(x, bias) + return x + + +@register_atomic_subgraph(init_args=(), is_core_aten=True) +class MulMul(torch.nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + + def forward( + self, x: torch.Tensor, y: torch.Tensor, z: 
torch.Tensor
+    ) -> torch.Tensor:
+        x = aten.mul.Tensor(x, y)
+        x = aten.mul.Tensor(x, z)
+        return x
+
+
+def get_node_in_fusion_pattern(
+    graph: torch.fx.Graph,
+) -> Dict[torch.fx.Node, Set[torch.fx.Node]]:
+    """
+    Build a map of the fusion patterns found in the graph.
+    Key: a node that appears in a fusion pattern
+    Value: the set of nodes that should be fused together with it
+    """
+    fusion_nodes = defaultdict(set)
+    for compiled_pattern_graph in get_compiled_atomic_subgraphs():
+        subgraph_matcher = SubgraphMatcher(compiled_pattern_graph.graph)
+        match_result = subgraph_matcher.match(graph)
+        for match in match_result:
+            fusion_group = {
+                node
+                for node in match.nodes_map.values()
+                if node
+                and type(node) == torch.fx.Node
+                and node.op == "call_function"
+                and node not in match.placeholder_nodes
+            }
+            for node in fusion_group:
+                fusion_nodes[node].update(fusion_group)
+
+    return fusion_nodes
+
+
+def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]:
+    """
+    Trace every registered atomic subgraph into a pattern graph for matching.
+    The underlying `trace_atomic_graph` call is LRU-cached to avoid retracing
+    the same pattern multiple times.
+    """
+    compiled_atomic_subgraphs = []
+    for pattern, init_args, is_core_aten in ATOMIC_SUBGRAPHS:
+        pattern_graph = trace_atomic_graph(pattern, init_args, is_core_aten)
+        if not is_core_aten:
+            # TODO: Add decomposition and lowering if is_core_aten is False
+            raise NotImplementedError(
+                "Atomic subgraphs are not supported for non-aten subgraphs yet."
+            )
+        compiled_atomic_subgraphs.append(pattern_graph)
+    return compiled_atomic_subgraphs
+
+
+@lru_cache(maxsize=None)
+def trace_atomic_graph(
+    graph: torch.nn.Module, init_args: Any, is_core_aten: bool = True
+) -> torch.fx.GraphModule:
+    if is_core_aten:
+        return torch.fx.symbolic_trace(graph(*init_args))
+    else:
+        raise NotImplementedError(
+            "Resource partitioner currently does not support unlowered atomic subgraphs"
+        )
diff --git a/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py
new file mode 100644
index 0000000000..25f7f7169d
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/partitioning/_resource_partitioner.py
@@ -0,0 +1,570 @@
+"""Resource-aware graph partitioner for TensorRT compilation.
+
+This module refines an existing capability-based partitioning (accelerated vs
+non-accelerated subgraphs) by further splitting accelerated subgraphs to meet
+host CPU memory constraints during TensorRT engine building.
+
+High-level algorithm
+--------------------
+Given an original `torch.fx.GraphModule` and a capability-partitioned
+`GraphModule` (produced earlier in the pipeline), we:
+
+1) Reconstruct subgraphs on the original graph
+   - Iterate over the capability-partitioned module to determine which original
+     nodes belong to which subgraph (accelerated or not).
+   - Preserve fusion groups discovered in each subgraph so that all nodes in a
+     fusion group remain in the same subgraph and are not split across
+     subgraphs.
+   - Verify that subgraphs respect topological order, which guarantees their
+     validity.
+   - Reconstructing subgraphs from the partitioned module is easier than
+     building nested partitioned graph modules and flattening them later.
+
+2) Estimate the memory cost of each subgraph
+   - Compute a per-subgraph "size" by traversing the graph to find weights
+     (get_attr) reachable from its nodes and summing tensor bytes.
+   - Use a set to record visited nodes and avoid double counting shared
+     parameters across subgraphs.
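+
+3) Compute a per-engine size budget
+   - Divide the remaining budget (the explicit `cpu_memory_budget` minus the
+     host RSS already in use, or the available system memory when no budget is
+     set) by a safety multiplier, since TensorRT engine building can
+     transiently use roughly `ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER` (4x)
+     times an engine's weight bytes. See `calculate_size_budget`.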
+
+4) Split large accelerated subgraphs
+   - While a subgraph exceeds the per-engine budget, split it into two or more
+     subgraphs.
+   - Move nodes incrementally from the front of the original subgraph into a
+     new left subgraph, repeatedly validating/correcting topological,
+     partitioning, and dependency constraints.
+   - Ensure we never split across a fusion group; when a split would break a
+     fusion, we backtrack dependencies and move the entire fusion and related
+     nodes into the left side.
+   - Continue until the left subgraph fits the budget.
+   - Repeat the process for the right subgraph until all subgraphs fit the
+     budget.
+
+5) Finalize
+   - After splitting, assert all fusion groups reside in a single subgraph.
+   - Tag nodes and produce a `GraphModule` where each subgraph becomes either a
+     TRT engine (accelerated) or runs in Torch (non-accelerated).
+
+Notes
+-----
+- The budget is a heuristic bound. If the total model size exceeds 50x the
+  per-engine budget (`MAX_NUM_OF_ENGINES`), we fail early with a clear error
+  suggesting remedies.
+"""
+
+import logging
+from typing import Dict, List, Optional, Set, Tuple
+
+import psutil
+import torch
+from torch.fx.experimental.const_fold import _inline_module
+from torch.fx.passes.splitter_base import Subgraph, _SplitterBase
+from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
+from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import (
+    get_node_in_fusion_pattern,
+)
+
+logger = logging.getLogger(__name__)
+
+MAX_NUM_OF_ENGINES = 50
+ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER = 4
+
+
+class ResourcePartitioner(_SplitterBase):  # type: ignore
+    """Refine capability-based subgraphs to meet host CPU memory constraints.
+
+    This partitioner takes:
+    - an accelerated submodule produced by capability partitioning (`module`)
+    - the submodule's name, used to tag the split segments (`submodule_name`)
+    - a CPU memory budget in bytes (`cpu_memory_budget`)
+
+    It groups the submodule's nodes into subgraphs and then splits oversized
+    ones so that each resulting TRT engine's estimated size fits within a
+    conservative budget derived from available CPU memory or the configured
+    CPU budget.
+    """
+
+    def __init__(
+        self,
+        module: torch.fx.GraphModule,
+        cpu_memory_budget: Optional[int],
+        submodule_name: str,
+    ):
+        assert isinstance(module, torch.fx.GraphModule)
+
+        self.module = module
+        used_rss: int = psutil.Process().memory_info().rss
+        self.remaining_memory_budget = (
+            cpu_memory_budget - used_rss
+            if cpu_memory_budget is not None
+            else psutil.virtual_memory().available
+        )
+        self.not_set_limit = cpu_memory_budget is None
+        self.resource_split_count = 0
+        self.submodule_name = submodule_name
+        self.deps = self.find_deps()
+
+        self._node_submodule_map: Dict[str, str] = {}
+        self._return_tuple = False
+        self.fusion_patterns: Dict[torch.fx.Node, Set[torch.fx.Node]] = {}
+
+    def partition_graph(self) -> torch.fx.GraphModule:
+        """Build the final partitioned `GraphModule` honoring memory constraints.
+
+        Steps:
+        - Build subgraph assignments from the capability-partitioned module
+        - Split oversized accelerated subgraphs based on the memory budget
+        - Tag nodes and construct the final split graph
+
+        Returns:
+            torch.fx.GraphModule: A graph split into subgraphs based on
+            capability partitioning and memory constraints.
+ """ + # Delegate nodes based on operator coverage + subgraphs = self.put_nodes_into_subgraphs() + sizes = self.size_of_subgraphs(subgraphs) + if ( + sum(sizes) * ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER + < self.remaining_memory_budget + ): + return self.module + + subgraphs = self.break_subgraphs( + subgraphs, subgraph_size_budget=self.calculate_size_budget() + ) + + if len(subgraphs) == 1: + return self.module + + # Set the number of TRT engines to be generated + self.num_trt_accelerated_subgraphs = len([s for s in subgraphs if s.is_acc]) + + # Tag the accelerated nodes and split the graph accordingly + self.tag(subgraphs) + + gm = self.split(remove_tag=True) + + return gm + + def tag(self, subgraphs: list[Subgraph]) -> None: + self.tags = [] + for subgraph in subgraphs: + tag = f"{self.submodule_name}_resource_split_{self.resource_split_count}" + self.resource_split_count += 1 + self.tags.append(tag) + for node in subgraph.nodes: + node.tag = tag + self._node_submodule_map[node.name] = tag + + def put_nodes_into_subgraphs(self) -> list[Subgraph]: + """ + Put the nodes into the subgraphs and erase the tag from previous partitioner if it exists. + Returns: + list[Subgraph]: Ordered subgraphs consisting of nodes in `module`. + """ + + nodes = [] + for node in self.module.graph.nodes: + if hasattr(node, "tag"): + del node.tag + if node.op in CALLABLE_NODE_OPS: + nodes.append(node) + subgraphs = [Subgraph(is_acc=True, nodes=nodes)] + self.fusion_patterns = get_node_in_fusion_pattern(self.module.graph) + + return subgraphs + + def check_topological_order(self, subgraphs: List[Subgraph]) -> bool: + """Return True if subgraphs are in a valid topological order. + + Each node's dependencies must appear in earlier subgraphs or earlier + positions within the same subgraph. Subgraphs should be topologically ordered to ensure the validity of the subgraphs. + """ + visited_nodes: set[torch.fx.Node] = set() + for subgraph in subgraphs: + for node in subgraph.nodes: + if self.deps[node] > visited_nodes: + return False + visited_nodes.add(node) + return True + + def calculate_size_budget( + self, + engine_compilation_memory_usage_multiplier: int = ENGINE_COMPILATION_MEMORY_USAGE_MULTIPLIER, + ) -> int: + """Compute the per-engine size budget in bytes. + + Uses explicit `cpu_memory_budget` minus used RSS + divided by a safety multiplier. + + Args: + engine_compilation_memory_usage_multiplier: Safety divisor applied to + available memory to approximate a per-engine budget. By default we assume TensorRT + compilation requires up to 4x the model's size. + + Returns: + int: Budget in bytes for a single accelerated subgraph. + """ + + return ( + self.remaining_memory_budget // engine_compilation_memory_usage_multiplier + ) + + def break_subgraphs( + self, subgraphs: List[Subgraph], subgraph_size_budget: int + ) -> List[Subgraph]: + """Split oversized accelerated subgraphs until they fit within budget. + + - Compute sizes for each subgraph (in bytes of parameters reachable from + that subgraph). + - If the sum of all sizes is catastrophically larger than budget + (threshold 40x), raise a ValueError with guidance. + - For any subgraph whose size exceeds `subgraph_size_budget`, iteratively + split it using `break_subgraph_by_size` and append resulting segments. + - Validate that fusion groups remain intact post splitting. + + Args: + subgraphs: Ordered list of subgraphs from capability partitioning. + subgraph_size_budget: Target maximum size per accelerated subgraph. 
+
+        Returns:
+            List[Subgraph]: New list of subgraphs after resource-aware splitting.
+        """
+        new_subgraphs = []
+        # Fail fast if the remaining memory is tiny compared to the model size,
+        # e.g. if 4 GiB remain (a 1 GiB per-engine budget) and the model is
+        # larger than 50 GiB, we stop the compilation.
+        sizes = self.size_of_subgraphs(subgraphs)
+        if sum(sizes) > subgraph_size_budget * MAX_NUM_OF_ENGINES:
+            if self.not_set_limit:
+                raise ValueError(
+                    "The system memory is too constrained to compile the model without severe perf degradation. Consider setting offload_module_to_cpu=False to save more CPU memory."
+                )
+            else:
+                raise ValueError(
+                    "CPU memory budget is too small to compile the model. "
+                    + f"CPU memory budget: {self.remaining_memory_budget // (1024 * 1024)} MB, Model size: {sum(sizes) // (1024 * 1024)} MB. "
+                    + "Consider setting cpu_memory_budget to a larger value."
+                )
+        for subgraph, size in zip(subgraphs, sizes):
+            while size > subgraph_size_budget:
+                broken_subgraphs, size_0, size_1 = self.break_subgraph_by_size(
+                    subgraph, subgraph_size_budget
+                )
+                size = size_1
+                new_subgraphs.append(broken_subgraphs[0])
+                subgraph = broken_subgraphs[1]
+
+            if len(subgraph.nodes) != 0:
+                new_subgraphs.append(subgraph)
+
+        self._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs)
+
+        return new_subgraphs
+
+    def _verify_all_fusion_nodes_in_same_subgraph(
+        self, subgraphs: List[Subgraph]
+    ) -> None:
+        """Assert that every fusion group is contained in exactly one subgraph."""
+        node_to_subgraph = {}
+        for i, s in enumerate(subgraphs):
+            for n in s.nodes:
+                node_to_subgraph[n] = i
+
+        # Every fusion group must map onto exactly one subgraph index
+        fusion_nodes_map_list = [
+            len({node_to_subgraph[n] for n in ns}) == 1
+            for ns in self.fusion_patterns.values()
+        ]
+
+        assert all(
+            fusion_nodes_map_list
+        ), "All fusion nodes must be in the same subgraph"
+        logger.info("All fusion nodes are in the same subgraph.")
+
+    def break_subgraph_by_size(
+        self, subgraph: Subgraph, size_to_break: int
+    ) -> Tuple[List[Subgraph], int, int]:
+        """Split a single oversized subgraph into two valid subgraphs.
+
+        Moves nodes from the head of `subgraph` into a new left segment until
+        the left segment's estimated size exceeds `size_to_break`. During the
+        process we:
+        - Repeatedly validate/correct topological placement
+        - Detect and avoid splitting fusion groups by moving all fused nodes
+          (and their producer chain) into the left segment
+
+        Returns:
+            (segments, size_left, size_right):
+                segments[0] is the new left subgraph, segments[1] is the residual
+                right subgraph. Sizes are estimated parameter bytes of each.
+        """
+        all_nodes = subgraph.nodes
+        device_ordinal = subgraph.device_ordinal
+        new_subgraphs = [
+            Subgraph(
+                is_acc=True,
+                nodes=[],
+                device_ordinal=device_ordinal,
+            ),
+            Subgraph(
+                is_acc=True,
+                nodes=all_nodes,
+                device_ordinal=device_ordinal,
+            ),
+        ]
+
+        # We break the subgraph until the left subgraph fits the budget.
+        while True:
+            # Use a step size proportional to the size of the subgraph to make
+            # the algorithm more efficient. This reduces the time complexity
+            # from O(N**2) to O(N); at most roughly 50 steps are taken per split.
+            # Note: we want the first step size to be 1.
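+            # For example, a subgraph with 5,000 nodes advances 5000 // 50 =
+            # 100 nodes per step once the first node has been moved over.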
+            step_size = (
+                1 if not new_subgraphs[0].nodes else max(1, len(all_nodes) // 50)
+            )
+            new_subgraphs = self.step_and_validate(new_subgraphs, step_size)
+            size_0, size_1 = self.size_of_subgraphs(new_subgraphs)
+            if size_0 > size_to_break or size_0 > size_1:
+                break
+
+        return new_subgraphs, size_0, size_1
+
+    def step_and_validate(
+        self, new_subgraphs: List[Subgraph], step_size: int = 1
+    ) -> List[Subgraph]:
+        """Advance the split by `step_size` nodes, then add more nodes to the
+        left subgraph if any rule is broken. There are two rules to check:
+        1. The subgraphs must be ordered in a way that is safe to partition.
+           This is checked by `validate_and_correct_subgraphs`; see that
+           function for details.
+        2. The subgraphs must not break any fusion group.
+
+        - Move `step_size` nodes from the right to the left subgraph.
+        - Run validation/correction to ensure a legal partitioning placement.
+        - Collect the leaf nodes of the left subgraph and check whether any of
+          them belong to a fusion group.
+        - If the move splits a fusion group, migrate the entire fusion into the
+          left subgraph.
+
+        Returns:
+            List[Subgraph]: Updated pair of subgraphs after stabilization.
+        """
+        for _ in range(step_size):
+            new_subgraphs[0].nodes.append(new_subgraphs[1].nodes.pop(0))
+
+        while True:
+            new_subgraphs = self.validate_and_correct_subgraphs(new_subgraphs)
+            nodes_in_first_subgraph = set(new_subgraphs[0].nodes)
+            nodes_in_second_subgraph = set(new_subgraphs[1].nodes)
+            leaf_node = self.get_leaf_node(nodes_in_first_subgraph)
+            broken_fusion = self.step_if_break_fusion(
+                new_subgraphs,
+                leaf_node,
+                nodes_in_first_subgraph,
+                nodes_in_second_subgraph,
+            )
+            if not broken_fusion or len(new_subgraphs[1].nodes) == 0:
+                break
+
+        return new_subgraphs
+
+    def step_if_break_fusion(
+        self,
+        subgraphs: List[Subgraph],
+        leaf_nodes: set[torch.fx.Node],
+        nodes_in_first_subgraph: set[torch.fx.Node],
+        nodes_in_second_subgraph: set[torch.fx.Node],
+    ) -> bool:
+        """Detect a fusion split and migrate fused nodes to the left subgraph.
+
+        Given the current split boundary (captured by the `leaf_nodes` of the
+        left subgraph), check all recorded fusion groups. If any fused node
+        remains on the right while its peer is on the left, pull the node and
+        all of its producer chain into the left subgraph to keep fusions intact.
+
+        Returns:
+            bool: True if any fusion was migrated (i.e., a split would have
+            broken a fusion), otherwise False.
+        """
+
+        def add_nodes(node: torch.fx.Node) -> None:
+            """Add a node and all of its producers to the first subgraph, and
+            remove them from the second subgraph, in post order."""
+            if (
+                node.op in CALLABLE_NODE_OPS
+                and node not in nodes_in_first_subgraph
+                and node in nodes_in_second_subgraph
+            ):
+                # Skip nodes that already live in the first subgraph
+                nodes_in_first_subgraph.add(node)
+                nodes_in_second_subgraph.remove(node)
+                for input_node in node._input_nodes:
+                    add_nodes(input_node)
+                subgraphs[0].nodes.append(node)
+                subgraphs[1].nodes.remove(node)
+
+        fusion_broken = False
+        for leaf in leaf_nodes:
+            for node in self.fusion_patterns.get(leaf, []):
+                if (
+                    node not in nodes_in_first_subgraph
+                    and node in nodes_in_second_subgraph
+                ):
+                    fusion_broken = True
+                    add_nodes(node)
+
+        return fusion_broken
+
+    def get_leaf_node(
+        self, nodes_in_first_subgraph: set[torch.fx.Node]
+    ) -> set[torch.fx.Node]:
+        """Return nodes in the left subgraph that feed any node outside it.
+
+        A node is considered a leaf if at least one of its users is not in the
+        left subgraph.
+ """ + leaf_node = set() + + for node in nodes_in_first_subgraph: + for user in node.users: + if user not in nodes_in_first_subgraph: + leaf_node.add(node) + break + return leaf_node + + def size_of_subgraphs(self, subgraphs: List[Subgraph]) -> List[int]: + """Estimate parameter footprint (bytes) for each subgraph. + + Traverses each subgraph's nodes and their producer chains to find + parameters referenced via `get_attr`, summing tensor bytes. Shared + parameters are counted only once globally. + + Returns: + List[int]: Size per subgraph in bytes. + """ + state_dict = self.module.state_dict(keep_vars=True) + sizes = [] + weight_visited_nodes = set() + for subgraph in subgraphs: + nodes_in_subgraph = set(subgraph.nodes) + stack = subgraph.nodes.copy() + size = 0 + while stack: + node = stack.pop() + if node in weight_visited_nodes: + continue + weight_visited_nodes.add(node) + if node.op == "get_attr": + weight = state_dict.get(node.target, None) + if weight is None: + logger.warning(f"Weight {node.target} not found in state_dict") + continue + size += weight.numel() * weight.element_size() + continue + if node not in nodes_in_subgraph: + # Trace to other subgraphs + continue + for input_node in node._input_nodes: + if input_node not in weight_visited_nodes: + stack.append(input_node) + sizes.append(size) + return sizes + + def validate_and_correct_subgraphs( + self, subgraphs: List[Subgraph] + ) -> List[Subgraph]: + """This is very important for the correctness of the partitioning. Torch gives undefined behavior if the subgraphs are not ordered correctly. + + The principle is: nodes that have all dependencies resolved in previous subgraphs should also be moved to the previous subgraph. + For example, given a breakpoint node n resulting in two subgraphs S1 [..., n] and S2 [n+1, ...], all nodes in S2 that is not directly or indirectly depend on n should be moved to S1. + + We use a map to record the index of the subgraph that a node's users should belong to. If the node N is in subgraph S1 and is not the breakpoint node (subgraph.nodes[-1]), + then the users that only depend on N should also be moved to S1. However, N is a breakpoint node, then the users that only depend on N should also be moved to S2. + + With the map, we can determine with subgraph a later node should be moved to according to all its inputs. We take max indices of all inputs nodes to determine the subgraph index. + + Returns: + List[Subgraph]: Corrected subgraphs. + """ + # a map from a node to the index of the subgraph it's user should belong to + visited_nodes = {} + + for i, subgraph in enumerate(subgraphs): + if i == 0: + for node in subgraph.nodes: + visited_nodes[node] = i + # breakpoint node's users should belong to the next subgraph + visited_nodes[subgraph.nodes[-1]] = i + 1 + continue + + else: + to_remove_nodes = [] + for j, node in enumerate(subgraph.nodes): + if j == len(subgraph.nodes) - 1: + # breakpoint node's users should belong to the next subgraph + visited_nodes[node] = i + 1 + continue + subgraph_idx = 0 + for dep in self.deps[node]: + if dep in visited_nodes: + # We take max indices of all inputs nodes to determine the subgraph index. + subgraph_idx = max(subgraph_idx, visited_nodes[dep]) + + if subgraph_idx != i: + # If the node should be moved to a different subgraph, we move it and remove it from the current subgraph. 
+                        subgraphs[subgraph_idx].nodes.append(node)
+                        to_remove_nodes.append(node)
+                    # Record the subgraph that the users of this node should belong to
+                    visited_nodes[node] = subgraph_idx
+
+                # Remove the nodes that were moved to other subgraphs
+                for node in to_remove_nodes:
+                    subgraph.nodes.remove(node)
+
+        return subgraphs
+
+
+def resource_partition(
+    gm: torch.fx.GraphModule,
+    cpu_memory_budget: Optional[int],
+) -> torch.fx.GraphModule:
+    """Resource-aware partitioning entry point.
+
+    Takes a capability-partitioned `GraphModule` (`gm`) and returns a final
+    graph where accelerated submodules are split further, if necessary, to
+    satisfy CPU memory limits for TRT engine compilation.
+
+    Args:
+        gm: Capability-partitioned `GraphModule` whose `_run_on_acc` submodules
+            delineate the accelerated regions.
+        cpu_memory_budget: CPU memory budget in bytes for engine compilation.
+            Use None to base the budget on currently available system memory.
+
+    Returns:
+        torch.fx.GraphModule: Final graph with resource-constrained subgraphs.
+    """
+    # Split each oversized accelerated submodule in place
+    for name, _ in gm.named_children():
+        submodule = getattr(gm, name)
+        if (
+            not isinstance(submodule, torch.fx.graph_module.GraphModule)
+            or "_run_on_acc" not in name
+        ):
+            continue
+        partitioner = ResourcePartitioner(
+            submodule,
+            submodule_name=name,
+            cpu_memory_budget=cpu_memory_budget,
+        )
+
+        partitioned_graph = partitioner.partition_graph()
+        setattr(gm, name, partitioned_graph)
+
+    # Inline split submodules so each segment becomes a direct child of `gm`
+    for name, module in list(gm.named_children()):
+        split = False
+        if "_run_on_acc" in name:
+            for subname, submodule in module.named_children():
+                if "resource_split" in subname:
+                    split = True
+                    setattr(gm, subname, submodule)
+            if split:
+                _inline_module(gm, name)
+                delattr(gm, name)
+
+    gm.recompile()
+    return gm
diff --git a/tests/py/dynamo/partitioning/test_fast_partitioning.py b/tests/py/dynamo/partitioning/test_000_fast_partitioning.py
similarity index 100%
rename from tests/py/dynamo/partitioning/test_fast_partitioning.py
rename to tests/py/dynamo/partitioning/test_000_fast_partitioning.py
diff --git a/tests/py/dynamo/partitioning/test_flaky_global_partitioning.py b/tests/py/dynamo/partitioning/test_000_flaky_global_partitioning.py
similarity index 100%
rename from tests/py/dynamo/partitioning/test_flaky_global_partitioning.py
rename to tests/py/dynamo/partitioning/test_000_flaky_global_partitioning.py
diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_000_global_partitioning.py
similarity index 100%
rename from tests/py/dynamo/partitioning/test_global_partitioning.py
rename to tests/py/dynamo/partitioning/test_000_global_partitioning.py
diff --git a/tests/py/dynamo/partitioning/test_hierarchical_partitioning.py b/tests/py/dynamo/partitioning/test_000_hierarchical_partitioning.py
similarity index 100%
rename from tests/py/dynamo/partitioning/test_hierarchical_partitioning.py
rename to tests/py/dynamo/partitioning/test_000_hierarchical_partitioning.py
diff --git a/tests/py/dynamo/partitioning/test_000_resource_partitioning.py b/tests/py/dynamo/partitioning/test_000_resource_partitioning.py
new file mode 100644
index 0000000000..2014eea8fe
--- /dev/null
+++ b/tests/py/dynamo/partitioning/test_000_resource_partitioning.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+from torch.fx.passes.splitter_base import Subgraph
+from torch.ops import aten
+from torch.testing._internal.common_utils import TestCase, run_tests
+from
torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo.conversion import CompilationSettings +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from torch_tensorrt.dynamo.lowering.passes import post_lowering, pre_export_lowering +from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( + ResourcePartitioner, +) + + +class TestResourcePartitioning(TestCase): + def test_atomic_subgraph_correction(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 3, 3, padding=1) + self.bn1 = nn.BatchNorm2d(3) + self.relu = nn.ReLU() + self.fc = nn.Linear(3 * 224 * 224, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 3, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "enable_resource_partitioning": True, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + for name, _ in partitioned_module.named_children(): + submodule = getattr(partitioned_module, name) + if ( + not isinstance(submodule, torch.fx.graph_module.GraphModule) + or "_run_on_acc" not in name + ): + continue + partitioner = ResourcePartitioner( + submodule, + submodule_name=name, + cpu_memory_budget=2 * 1024 * 1024 * 1024, + ) + subgraphs = partitioner.put_nodes_into_subgraphs() + new_subgraphs = [] + current_subgraph = [] + # Split the subgraph into two subgraphs by the ReLU node, which breaks the fusion group. 
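+            # (In the lowered graph, conv -> batch_norm -> relu matches the
+            # registered ConvBNActivation pattern, so cutting just before the
+            # ReLU is guaranteed to break a fusion group.)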
+ for node in subgraphs[0].nodes: + if node.op == "call_function" and node.target == aten.relu.default: + new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph)) + current_subgraph = [] + current_subgraph.append(node) + if current_subgraph: + new_subgraphs.append(Subgraph(is_acc=True, nodes=current_subgraph)) + + leaf_node = partitioner.get_leaf_node(new_subgraphs[0].nodes) + broken_fusion = partitioner.step_if_break_fusion( + new_subgraphs, + leaf_node, + set(new_subgraphs[0].nodes), + set(new_subgraphs[1].nodes), + ) + # The fusion was broken + assert broken_fusion + + # The fusion should be fixed after the step + partitioner._verify_all_fusion_nodes_in_same_subgraph(new_subgraphs) + + break + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/partitioning/test_001_resource_partitioning.py b/tests/py/dynamo/partitioning/test_001_resource_partitioning.py new file mode 100644 index 0000000000..b8c5a68276 --- /dev/null +++ b/tests/py/dynamo/partitioning/test_001_resource_partitioning.py @@ -0,0 +1,417 @@ +from typing import Any, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch_tensorrt as torchtrt +from torch.fx.passes.splitter_base import Subgraph +from torch.ops import aten +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo import partitioning +from torch_tensorrt.dynamo.conversion import CompilationSettings +from torch_tensorrt.dynamo.lowering import ( + get_decompositions, + post_lowering, + pre_export_lowering, +) +from torch_tensorrt.dynamo.lowering.passes import post_lowering, pre_export_lowering +from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import ( + ATOMIC_SUBGRAPHS, + register_atomic_subgraph, +) +from torch_tensorrt.dynamo.partitioning._resource_partitioner import ( + ResourcePartitioner, + resource_partition, +) + + +class TestResourcePartitioning(TestCase): + def test_resource_partitioning(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn2 = nn.BatchNorm2d(1024) + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "enable_resource_partitioning": True, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + partitioned_module = resource_partition( + 
partitioned_module, cpu_memory_budget=2 * 1024 * 1024 * 1024 # 2GB, + ) + + self.assertEqual( + len(list[Any](partitioned_module.named_children())), + 2, + "The graph should have 2 subgraphs", + ) + + torch._dynamo.reset() + + def test_resource_partitioning_with_capability_partitioning(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn2 = nn.BatchNorm2d(4096) + + self.conv3 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn3 = nn.BatchNorm2d(4096) + self.conv4 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn4 = nn.BatchNorm2d(1024) + + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x) + x = self.conv4(x) + x = self.bn4(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + "enable_resource_partitioning": True, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + partitioned_module = resource_partition( + partitioned_module, cpu_memory_budget=1.4 * 1024 * 1024 * 1024 # 1.4GB, + ) + + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_acc" in name + ] + ) + == 5 + ), "The graph should have 5 accelerated subgraphs" + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_gpu" in name + ] + ) + == 2 + ), "The graph should have 2 non-accelerated subgraphs" + + torch._dynamo.reset() + + def test_resource_partitioning_with_capability_partitioning_and_atomic_subgraphs( + self, + ): + """ + After defining the atomic subgraphs, the resource partitioner will not be able to find valid partition in the subgraph. + So there should only be 3 accelerated subgraphs and 2 non-accelerated subgraphs. 
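+        (The ReLUConv pattern registered below glues every ReLU to the
+        convolution that follows it, so the candidate split points inside the
+        large accelerated blocks would all break a fusion group.)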
+ """ + + @register_atomic_subgraph(init_args=(), is_core_aten=True) + class ReLUConv(nn.Module): + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + transposed: bool, + output_padding: List[int], + groups: int, + ) -> torch.Tensor: + x = aten.relu.default(x) + x = aten.convolution.default( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) + return x + + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn2 = nn.BatchNorm2d(4096) + + self.conv3 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn3 = nn.BatchNorm2d(4096) + self.conv4 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn4 = nn.BatchNorm2d(1024) + + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x) + x = self.conv4(x) + x = self.bn4(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + "enable_resource_partitioning": True, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.fast_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + skip_fusion=True, + ) + + partitioned_module = resource_partition( + partitioned_module, cpu_memory_budget=1.4 * 1024 * 1024 * 1024 # 1.4GB, + ) + + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_acc" in name + ] + ) + == 3 + ), "The graph should have 3 accelerated subgraphs" + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_gpu" in name + ] + ) + == 2 + ), "The graph should have 2 non-accelerated subgraphs" + + ATOMIC_SUBGRAPHS.remove((ReLUConv, (), True)) + + torch._dynamo.reset() + + def test_resource_partitioning_with_global_capability_partitioning(self): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1) + self.bn1 = nn.BatchNorm2d(4096) + self.conv2 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn2 = nn.BatchNorm2d(4096) + + self.conv3 = nn.Conv2d(4096, 4096, 3, padding=1) + self.bn3 = nn.BatchNorm2d(4096) + self.conv4 = nn.Conv2d(4096, 1024, 3, padding=1) + self.bn4 = nn.BatchNorm2d(1024) + + self.fc1 = nn.Linear(1024 * 56 * 56, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = 
F.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv3(x) + x = self.bn3(x) + x = F.relu(x) + x = self.conv4(x) + x = self.bn4(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + return self.fc1(x) + + model = net().eval() + model.to("cuda") + inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")] + + enabled_precisions = {torch.float} + use_python_runtime = False + + exp_program = torch.export.export(model, tuple(inputs)) + + compilation_options = { + "use_python_runtime": use_python_runtime, + "enabled_precisions": enabled_precisions, + "min_block_size": 1, + "immutable_weights": True, + "reuse_cached_engines": False, + "torch_executed_ops": {"torch.ops.aten.max_pool2d.default"}, + "enable_resource_partitioning": True, + } + settings = CompilationSettings(**compilation_options) + + exported_program = pre_export_lowering(exp_program, settings) + exported_program = exported_program.run_decompositions( + get_decompositions(False) + ) + + gm = exported_program.module() + gm = post_lowering(gm, settings) + + partitioned_module, supported_ops = partitioning.global_partition( + gm, + min_block_size=settings.min_block_size, + torch_executed_ops=settings.torch_executed_ops, + require_full_compilation=settings.require_full_compilation, + ) + + partitioned_module = resource_partition( + partitioned_module, cpu_memory_budget=1.4 * 1024 * 1024 * 1024 # 1.4GB, + ) + + assert ( + len( + [ + name + for name, _ in partitioned_module.named_children() + if "_run_on_acc" in name + ] + ) + == 5 + ), "The graph should have 5 accelerated subgraphs" + + torch._dynamo.reset() + + +if __name__ == "__main__": + run_tests()
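
Usage note: beyond setting `cpu_memory_budget`, the fusion handling above is extensible. Below is a minimal sketch, mirroring the `ReLUConv` registration in `test_001_resource_partitioning.py`, of registering a custom atomic subgraph so the resource partitioner never splits it across engines. The `MulSigmoid` pattern and its op sequence are illustrative, not part of this PR:

```python
import torch
from torch.ops import aten
from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import (
    ATOMIC_SUBGRAPHS,
    register_atomic_subgraph,
)


# Register a two-op pattern: the subgraph matcher will treat `mul` followed
# by `sigmoid` as one fusion group that must stay inside a single TRT engine.
@register_atomic_subgraph(init_args=(), is_core_aten=True)
class MulSigmoid(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = aten.mul.Tensor(x, y)
        x = aten.sigmoid.default(x)
        return x


# The registry is a plain module-level list, so the registration can be
# undone afterwards, as the tests above do:
# ATOMIC_SUBGRAPHS.remove((MulSigmoid, (), True))
```

Since patterns are matched against the lowered core-aten graph, `is_core_aten=True` is required for now; non-aten patterns raise `NotImplementedError` in `trace_atomic_graph`.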