diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 5f7b8e8da1b..94bb8a98e09 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -88,7 +88,7 @@ export SDL_VIDEODRIVER=dummy # legacy from bash scripts: remove? conda env config vars set \ MAX_IDLE_COUNT=1000 \ - MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:99 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG TOKENIZERS_PARALLELISM=true + MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:99 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=INFO TOKENIZERS_PARALLELISM=true pip3 install pip --upgrade pip install virtualenv diff --git a/.github/unittest/linux_sota/scripts/run_all.sh b/.github/unittest/linux_sota/scripts/run_all.sh index 40da5075f4c..cf1a9be33a8 100755 --- a/.github/unittest/linux_sota/scripts/run_all.sh +++ b/.github/unittest/linux_sota/scripts/run_all.sh @@ -52,6 +52,16 @@ if [ ! -d "${env_dir}" ]; then fi conda activate "${env_dir}" +# Verify we have CPython, not PyPy +python_impl=$(python -c "import platform; print(platform.python_implementation())") +if [ "$python_impl" != "CPython" ]; then + echo "ERROR: Expected CPython but got $python_impl" + echo "Python executable: $(which python)" + echo "Python version: $(python --version)" + exit 1 +fi +printf "* Verified Python implementation: %s\n" "$python_impl" + # 3. Install mujoco printf "* Installing mujoco and related\n" mkdir -p $root_dir/.mujoco @@ -64,7 +74,10 @@ cd "${root_dir}" # 4. Install Conda dependencies printf "* Installing dependencies (except PyTorch)\n" -echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +# Add python version to environment.yml if not already present (idempotent) +if ! 
grep -q "python=${PYTHON_VERSION}" "${this_dir}/environment.yml"; then + echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +fi cat "${this_dir}/environment.yml" export MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 @@ -100,11 +113,27 @@ pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl # TODO: move this down -- will break torchrl installation conda install -y -c conda-forge libstdcxx-ng=12 -## find libstdc -STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1) -conda env config vars set \ - MAX_IDLE_COUNT=1000 \ - LD_PRELOAD=${root_dir}/$STDC_LOC TOKENIZERS_PARALLELISM=true +## find libstdc - search in the env's lib directory first, then fall back to conda packages +STDC_LOC=$(find "${env_dir}/lib" -name "libstdc++.so.6" 2>/dev/null | head -1) +if [ -z "$STDC_LOC" ]; then + # Fall back to searching in conda packages for libstdcxx-ng specifically + STDC_LOC=$(find conda/pkgs -path "*libstdcxx*" -name "libstdc++.so.6" 2>/dev/null | head -1) +fi +if [ -z "$STDC_LOC" ]; then + echo "WARNING: Could not find libstdc++.so.6, skipping LD_PRELOAD" + conda env config vars set \ + MAX_IDLE_COUNT=1000 \ + TOKENIZERS_PARALLELISM=true +else + echo "Found libstdc++ at: $STDC_LOC" + conda env config vars set \ + MAX_IDLE_COUNT=1000 \ + LD_PRELOAD=${STDC_LOC} TOKENIZERS_PARALLELISM=true +fi + +# Reactivate environment to apply the new env vars +conda deactivate +conda activate "${env_dir}" # compile mujoco-py (bc it's done at runtime for whatever reason someone thought it was a good idea) python -c """import gym;import d4rl""" diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 50c45220942..34f79429417 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -27,7 +27,7 @@ ) from torchrl.envs import EnvCreator, GymEnv, ParallelEnv from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy if __name__ == "__main__": avail_devices = ("cpu",) diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index 4af76440290..bf92deb1284 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -144,7 +144,7 @@ def __init__(self, capacity: int): rank = args.rank storage_type = args.storage - torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}") + torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py index ccbcaea7055..c3887352b7d 100644 --- a/benchmarks/test_collectors_benchmark.py +++ b/benchmarks/test_collectors_benchmark.py @@ -18,7 +18,7 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy def single_collector_setup(): diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index 0fcf174f3c1..0bb49669664 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -11,7 +11,7 @@ used in both instances. 
From there, anything can happen: - In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights - for his instance of the policy. + for their instance of the policy. - In some cases, the environment or the postprocessing hooks can rely on the usage of a model which itself needs synchronization. This means that there can be multiple ends in the data transfer API and one needs to think beyond policy-to-policy weight synchronization strategies. @@ -23,106 +23,354 @@ used in both instances. From there, anything can happen: asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach is to store the weights on some intermediary server and let the workers fetch them when necessary. -TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight +TorchRL tries to account for each of these problems in a flexible manner. We identify three basic components in a weight transfer: -- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer; -- A `Receiver` class that casts the weights to the destination module (policy or other utility module); -- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else). -- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them. +- A **Scheme** class that orchestrates the entire weight synchronization lifecycle, including initialization, + connection setup, and weight transfer coordination. +- A **Transport** class that handles the actual transfer of weights (through shared memory, queues, torch.distributed, + Ray, etc.). Each scheme creates one or more transports for communication with workers. +- A **Strategy** class that determines the weight format (TensorDict or state_dict) and how weights are + extracted from and applied to models. Each of these classes is detailed below. -Usage Examples --------------- +.. note:: + **For most users, weight synchronization happens automatically.** When using TorchRL collectors + with the ``weight_sync_schemes`` argument, the collector handles all initialization, connection, + and synchronization calls internally. You simply call ``collector.update_policy_weights_()`` and + the weights are propagated to all workers. 
+ + The ``update_policy_weights_`` method supports multiple calling conventions:: + + # No arguments - uses registered policy + collector.update_policy_weights_() + + # Positional argument - policy module or TensorDict + collector.update_policy_weights_(policy_module) + collector.update_policy_weights_(weights_tensordict) + + # Keyword arguments for clarity + collector.update_policy_weights_(policy=actor_module) + collector.update_policy_weights_(weights=weights_td, model_id="actor") + + # Multiple models atomically + collector.update_policy_weights_(weights_dict={"actor": actor_td, "critic": critic_td}) + + The detailed lifecycle documentation below is primarily intended for developers who want to: + + - Understand the internals of weight synchronization + - Implement custom weight sync schemes for specialized use cases (e.g., new distributed backends, custom serialization) + - Debug synchronization issues in complex distributed setups + - Use weight sync schemes outside of collectors for custom multiprocessing scenarios + +Lifecycle of Weight Synchronization +----------------------------------- + +Weight synchronization follows a **two-phase initialization pattern** with a clear separation between +local setup and inter-process communication: + +.. code-block:: text + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ SENDER (Main Process) │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ 1. scheme.init_on_sender(model_id, context, ...) │ + │ └─ Sets up local state, creates transports, NO communication │ + │ │ + │ 2. Send scheme to receiver (via multiprocessing/pickle) │ + │ └─ Scheme object is passed to worker processes │ + │ │ + │ 3. scheme.connect() ◄──── BLOCKING RENDEZ-VOUS ────► │ + │ └─ Sends initial weights (if model is stateful) │ + │ │ + │ 4. scheme.send(weights) [ready for ongoing updates] │ + └─────────────────────────────────────────────────────────────────────────┘ + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ RECEIVER (Worker Process) │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ 1. scheme.init_on_receiver(model_id, context, ...) │ + │ └─ Sets up local state, resolves model, NO communication │ + │ │ + │ 2. scheme.connect() ◄──── BLOCKING RENDEZ-VOUS ────► │ + │ └─ Receives initial weights, applies to model │ + │ └─ (May be no-op if sender handles via remote call) │ + │ │ + │ 3. scheme.receive() [for ongoing updates] │ + └─────────────────────────────────────────────────────────────────────────┘ + +Phase 1: Initialization (No Communication) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``init_on_sender()`` and ``init_on_receiver()`` methods prepare local state without any +inter-process communication: + +- Set up local attributes and references (model, context, worker indices) +- Create transport objects and register them +- Prepare queues, buffers, or other communication primitives +- **Do NOT perform any inter-worker communication** + +This separation allows the scheme to be pickled and sent to worker processes after sender +initialization but before any actual communication occurs. + +.. 
code-block:: python + + # === SENDER (main process) === + scheme = SharedMemWeightSyncScheme() + scheme.init_on_sender( + model_id="policy", + context=collector, # or explicit params like weights, devices, num_workers + ) + + # === Scheme is passed to workers via multiprocessing === + # (The scheme object is pickled and sent to worker processes) + + # === RECEIVER (worker process) === + scheme.init_on_receiver( + model_id="policy", + context=inner_collector, # or explicit params like model, worker_idx + ) + +Phase 2: Connection and Initial Weights (Rendez-vous) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``connect()`` method performs the actual inter-process communication. **Both sender and receiver +must call this method** (simultaneously or in the expected order for the scheme): + +1. **Connection rendez-vous**: Sender and receiver synchronize (e.g., torch.distributed process group + initialization, shared memory buffer exchange via queues) +2. **Initial weight transfer**: If the sender has a stateful model, weights are sent to receivers + so they start with the correct parameters + +.. code-block:: python + + # === Called simultaneously on both ends === + + # Sender side (main process): + scheme.connect() # Blocks until receivers are ready, sends initial weights + + # Receiver side (worker process): + scheme.connect(worker_idx=0) # Blocks until sender sends, receives initial weights .. note:: - **Runnable versions** of these examples are available in the repository: - - - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization - - `examples/collectors/weight_sync_collectors.py `_: Collector integration + The ``connect()`` method is a **blocking rendez-vous** for most schemes. The exact behavior + depends on the scheme: -Using Weight Update Schemes Independently -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + - **Queue-based schemes** (SharedMem, MultiProcess): Sender puts to queue, receiver blocks reading + - **Distributed schemes** (Distributed, Ray): Both sides block on ``torch.distributed.send/recv`` + - **RPC/Ray with remote calls**: Receiver's ``connect()`` may be a no-op if the sender triggers + the receiver via a remote call (e.g., ``RayModuleTransformScheme``) -Weight update schemes can be used outside of collectors for custom synchronization scenarios. -The new simplified API provides four core methods for weight synchronization: +Phase 3: Ongoing Weight Updates +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side -- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side -- ``get_sender()`` - Get the configured sender instance -- ``get_receiver()`` - Get the configured receiver instance +After ``connect()`` completes, the scheme is ready for ongoing weight synchronization: -Here's a basic example: +- ``send()`` / ``send_async()`` on the sender side pushes new weights +- ``receive()`` on the receiver side (or automatic for shared memory schemes) .. 
code-block:: python - import torch - import torch.nn as nn - from torch import multiprocessing as mp - from tensordict import TensorDict - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, - ) + # Training loop + for batch in dataloader: + loss = train_step(batch) + + # Push updated weights to workers + scheme.send(new_weights) + +For some schemes (Ray, RPC), the sender's ``send()`` makes a remote call that triggers the receiver +automatically, so the user doesn't need to explicitly poll ``receive()``. + +Scheme-Specific Behavior +------------------------ + +SharedMemWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Uses shared memory for zero-copy weight updates. After initial setup, weight updates are instantaneous +since all processes share the same memory buffers. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates shared buffers + per-worker queues + - Stores model reference + - None + * - ``connect`` + - Puts buffer references into queues + - Reads from queue, applies to model + - mp.Queue (blocking) + * - ``send`` + - Updates shared memory in-place + - N/A (sees updates automatically) + - Zero-copy shared memory + +MultiProcessWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Sends weight copies through multiprocessing queues. More flexible than shared memory but requires +explicit data transfer for each update. Supports timeout for non-blocking receives. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates per-worker queues + - Gets queue reference + - None + * - ``connect`` + - Sends weights via queue + - Reads from queue, applies to model via strategy + - mp.Queue (blocking) + * - ``send`` + - Puts weights into queues + - Must call ``receive()``, transport applies weights + - mp.Queue (supports timeout) + +DistributedWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Create a simple policy - policy = nn.Linear(4, 2) +Uses ``torch.distributed`` primitives with a TCPStore for signaling. Suitable for distributed +training scenarios where processes are already part of a process group. Supports timeout via +``irecv(return_premature=True)`` for non-blocking receives. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with TCPStore + rank + - Creates transport with store + rank + - None + * - ``connect`` + - Sends initial weights via ``torch.distributed.send()`` + - Receives initial weights via ``torch.distributed.recv()``, applies via strategy + - **Rendez-vous**: torch.distributed send/recv + * - ``send`` + - Sets TCPStore flag + ``torch.distributed.send()`` + - Must poll TCPStore, then call ``receive()``, transport applies weights + - TCPStore + torch.distributed (supports timeout) + +RPCWeightSyncScheme +~~~~~~~~~~~~~~~~~~~ + +Uses ``torch.distributed.rpc`` for signaling with ``torch.distributed`` for data transfer. +The sender's ``send()`` triggers the receiver via RPC, so no explicit receiver polling is needed. +Supports timeout via ``irecv(return_premature=True)`` for non-blocking receives. + +.. 
list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with RPC refs + - Stores model reference + - None + * - ``connect`` + - No-op + - No-op + - None + * - ``send`` + - **RPC call** triggers receiver + ``send()`` + - Triggered by RPC, does ``recv()``, transport applies weights + - RPC + torch.distributed (supports timeout) + +RayWeightSyncScheme +~~~~~~~~~~~~~~~~~~~ + +Uses Ray actors for coordination with ``torch.distributed`` for efficient weight transfer. +Suitable for Ray-based distributed RL setups. Supports timeout via ``irecv(return_premature=True)`` +for non-blocking receives. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with Ray actor handles + - Creates transport, stores model + - None + * - ``connect`` + - Creates ConnectionInfo Ray actor, ``init_process_group(rank=0)``, sends initial weights + - Waits for ConnectionInfo, ``init_process_group(rank=N)``, receives weights via strategy + - **Rendez-vous**: Ray actor + torch.distributed + * - ``send`` + - **Ray remote call** triggers receiver + ``isend()`` + - Triggered by Ray, does ``irecv()``, transport applies weights + - Ray + torch.distributed (supports timeout) + +RayModuleTransformScheme +~~~~~~~~~~~~~~~~~~~~~~~~ + +Specialized scheme for synchronizing weights to a module running inside a ``RayModuleTransform``. +The sender triggers all receiver operations via Ray remote calls. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transport for transform actor + - Creates transport, stores module + - None + * - ``connect`` + - **Ray call** triggers receiver init, then rendez-vous + weight send + - **Triggered by Ray**: joins process group, receives weights + - Ray + torch.distributed + * - ``send`` + - **Ray remote call** triggers receiver + ``isend()`` + - Triggered by Ray, does ``irecv()`` + - Ray + torch.distributed - # Example 1: Multiprocess weight synchronization with state_dict - # -------------------------------------------------------------- - # On the main process side (trainer): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - # Initialize scheme with pipes - parent_pipe, child_pipe = mp.Pipe() - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - # Get the sender and send weights - sender = scheme.get_sender() - weights = policy.state_dict() - sender.send(weights) # Synchronous send - # or sender.send_async(weights); sender.wait_async() # Asynchronous send - - # On the worker process side: - # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) - # receiver = scheme.get_receiver() - # # Non-blocking check for new weights - # if receiver.receive(timeout=0.001): - # # Weights were received and applied - - # Example 2: Shared memory weight synchronization - # ------------------------------------------------ - # Create shared memory scheme with auto-registration - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - # Initialize with pipes for lazy registration - parent_pipe2, child_pipe2 = mp.Pipe() - shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2]) - - # Get sender and send weights (automatically creates shared buffer on first send) - shared_sender = shared_scheme.get_sender() - weights_td = TensorDict.from_module(policy) - shared_sender.send(weights_td) +.. 
note:: + ``RayModuleTransformScheme`` is unique in that even ``connect`` on the sender + triggers the receiver initialization via a Ray remote call. The user only needs to call + ``connect()`` on the sender side. - # Workers automatically see updates via shared memory! +Usage Examples +-------------- -Using Weight Update Schemes with Collectors -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. note:: + **Runnable versions** of these examples are available in the repository: + + - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization + - `examples/collectors/weight_sync_collectors.py `_: Collector integration + +Using Weight Sync Schemes with Collectors +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization -across multiple inference workers: +Weight sync schemes integrate seamlessly with TorchRL collectors. The collector handles calling +``init_on_sender()``, ``init_on_receiver()``, and ``connect()`` automatically: .. code-block:: python import torch.nn as nn from tensordict.nn import TensorDictModule - from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector + from torchrl.collectors import MultiSyncDataCollector from torchrl.envs import GymEnv - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, - ) + from torchrl.weight_update import SharedMemWeightSyncScheme # Create environment and policy env = GymEnv("CartPole-v1") @@ -133,85 +381,171 @@ across multiple inference workers: out_keys=["action"], ) - # Example 1: Single collector with multiprocess scheme - # ----------------------------------------------------- - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + # Create scheme - collector handles initialization + scheme = SharedMemWeightSyncScheme(strategy="tensordict") - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), + collector = MultiSyncDataCollector( + create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, policy=policy, - frames_per_batch=64, - total_frames=1000, + frames_per_batch=192, + total_frames=10000, weight_sync_schemes={"policy": scheme}, ) - # Collect data and update weights periodically + # Collect data and update weights for i, data in enumerate(collector): - # ... training step with data ... - - # Update policy weights every N iterations + # ... training step ... 
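+        # Hypothetical sketch of what the training step could look like; the
+        # `loss_module` and `optimizer` names are assumptions and not part of
+        # the collector API:
+        #   loss = loss_module(data)
+        #   loss.backward()
+        #   optimizer.step()
+        #   optimizer.zero_grad()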
+ + # Update weights - multiple calling conventions supported: if i % 10 == 0: - new_weights = policy.state_dict() - collector.update_policy_weights_(new_weights) + # Option 1: No arguments (uses registered policy) + collector.update_policy_weights_() + + # Option 2: Pass policy module (positional) + collector.update_policy_weights_(policy) + + # Option 3: Pass weights TensorDict (positional) + # collector.update_policy_weights_(weights_tensordict) + + # Option 4: Use keyword arguments for clarity + # collector.update_policy_weights_(policy=policy) + # collector.update_policy_weights_(weights=weights_td, model_id="policy") collector.shutdown() - # Example 2: Multiple collectors with shared memory - # -------------------------------------------------- - # Shared memory is more efficient for frequent updates - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) +Using Weight Sync Schemes Standalone +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - collector = MultiSyncDataCollector( - create_env_fn=[ - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - ], - policy=policy, - frames_per_batch=192, - total_frames=10000, - weight_sync_schemes={"policy": shared_scheme}, +For custom multiprocessing scenarios, you can use schemes directly. The key is to follow the +two-phase pattern: initialize first (no communication), then connect (blocking rendez-vous): + +.. code-block:: python + + import torch + import torch.nn as nn + from torch import multiprocessing as mp + from tensordict import TensorDict + from torchrl.weight_update import SharedMemWeightSyncScheme + + def worker_fn(scheme, worker_idx): + """Worker process - receives scheme via pickle.""" + # Create local model (weights will be overwritten by sender's weights) + model = nn.Linear(4, 2) + + # PHASE 1: Initialize on receiver (no communication yet) + scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) + + # PHASE 2: Blocking rendez-vous - receive initial weights from sender + scheme.connect(worker_idx=worker_idx) + # model now has the sender's weights! + + # Ready to work - for SharedMem, weight updates are automatic + while True: + # ... use model for inference ... + # model.parameters() automatically reflect sender's updates + + # === MAIN PROCESS (Sender) === + policy = nn.Linear(4, 2) + scheme = SharedMemWeightSyncScheme() + + # PHASE 1: Initialize on sender (no communication yet) + scheme.init_on_sender( + model_id="policy", + weights=TensorDict.from_module(policy), + devices=[torch.device("cpu")] * 2, + num_workers=2, ) - # Workers automatically see weight updates via shared memory - for data in collector: - # ... training ... - collector.update_policy_weights_(TensorDict.from_module(policy)) + # Spawn workers - scheme is pickled and sent to each worker + workers = [mp.Process(target=worker_fn, args=(scheme, i)) for i in range(2)] + for w in workers: + w.start() - collector.shutdown() + # PHASE 2: Blocking rendez-vous - send initial weights to workers + scheme.connect() + # Workers now have copies of policy's weights! + + # PHASE 3: Ongoing updates (zero-copy for shared memory) + for epoch in range(10): + # ... training step updates policy weights ... + scheme.send() # Workers automatically see the new weights + + for w in workers: + w.join() .. note:: - When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all - processes share the same memory buffers. 
This is ideal for frequent weight updates but requires all - processes to be on the same machine. + When using ``SharedMemWeightSyncScheme``, weight updates after initialization are zero-copy and extremely + fast since all processes share the same memory buffers. Workers don't need to call ``receive()`` - they + automatically see updates. .. note:: The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state - dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured - models and supports advanced features like lazy initialization. + dictionaries, while ``"tensordict"`` (default) uses TensorDict format which is more efficient for + structured models and supports features like device mapping. -Weight Senders --------------- +Transports +---------- -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst +Transports handle the low-level communication between sender and receiver. Each scheme creates +appropriate transport instances for its workers. - WeightSender - RayModuleTransformSender +Transport Interface +~~~~~~~~~~~~~~~~~~~ -Weight Receivers ----------------- +All transports implement the ``TransportBackend`` protocol with a stateless design. The key methods +accept ``weights``, ``model``, and ``strategy`` as keyword arguments rather than storing them as +instance attributes: -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst +.. code-block:: python - WeightReceiver - RayModuleTransformReceiver + # Transport methods accept model/weights/strategy as kwargs + transport.receive_weights( + timeout=None, # Optional timeout in seconds (None = blocking) + weights=buffer, # Pre-allocated weight buffer + model=policy, # Model to apply weights to + strategy=strategy, # WeightStrategy for weight application + ) -Transports ----------- + transport.setup_connection_and_weights_on_receiver( + worker_idx=0, + weights=buffer, + model=policy, + strategy=strategy, + ) + +Timeout Support +~~~~~~~~~~~~~~~ + +Transports support timeout for non-blocking weight reception: + +.. list-table:: + :header-rows: 1 + + * - Transport + - Timeout Support + - Notes + * - ``MPTransport`` + - ✅ Yes + - Uses ``queue.get(timeout=...)`` + * - ``RPCTransport`` + - ✅ Yes + - Uses ``irecv(return_premature=True)`` with polling + * - ``RayTransport`` + - ✅ Yes + - Uses ``irecv(return_premature=True)`` with polling + * - ``DistributedTransport`` + - ✅ Yes + - Uses ``irecv(return_premature=True)`` with polling + * - ``SharedMemTransport`` + - N/A + - Shared memory is instant (no waiting) + +When ``timeout=None`` (default), the receive operation blocks until weights arrive. +When a timeout is specified, the method returns ``None`` if the timeout expires before +weights are received. + +Available Transports +~~~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ @@ -221,18 +555,21 @@ Transports MPTransport SharedMemTransport RayTransport - RayActorTransport RPCTransport DistributedTransport Schemes ------- +Schemes orchestrate the weight synchronization lifecycle, managing initialization, connection setup, +and ongoing weight transfers. + .. autosummary:: :toctree: generated/ :template: rl_template.rst WeightSyncScheme + WeightStrategy MultiProcessWeightSyncScheme SharedMemWeightSyncScheme NoWeightSyncScheme @@ -245,38 +582,9 @@ Legacy: Weight Updaters ----------------------- .. warning:: - The `WeightUpdater` is considered legacy as per the 0.11 release and will be deprecated soon. 
- The Weight update schemes, which provides more flexibility and a better compatibility with heavy - weight transfers (e.g., LLMs) is to be preferred. - -In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the -latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible -mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios. - -Sending and receiving model weights with WeightUpdaters -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The weight synchronization process is facilitated by one dedicated extension point: -:class:`~torchrl.collectors.WeightUpdaterBase`. These base class provides a structured interface for -implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs. - -:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to -the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary. -Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the -weight synchronization with the policy. -Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy -state-dict (assuming it is a :class:`~torch.nn.Module` instance). - -Extending the Updater Class -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations. -The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation -untouched. -This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware -setups. -By implementing the abstract methods in these base classes, users can define how weights are retrieved, -transformed, and applied, ensuring seamless integration with their existing infrastructure. + The `WeightUpdater` API is deprecated as of the 0.11 release. + The Weight Sync Schemes API provides more flexibility and better compatibility with heavy + weight transfers (e.g., LLMs) and should be preferred for all new code. .. 
currentmodule:: torchrl.collectors diff --git a/docs/source/reference/envs_api.rst b/docs/source/reference/envs_api.rst index 03cb747ece0..db2d5d4ccfe 100644 --- a/docs/source/reference/envs_api.rst +++ b/docs/source/reference/envs_api.rst @@ -198,7 +198,6 @@ Helpers :toctree: generated/ :template: rl_template_fun.rst - RandomPolicy check_env_specs exploration_type get_available_libraries diff --git a/docs/source/reference/modules_actors.rst b/docs/source/reference/modules_actors.rst index afbf90ba702..6543b7512ed 100644 --- a/docs/source/reference/modules_actors.rst +++ b/docs/source/reference/modules_actors.rst @@ -20,6 +20,7 @@ TensorDictModules and SafeModules SafeModule SafeSequential TanhModule + RandomPolicy Probabilistic actors -------------------- diff --git a/examples/collectors/multi_weight_updates.py b/examples/collectors/multi_weight_updates.py index 7011e7f4879..6533eda3975 100644 --- a/examples/collectors/multi_weight_updates.py +++ b/examples/collectors/multi_weight_updates.py @@ -25,7 +25,7 @@ from torchrl.data import LazyTensorStorage, ReplayBuffer from torchrl.envs.libs.gym import GymEnv from torchrl.envs.transforms.module import ModuleTransform -from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme +from torchrl.weight_update import MultiProcessWeightSyncScheme def make_module(): diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py index a3962966c8c..020ad0b8a61 100644 --- a/examples/collectors/weight_sync_collectors.py +++ b/examples/collectors/weight_sync_collectors.py @@ -90,7 +90,7 @@ def example_multi_collector_shared_memory(): env.close() # Shared memory is more efficient for frequent updates - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + scheme = SharedMemWeightSyncScheme(strategy="tensordict") print("Creating multi-collector with shared memory...") collector = MultiSyncDataCollector( diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py deleted file mode 100644 index 2d918cb10a2..00000000000 --- a/examples/collectors/weight_sync_standalone.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -""" -Weight Synchronization Schemes - Standalone Usage -================================================== - -This example demonstrates how to use weight synchronization schemes independently -of collectors for custom synchronization scenarios. - -The weight synchronization infrastructure provides flexible sender/receiver patterns -that can be used for various multiprocessing scenarios. 
-""" - -import torch -import torch.nn as nn -from tensordict import TensorDict -from torch import multiprocessing as mp -from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, -) - - -def worker_process_mp(child_pipe, model_state): - """Worker process that receives weights via multiprocessing pipe.""" - print("Worker: Starting...") - - # Create a policy on the worker side - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - # Create receiver and register the policy - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - receiver = scheme.create_receiver() - receiver.register_model(policy) - receiver.register_worker_transport(child_pipe) - - print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}") - - # Receive and apply weights - result = receiver._transport.receive_weights(timeout=5.0) - if result is not None: - model_id, weights = result - receiver.apply_weights(weights) - print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}") - else: - print("Worker: No weights received") - - # Store final state for verification - model_state["weight_sum"] = policy.weight.sum().item() - model_state["bias_sum"] = policy.bias.sum().item() - - -def worker_process_shared_mem(child_pipe, model_state): - """Worker process that receives shared memory buffer reference.""" - print("SharedMem Worker: Starting...") - - # Create a policy on the worker side - policy = nn.Linear(4, 2) - - # Wait for shared memory buffer registration - if child_pipe.poll(timeout=10.0): - data, msg = child_pipe.recv() - if msg == "register_shared_weights": - model_id, shared_weights = data - print(f"SharedMem Worker: Received shared buffer for model '{model_id}'") - # Apply shared weights to policy - shared_weights.to_module(policy) - # Send acknowledgment - child_pipe.send((None, "registered")) - - # Small delay to ensure main process updates shared memory - import time - - time.sleep(0.5) - - print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}") - - # Store final state for verification - model_state["weight_sum"] = policy.weight.sum().item() - model_state["bias_sum"] = policy.bias.sum().item() - - -def example_multiprocess_sync(): - """Example 1: Multiprocess weight synchronization with state_dict.""" - print("\n" + "=" * 70) - print("Example 1: Multiprocess Weight Synchronization") - print("=" * 70) - - # Create a simple policy on main process - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(1.0) - policy.bias.fill_(0.5) - - print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}") - - # Create scheme and sender - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - sender = scheme.create_sender() - - # Create pipe for communication - parent_pipe, child_pipe = mp.Pipe() - sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - - # Start worker process - manager = mp.Manager() - model_state = manager.dict() - process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state)) - process.start() - - # Send weights to worker - weights = policy.state_dict() - print("Main: Sending weights to worker...") - sender.update_weights(weights) - - # Wait for worker to complete - process.join(timeout=10.0) - - if process.is_alive(): - print("Warning: Worker process did not terminate in time") - process.terminate() - else: - print( - f"Main: Worker completed. 
Worker's weight sum: {model_state['weight_sum']:.4f}" - ) - print("Weight synchronization successful!") - - -def example_shared_memory_sync(): - """Example 2: Shared memory weight synchronization.""" - print("\n" + "=" * 70) - print("Example 2: Shared Memory Weight Synchronization") - print("=" * 70) - - # Create a simple policy - policy = nn.Linear(4, 2) - - # Create shared memory scheme with auto-registration - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - sender = scheme.create_sender() - - # Create pipe for lazy registration - parent_pipe, child_pipe = mp.Pipe() - sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - - # Start worker process - manager = mp.Manager() - model_state = manager.dict() - process = mp.Process( - target=worker_process_shared_mem, args=(child_pipe, model_state) - ) - process.start() - - # Send weights (automatically creates shared buffer on first send) - weights_td = TensorDict.from_module(policy) - with torch.no_grad(): - weights_td["weight"].fill_(2.0) - weights_td["bias"].fill_(1.0) - - print("Main: Sending weights via shared memory...") - sender.update_weights(weights_td) - - # Workers automatically see updates via shared memory! - print("Main: Weights are now in shared memory, workers can access them") - - # Wait for worker to complete - process.join(timeout=10.0) - - if process.is_alive(): - print("Warning: Worker process did not terminate in time") - process.terminate() - else: - print( - f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}" - ) - print("Shared memory synchronization successful!") - - -def main(): - """Run all examples.""" - print("\n" + "=" * 70) - print("Weight Synchronization Schemes - Standalone Usage Examples") - print("=" * 70) - - # Set multiprocessing start method - try: - mp.set_start_method("spawn") - except RuntimeError: - pass # Already set - - # Run examples - example_multiprocess_sync() - example_shared_memory_sync() - - print("\n" + "=" * 70) - print("All examples completed successfully!") - print("=" * 70 + "\n") - - -if __name__ == "__main__": - main() diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index 0061a895578..5139e811a65 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -116,7 +116,7 @@ def main(): from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend - from torchrl.envs.utils import RandomPolicy + from torchrl.modules import RandomPolicy collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector device_str = "device" if num_workers == 1 else "devices" diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index a684a1b724c..e2aab24753a 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -115,7 +115,7 @@ def main(): from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend - from torchrl.envs.utils import RandomPolicy + from torchrl.modules import RandomPolicy collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector device_str = "device" 
if num_workers == 1 else "devices" diff --git a/examples/distributed/collectors/multi_nodes/generic.py b/examples/distributed/collectors/multi_nodes/generic.py index 795660fc683..29144a9f796 100644 --- a/examples/distributed/collectors/multi_nodes/generic.py +++ b/examples/distributed/collectors/multi_nodes/generic.py @@ -14,7 +14,7 @@ from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/multi_nodes/rpc.py b/examples/distributed/collectors/multi_nodes/rpc.py index 151879a5423..208c6abdaec 100644 --- a/examples/distributed/collectors/multi_nodes/rpc.py +++ b/examples/distributed/collectors/multi_nodes/rpc.py @@ -15,7 +15,7 @@ from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/multi_nodes/sync.py b/examples/distributed/collectors/multi_nodes/sync.py index 10a37d47a87..100c598602b 100644 --- a/examples/distributed/collectors/multi_nodes/sync.py +++ b/examples/distributed/collectors/multi_nodes/sync.py @@ -14,7 +14,7 @@ from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index 2c52c84321a..21d9dc375db 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -34,7 +34,7 @@ from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index 5c9ef50b08a..009eb39ad53 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ b/examples/distributed/collectors/single_machine/rpc.py @@ -30,7 +30,7 @@ from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index 84cc1b1de99..51bc62af4af 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -31,7 +31,7 @@ from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy 
+from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py index f92f78de7e1..df522443c06 100644 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py @@ -172,7 +172,7 @@ def __init__(self, capacity: int): if __name__ == "__main__": args = parser.parse_args() rank = args.rank - torchrl_logger.info(f"Rank: {rank}") + torchrl_logger.debug(f"RANK: {rank}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" diff --git a/sota-implementations/expert-iteration/ei_utils.py b/sota-implementations/expert-iteration/ei_utils.py index c6732c1763e..6448a86d374 100644 --- a/sota-implementations/expert-iteration/ei_utils.py +++ b/sota-implementations/expert-iteration/ei_utils.py @@ -5,7 +5,6 @@ from __future__ import annotations import time - from typing import Any, Literal import torch @@ -612,7 +611,6 @@ def get_wandb_run_id(wandb_logger): """ try: # Wait a bit for wandb to initialize - import time max_attempts = 10 for attempt in range(max_attempts): diff --git a/test/llm/test_wrapper.py b/test/llm/test_wrapper.py index 30e2d7d7129..b496e749c78 100644 --- a/test/llm/test_wrapper.py +++ b/test/llm/test_wrapper.py @@ -7,8 +7,10 @@ import argparse import gc import importlib.util +import threading import time +from concurrent.futures import ThreadPoolExecutor, wait from functools import partial import pytest @@ -412,8 +414,6 @@ def slow_forward(self, td_input, **kwargs): @pytest.fixture def monkey_patch_forward_for_instrumentation(): """Fixture to monkey patch the forward method to add detailed processing event tracking.""" - import threading - import time # Track processing events processing_events = [] @@ -2706,8 +2706,6 @@ def test_batching_min_batch_size_one_immediate_processing( monkey_patch_forward_for_timing, ): """Test that with min_batch_size=1, first request is processed immediately and subsequent ones are grouped.""" - import time - from concurrent.futures import ThreadPoolExecutor, wait # Create wrapper using helper function wrapper = create_batching_test_wrapper( diff --git a/test/services/test_python_executor_service.py b/test/services/test_python_executor_service.py index cb55c0a6a10..b18181c573f 100644 --- a/test/services/test_python_executor_service.py +++ b/test/services/test_python_executor_service.py @@ -73,7 +73,7 @@ def test_service_execution(self, ray_init): result = x + y print(f"Result: {result}") """ - result = ray.get(executor.execute.remote(code), timeout=2) + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is True assert "Result: 30" in result["stdout"] @@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init): # Execute code with an error code = "raise ValueError('Test error')" - result = ray.get(executor.execute.remote(code), timeout=2) + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is False assert "ValueError: Test error" in result["stderr"] @@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init): "python_executor", PythonExecutorService, pool_size=4, - timeout=5.0, + timeout=10.0, num_cpus=4, max_concurrency=4, ) @@ -132,14 +132,16 @@ def test_multiple_executions(self, ray_init): code = f"print('Execution {i}')" futures.append(executor.execute.remote(code)) - # Wait for all to complete - results 
= ray.get(futures, timeout=5) + # Wait for all to complete with longer timeout + results = ray.get(futures, timeout=30) # All should succeed assert len(results) == 8 for i, result in enumerate(results): - assert result["success"] is True - assert f"Execution {i}" in result["stdout"] + assert result["success"] is True, f"Execution {i} failed: {result}" + assert ( + f"Execution {i}" in result["stdout"] + ), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}" finally: services.reset() diff --git a/test/test_collector.py b/test/test_collector.py index 73c6e5c3d21..1be9bc9ed15 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -5,7 +5,6 @@ from __future__ import annotations import argparse - import contextlib import functools import gc @@ -13,11 +12,16 @@ import subprocess import sys import time +import traceback +from contextlib import nullcontext from unittest.mock import patch import numpy as np import pytest import torch + +import torchrl.collectors._multi_base +import torchrl.collectors._runner from packaging import version from tensordict import ( assert_allclose_td, @@ -33,7 +37,6 @@ TensorDictSequential, ) from torch import nn - from torchrl._utils import ( _make_ordinal_device, _replace_last, @@ -48,7 +51,8 @@ SyncDataCollector, WeightUpdaterBase, ) -from torchrl.collectors.collectors import _Interruptor +from torchrl.collectors._constants import _Interruptor +from torchrl.collectors._multi_base import _MultiDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( @@ -76,9 +80,14 @@ _aggregate_end_of_traj, check_env_specs, PARTIAL_MISSING_ERR, +) +from torchrl.modules import ( + Actor, + OrnsteinUhlenbeckProcessModule, RandomPolicy, + SafeModule, ) -from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule +from torchrl.testing.modules import BiasModule, NonSerializableBiasModule from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, @@ -1130,40 +1139,20 @@ def make_and_test_policy( policy, policy_device=original_device, env_device=original_device ) - # a deepcopy must occur when the policy_device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): + # Test that we DON'T raise deepcopy errors anymore even when policy_device differs + # These scenarios previously would have triggered deepcopy, but now use meta device context manager + if collector_type is not SyncDataCollector: + # policy_device differs from the actual device - previously required deepcopy, now works! policy = make_policy(device=original_device) make_and_test_policy( policy, policy_device=shared_device, env_device=shared_device ) - # a deepcopy must occur when device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): + if collector_type is not SyncDataCollector: + # device differs from the actual device - previously required deepcopy, now works! 
policy = make_policy(device=original_device) make_and_test_policy(policy, device=shared_device) - # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device - # is there to inform us - substitute_device = ( - original_device if torch.cuda.is_available() else torch.device("cpu") - ) - policy = make_policy(substitute_device, nn_module=False) - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=substitute_device - ) - # For instance, if the env is on CPU, knowing the policy device helps with casting stuff on the right device - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=shared_device - ) - make_and_test_policy( - policy, - policy_device=substitute_device, - env_device=shared_device, - trust_policy=True, - ) - # If there is no policy_device, we assume that the user is doing things right too but don't warn if collector_type is SyncDataCollector or original_device.type != "mps": policy = make_policy(original_device, nn_module=False) @@ -1487,12 +1476,14 @@ def env_fn(seed): assert_allclose_td(data10, data20) @pytest.mark.parametrize("use_async", [False, True]) - @pytest.mark.parametrize("cudagraph", [False, True]) + @pytest.mark.parametrize( + "cudagraph", [False, True] if torch.cuda.is_available() else [False] + ) @pytest.mark.parametrize( "weight_sync_scheme", [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") + # @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="no cuda/mps device found") def test_update_weights(self, use_async, cudagraph, weight_sync_scheme): def create_env(): return ContinuousActionVecMockEnv() @@ -1509,17 +1500,17 @@ def create_env(): kwargs = {} if weight_sync_scheme is not None: kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()} + device = "cuda:0" if torch.cuda.is_available() else "cpu" collector = collector_class( [create_env] * 3, policy=policy, - device=[torch.device("cuda:0")] * 3, - storing_device=[torch.device("cuda:0")] * 3, + device=[torch.device(device)] * 3, + storing_device=[torch.device(device)] * 3, frames_per_batch=20, cat_results="stack", cudagraph_policy=cudagraph, **kwargs, ) - assert "policy" in collector._weight_senders, collector._weight_senders.keys() try: # collect state_dict state_dict = collector.state_dict() @@ -1530,7 +1521,7 @@ def create_env(): ].keys() for k in state_dict[f"worker{worker}"]["policy_state_dict"]: torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) @@ -1544,9 +1535,11 @@ def create_env(): # check they don't match for worker in range(3): for k in state_dict[f"worker{worker}"]["policy_state_dict"]: - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError + ) if torch.cuda.is_available() else nullcontext(): torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) @@ -1559,7 +1552,7 @@ def create_env(): for worker in range(3): for k in state_dict[f"worker{worker}"]["policy_state_dict"]: torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), 
policy_state_dict[k].cpu(), ) finally: @@ -1571,8 +1564,6 @@ def create_env(): ) # MultiSync has known indexing issues with SharedMem def test_update_weights_shared_mem(self, use_async): """Test shared memory weight synchronization scheme.""" - from tensordict import TensorDict - from torchrl.weight_update.weight_sync_schemes import SharedMemWeightSyncScheme def create_env(): return ContinuousActionVecMockEnv() @@ -1589,7 +1580,11 @@ def create_env(): # Create shared memory weight sync scheme weight_sync_scheme = SharedMemWeightSyncScheme() - weight_sync_scheme.register_shared_weights("policy", policy_weights) + # Use the new init_on_sender API with params_map + # All 3 workers share the same CPU weights in shared memory + weight_sync_scheme.init_on_sender( + params_map={0: policy_weights, 1: policy_weights, 2: policy_weights}, + ) collector_class = ( MultiSyncDataCollector if not use_async else MultiaSyncDataCollector @@ -1841,8 +1836,14 @@ def forward(self, tensordict): class PolicyWithDevice(TensorDictModuleBase): in_keys = ["observation"] out_keys = ["action"] - # receives and sends data on gpu - default_device = "cuda:0" if torch.cuda.device_count() else "cpu" + + def __init__(self, default_device=None): + super().__init__() + self.default_device = ( + default_device + if default_device is not None + else ("cuda:0" if torch.cuda.device_count() else "cpu") + ) def forward(self, tensordict): assert tensordict.device == _make_ordinal_device( @@ -1859,7 +1860,7 @@ def test_output_device(self, main_device, storing_device): env_device = None policy_device = main_device env = self.DeviceLessEnv(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1900,7 +1901,7 @@ def test_output_device(self, main_device, storing_device): env_device = None policy_device = None env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1913,14 +1914,16 @@ def test_output_device(self, main_device, storing_device): ) for data in collector: # noqa: B007 break - assert data.device == main_device + # When storing_device is None, it falls back to device + expected_device = storing_device if storing_device is not None else main_device + assert data.device == expected_device # same but more specific device = None env_device = main_device policy_device = main_device env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1933,7 +1936,9 @@ def test_output_device(self, main_device, storing_device): ) for data in collector: # noqa: B007 break - assert data.device == main_device + # When storing_device is None, and env_device == policy_device, it falls back to env_device + expected_device = storing_device if storing_device is not None else main_device + assert data.device == expected_device # none has a device device = None @@ -2334,6 +2339,9 @@ def test_auto_wrap_modules( ), device=device, ) + if isinstance(collector, _MultiDataCollector): + assert collector._weight_sync_schemes is not None + assert "policy" in collector._weight_sync_schemes try: out_keys = ["action"] @@ -2354,6 +2362,7 @@ def test_auto_wrap_modules( p.data.zero_() assert p.device == torch.device("cpu") # Debug: updating policy weights + torchrl_logger.debug("Calling update_policy_weights_") collector.update_policy_weights_() # Debug: updated policy 
weights elif i == 4: @@ -2401,7 +2410,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) with pytest.raises( TypeError, - match=("Arguments to policy.forward are incompatible with entries in"), + match=( + "Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True." + ), ): collector_class( **self._create_collector_kwargs( @@ -2980,6 +2991,93 @@ def test_param_sync_mixed_device( col.shutdown() del col + @pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 3, + reason="requires at least 3 CUDA devices", + ) + def test_shared_device_weight_update(self): + """Test that weight updates work correctly when multiple workers share the same device. + + This test specifically validates the per-worker queue implementation in SharedMemWeightSyncScheme. + When workers 0 and 2 share cuda:2, each should receive its own copy of the weights through + dedicated queues, preventing race conditions that could occur with a single shared queue. + + Note: This test only uses SharedMemWeightSyncScheme (not MultiProcessWeightSyncScheme) because + the latter sends tensors through pipes, which we want to avoid. + """ + # Create policy on cuda:0 + policy = TensorDictModule( + nn.Linear(7, 7, device="cuda:0"), + in_keys=["observation"], + out_keys=["action"], + ) + + def make_env(): + return ContinuousActionVecMockEnv() + + # Create collector with workers on cuda:2, cuda:1, cuda:2 + # Workers 0 and 2 share cuda:2 - this is the key test case + collector = MultiaSyncDataCollector( + [make_env, make_env, make_env], + policy=policy, + frames_per_batch=30, + total_frames=300, + device=["cuda:2", "cuda:1", "cuda:2"], + storing_device=["cuda:2", "cuda:1", "cuda:2"], + weight_sync_schemes={"policy": SharedMemWeightSyncScheme()}, + ) + + try: + # Collect first batch to initialize workers + for _ in collector: + break + + # Get initial weights + old_weight = policy.module.weight.data.clone() + + # Modify policy weights on cuda:0 + for p in policy.parameters(): + p.data += torch.randn_like(p) + + new_weight = policy.module.weight.data.clone() + assert not torch.allclose( + old_weight, new_weight + ), "Weights should have changed" + + # Update weights - this should propagate to all workers via their dedicated queues + collector.update_policy_weights_() + + # Collect more batches to ensure weights are propagated + for i, _ in enumerate(collector): + if i >= 2: + break + + # Get state dict from all workers + state_dict = collector.state_dict() + + # Verify all workers have the new weights, including both workers on cuda:2 + for worker_idx in range(3): + worker_key = f"worker{worker_idx}" + assert ( + "policy_state_dict" in state_dict[worker_key] + ), f"Worker {worker_idx} should have policy_state_dict" + worker_weight = state_dict[worker_key]["policy_state_dict"][ + "module.weight" + ] + torch.testing.assert_close( + worker_weight.cpu(), + new_weight.cpu(), + msg=( + f"Worker {worker_idx} weights don't match expected weights. " + f"Workers 0 and 2 share device cuda:2, worker 1 is on cuda:1. " + f"This test validates that the per-worker queue system correctly " + f"distributes weights even when multiple workers share a device." 
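# The docstring above argues that SharedMemWeightSyncScheme needs one queue per
# worker (rather than a single shared queue) so that two workers pinned to the
# same device each receive their own copy of the weights. Below is a minimal
# schematic of that idea using plain multiprocessing queues -- this is NOT
# TorchRL's implementation, and the names are invented for illustration only.
import torch
from torch import multiprocessing as mp


def broadcast_weights(weights: dict, worker_queues: list) -> None:
    # One private copy per worker queue: workers sharing a device can never
    # race for (or accidentally consume) another worker's update message.
    for q in worker_queues:
        q.put({k: v.clone() for k, v in weights.items()})


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queues = [ctx.Queue() for _ in range(3)]
    broadcast_weights({"weight": torch.ones(2, 4)}, queues)
    # Every worker queue now holds exactly one private copy of the weights.
    assert all(float(q.get()["weight"].sum()) == 8.0 for q in queues)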
+ ), + ) + finally: + collector.shutdown() + del collector + class TestAggregateReset: def test_aggregate_reset_to_root(self): @@ -3170,82 +3268,164 @@ def test_aggregate_reset_to_root_errors(self): ) +def _subprocess_test_worker(func, error_queue): + """Worker function that runs a test function and reports errors via queue.""" + try: + func() + except Exception as e: + error_queue.put((type(e).__name__, str(e), traceback.format_exc())) + else: + error_queue.put(None) + + +def _run_test_in_subprocess(func, timeout=120): + """Run a test function in a fresh subprocess to avoid thread pool initialization issues. + + This is necessary because torch.set_num_threads() may not work correctly + if the thread pool has already been initialized in the parent process. + Running in a fresh subprocess ensures a clean PyTorch state. + + Args: + func: The test function to run. Must be picklable (module-level function). + timeout: Timeout in seconds for the subprocess. + + Raises: + AssertionError: If the test function raises an exception in the subprocess. + """ + ctx = torch.multiprocessing.get_context("spawn") + error_queue = ctx.Queue() + + proc = ctx.Process(target=_subprocess_test_worker, args=(func, error_queue)) + proc.start() + proc.join(timeout=timeout) + + if proc.is_alive(): + proc.terminate() + proc.join() + raise AssertionError(f"Test timed out after {timeout} seconds") + + if proc.exitcode != 0: + try: + result = error_queue.get_nowait() + except Exception: + result = None + + if result is not None: + exc_type, exc_msg, tb = result + raise AssertionError(f"Test failed with {exc_type}: {exc_msg}\n{tb}") + else: + raise AssertionError(f"Test subprocess exited with code {proc.exitcode}") + + # Check if there was an exception even with exitcode 0 + try: + result = error_queue.get_nowait() + if result is not None: + exc_type, exc_msg, tb = result + raise AssertionError(f"Test failed with {exc_type}: {exc_msg}\n{tb}") + except Exception: + pass + + +def _test_num_threads_impl(): + """Implementation of test_num_threads that runs in a subprocess.""" + env = ContinuousActionVecMockEnv() + _main_async_collector_saved = torchrl.collectors._multi_base._main_async_collector + torchrl.collectors._multi_base._main_async_collector = decorate_thread_sub_func( + torchrl.collectors._multi_base._main_async_collector, num_threads=3 + ) + num_threads = torch.get_num_threads() + try: + c = MultiSyncDataCollector( + [env], + policy=RandomPolicy(env.action_spec), + num_threads=7, + num_sub_threads=3, + total_frames=200, + frames_per_batch=200, + cat_results="stack", + ) + assert ( + torch.get_num_threads() == 7 + ), f"Expected 7 threads, got {torch.get_num_threads()}" + for _ in c: + pass + finally: + try: + c.shutdown() + del c + except Exception: + pass + torchrl.collectors._multi_base._main_async_collector = ( + _main_async_collector_saved + ) + torch.set_num_threads(num_threads) + + +def _test_auto_num_threads_impl(): + """Implementation of test_auto_num_threads that runs in a subprocess.""" + init_threads = torch.get_num_threads() + + # Test 1: Single env + try: + collector = MultiSyncDataCollector( + [ContinuousActionVecMockEnv], + RandomPolicy(ContinuousActionVecMockEnv().full_action_spec), + frames_per_batch=3, + cat_results="stack", + ) + for _ in collector: + current = torch.get_num_threads() + expected = init_threads - 1 + assert current == expected, f"Expected {expected} threads, got {current}" + break + collector.shutdown() + current = torch.get_num_threads() + assert ( + current == init_threads + ), 
f"After shutdown: expected {init_threads} threads, got {current}" + del collector + gc.collect() + finally: + torch.set_num_threads(init_threads) + + # Test 2: ParallelEnv with 2 workers + try: + collector = MultiSyncDataCollector( + [ParallelEnv(2, ContinuousActionVecMockEnv)], + RandomPolicy(ContinuousActionVecMockEnv().full_action_spec.expand(2)), + frames_per_batch=3, + cat_results="stack", + ) + for _ in collector: + current = torch.get_num_threads() + expected = init_threads - 2 + assert current == expected, f"Expected {expected} threads, got {current}" + break + collector.shutdown() + current = torch.get_num_threads() + assert ( + current == init_threads + ), f"After shutdown: expected {init_threads} threads, got {current}" + del collector + gc.collect() + finally: + torch.set_num_threads(init_threads) + + class TestLibThreading: @pytest.mark.skipif( IS_OSX, reason="setting different threads across workers can randomly fail on OSX.", ) def test_num_threads(self): - from torchrl.collectors import collectors - - _main_async_collector_saved = collectors._main_async_collector - collectors._main_async_collector = decorate_thread_sub_func( - collectors._main_async_collector, num_threads=3 - ) - num_threads = torch.get_num_threads() - try: - env = ContinuousActionVecMockEnv() - c = MultiSyncDataCollector( - [env], - policy=RandomPolicy(env.action_spec), - num_threads=7, - num_sub_threads=3, - total_frames=200, - frames_per_batch=200, - cat_results="stack", - ) - assert torch.get_num_threads() == 7 - for _ in c: - pass - finally: - try: - c.shutdown() - del c - except Exception: - torchrl_logger.info("Failed to shut down collector") - # reset vals - collectors._main_async_collector = _main_async_collector_saved - torch.set_num_threads(num_threads) + _run_test_in_subprocess(_test_num_threads_impl) @pytest.mark.skipif( IS_OSX or IS_WINDOWS, reason="setting different threads across workers can randomly fail on OSX.", ) def test_auto_num_threads(self): - init_threads = torch.get_num_threads() - try: - collector = MultiSyncDataCollector( - [ContinuousActionVecMockEnv], - RandomPolicy(ContinuousActionVecMockEnv().full_action_spec), - frames_per_batch=3, - cat_results="stack", - ) - for _ in collector: - assert torch.get_num_threads() == init_threads - 1 - break - collector.shutdown() - assert torch.get_num_threads() == init_threads - del collector - gc.collect() - finally: - torch.set_num_threads(init_threads) - - try: - collector = MultiSyncDataCollector( - [ParallelEnv(2, ContinuousActionVecMockEnv)], - RandomPolicy(ContinuousActionVecMockEnv().full_action_spec.expand(2)), - frames_per_batch=3, - cat_results="stack", - ) - for _ in collector: - assert torch.get_num_threads() == init_threads - 2 - break - collector.shutdown() - assert torch.get_num_threads() == init_threads - del collector - gc.collect() - finally: - torch.set_num_threads(init_threads) + _run_test_in_subprocess(_test_auto_num_threads_impl) class TestUniqueTraj: @@ -3848,13 +4028,12 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: def all_worker_ids(self) -> list[int] | list[torch.device]: return list(range(self.num_workers)) - @pytest.mark.skipif(not _has_cuda, reason="requires cuda another device than CPU.") @pytest.mark.skipif(not _has_gym, reason="requires gym") @pytest.mark.parametrize( - "weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"] + "weight_updater", ["scheme_shared", "scheme_mp", "weight_updater"] ) - def test_weight_update(self, weight_updater): - device = 
"cuda:0" + def test_update_weights(self, weight_updater): + device = "cuda:0" if torch.cuda.is_available() else "cpu" env_maker = lambda: GymEnv(PENDULUM_VERSIONED(), device="cpu") policy_factory = lambda: TensorDictModule( nn.Linear(3, 1, device=device), in_keys=["observation"], out_keys=["action"] @@ -3863,14 +4042,22 @@ def test_weight_update(self, weight_updater): policy_weights = TensorDict.from_module(policy) kwargs = {} if weight_updater == "scheme_shared": - kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}} - elif weight_updater == "scheme_pipe": - kwargs = {"weight_sync_schemes": {"policy": MultiProcessWeightSyncScheme()}} + scheme = SharedMemWeightSyncScheme() + kwargs = {"weight_sync_schemes": {"policy": scheme}} + elif weight_updater == "scheme_mp": + scheme = MultiProcessWeightSyncScheme() + kwargs = {"weight_sync_schemes": {"policy": scheme}} elif weight_updater == "weight_updater": + scheme = None kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)} else: raise NotImplementedError + if scheme is not None: + scheme.init_on_sender( + model=policy_factory(), devices=[device] * 2, model_id="policy" + ) + collector = MultiSyncDataCollector( create_env_fn=[env_maker, env_maker], policy_factory=policy_factory, @@ -3883,10 +4070,13 @@ def test_weight_update(self, weight_updater): storing_device="cpu", **kwargs, ) - - # When using policy_factory, must pass weights explicitly - collector.update_policy_weights_(policy_weights) try: + if weight_updater == "weight_updater": + assert collector._legacy_weight_updater + + # When using policy_factory, must pass weights explicitly + collector.update_policy_weights_(policy_weights) + for i, data in enumerate(collector): if i == 2: assert (data["action"] != 0).any() @@ -3900,6 +4090,79 @@ def test_weight_update(self, weight_updater): finally: collector.shutdown() + @pytest.mark.parametrize( + "collector_cls", + [ + functools.partial(MultiSyncDataCollector, cat_results="stack"), + MultiaSyncDataCollector, + ], + ) + @pytest.mark.parametrize( + "weight_sync_scheme_cls", + [MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], + ) + def test_nonserializable_policy_with_factory_and_weight_sync( + self, collector_cls, weight_sync_scheme_cls + ): + """Test that a non-serializable policy can be used on the main node alongside a policy_factory. + + The policy instance is used only for weight extraction on the main node, while + the policy_factory is what gets sent to and instantiated on workers. + """ + + # Simple continuous-control env + def create_env(): + return ContinuousActionVecMockEnv() + + # Non-serializable policy instance on main node + base_module = NonSerializableBiasModule(0.0) + policy = TensorDictModule( + base_module, in_keys=["observation"], out_keys=["action"] + ) + + # Serializable factory used to build worker policies + def policy_factory(): + return TensorDictModule( + BiasModule(0.0), in_keys=["observation"], out_keys=["action"] + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Weight sync scheme will be initialized on the sender side by the collector, + # using the policy instance passed above as the source of weights. 
+ weight_sync_scheme = weight_sync_scheme_cls() + + collector = collector_cls( + [create_env, create_env], + policy=policy, + policy_factory=policy_factory, + frames_per_batch=16, + total_frames=64, + device=device, + storing_device="cpu", + weight_sync_schemes={"policy": weight_sync_scheme}, + ) + + try: + # Ensure we can collect at least one batch without serialization issues + iterator = iter(collector) + _ = next(iterator) + + # Change the main-node policy weights and update workers without passing weights explicitly + with torch.no_grad(): + base_module.bias.add_(1.0) + + # This call should: + # - Use the (non-serializable) policy to extract weights via TensorDict.from_module() + # - Send those weights through the weight sync scheme + # - NOT attempt to serialize the policy itself + collector.update_policy_weights_() + + # Collect again to exercise the updated weights path and ensure workers didn't crash + _ = next(iterator) + finally: + collector.shutdown() + class TestAsyncCollection: @pytest.mark.parametrize("total_frames", [-1, 1_000_000_000]) @@ -3993,6 +4256,7 @@ def test_start_multi(self, total_frames, cls): "weight_sync_scheme", [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], ) + @pytest.mark.flaky(reruns=3, reruns_delay=0.5) def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) env = CountingEnv() @@ -4025,16 +4289,17 @@ def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): frames_per_batch=16, **kwargs, ) - if not isinstance(collector, SyncDataCollector): - if weight_sync_scheme is not None: - assert isinstance( - collector._weight_sync_schemes["policy"], weight_sync_scheme - ) - else: - assert isinstance( - collector._weight_sync_schemes["policy"], SharedMemWeightSyncScheme - ) try: + if not isinstance(collector, SyncDataCollector): + if weight_sync_scheme is not None: + assert isinstance( + collector._weight_sync_schemes["policy"], weight_sync_scheme + ) + else: + assert isinstance( + collector._weight_sync_schemes["policy"], + SharedMemWeightSyncScheme, + ) collector.start() for _ in range(10): time.sleep(0.1) @@ -4050,7 +4315,7 @@ def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): if (rb[-16:]["action"] == 1).all(): break else: - raise RuntimeError + raise RuntimeError("Failed to update policy weights") finally: collector.async_shutdown(timeout=10) del collector diff --git a/test/test_cost.py b/test/test_cost.py index 53a3966495a..5906100e5e5 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1121,6 +1121,7 @@ def test_dqn_prioritized_weights(self): value_network=value, action_space="categorical", reduction="mean" ) loss_fn.make_value_estimator() + softupdate = SoftUpdate(loss_fn, eps=0.5) # Create prioritized replay buffer rb = TensorDictPrioritizedReplayBuffer( @@ -1174,6 +1175,7 @@ def test_dqn_prioritized_weights(self): reduction="none", use_prioritized_weights=False, ) + softupdate = SoftUpdate(loss_fn_no_reduction, eps=0.5) loss_fn_no_reduction.make_value_estimator() loss_fn_no_reduction.target_value_network_params = ( loss_fn.target_value_network_params @@ -1673,6 +1675,7 @@ def test_dqn_prioritized_weights(self): loss_fn = DQNLoss( value_network=value, action_space="categorical", reduction="mean" ) + softupdate = SoftUpdate(loss_fn, eps=0.5) loss_fn.make_value_estimator() # Create prioritized replay buffer @@ -1727,6 +1730,7 @@ def test_dqn_prioritized_weights(self): reduction="none", use_prioritized_weights=False, 
) + softupdate = SoftUpdate(loss_fn_no_reduction, eps=0.5) loss_fn_no_reduction.make_value_estimator() loss_fn_no_reduction.target_value_network_params = ( loss_fn.target_value_network_params @@ -2396,6 +2400,7 @@ def test_ddpg_prioritized_weights(self): # Create DDPG loss loss_fn = DDPGLoss(actor_network=actor, value_network=qvalue) + softupdate = SoftUpdate(loss_fn, eps=0.5) loss_fn.make_value_estimator() # Create prioritized replay buffer @@ -2449,6 +2454,7 @@ def test_ddpg_prioritized_weights(self): value_network=qvalue, use_prioritized_weights=False, ) + softupdate = SoftUpdate(loss_fn_no_weights, eps=0.5) loss_fn_no_weights.make_value_estimator() loss_fn_no_weights.value_network_params = loss_fn.value_network_params loss_fn_no_weights.target_value_network_params = ( @@ -3303,6 +3309,7 @@ def test_td3_prioritized_weights(self): low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) ), ) + softupdate = SoftUpdate(loss_fn, eps=0.5) loss_fn.make_value_estimator() # Create prioritized replay buffer @@ -3360,6 +3367,7 @@ def test_td3_prioritized_weights(self): ), use_prioritized_weights=False, ) + softupdate = SoftUpdate(loss_fn_no_weights, eps=0.5) loss_fn_no_weights.make_value_estimator() loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params loss_fn_no_weights.target_qvalue_network_params = ( @@ -5288,6 +5296,122 @@ def test_sac_reduction(self, reduction, version, composite_action_dist): continue assert loss[key].shape == torch.Size([]) + def test_sac_prioritized_weights(self, version): + """Test SAC with prioritized replay buffer weighted loss reduction.""" + if version != 2: + pytest.skip("Test not implemented for version 1.") + n_obs = 4 + n_act = 2 + batch_size = 32 + buffer_size = 100 + + # Actor network + actor_net = nn.Sequential( + nn.Linear(n_obs, 64), + nn.ReLU(), + nn.Linear(64, 2 * n_act), + NormalParamExtractor(), + ) + actor_module = TensorDictModule( + actor_net, in_keys=["observation"], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=actor_module, + in_keys=["loc", "scale"], + distribution_class=TanhNormal, + return_log_prob=True, + spec=Bounded( + low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) + ), + ) + + # Q-value network + qvalue_net = MLP(in_features=n_obs + n_act, out_features=1, num_cells=[64, 64]) + qvalue = ValueOperator(module=qvalue_net, in_keys=["observation", "action"]) + + # Value network for SAC v1 + value_net = MLP(in_features=n_obs, out_features=1, num_cells=[64, 64]) + value = ValueOperator(module=value_net, in_keys=["observation"]) + + # Create SAC loss + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=2, + ) + SoftUpdate(loss_fn, eps=0.5) + loss_fn.make_value_estimator() + + # Create prioritized replay buffer + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.9, + storage=LazyTensorStorage(buffer_size), + batch_size=batch_size, + priority_key="td_error", + ) + + # Create initial data + initial_data = TensorDict( + { + "observation": torch.randn(buffer_size, n_obs), + "action": torch.randn(buffer_size, n_act).clamp(-1, 1), + ("next", "observation"): torch.randn(buffer_size, n_obs), + ("next", "reward"): torch.randn(buffer_size, 1), + ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), + ("next", "terminated"): torch.zeros(buffer_size, 1, dtype=torch.bool), + }, + batch_size=[buffer_size], + ) + rb.extend(initial_data) + + # Sample (weights should all be identical initially) + sample1 = rb.sample() + assert 
"priority_weight" in sample1.keys() + weights1 = sample1["priority_weight"] + assert torch.allclose(weights1, weights1[0], atol=1e-5) + + # Run loss to get priorities + loss_fn(sample1) + assert "td_error" in sample1.keys() + + # Update replay buffer with new priorities + rb.update_tensordict_priority(sample1) + + # Sample again - weights should now be non-equal + sample2 = rb.sample() + weights2 = sample2["priority_weight"] + assert weights2.std() > 1e-5 + + # Run loss again with varied weights + loss_out2 = loss_fn(sample2) + assert torch.isfinite(loss_out2["loss_qvalue"]) + + # Verify weighted vs unweighted differ + loss_fn_no_weights = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=2, + use_prioritized_weights=False, + ) + SoftUpdate(loss_fn_no_weights, eps=0.5) + loss_fn_no_weights.make_value_estimator() + loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params + loss_fn_no_weights.target_qvalue_network_params = ( + loss_fn.target_qvalue_network_params + ) + loss_fn_no_weights.actor_network_params = loss_fn.actor_network_params + loss_fn_no_weights.value_network_params = loss_fn.value_network_params + loss_fn_no_weights.target_value_network_params = ( + loss_fn.target_value_network_params + ) + + loss_out_no_weights = loss_fn_no_weights(sample2) + # Weighted and unweighted should differ (in general) + assert torch.isfinite(loss_out_no_weights["loss_qvalue"]) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -7767,118 +7891,6 @@ def test_redq_reduction(self, reduction, deprecated_loss): continue assert loss[key].shape == torch.Size([]) - def test_sac_prioritized_weights(self): - """Test SAC with prioritized replay buffer weighted loss reduction.""" - n_obs = 4 - n_act = 2 - batch_size = 32 - buffer_size = 100 - - # Actor network - actor_net = nn.Sequential( - nn.Linear(n_obs, 64), - nn.ReLU(), - nn.Linear(64, 2 * n_act), - NormalParamExtractor(), - ) - actor_module = TensorDictModule( - actor_net, in_keys=["observation"], out_keys=["loc", "scale"] - ) - actor = ProbabilisticActor( - module=actor_module, - in_keys=["loc", "scale"], - distribution_class=TanhNormal, - return_log_prob=True, - spec=Bounded( - low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) - ), - ) - - # Q-value network - qvalue_net = MLP(in_features=n_obs + n_act, out_features=1, num_cells=[64, 64]) - qvalue = ValueOperator(module=qvalue_net, in_keys=["observation", "action"]) - - # Value network for SAC v1 - value_net = MLP(in_features=n_obs, out_features=1, num_cells=[64, 64]) - value = ValueOperator(module=value_net, in_keys=["observation"]) - - # Create SAC loss - loss_fn = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - value_network=value, - num_qvalue_nets=2, - ) - loss_fn.make_value_estimator() - - # Create prioritized replay buffer - rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.9, - storage=LazyTensorStorage(buffer_size), - batch_size=batch_size, - priority_key="td_error", - ) - - # Create initial data - initial_data = TensorDict( - { - "observation": torch.randn(buffer_size, n_obs), - "action": torch.randn(buffer_size, n_act).clamp(-1, 1), - ("next", "observation"): torch.randn(buffer_size, n_obs), - ("next", "reward"): torch.randn(buffer_size, 1), - ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), - ("next", "terminated"): torch.zeros(buffer_size, 1, dtype=torch.bool), - }, - batch_size=[buffer_size], - ) - rb.extend(initial_data) - - # 
Sample (weights should all be identical initially) - sample1 = rb.sample() - assert "priority_weight" in sample1.keys() - weights1 = sample1["priority_weight"] - assert torch.allclose(weights1, weights1[0], atol=1e-5) - - # Run loss to get priorities - loss_fn(sample1) - assert "td_error" in sample1.keys() - - # Update replay buffer with new priorities - rb.update_tensordict_priority(sample1) - - # Sample again - weights should now be non-equal - sample2 = rb.sample() - weights2 = sample2["priority_weight"] - assert weights2.std() > 1e-5 - - # Run loss again with varied weights - loss_out2 = loss_fn(sample2) - assert torch.isfinite(loss_out2["loss_qvalue"]) - - # Verify weighted vs unweighted differ - loss_fn_no_weights = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - value_network=value, - num_qvalue_nets=2, - use_prioritized_weights=False, - ) - loss_fn_no_weights.make_value_estimator() - loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params - loss_fn_no_weights.target_qvalue_network_params = ( - loss_fn.target_qvalue_network_params - ) - loss_fn_no_weights.actor_network_params = loss_fn.actor_network_params - loss_fn_no_weights.value_network_params = loss_fn.value_network_params - loss_fn_no_weights.target_value_network_params = ( - loss_fn.target_value_network_params - ) - - loss_out_no_weights = loss_fn_no_weights(sample2) - # Weighted and unweighted should differ (in general) - assert torch.isfinite(loss_out_no_weights["loss_qvalue"]) - class TestCQL(LossModuleTestBase): seed = 0 diff --git a/test/test_distributed.py b/test/test_distributed.py index 6183132394e..48d2ea5181c 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -10,35 +10,22 @@ import abc import argparse +import importlib import os +import socket import sys import time +import traceback from functools import partial import pytest -from tensordict import TensorDict -from tensordict.nn import TensorDictModuleBase -from torchrl._utils import logger as torchrl_logger -from torchrl.data import ( - LazyTensorStorage, - RandomSampler, - RayReplayBuffer, - RoundRobinWriter, - SamplerWithoutReplacement, -) - -try: - import ray - - _has_ray = True - RAY_ERR = None -except ModuleNotFoundError as err: - _has_ray = False - RAY_ERR = err import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential from torch import multiprocessing as mp, nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import ( MultiaSyncDataCollector, @@ -52,7 +39,16 @@ RPCDataCollector, ) from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG -from torchrl.envs.utils import RandomPolicy +from torchrl.data import ( + LazyTensorStorage, + RandomSampler, + RayReplayBuffer, + RoundRobinWriter, + SamplerWithoutReplacement, +) +from torchrl.modules import RandomPolicy + +_has_ray = importlib.util.find_spec("ray") is not None if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test.mocking_classes import ContinuousActionVecMockEnv, CountingEnv @@ -115,16 +111,17 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): **cls.distributed_kwargs(), ) total = 0 - torchrl_logger.info("getting data...") for data in collector: total += data.numel() assert data.numel() == frames_per_batch assert data.names[-1] == "time" - collector.shutdown() assert total == 1000 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = 
traceback.format_exc() + queue.put(("not passed", (e, tb))) + finally: + collector.shutdown() @pytest.mark.parametrize("frames_per_batch", [50, 100]) def test_distributed_collector_basic(self, frames_per_batch): @@ -136,8 +133,9 @@ def test_distributed_collector_basic(self, frames_per_batch): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -163,9 +161,10 @@ def _test_distributed_collector_mult(cls, queue, frames_per_batch): assert data.numel() == frames_per_batch collector.shutdown() assert total == -frames_per_batch * (1000 // -frames_per_batch) - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {e}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) def test_distributed_collector_mult(self, frames_per_batch=200): """Testing multiple nodes.""" @@ -177,8 +176,9 @@ def test_distributed_collector_mult(self, frames_per_batch=200): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -205,9 +205,10 @@ def _test_distributed_collector_sync(cls, queue, sync): assert data.numel() == frames_per_batch collector.shutdown() assert total == 200 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize("sync", [False, True]) def test_distributed_collector_sync(self, sync): @@ -219,8 +220,9 @@ def test_distributed_collector_sync(self, sync): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -247,9 +249,10 @@ def _test_distributed_collector_class(cls, queue, collector_class): assert data.numel() == frames_per_batch collector.shutdown() assert total == 200 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize( "collector_class", @@ -268,8 +271,9 @@ def test_distributed_collector_class(self, collector_class): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -277,21 +281,36 @@ def test_distributed_collector_class(self, collector_class): queue.close() @classmethod - def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): + def _test_distributed_collector_updatepolicy( + cls, queue, collector_class, sync, pfactory + ): try: frames_per_batch = 50 total_frames = 300 env = CountingEnv - policy = CountingPolicy() + if pfactory: + policy_factory = CountingPolicy + policy = None + else: + policy = CountingPolicy() + policy_factory = None if collector_class is MultiaSyncDataCollector: # otherwise we may collect data from a collector that has not yet been # updated n_collectors 
= 1 else: n_collectors = 2 - collector = cls.distributed_class()( + weights = None + if policy is None and policy_factory is not None: + policy_stateful = policy_factory() + weights = TensorDict.from_module(policy_stateful).lock_() + dcls = cls.distributed_class() + torchrl_logger.info(f"Using distributed collector {dcls}") + collector = None + collector = dcls( [env] * n_collectors, policy, + policy_factory=policy_factory, collector_class=collector_class, total_frames=total_frames, frames_per_batch=frames_per_batch, @@ -306,17 +325,23 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): assert data.numel() == frames_per_batch if i == 0: first_batch = data - policy.weight.data += 1 - collector.update_policy_weights_() + if policy is not None: + policy.weight.data += 1 + else: + weights.data += 1 + torchrl_logger.info("TEST -- Calling update_policy_weights_()") + collector.update_policy_weights_(weights) + torchrl_logger.info("TEST -- Done calling update_policy_weights_()") elif total == total_frames - frames_per_batch: last_batch = data assert (first_batch["action"] == 1).all(), first_batch["action"] assert (last_batch["action"] == 2).all(), last_batch["action"] collector.shutdown() assert total == total_frames - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize( "collector_class", @@ -327,18 +352,20 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): ], ) @pytest.mark.parametrize("sync", [False, True]) - def test_distributed_collector_updatepolicy(self, collector_class, sync): + @pytest.mark.parametrize("pfactory", [False, True]) + def test_distributed_collector_updatepolicy(self, collector_class, sync, pfactory): """Testing various collector classes to be used in nodes.""" queue = mp.Queue(1) proc = mp.Process( target=self._test_distributed_collector_updatepolicy, - args=(queue, collector_class, sync), + args=(queue, collector_class, sync, pfactory), ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -353,7 +380,13 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + # Pick an ephemeral free TCP port on localhost for each test process to + # avoid address-in-use errors when tests are run repeatedly or in quick + # succession. 
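# The comment above describes the ephemeral-port trick used inline below (bind
# to port 0, then read back the OS-assigned port). The same pattern, factored
# into a small reusable helper -- a convenience sketch only, not part of this
# patch. Usual caveat: the port is only guaranteed free at bind time and could
# be grabbed by another process before the launcher binds it again.
import socket


def find_free_tcp_port(host: str = "localhost") -> int:
    """Ask the OS for an ephemeral TCP port and return its number."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return s.getsockname()[1]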
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -367,7 +400,10 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -381,7 +417,10 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -392,15 +431,24 @@ def test_distributed_collector_sync(self, *args): @classmethod def _test_distributed_collector_updatepolicy( - cls, queue, collector_class, update_interval + cls, + queue, + collector_class, + update_interval, + pfactory, ): frames_per_batch = 50 total_frames = 300 env = CountingEnv + if pfactory: + policy_factory = CountingPolicy + else: + policy_factory = None policy = CountingPolicy() collector = cls.distributed_class()( [env] * 2, policy, + policy_factory=policy_factory, collector_class=collector_class, total_frames=total_frames, frames_per_batch=frames_per_batch, @@ -408,7 +456,6 @@ def _test_distributed_collector_updatepolicy( **cls.distributed_kwargs(), ) try: - total = 0 first_batch = None last_batch = None @@ -426,10 +473,12 @@ def _test_distributed_collector_updatepolicy( else: assert (last_batch["action"] == 1).all(), last_batch["action"] assert total == total_frames - queue.put("passed") + queue.put(("passed", None)) + except Exception as e: + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) finally: collector.shutdown() - queue.put("not passed") @pytest.mark.parametrize( "collector_class", @@ -440,18 +489,22 @@ def _test_distributed_collector_updatepolicy( ], ) @pytest.mark.parametrize("update_interval", [1]) - def test_distributed_collector_updatepolicy(self, collector_class, update_interval): + @pytest.mark.parametrize("pfactory", [False, True]) + def test_distributed_collector_updatepolicy( + self, collector_class, update_interval, pfactory + ): """Testing various collector classes to be used in nodes.""" queue = mp.Queue(1) proc = mp.Process( target=self._test_distributed_collector_updatepolicy, - args=(queue, collector_class, update_interval), + args=(queue, collector_class, update_interval, pfactory), ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -459,7 +512,9 @@ def test_distributed_collector_updatepolicy(self, collector_class, update_interv queue.close() -@pytest.mark.skipif(not _has_ray, reason=f"Ray not found (error: {RAY_ERR})") +@pytest.mark.skipif( + not _has_ray, reason="Ray not found. Ray may be badly configured or not installed." +) class TestRayCollector(DistributedCollectorBase): """A testing distributed data collector class that runs tests without using a Queue, to avoid potential deadlocks when combining Ray and multiprocessing. 
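# The Ray tests below add autouse fixtures that restart Ray per class and tear
# down any default torch.distributed process group left over from a previous
# test. A standalone sketch of that cleanup pattern, assuming a pytest context;
# the fixture name mirrors the one added in this patch.
import pytest
import torch.distributed as dist


@pytest.fixture(autouse=True)
def reset_process_group():
    # Destroy a lingering default process group so distributed collectors can
    # re-initialize from a clean state, then let the test run.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
    yield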
@@ -467,6 +522,7 @@ class TestRayCollector(DistributedCollectorBase): @pytest.fixture(autouse=True, scope="class") def start_ray(self): + import ray from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG ray.init(**DEFAULT_RAY_INIT_CONFIG) @@ -474,12 +530,24 @@ def start_ray(self): yield ray.shutdown() + @pytest.fixture(autouse=True, scope="function") + def reset_process_group(self): + import torch.distributed as dist + + try: + dist.destroy_process_group() + except Exception: + pass + yield + @classmethod def distributed_class(cls) -> type: return RayCollector @classmethod def distributed_kwargs(cls) -> dict: + import ray + ray.shutdown() # make sure ray is not running ray_init_config = DEFAULT_RAY_INIT_CONFIG ray_init_config["runtime_env"] = { @@ -557,20 +625,31 @@ def test_distributed_collector_class(self, collector_class): ], ) @pytest.mark.parametrize("sync", [False, True]) - def test_distributed_collector_updatepolicy(self, collector_class, sync): + @pytest.mark.parametrize("pfactory", [False, True]) + def test_distributed_collector_updatepolicy(self, collector_class, sync, pfactory): frames_per_batch = 50 total_frames = 300 env = CountingEnv - policy = CountingPolicy() + if pfactory: + policy_factory = CountingPolicy + policy = None + else: + policy = CountingPolicy() + policy_factory = None if collector_class is MultiaSyncDataCollector: # otherwise we may collect data from a collector that has not yet been # updated n_collectors = 1 else: n_collectors = 2 + weights = None + if policy is None and policy_factory is not None: + policy_stateful = policy_factory() + weights = TensorDict.from_module(policy_stateful) collector = self.distributed_class()( [env] * n_collectors, policy, + policy_factory=policy_factory, collector_class=collector_class, total_frames=total_frames, frames_per_batch=frames_per_batch, @@ -586,8 +665,11 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync): assert data.numel() == frames_per_batch if i == 0: first_batch = data - policy.weight.data += 1 - collector.update_policy_weights_() + if policy is not None: + policy.weight.data += 1 + else: + weights.data += 1 + collector.update_policy_weights_(weights) elif total == total_frames - frames_per_batch: last_batch = data assert (first_batch["action"] == 1).all(), first_batch["action"] @@ -631,7 +713,19 @@ def test_ray_collector_policy_constructor(self): env = CountingEnv def policy_constructor(): - return lambda td: td.set("action", torch.full(td.shape, 2)) + return TensorDictSequential( + TensorDictModule( + lambda x: x.float(), + in_keys=["observation"], + out_keys=["_obs_float"], + ), + TensorDictModule( + nn.Linear(1, 1), out_keys=["action"], in_keys=["_obs_float"] + ), + TensorDictModule( + lambda x: x.int(), in_keys=["action"], out_keys=["action"] + ), + ) collector = self.distributed_class()( [env] * n_collectors, @@ -641,9 +735,16 @@ def policy_constructor(): frames_per_batch=frames_per_batch, **self.distributed_kwargs(), ) + p = policy_constructor() + # p(env().reset()) + weights = TensorDict.from_module(p) + weights["module", "1", "module", "weight"].data.fill_(0) + weights["module", "1", "module", "bias"].data.fill_(2) + collector.update_policy_weights_(weights) try: for data in collector: assert (data["action"] == 2).all() + collector.update_policy_weights_(weights) finally: collector.shutdown() diff --git a/test/test_env.py b/test/test_env.py index 7aa00e98d2d..8031a66e986 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -78,10 +78,16 @@ 
check_marl_grouping, make_composite_from_td, MarlGroupMapType, - RandomPolicy, step_mdp, ) -from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator +from torchrl.modules import ( + Actor, + ActorCriticOperator, + MLP, + RandomPolicy, + SafeModule, + ValueOperator, +) from torchrl.modules.tensordict_module import WorldModelWrapper pytestmark = [ diff --git a/test/test_libs.py b/test/test_libs.py index 9157734b376..3973cc2604b 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -123,16 +123,12 @@ from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.transforms import ActionMask, TransformedEnv -from torchrl.envs.utils import ( - check_env_specs, - ExplorationType, - MarlGroupMapType, - RandomPolicy, -) +from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType from torchrl.modules import ( ActorCriticOperator, MaskedCategorical, MLP, + RandomPolicy, SafeModule, ValueOperator, ) diff --git a/test/test_rb.py b/test/test_rb.py index 15b9b9af0e5..85b1fe9eb22 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -35,7 +35,7 @@ from torch.utils._pytree import tree_flatten, tree_map from torchrl._utils import _replace_last, logger as torchrl_logger -from torchrl.collectors import RandomPolicy, SyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( CompressedListStorage, @@ -107,6 +107,7 @@ UnsqueezeTransform, VecNorm, ) +from torchrl.modules import RandomPolicy if os.getenv("PYTORCH_TEST_FBCODE"): @@ -1398,17 +1399,17 @@ def test_replay_buffer_trajectories(stack, reduction, datatype): if datatype == "tc": rb.update_priority(index, sampled_td) sampled_td, info = rb.sample(return_info=True) - assert (info["_weight"] > 0).all() + assert (info["priority_weight"] > 0).all() assert sampled_td.batch_size == torch.Size([3, 4]) else: rb.update_tensordict_priority(sampled_td) sampled_td = rb.sample(include_info=True) - assert (sampled_td.get("_weight") > 0).all() + assert (sampled_td.get("priority_weight") > 0).all() assert sampled_td.batch_size == torch.Size([3, 4]) # # set back the trajectory length # sampled_td_filtered = sampled_td.to_tensordict().exclude( - # "_weight", "index", "td_error" + # "priority_weight", "index", "td_error" # ) # sampled_td_filtered.batch_size = [3, 4] @@ -1904,12 +1905,12 @@ def test_rb_trajectories(stack, reduction): sampled_td.set("td_error", torch.rand(3, 4)) rb.update_tensordict_priority(sampled_td) sampled_td = rb.sample(include_info=True) - assert (sampled_td.get("_weight") > 0).all() + assert (sampled_td.get("priority_weight") > 0).all() assert sampled_td.batch_size == torch.Size([3, 4]) # set back the trajectory length sampled_td_filtered = sampled_td.to_tensordict().exclude( - "_weight", "index", "td_error" + "priority_weight", "index", "td_error" ) sampled_td_filtered.batch_size = [3, 4] @@ -3379,14 +3380,14 @@ def test_prioritized_slice_sampler_doc_example(): sample, info = rb.sample(return_info=True) # print("episode", sample["episode"].tolist()) # print("steps", sample["steps"].tolist()) - # print("weight", info["_weight"].tolist()) + # print("weight", info["priority_weight"].tolist()) priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1]) rb.update_priority(torch.arange(0, 9, 1), priority=priority) sample, info = rb.sample(return_info=True) # print("episode", sample["episode"].tolist()) # print("steps", sample["steps"].tolist()) - # print("weight", info["_weight"].tolist()) + # 
print("weight", info["priority_weight"].tolist()) @pytest.mark.parametrize("device", get_default_devices()) diff --git a/test/test_transforms.py b/test/test_transforms.py index 567c0995d20..37239021ea8 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -13,6 +13,7 @@ import os import pickle import re + import sys from copy import copy from functools import partial @@ -37,9 +38,10 @@ from tensordict.nn import TensorDictModule, TensorDictSequential, WrapModule from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td from torch import multiprocessing as mp, nn, Tensor +from torchrl import logger as torchrl_logger from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env -from torchrl.collectors import MultiSyncDataCollector +from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import ( Bounded, BoundedContinuous, @@ -136,9 +138,18 @@ from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp -from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal +from torchrl.modules import ( + GRUModule, + LSTMModule, + MLP, + ProbabilisticActor, + RandomPolicy, + TanhNormal, +) from torchrl.modules.utils import get_primers_from_module from torchrl.record.recorder import VideoRecorder +from torchrl.testing.modules import BiasModule +from torchrl.weight_update import RayModuleTransformScheme if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import ( # noqa @@ -15015,6 +15026,136 @@ def test_ray_extension(self): ray.stop() +class TestRayModuleTransform: + @pytest.fixture(autouse=True, scope="function") + def start_ray(self): + import ray + from torchrl import merge_ray_runtime_env + from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG + + if ray.is_initialized(): + ray.shutdown() + + # Use merge_ray_runtime_env to exclude large directories from the runtime environment + # This prevents issues with Ray's working_dir size limits and GCS package expiration + ray_init_config = merge_ray_runtime_env(dict(DEFAULT_RAY_INIT_CONFIG)) + ray.init(**ray_init_config) + + yield + ray.shutdown() + + @pytest.fixture(autouse=True, scope="function") + def reset_process_group(self): + import torch.distributed as dist + + try: + dist.destroy_process_group() + except Exception: + pass + yield + + def test_ray_module_transform_scheme_flow(self): + bias_module = BiasModule(2.0) + module_fact = lambda: TensorDictModule( + bias_module, + in_keys=["observation"], + out_keys=["action"], + ) + + # Create scheme and transform + scheme = RayModuleTransformScheme() + transform = ModuleTransform( + module_factory=module_fact, + weight_sync_scheme=scheme, + use_ray_service=True, + actor_name="my_transform", + ) + assert transform.in_keys == ["observation"] + assert transform.out_keys == ["action"] + dummy_data = TensorDict(observation=torch.zeros(2, 3), batch_size=[2]) + + module = module_fact() + assert (module(dummy_data)["action"] == 2).all() + + # test sending weights + weights = TensorDict.from_module(module) + d = weights.data + d *= 0 + d += 1 + scheme.send(weights) + assert (module(dummy_data)["action"] == 1).all() + + def test_ray_module_transform_scheme_collector(self): + # Create a simple module that adds a learnable bias to observations + # We use addition instead of scaling to avoid issues with observation values + + bias_module = 
BiasModule() + module = TensorDictModule( + bias_module, + in_keys=["observation"], + out_keys=["observation"], # Transform in-place + ) + + # Create scheme and transform + scheme = RayModuleTransformScheme() + transform = RayModuleTransform( + module=module, + weight_sync_scheme=scheme, + ) + + # Create transformed env + base_env = ContinuousActionVecMockEnv + + def make_env(): + return TransformedEnv(base_env(), transform) + + # Create collector with scheme registered + torchrl_logger.debug("Creating collector") + policy = RandomPolicy(base_env().action_spec) + collector = SyncDataCollector( + make_env, + policy, + frames_per_batch=50, + total_frames=200, + weight_sync_schemes={"transform_module": scheme}, + ) + + torchrl_logger.debug("Starting collector") + first_batch_mean = None + second_batch_mean = None + try: + for i, data in enumerate(collector): + obs_mean = data["observation"].mean().item() + + if i == 0: + first_batch_mean = obs_mean + + # Update weights: set bias to 100.0 (large value to be clearly visible) + torchrl_logger.debug("Updating weights") + new_weights = TensorDict.from_module(module) + new_weights["module", "bias"].data.fill_(100.0) + collector.update_policy_weights_( + new_weights, model_id="transform_module" + ) + elif i == 1: + second_batch_mean = obs_mean + break + finally: + collector.shutdown() + + # Verify that weights were updated + # With bias=0.0, first batch should have observations around 0 (env default) + # With bias=100.0, second batch should have observations shifted by 100 + assert first_batch_mean is not None, "First batch not collected" + assert second_batch_mean is not None, "Second batch not collected" + + # The second batch should have significantly higher mean due to bias=100 + assert second_batch_mean > first_batch_mean + 50, ( + f"Weight update did not take effect: first_mean={first_batch_mean:.2f}, " + f"second_mean={second_batch_mean:.2f}. Expected second to be at least 50 higher." + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_weightsync.py b/test/test_weightsync.py deleted file mode 100644 index 2ccd4308ccf..00000000000 --- a/test/test_weightsync.py +++ /dev/null @@ -1,863 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
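# The deleted test_weightsync.py below exercised weight extraction and
# application directly. The basic extract -> modify -> apply roundtrip those
# tests relied on (and which the collector-level tests above still reach via
# update_policy_weights_) can be sketched with the public tensordict API alone;
# this snippet is illustrative and not taken from the deleted file.
import torch
from tensordict import TensorDict
from torch import nn

module = nn.Linear(3, 4)
target = nn.Linear(3, 4)

weights = TensorDict.from_module(module).data.clone()  # detached copy of the params
weights["weight"].fill_(1.0)
weights["bias"].fill_(0.0)
weights.to_module(target)  # write the edited values into the target module

assert torch.allclose(target.weight, torch.ones(4, 3))
assert torch.allclose(target.bias, torch.zeros(4))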
-from __future__ import annotations - -import argparse -import importlib.util -import pickle -import time - -import pytest -import torch -import torch.nn as nn -from mocking_classes import ContinuousActionVecMockEnv -from tensordict import TensorDict -from tensordict.nn import TensorDictModule -from torch import multiprocessing as mp -from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector -from torchrl.weight_update.weight_sync_schemes import ( - _resolve_model, - DistributedWeightSyncScheme, - MPTransport, - MultiProcessWeightSyncScheme, - NoWeightSyncScheme, - RayModuleTransformScheme, - RayWeightSyncScheme, - RPCWeightSyncScheme, - SharedMemTransport, - SharedMemWeightSyncScheme, - WeightStrategy, -) - -_has_ray = importlib.util.find_spec("ray") is not None - - -def worker_update_policy(pipe, timeout=5.0): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() - - if receiver._transport.pipe.poll(timeout): - data, msg = receiver._transport.pipe.recv() - if msg == "update_weights": - model_id, weights = data - receiver.apply_weights(weights) - - return policy.weight.sum().item(), policy.bias.sum().item() - - -def worker_update_policy_tensordict(pipe, timeout=5.0): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() - - if receiver._transport.pipe.poll(timeout): - data, msg = receiver._transport.pipe.recv() - if msg == "update_weights": - model_id, weights = data - receiver.apply_weights(weights) - - return policy.weight.sum().item(), policy.bias.sum().item() - - -def worker_shared_mem(pipe, timeout=10.0): - policy = nn.Linear(4, 2) - - if pipe.poll(timeout): - data, msg = pipe.recv() - if msg == "register_shared_weights": - model_id, shared_weights = data - shared_weights.to_module(policy) - pipe.send((None, "registered")) - - time.sleep(0.5) - - return policy.weight.sum().item(), policy.bias.sum().item() - - -class TestTransportBackends: - def test_mp_transport_basic(self): - parent_pipe, child_pipe = mp.Pipe() - transport = MPTransport(parent_pipe) - - assert transport.check_connection() - - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - proc.start() - - test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights("policy", test_weights) - - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_mp_transport_async(self): - parent_pipe, child_pipe = mp.Pipe() - transport = MPTransport(parent_pipe) - - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - proc.start() - - test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights_async("policy", test_weights) - transport.wait_ack() - - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_shared_mem_transport(self): - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - transport = SharedMemTransport({"policy": shared_buffer}) - - new_weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - - transport.send_weights("policy", new_weights) - - assert 
torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) - assert torch.allclose(shared_buffer["bias"], torch.ones(2)) - - -class TestWeightStrategies: - def test_state_dict_strategy(self): - strategy = WeightStrategy(extract_as="state_dict") - - policy = nn.Linear(3, 4) - weights = strategy.extract_weights(policy) - assert isinstance(weights, dict) - assert "weight" in weights - assert "bias" in weights - - target_policy = nn.Linear(3, 4) - with torch.no_grad(): - target_policy.weight.fill_(0.0) - target_policy.bias.fill_(0.0) - - strategy.apply_weights(target_policy, weights) - - assert torch.allclose(policy.weight, target_policy.weight) - assert torch.allclose(policy.bias, target_policy.bias) - - def test_tensordict_strategy(self): - strategy = WeightStrategy(extract_as="tensordict") - - policy = nn.Linear(3, 4) - weights = strategy.extract_weights(policy) - assert isinstance(weights, TensorDict) - - target_policy = nn.Linear(3, 4) - with torch.no_grad(): - target_policy.weight.fill_(0.0) - target_policy.bias.fill_(0.0) - - strategy.apply_weights(target_policy, weights) - - assert torch.allclose(policy.weight, target_policy.weight) - assert torch.allclose(policy.bias, target_policy.bias) - - def test_cross_format_conversion(self): - policy = nn.Linear(3, 4) - - state_dict_strategy = WeightStrategy(extract_as="state_dict") - tensordict_strategy = WeightStrategy(extract_as="tensordict") - - state_dict_weights = state_dict_strategy.extract_weights(policy) - tensordict_weights = tensordict_strategy.extract_weights(policy) - - target_policy_1 = nn.Linear(3, 4) - target_policy_2 = nn.Linear(3, 4) - - with torch.no_grad(): - target_policy_1.weight.fill_(0.0) - target_policy_1.bias.fill_(0.0) - target_policy_2.weight.fill_(0.0) - target_policy_2.bias.fill_(0.0) - - state_dict_strategy.apply_weights(target_policy_1, tensordict_weights) - tensordict_strategy.apply_weights(target_policy_2, state_dict_weights) - - assert torch.allclose(policy.weight, target_policy_1.weight) - assert torch.allclose(policy.weight, target_policy_2.weight) - - -class TestWeightSyncSchemes: - """Tests for weight sync schemes using the new simplified API. - - Lower-level transport and legacy API tests are in TestTransportBackends. 
- """ - - def test_multiprocess_scheme_state_dict(self): - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() - - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - try: - proc.start() - - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - sender.send(weights) - finally: - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_multiprocess_scheme_tensordict(self): - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() - - proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,)) - try: - proc.start() - - weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - sender.send(weights) - finally: - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_shared_mem_scheme(self): - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - scheme = SharedMemWeightSyncScheme( - policy_weights={"policy": shared_buffer}, - strategy="tensordict", - auto_register=False, - ) - - transport = scheme.create_transport(None) - - new_weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - - transport.send_weights("policy", new_weights) - - assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) - assert torch.allclose(shared_buffer["bias"], torch.ones(2)) - - def test_shared_mem_scheme_auto_register(self): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - transport = scheme.create_transport(None) - - weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - - transport.send_weights("policy", weights) - - assert "policy" in scheme.policy_weights - assert torch.allclose( - scheme.policy_weights["policy"]["weight"], torch.ones(2, 4) - ) - - def test_no_weight_sync_scheme(self): - scheme = NoWeightSyncScheme() - transport = scheme.create_transport(None) - - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights("policy", weights) - - @classmethod - def _worker_with_receive(cls, pipe, scheme): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() - - # Non-blocking receive should return False when no data - result = receiver.receive(timeout=0.001) - assert result is False - - # Now actually receive the weights - result = receiver.receive(timeout=5.0) - assert result is True - - # Check weights were applied - return policy.weight.sum().item(), policy.bias.sum().item() - - def test_receiver_receive_method(self): - """Test the new non-blocking receive() method.""" - - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() - - proc = mp.Process(target=self._worker_with_receive, args=(child_pipe, scheme)) - try: - proc.start() - - # Give worker time to call receive with no data - - time.sleep(0.1) - - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - sender.send(weights) - - finally: - proc.join(timeout=10.0) 
- assert not proc.is_alive() - - -class TestCollectorIntegration: - @pytest.fixture - def simple_env(self): - return ContinuousActionVecMockEnv() - - @pytest.fixture - def simple_policy(self, simple_env): - return TensorDictModule( - nn.Linear( - simple_env.observation_spec["observation"].shape[-1], - simple_env.action_spec.shape[-1], - ), - in_keys=["observation"], - out_keys=["action"], - ) - - def test_syncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - new_weights = simple_policy.state_dict() - with torch.no_grad(): - for key in new_weights: - new_weights[key].fill_(1.0) - - collector.update_policy_weights_(new_weights) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - collector = MultiSyncDataCollector( - create_env_fn=[ - ContinuousActionVecMockEnv, - ContinuousActionVecMockEnv, - ], - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - new_weights = simple_policy.state_dict() - with torch.no_grad(): - for key in new_weights: - new_weights[key].fill_(1.0) - - collector.update_policy_weights_(new_weights) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - collector = MultiSyncDataCollector( - create_env_fn=[ - ContinuousActionVecMockEnv, - ContinuousActionVecMockEnv, - ], - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - new_weights = TensorDict.from_module(simple_policy) - with torch.no_grad(): - new_weights["module"]["weight"].fill_(1.0) - new_weights["module"]["bias"].fill_(1.0) - - collector.update_policy_weights_(new_weights) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - def test_collector_no_weight_sync(self, simple_policy): - scheme = NoWeightSyncScheme() - - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - -class TestMultiModelUpdates: - def test_multi_model_state_dict_updates(self): - env = ContinuousActionVecMockEnv() - - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), - "value": MultiProcessWeightSyncScheme(strategy="state_dict"), - } - - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes=weight_sync_schemes, - ) - - policy_weights = policy.state_dict() - value_weights = 
value.state_dict() - - with torch.no_grad(): - for key in policy_weights: - policy_weights[key].fill_(1.0) - for key in value_weights: - value_weights[key].fill_(2.0) - - collector.update_policy_weights_( - weights_dict={ - "policy": policy_weights, - "value": value_weights, - } - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - env.close() - - def test_multi_model_tensordict_updates(self): - env = ContinuousActionVecMockEnv() - - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="tensordict"), - "value": MultiProcessWeightSyncScheme(strategy="tensordict"), - } - - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes=weight_sync_schemes, - ) - - policy_weights = TensorDict.from_module(policy) - value_weights = TensorDict.from_module(value) - - with torch.no_grad(): - policy_weights["module"]["weight"].fill_(1.0) - policy_weights["module"]["bias"].fill_(1.0) - value_weights["module"]["weight"].fill_(2.0) - value_weights["module"]["bias"].fill_(2.0) - - collector.update_policy_weights_( - weights_dict={ - "policy": policy_weights, - "value": value_weights, - } - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - env.close() - - -class TestHelpers: - def test_resolve_model_simple(self): - class Context: - def __init__(self): - self.policy = nn.Linear(2, 3) - - context = Context() - resolved = _resolve_model(context, "policy") - assert resolved is context.policy - - def test_resolve_model_nested(self): - class Inner: - def __init__(self): - self.value_net = nn.Linear(2, 3) - - class Context: - def __init__(self): - self.env = Inner() - - context = Context() - resolved = _resolve_model(context, "env.value_net") - assert resolved is context.env.value_net - - def test_resolve_model_with_index(self): - class Context: - def __init__(self): - self.transform = [nn.Linear(2, 3), nn.Linear(3, 4)] - - context = Context() - resolved = _resolve_model(context, "transform[0]") - assert resolved is context.transform[0] - - resolved = _resolve_model(context, "transform[1]") - assert resolved is context.transform[1] - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -class TestDeviceHandling: - def test_weight_update_cpu_to_cpu(self): - policy = nn.Linear(3, 4) - strategy = WeightStrategy(extract_as="state_dict") - - weights = strategy.extract_weights(policy) - target = nn.Linear(3, 4) - strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - - def test_weight_update_cuda_to_cuda(self): - policy = nn.Linear(3, 4).cuda() - strategy = WeightStrategy(extract_as="tensordict") - - weights = strategy.extract_weights(policy) - target = nn.Linear(3, 4).cuda() - strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - - -@pytest.mark.parametrize("strategy", ["state_dict", "tensordict"]) -def test_weight_strategy_parametrized(strategy): - weight_strategy = WeightStrategy(extract_as=strategy) - - policy = nn.Linear(3, 4) - weights = weight_strategy.extract_weights(policy) - - target = 
nn.Linear(3, 4) - with torch.no_grad(): - target.weight.fill_(0.0) - target.bias.fill_(0.0) - - weight_strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - assert torch.allclose(policy.bias, target.bias) - - -class TestSerializeScheme: - """Test that WeightSyncScheme instances can be serialized after initialization. - - This is critical for multiprocessing and Ray, where schemes may be pickled - and sent across process boundaries. The _sender and _receiver attributes - contain non-serializable objects (pipes, weak references, etc.) and must - be excluded from serialization. - """ - - def test_multiprocess_scheme_serialize_before_init(self): - """Test that uninitialized scheme can be pickled.""" - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None - assert not restored._initialized_on_sender - assert not restored._initialized_on_worker - - def test_multiprocess_scheme_serialize_after_sender_init(self): - """Test that initialized sender can be pickled (excluding runtime state).""" - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - # Scheme now has _sender with non-serializable pipes - assert scheme._sender is not None - assert scheme._initialized_on_sender - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved but runtime state is cleared - assert restored.strategy == "state_dict" - assert restored._sender is None # Runtime state excluded - assert restored._receiver is None - assert not restored._initialized_on_sender # Reset - assert not restored._initialized_on_worker - - # Clean up - parent_pipe.close() - child_pipe.close() - - def test_shared_mem_scheme_serialize_before_init(self): - """Test that uninitialized SharedMemWeightSyncScheme can be pickled.""" - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "tensordict" - assert restored._sender is None - assert restored._receiver is None - - def test_shared_mem_scheme_serialize_after_init(self): - """Test that initialized SharedMemWeightSyncScheme can be pickled.""" - parent_pipe, child_pipe = mp.Pipe() - - # Create shared buffer - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - scheme = SharedMemWeightSyncScheme( - policy_weights={"policy": shared_buffer}, - strategy="tensordict", - auto_register=False, - ) - - def init_on_sender(scheme, child_pipe): - (model_id, data), msg = child_pipe.recv() - if msg == "register_shared_weights": - child_pipe.send((None, "registered")) - else: - raise ValueError(f"Expected 'register_shared_weights' but got {msg}") - - # Initialize the scheme with the pipes, in 2 separate threads because init requires acknowledgement from the worker - import threading - - future_sender = threading.Thread( - target=scheme.init_on_sender, - kwargs={"model_id": "policy", "pipes": [parent_pipe]}, - ) - future_receiver = threading.Thread( - 
target=init_on_sender, - kwargs={"scheme": scheme, "child_pipe": child_pipe}, - ) - future_receiver.start() - future_sender.start() - future_receiver.join() - future_sender.join() - - # Scheme now has _sender with non-serializable state - assert scheme._sender is not None - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved but runtime state is cleared - assert restored.strategy == "tensordict" - assert restored._sender is None - assert not restored._initialized_on_sender - - # Note: policy_weights dict is preserved (but may need re-sharing) - assert "policy" in restored.policy_weights - - # Clean up - parent_pipe.close() - child_pipe.close() - - def test_no_weight_sync_scheme_serialize(self): - """Test that NoWeightSyncScheme can be pickled.""" - scheme = NoWeightSyncScheme() - scheme.init_on_sender(model_id="policy") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that it's still a no-op scheme - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif( - not torch.distributed.is_available(), reason="torch.distributed not available" - ) - def test_distributed_scheme_serialize_before_init(self): - """Test that uninitialized DistributedWeightSyncScheme can be pickled.""" - - scheme = DistributedWeightSyncScheme(backend="gloo", sync=True) - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.backend == "gloo" - assert restored.sync is True - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif(not _has_ray, reason="Ray not available") - def test_ray_weight_sync_scheme_serialize_before_init(self): - """Test that uninitialized RayWeightSyncScheme can be pickled.""" - scheme = RayWeightSyncScheme(strategy="state_dict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif(not _has_ray, reason="Ray not available") - def test_ray_module_transform_scheme_serialize_before_init(self): - """Test that uninitialized RayModuleTransformScheme can be pickled.""" - - scheme = RayModuleTransformScheme(strategy="tensordict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "tensordict" - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif( - not torch.distributed.is_available(), reason="torch.distributed not available" - ) - def test_rpc_weight_sync_scheme_serialize_before_init(self): - """Test that uninitialized RPCWeightSyncScheme can be pickled.""" - - scheme = RPCWeightSyncScheme(strategy="state_dict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None - - def test_scheme_reinitialization_after_unpickle(self): - """Test that a scheme can be re-initialized after unpickling. 
- - This is the expected workflow: pickle a scheme, unpickle it in a worker, - then call init_on_worker() to establish new runtime resources. - """ - # Initialize and pickle a scheme - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - pickled = pickle.dumps(scheme) - - # Clean up original - parent_pipe.close() - - # Unpickle and re-initialize - restored = pickle.loads(pickled) - - # Should be able to initialize again with new pipes - new_parent, new_child = mp.Pipe() - - # Re-initialize on sender - restored.init_on_sender(model_id="policy", pipes=[new_parent]) - sender = restored.get_sender() - - assert sender is not None - assert restored._initialized_on_sender - - # Clean up - new_parent.close() - new_child.close() - child_pipe.close() - - -if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst", "-v"] + unknown) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index f090831d25c..29ddba9ae35 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -52,7 +52,7 @@ def strtobool(val: Any) -> bool: LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") -logger.setLevel(getattr(logging, LOGGING_LEVEL)) +logger.setLevel(LOGGING_LEVEL) logger.propagate = False # Clear existing handlers while logger.hasHandlers(): @@ -85,7 +85,9 @@ def format(self, record): console_handler = logging.StreamHandler(stream=stream_handler) console_handler.setFormatter(_CustomFormatter()) logger.addHandler(console_handler) -console_handler.setLevel(logging.INFO) + +console_handler.setLevel(LOGGING_LEVEL) +logger.debug(f"Logging level: {logger.getEffectiveLevel()}") VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG)))) _os_is_windows = sys.platform == "win32" @@ -1045,9 +1047,13 @@ def merge_ray_runtime_env(ray_init_config: dict[str, Any]) -> dict[str, Any]: """ default_runtime_env = get_ray_default_runtime_env() - runtime_env = ray_init_config.setdefault("runtime_env", {}) + runtime_env = ray_init_config.get("runtime_env") - if not isinstance(runtime_env, dict): + # Handle None or missing runtime_env + if runtime_env is None: + runtime_env = {} + ray_init_config["runtime_env"] = runtime_env + elif not isinstance(runtime_env, dict): runtime_env = dict(runtime_env) ray_init_config["runtime_env"] = runtime_env diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 7f1c812943d..208bd2cab9c 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -3,15 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from torchrl.envs.utils import RandomPolicy -from .collectors import ( - aSyncDataCollector, - DataCollectorBase, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.modules.tensordict_module.exploration import RandomPolicy + +from ._base import DataCollectorBase + +from ._multi_async import MultiaSyncDataCollector +from ._multi_sync import MultiSyncDataCollector +from ._single import SyncDataCollector + +from ._single_async import aSyncDataCollector from .weight_update import ( MultiProcessedWeightUpdater, RayWeightUpdater, @@ -21,9 +22,9 @@ ) __all__ = [ - "RandomPolicy", "WeightUpdaterBase", "VanillaWeightUpdater", + "RandomPolicy", "RayWeightUpdater", "RemoteModuleWeightUpdater", "MultiProcessedWeightUpdater", diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py new file mode 100644 index 00000000000..9e2c4044744 --- /dev/null +++ b/torchrl/collectors/_base.py @@ -0,0 +1,810 @@ +from __future__ import annotations + +import abc +import contextlib +import functools +import typing +import warnings +from collections import OrderedDict +from collections.abc import Callable, Iterator +from copy import deepcopy +from typing import Any, overload + +import torch +from tensordict import TensorDict, TensorDictBase +from tensordict.base import NO_DEFAULT +from tensordict.nn import TensorDictModule, TensorDictModuleBase +from torch import nn as nn +from torch.utils.data import IterableDataset +from torchrl._utils import logger as torchrl_logger +from torchrl.collectors.utils import _map_weight + +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.weight_update.utils import _resolve_attr +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme + + +class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): + """Base class for data collectors.""" + + _task = None + _iterator = None + total_frames: int + requested_frames_per_batch: int + frames_per_batch: int + trust_policy: bool + compiled_policy: bool + cudagraphed_policy: bool + _weight_updater: WeightUpdaterBase | None = None + _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None + verbose: bool = False + + @property + def weight_updater(self) -> WeightUpdaterBase: + return self._weight_updater + + @weight_updater.setter + def weight_updater(self, value: WeightUpdaterBase | None): + if value is not None: + if not isinstance(value, WeightUpdaterBase) and callable( + value + ): # Fall back to default constructor + value = value() + value.register_collector(self) + if value.collector is not self: + raise RuntimeError("Failed to register collector.") + self._weight_updater = value + + @property + def worker_idx(self) -> int | None: + """Get the worker index for this collector. + + Returns: + The worker index (0-indexed). + + Raises: + RuntimeError: If worker_idx has not been set. + """ + if not hasattr(self, "_worker_idx"): + raise RuntimeError( + "worker_idx has not been set. This collector may not have been " + "initialized as a worker in a distributed setup." + ) + return self._worker_idx + + @worker_idx.setter + def worker_idx(self, value: int | None) -> None: + """Set the worker index for this collector. + + Args: + value: The worker index (0-indexed) or None. + """ + self._worker_idx = value + + def cascade_execute(self, attr_path: str, *args, **kwargs) -> Any: + """Execute a method on a nested attribute of this collector. 
+ + This method allows remote callers to invoke methods on nested attributes + of the collector without needing to know the full structure. It's particularly + useful for calling methods on weight sync schemes from the sender side. + + Args: + attr_path: Full path to the callable, e.g., + "_receiver_schemes['model_id']._set_dist_connection_info" + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + The return value of the method call. + + Examples: + >>> collector.cascade_execute( + ... "_receiver_schemes['policy']._set_dist_connection_info", + ... connection_info_ref, + ... worker_idx=0 + ... ) + """ + attr = _resolve_attr(self, attr_path) + if callable(attr): + return attr(*args, **kwargs) + else: + if args or kwargs: + raise ValueError( + f"Arguments and keyword arguments are not supported for non-callable attributes. Got {args} and {kwargs} for {attr_path}" + ) + return attr + + def _get_policy_and_device( + self, + policy: Callable[[Any], Any] | None = None, + policy_device: Any = NO_DEFAULT, + env_maker: Any | None = None, + env_maker_kwargs: dict[str, Any] | None = None, + ) -> tuple[TensorDictModule, None | Callable[[], dict]]: + """Util method to get a policy and its device given the collector __init__ inputs. + + We want to copy the policy and then move the data there, not call policy.to(device). + + Args: + policy (TensorDictModule, optional): a policy to be used + policy_device (torch.device, optional): the device where the policy should be placed. + Defaults to self.policy_device + env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. + env_maker_kwargs (a dict, optional): the env_maker function kwargs. + + """ + if policy_device is NO_DEFAULT: + policy_device = self.policy_device + + if not policy_device: + return policy, None + + if isinstance(policy, nn.Module): + param_and_buf = TensorDict.from_module(policy, as_module=True) + else: + # Because we want to reach the warning + param_and_buf = TensorDict() + + i = -1 + for p in param_and_buf.values(True, True): + i += 1 + if p.device != policy_device: + # Then we need casting + break + else: + if i == -1 and not self.trust_policy: + # We trust that the policy policy device is adequate + warnings.warn( + "A policy device was provided but no parameter/buffer could be found in " + "the policy. Casting to policy_device is therefore impossible. " + "The collector will trust that the devices match. To suppress this " + "warning, set `trust_policy=True` when building the collector." + ) + return policy, None + + # Create a stateless policy, then populate this copy with params on device + def get_original_weights(policy=policy): + td = TensorDict.from_module(policy) + return td.data + + # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function + with param_and_buf.data.to("meta").to_module(policy): + policy_new_device = deepcopy(policy) + + param_and_buf_new_device = param_and_buf.apply( + functools.partial(_map_weight, policy_device=policy_device), + filter_empty=False, + ) + param_and_buf_new_device.to_module(policy_new_device) + # Sanity check + if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set( + get_original_weights().keys(True, True) + ): + raise RuntimeError("Failed to map weights. The weight sets mismatch.") + return policy_new_device, get_original_weights + + def start(self): + """Starts the collector for asynchronous data collection. 
+ + This method initiates the background collection of data, allowing for decoupling of data collection and training. + + The collected data is typically stored in a replay buffer passed during the collector's initialization. + + .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown` + when you're done with it to free up resources. + + .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature. + Ensure you understand the implications for your specific algorithm before using this mode. + + Raises: + NotImplementedError: If not implemented by a subclass. + """ + raise NotImplementedError( + f"Collector start() is not implemented for {type(self).__name__}." + ) + + @contextlib.contextmanager + def pause(self): + """Context manager that pauses the collector if it is running free.""" + raise NotImplementedError( + f"Collector pause() is not implemented for {type(self).__name__}." + ) + + def async_shutdown( + self, timeout: float | None = None, close_env: bool = True + ) -> None: + """Shuts down the collector when started asynchronously with the `start` method. + + Args: + timeout (float, optional): The maximum time to wait for the collector to shutdown. + close_env (bool, optional): If True, the collector will close the contained environment. + Defaults to `True`. + + .. seealso:: :meth:`~.start` + + """ + return self.shutdown(timeout=timeout, close_env=close_env) + + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For the new weight sync scheme system, weight preparation is handled + by the scheme's prepare_weights() method. This method now only handles + legacy weight updater cases. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier for resolving string paths. + + Returns: + Extracted weights in the appropriate format. + """ + # New weight sync schemes handle preparation themselves + if self._weight_sync_schemes: + # Just pass through - WeightSender will call scheme.prepare_weights() + return weights + + # Legacy weight updater path + return self._legacy_extract_weights(weights, model_id) + + def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: + """Legacy weight extraction for old weight updater system. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier. + + Returns: + Extracted weights. + """ + if weights is None: + if model_id == "policy" and hasattr(self, "policy_weights"): + return self.policy_weights + elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): + policy_device = ( + self.policy_device + if not isinstance(self.policy_device, (list, tuple)) + else self.policy_device[0] + ) + return self._policy_weights_dict.get(policy_device) + return None + + return weights + + @property + def _legacy_weight_updater(self) -> bool: + return self._weight_updater is not None + + # Overloads for update_policy_weights_ to support multiple calling conventions + @overload + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict, + /, + ) -> None: + ... 
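+ # The remaining overloads below spell out the keyword-only calling conventions. A minimal usage sketch, assuming schemes were registered as weight_sync_schemes={"policy": ..., "value": ...}: + # collector.update_policy_weights_(policy_module) # positional module or weights + # collector.update_policy_weights_(weights=weights_td, model_id="value") + # collector.update_policy_weights_(weights_dict={"policy": policy_td, "value": value_td}, worker_ids=[0, 1])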
+ + @overload + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict, + /, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + ) -> None: + ... + + @overload + def update_policy_weights_( + self, + *, + weights: TensorDictBase | dict, + model_id: str | None = None, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + ) -> None: + ... + + @overload + def update_policy_weights_( + self, + *, + policy: TensorDictModuleBase | nn.Module, + model_id: str | None = None, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + ) -> None: + ... + + @overload + def update_policy_weights_( + self, + *, + weights_dict: dict[ + str, TensorDictBase | TensorDictModuleBase | nn.Module | dict + ], + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + ) -> None: + ... + + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase + | TensorDictModuleBase + | nn.Module + | dict + | None = None, + *, + weights: TensorDictBase | dict | None = None, + policy: TensorDictModuleBase | nn.Module | None = None, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Update policy weights for the data collector. + + This method synchronizes the policy weights used by the collector with the latest + trained weights. It supports both local and remote weight updates, depending on + the collector configuration. + + The method accepts weights in multiple forms for convenience: + + Examples: + >>> # Pass policy module as positional argument + >>> collector.update_policy_weights_(policy_module) + >>> + >>> # Pass TensorDict weights as positional argument + >>> collector.update_policy_weights_(weights_tensordict) + >>> + >>> # Use keyword arguments for clarity + >>> collector.update_policy_weights_(weights=weights_td, model_id="actor") + >>> collector.update_policy_weights_(policy=actor_module, model_id="actor") + >>> + >>> # Update multiple models atomically + >>> collector.update_policy_weights_(weights_dict={ + ... "actor": actor_weights, + ... "critic": critic_weights, + ... }) + + Args: + policy_or_weights: The weights to update with. Can be: + + - ``nn.Module``: A policy module whose weights will be extracted + - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted + - ``TensorDictBase``: A TensorDict containing weights + - ``dict``: A regular dict containing weights + - ``None``: Will try to get weights from server using ``_get_server_weights()`` + + Keyword Args: + weights: Alternative to positional argument. A TensorDict or dict containing + weights to update. Cannot be used together with ``policy_or_weights`` or ``policy``. + policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase`` + whose weights will be extracted. Cannot be used together with ``policy_or_weights`` + or ``weights``. + worker_ids: Identifiers for the workers to update. Relevant when the collector + has multiple workers. Can be int, list of ints, device, or list of devices. + model_id: The model identifier to update (default: ``"policy"``). + Cannot be used together with ``weights_dict``. + weights_dict: Dictionary mapping model_id to weights for updating + multiple models atomically. 
Keys should match model_ids registered in + ``weight_sync_schemes``. Cannot be used together with ``model_id``, + ``policy_or_weights``, ``weights``, or ``policy``. + + Raises: + TypeError: If ``worker_ids`` is provided but no ``weight_updater`` is configured. + ValueError: If conflicting parameters are provided. + + .. note:: Users should extend the ``WeightUpdaterBase`` classes to customize + the weight update logic for specific use cases. + + .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and + :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. + + """ + # Handle the different keyword argument forms + if weights is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify both positional 'policy_or_weights' and keyword 'weights'" + ) + if policy is not None: + raise ValueError("Cannot specify both 'weights' and 'policy'") + policy_or_weights = weights + + if policy is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify both positional 'policy_or_weights' and keyword 'policy'" + ) + policy_or_weights = policy + if self._legacy_weight_updater: + return self._legacy_weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + else: + return self._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + def _legacy_weight_update_impl( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + if weights_dict is not None: + raise ValueError("weights_dict is not supported with legacy weight updater") + if model_id is not None: + raise ValueError("model_id is not supported with legacy weight updater") + # Fall back to old weight updater system + self.weight_updater( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. 
Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + if weights_dict is not None and model_id is not None: + raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") + + if weights_dict is not None and policy_or_weights is not None: + raise ValueError( + "Cannot specify both 'weights_dict' and 'policy_or_weights'" + ) + + if self._weight_sync_schemes: + if model_id is None: + model_id = "policy" + if policy_or_weights is not None and weights_dict is None: + # Use model_id as the key, not hardcoded "policy" + weights_dict = {model_id: policy_or_weights} + elif weights_dict is None: + weights_dict = {model_id: policy_or_weights} + torchrl_logger.debug( + f"Calling weight update with {model_id=} and {weights_dict.keys()=}" + ) + for target_model_id, weights in weights_dict.items(): + if target_model_id not in self._weight_sync_schemes: + raise KeyError( + f"Model '{target_model_id}' not found in registered weight sync schemes. " + f"Available models: {list(self._weight_sync_schemes.keys())}" + ) + processed_weights = self._extract_weights_if_needed( + weights, target_model_id + ) + # Use new send() API with worker_ids support + torchrl_logger.debug("weight update -- getting scheme") + scheme = self._weight_sync_schemes.get(target_model_id) + if not isinstance(scheme, WeightSyncScheme): + raise TypeError(f"Expected WeightSyncScheme, got {target_model_id}") + torchrl_logger.debug( + f"calling send() on scheme {type(scheme).__name__}" + ) + self._send_weights_scheme( + scheme=scheme, + processed_weights=processed_weights, + worker_ids=worker_ids, + model_id=target_model_id, + ) + elif self._weight_updater is not None: + # unreachable + raise RuntimeError + else: + # No weight updater configured, try fallback + torchrl_logger.debug("No weight update configured, trying fallback.") + self._maybe_fallback_update(policy_or_weights, model_id=model_id) + + def _maybe_fallback_update( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + model_id: str | None = None, + ) -> None: + """Fallback weight update when no scheme is configured. + + Override in subclasses to provide custom fallback behavior. + By default, this is a no-op. + """ + + def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids): + # method to override if the scheme requires an RPC call to receive the weights + scheme.send(weights=processed_weights, worker_ids=worker_ids) + + def _receive_weights_scheme(self): + """Receive weights for all registered receiver schemes. + + scheme.receive() handles both applying weights locally and cascading + to sub-collectors via context.update_policy_weights_(). + """ + if not hasattr(self, "_receiver_schemes"): + raise RuntimeError("No receiver schemes registered.") + + for model_id, scheme in self._receiver_schemes.items(): + torchrl_logger.debug( + f"Receiving weights for scheme {type(scheme).__name__} for model '{model_id}' on worker {self._worker_idx}" + ) + received_weights = scheme.receive() + torchrl_logger.debug(f"Received weights: {type(received_weights)=}") + + # Overloads for receive_weights to support multiple calling conventions + @overload + def receive_weights(self) -> None: + ... + + @overload + def receive_weights( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict, + /, + ) -> None: + ... + + @overload + def receive_weights( + self, + *, + weights: TensorDictBase | dict, + ) -> None: + ... 
+ + @overload + def receive_weights( + self, + *, + policy: TensorDictModuleBase | nn.Module, + ) -> None: + ... + + def receive_weights( + self, + policy_or_weights: TensorDictBase + | TensorDictModuleBase + | nn.Module + | dict + | None = None, + *, + weights: TensorDictBase | dict | None = None, + policy: TensorDictModuleBase | nn.Module | None = None, + ) -> None: + """Receive and apply weights to the collector's policy. + + This method applies weights to the local policy. When receiver schemes are + registered, it delegates to those schemes. Otherwise, it directly applies + the provided weights. + + The method accepts weights in multiple forms for convenience: + + Examples: + >>> # Receive from registered schemes (distributed collectors) + >>> collector.receive_weights() + >>> + >>> # Apply weights from a policy module (positional) + >>> collector.receive_weights(trained_policy) + >>> + >>> # Apply weights from a TensorDict (positional) + >>> collector.receive_weights(weights_tensordict) + >>> + >>> # Use keyword arguments for clarity + >>> collector.receive_weights(weights=weights_td) + >>> collector.receive_weights(policy=trained_policy) + + Args: + policy_or_weights: The weights to apply. Can be: + + - ``nn.Module``: A policy module whose weights will be extracted and applied + - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted + - ``TensorDictBase``: A TensorDict containing weights + - ``dict``: A regular dict containing weights + - ``None``: Receive from registered schemes or mirror from original policy + + Keyword Args: + weights: Alternative to positional argument. A TensorDict or dict containing + weights to apply. Cannot be used together with ``policy_or_weights`` or ``policy``. + policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase`` + whose weights will be extracted. Cannot be used together with ``policy_or_weights`` + or ``weights``. + + Raises: + ValueError: If conflicting parameters are provided or if arguments are passed + when receiver schemes are registered. + + """ + # Handle the different keyword argument forms + if weights is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify both positional 'policy_or_weights' and keyword 'weights'" + ) + if policy is not None: + raise ValueError("Cannot specify both 'weights' and 'policy'") + policy_or_weights = weights + + if policy is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify both positional 'policy_or_weights' and keyword 'policy'" + ) + policy_or_weights = policy + + if getattr(self, "_receiver_schemes", None) is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify 'policy_or_weights' when using 'receiver_schemes'. Schemes should know how to get the weights." 
+ ) + self._receive_weights_scheme() + return + + # No weight updater configured + # For single-process collectors, apply weights locally if explicitly provided + if policy_or_weights is not None: + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + # Use WeightStrategy to apply weights properly + strategy = WeightStrategy(extract_as="tensordict") + + # Extract weights if needed + if isinstance(policy_or_weights, nn.Module): + weights = strategy.extract_weights(policy_or_weights) + else: + weights = policy_or_weights + + # Apply to local policy + if hasattr(self, "policy") and isinstance(self.policy, nn.Module): + strategy.apply_weights(self.policy, weights) + # Otherwise, no action needed - policy is local and changes are immediately visible + + def register_scheme_receiver( + self, + weight_recv_schemes: dict[str, WeightSyncScheme], + *, + synchronize_weights: bool = True, + ): # noqa: D417 + """Set up receiver schemes for this collector to receive weights from parent collectors. + + This method initializes receiver schemes and stores them in _receiver_schemes + for later use by _receive_weights_scheme() and receive_weights(). + + Receiver schemes enable cascading weight updates across collector hierarchies: + - Parent collector sends weights via its weight_sync_schemes (senders) + - Child collector receives weights via its weight_recv_schemes (receivers) + - If child is also a parent (intermediate node), it can propagate to its own children + + Args: + weight_recv_schemes (dict[str, WeightSyncScheme]): Dictionary of {model_id: WeightSyncScheme} to set up as receivers. + These schemes will receive weights from parent collectors. + + Keyword Args: + synchronize_weights (bool, optional): If True, synchronize weights immediately after registering the schemes. + Defaults to `True`. + """ + # Initialize _receiver_schemes if not already present + if not hasattr(self, "_receiver_schemes"): + self._receiver_schemes = {} + + # Initialize each scheme on the receiver side + for model_id, scheme in weight_recv_schemes.items(): + if not scheme.initialized_on_receiver: + if scheme.initialized_on_sender: + raise RuntimeError( + "Weight sync scheme cannot be initialized on both sender and receiver." 
+ ) + scheme.init_on_receiver( + model_id=model_id, + context=self, + worker_idx=self.worker_idx, + ) + + # Store the scheme for later use in receive_weights() + self._receiver_schemes[model_id] = scheme + + # Perform initial synchronization + if synchronize_weights: + for model_id, scheme in weight_recv_schemes.items(): + if not scheme.synchronized_on_receiver: + torchrl_logger.debug( + f"Synchronizing weights for scheme {type(scheme).__name__} for model '{model_id}'" + ) + scheme.connect(worker_idx=self.worker_idx) + + def __iter__(self) -> Iterator[TensorDictBase]: + try: + yield from self.iterator() + except Exception: + self.shutdown() + raise + + def next(self): + try: + if self._iterator is None: + self._iterator = iter(self) + out = next(self._iterator) + # if any, we don't want the device ref to be passed in distributed settings + if out is not None and (out.device != "cpu"): + out = out.copy().clear_device_() + return out + except StopIteration: + return None + + @abc.abstractmethod + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + raise NotImplementedError + + @abc.abstractmethod + def iterator(self) -> Iterator[TensorDictBase]: + raise NotImplementedError + + @abc.abstractmethod + def set_seed(self, seed: int, static_seed: bool = False) -> int: + raise NotImplementedError + + @abc.abstractmethod + def state_dict(self) -> OrderedDict: + raise NotImplementedError + + @abc.abstractmethod + def load_state_dict(self, state_dict: OrderedDict) -> None: + raise NotImplementedError + + def _read_compile_kwargs(self, compile_policy, cudagraph_policy): + self.compiled_policy = compile_policy not in (False, None) + self.cudagraphed_policy = cudagraph_policy not in (False, None) + self.compiled_policy_kwargs = ( + {} if not isinstance(compile_policy, typing.Mapping) else compile_policy + ) + self.cudagraphed_policy_kwargs = ( + {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy + ) + + def __repr__(self) -> str: + string = f"{self.__class__.__name__}()" + return string + + def __class_getitem__(self, index): + raise NotImplementedError + + def __len__(self) -> int: + if self.total_frames > 0: + return -(self.total_frames // -self.requested_frames_per_batch) + raise RuntimeError("Non-terminating collectors do not have a length") + + def init_updater(self, *args, **kwargs): + """Initialize the weight updater with custom arguments. + + This method passes the arguments to the weight updater's init method. + If no weight updater is set, this is a no-op. + + Args: + *args: Positional arguments for weight updater initialization + **kwargs: Keyword arguments for weight updater initialization + """ + if self.weight_updater is not None: + self.weight_updater.init(*args, **kwargs) diff --git a/torchrl/collectors/_constants.py b/torchrl/collectors/_constants.py new file mode 100644 index 00000000000..1587d800166 --- /dev/null +++ b/torchrl/collectors/_constants.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+"""Constants and helper classes for collectors.""" +from __future__ import annotations + +import os +import sys +from multiprocessing.managers import SyncManager + +import torch +from torch import multiprocessing as mp + +from torchrl.envs.utils import ExplorationType + +try: + from torch.compiler import cudagraph_mark_step_begin +except ImportError: + + def cudagraph_mark_step_begin(): + """Placeholder for missing cudagraph_mark_step_begin method.""" + raise NotImplementedError("cudagraph_mark_step_begin not implemented.") + + +__all__ = [ + "_TIMEOUT", + "INSTANTIATE_TIMEOUT", + "_MIN_TIMEOUT", + "_MAX_IDLE_COUNT", + "DEFAULT_EXPLORATION_TYPE", + "_is_osx", + "_Interruptor", + "_InterruptorManager", + "cudagraph_mark_step_begin", +] + +_TIMEOUT = 1.0 +INSTANTIATE_TIMEOUT = 20 +_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory +# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. +_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)) + +DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM + +_is_osx = sys.platform.startswith("darwin") + + +class _Interruptor: + """A class for managing the collection state of a process. + + This class provides methods to start and stop collection, and to check + whether collection has been stopped. The collection state is protected + by a lock to ensure thread-safety. + """ + + # interrupter vs interruptor: google trends seems to indicate that "or" is more + # widely used than "er" even if my IDE complains about that... + def __init__(self): + self._collect = True + self._lock = mp.Lock() + + def start_collection(self): + with self._lock: + self._collect = True + + def stop_collection(self): + with self._lock: + self._collect = False + + def collection_stopped(self): + with self._lock: + return self._collect is False + + +class _InterruptorManager(SyncManager): + """A custom SyncManager for managing the collection state of a process. + + This class extends the SyncManager class and allows to share an Interruptor object + between processes. + """ + + +_InterruptorManager.register("_Interruptor", _Interruptor) diff --git a/torchrl/collectors/_multi_async.py b/torchrl/collectors/_multi_async.py new file mode 100644 index 00000000000..a7b468e5dc7 --- /dev/null +++ b/torchrl/collectors/_multi_async.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +import time +import warnings +from collections import defaultdict, OrderedDict +from collections.abc import Iterator, Sequence +from copy import deepcopy +from queue import Empty + +import torch + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl._utils import _check_for_faulty_process, accept_remote_rref_udf_invocation +from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors.utils import split_trajectories + + +@accept_remote_rref_udf_invocation +class MultiaSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes asynchronously. + + .. 
aafig:: + + + +----------------------------------------------------------------------+ + | "MultiConcurrentCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor | "step" | "step" | "actor" | | + | | | | | | + | "yield batch 1" | "actor" | |"collect, train"| + | | | | | + | "step" | "step" | | "yield batch 2" |"collect, train"| + | | | | | | + | | | "yield batch 3" | |"collect, train"| + | | | | | | + +----------------------------------------------------------------------+ + + Environment types can be identical or different. + + The collection keeps on occurring on all processes even between the time + the batch of rollouts is collected and the next call to the iterator. + This class can be safely used with offline RL sota-implementations. + + .. note:: Python requires multiprocessed code to be instantiated within a main guard: + + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + + See https://docs.python.org/3/library/multiprocessing.html for more info. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiaSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + ... collector.shutdown() + ... 
del collector + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + + """ + + __doc__ += _MultiDataCollector.__doc__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.out_tensordicts = defaultdict(lambda: None) + self.running = False + + if self.postprocs is not None and self.replay_buffer is None: + postproc = self.postprocs + self.postprocs = {} + for _device in self.storing_device: + if _device not in self.postprocs: + if hasattr(postproc, "to"): + postproc = deepcopy(postproc).to(_device) + self.postprocs[_device] = postproc + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." + ) + return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. 
Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: + return self.requested_frames_per_batch + + def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: + new_data, j = self.queue_out.get(timeout=timeout) + use_buffers = self._use_buffers + if self.replay_buffer is not None: + idx = new_data + elif j == 0 or not use_buffers: + try: + data, idx = new_data + self.out_tensordicts[idx] = data + if use_buffers is None and j > 0: + use_buffers = self._use_buffers = False + except TypeError: + if use_buffers is None: + use_buffers = self._use_buffers = True + idx = new_data + else: + raise + else: + idx = new_data + out = self.out_tensordicts[idx] + if not self.replay_buffer and (j == 0 or use_buffers): + # we clone the data to make sure that we'll be working with a fixed copy + out = out.clone() + return idx, j, out + + @property + def _queue_len(self) -> int: + return 1 + + def iterator(self) -> Iterator[TensorDictBase]: + if self.update_at_each_batch: + self.update_policy_weights_() + + for i in range(self.num_workers): + if self.init_random_frames is not None and self.init_random_frames > 0: + self.pipes[i].send((None, "continue_random")) + else: + self.pipes[i].send((None, "continue")) + self.running = True + + workers_frames = [0 for _ in range(self.num_workers)] + while self._frames < self.total_frames: + self._iter += 1 + counter = 0 + while True: + try: + idx, j, out = self._get_from_queue(timeout=_TIMEOUT) + break + except (TimeoutError, Empty): + counter += _TIMEOUT + _check_for_faulty_process(self.procs) + if counter > (_TIMEOUT * _MAX_IDLE_COUNT): + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
+ ) + if self.replay_buffer is None: + worker_frames = out.numel() + if self.split_trajs: + out = split_trajectories(out, prefix="collector") + else: + worker_frames = self.frames_per_batch_worker() + self._frames += worker_frames + workers_frames[idx] = workers_frames[idx] + worker_frames + if out is not None and self.postprocs: + out = self.postprocs[out.device](out) + + # the function blocks here until the next item is asked, hence we send the message to the + # worker to keep on working in the meantime before the yield statement + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + msg = "continue_random" + else: + msg = "continue" + self.pipes[idx].send((idx, msg)) + if out is not None and self._exclude_private_keys: + excluded_keys = [key for key in out.keys() if key.startswith("_")] + out = out.exclude(*excluded_keys) + yield out + + # We don't want to shutdown yet, the user may want to call state_dict before + # self._shutdown_main() + self.running = False + + def _shutdown_main(self, *args, **kwargs) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + return super()._shutdown_main(*args, **kwargs) + + def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + super().reset(reset_idx) + if self.queue_out.full(): + time.sleep(_TIMEOUT) # wait until queue is empty + if self.queue_out.full(): + raise Exception("self.queue_out is full") + if self.running: + for idx in range(self.num_workers): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + self.pipes[idx].send((idx, "continue_random")) + else: + self.pipes[idx].send((idx, "continue")) + + # for RPC + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() + + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py new file mode 100644 index 00000000000..d3ce8c38c60 --- /dev/null +++ b/torchrl/collectors/_multi_base.py @@ -0,0 +1,1559 @@ +from __future__ import annotations + +import _pickle + +import contextlib +import warnings +from collections import OrderedDict +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import CudaGraphModule, TensorDictModule +from tensordict.utils import _zip_strict +from torch import multiprocessing as mp, nn +from torchrl import logger as torchrl_logger +from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, RL_WARNINGS +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import ( + _InterruptorManager, + _is_osx, + DEFAULT_EXPLORATION_TYPE, + ExplorationType, + INSTANTIATE_TIMEOUT, +) +from torchrl.collectors._runner import _main_async_collector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.data import ReplayBuffer +from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING +from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.llm.transforms import PolicyVersion +from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + WeightSyncScheme, +) +from torchrl.weight_update.utils import 
_resolve_model + + +class _MultiDataCollector(DataCollectorBase): + """Runs a given number of DataCollectors on separate processes. + + Args: + create_env_fn (List[Callable]): list of Callables, each returning an + instance of :class:`~torchrl.envs.EnvBase`. + policy (Callable): Policy to be executed in the environment. + Must accept a :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided (default), the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instance) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + + - In all other cases, an attempt to wrap it will be made as follows: + ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the ``policy_factory`` should be used instead. + + .. note:: When using ``weight_sync_schemes``, both ``policy`` and ``policy_factory`` can be provided together. + In this case, the ``policy`` is used ONLY for weight extraction (via ``TensorDict.from_module()``) to + set up weight synchronization, but it is NOT sent to workers and its weights are NOT depopulated. + The ``policy_factory`` is what actually gets passed to workers to create their local policy instances. + This is useful when the policy is hard to serialize but you have a copy on the main node for + weight synchronization purposes. + + Keyword Args: + policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable + (or list of callables) that returns a policy instance. + + When not using ``weight_sync_schemes``, this is mutually exclusive with the ``policy`` argument. + + When using ``weight_sync_schemes``, both ``policy`` and ``policy_factory`` can be provided: + the ``policy`` is used for weight extraction only, while ``policy_factory`` creates policies on workers. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + .. warning:: `policy_factory` is currently not compatible with multiprocessed data + collectors. + + num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored. + Defaults to `None` (workers determined by the `create_env_fn` length). + frames_per_batch (int, Sequence[int]): A keyword-only argument representing the + total number of elements in a batch. If a sequence is provided, represents the number of elements in a + batch per worker. Total number of elements in a batch is then the sum over the sequence. + total_frames (int, optional): A keyword-only argument representing the + total number of frames returned by the collector + during its lifespan.
If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be + :class:`~torchrl.collectors.SyncDataCollector`, + :class:`~torchrl.collectors.MultiSyncDataCollector`, + :class:`~torchrl.collectors.MultiaSyncDataCollector` + or a derived class of these. + Defaults to :class:`~torchrl.collectors.SyncDataCollector`. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). + Once a trajectory reaches ``n_steps``, the environment is reset. 
+ If the environment wraps multiple + environments together, the number + of steps is tracked for each environment independently. Negative + values are allowed, in which case this argument is ignored. + Defaults to ``None`` (i.e. no maximum number of steps). + init_random_frames (int, optional): Number of frames for which the + policy is ignored before it is called. This feature is mainly + intended to be used in offline/model-based settings, where a + batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). + reset_at_each_iter (bool, optional): Whether environments should be reset + at the beginning of a batch collection. + Defaults to ``False``. + postproc (Callable, optional): A post-processing transform, such as + a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` + instance. + Defaults to ``None``. + split_trajs (bool, optional): Boolean indicating whether the resulting + TensorDict should be split according to the trajectories. + See :func:`~torchrl.collectors.utils.split_trajectories` for more + information. + Defaults to ``False``. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, + ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` + or ``torchrl.envs.utils.ExplorationType.MEAN``. + reset_when_done (bool, optional): if ``True`` (default), an environment + that returns a ``True`` value in its ``"done"`` or ``"truncated"`` + entry will be reset at the corresponding indices. + update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()` + will be called before (sync) or after (async) each data collection. + Defaults to ``False``. + preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers + that will be allowed to finish collecting their rollout before the rest are forced to end early. + num_threads (int, optional): number of threads for this process. + Defaults to the number of workers. + num_sub_threads (int, optional): number of threads of the subprocesses. + Should be equal to one plus the number of processes launched within + each subprocess (or one if a single process is launched). + Defaults to 1 for safety: if none is indicated, launching multiple + workers may overload the CPU and harm performance. + cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively). + If ``"stack"``, the data collected from the workers will be stacked along the + first dimension. This is the preferred behavior as it is the most compatible + with the rest of the library. + If ``0``, results will be concatenated along the first dimension + of the outputs, which can be the batched dimension if the environments are + batched or the time dimension if not. + A ``cat_results`` value of ``-1`` will always concatenate results along the + time dimension. This should be preferred over the default. Intermediate values + are also accepted. + Defaults to ``"stack"``. + + .. note:: From v0.5, this argument will default to ``"stack"`` for better + interoperability with the rest of the library.
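+            As an illustration (shapes are hypothetical, assuming two workers each
+            contributing 100 frames with batch size ``[100]``):
+
+                >>> # cat_results="stack"  -> yielded data has shape [2, 100]
+                >>> # cat_results=0 or -1  -> yielded data has shape [200]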
+ + set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding + ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of + a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. + Truncated keys can be set through ``env.add_truncated_keys``. + Defaults to ``False``. + use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. + This isn't compatible with environments with dynamic specs. Defaults to ``True`` + for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts + but populate the buffer instead. Defaults to ``None``. + extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not + with single steps. Defaults to `True` for multiprocessed data collectors. + local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize + the replay buffer in the main process (legacy behavior). If ``True``, the storage-level + coordination will handle initialization with real data from worker processes. + Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. + This parameter is deprecated and will be removed in v0.12. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkills `_) cuda synchronization may cause unexpected + crashes. + Defaults to ``False``. + weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, + which handles weight synchronization across multiple processes. + Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to worker sub-collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights to child processes. + If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. + This is for propagating weights DOWN the hierarchy (parent -> children). + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This enables cascading in hierarchies like: RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector. 
+ Received weights are automatically propagated to sub-collectors if matching model_ids exist. + Defaults to ``None``. + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track + the policy version. + Defaults to `False`. + worker_idx (int, optional): the index of the worker. + + """ + + def __init__( + self, + create_env_fn: Sequence[Callable[[], EnvBase]], + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + num_workers: int | None = None, + policy_factory: Callable[[], Callable] + | list[Callable[[], Callable]] + | None = None, + frames_per_batch: int | Sequence[int], + total_frames: int | None = -1, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Sequence[dict] | None = None, + collector_class: type | Callable[[], DataCollectorBase] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float | None = None, + num_threads: int | None = None, + num_sub_threads: int = 1, + cat_results: str | int | None = None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + replay_buffer_chunk: bool | None = None, + local_init_rb: bool | None = None, + trust_policy: bool | None = None, + compile_policy: bool | dict[str, Any] | None = None, + cudagraph_policy: bool | dict[str, Any] | None = None, + no_cuda_sync: bool = False, + weight_updater: WeightUpdaterBase + | Callable[[], WeightUpdaterBase] + | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, + track_policy_version: bool = False, + worker_idx: int | None = None, + ): + self.closed = True + self.worker_idx = worker_idx + + # Set up workers and environment functions + create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( + create_env_fn, num_workers, frames_per_batch + ) + + # Set up basic configuration + self.set_truncated = set_truncated + self.num_sub_threads = num_sub_threads + self.num_threads = num_threads + self.create_env_fn = create_env_fn + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Set up environment kwargs + self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) + + # Set up devices + storing_devices, policy_devices, env_devices = self._get_devices( + storing_device=storing_device, + env_device=env_device, + policy_device=policy_device, + device=device, + ) + self.storing_device = storing_devices + self.policy_device = policy_devices + self.env_device = env_devices + self.collector_class = collector_class + del storing_device, 
env_device, policy_device, device + self.no_cuda_sync = no_cuda_sync + + # Set up replay buffer + self._use_buffers = use_buffers + self.replay_buffer = replay_buffer + self._setup_multi_replay_buffer( + local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer + ) + + # Set up policy and weights + if trust_policy is None: + trust_policy = policy is not None and isinstance(policy, CudaGraphModule) + self.trust_policy = trust_policy + + policy_factory = self._setup_policy_factory(policy_factory) + + # Set up weight synchronization + if weight_sync_schemes is None and weight_updater is None: + weight_sync_schemes = {} + elif weight_sync_schemes is not None and weight_updater is not None: + raise TypeError( + "Cannot specify both weight_sync_schemes and weight_updater." + ) + if ( + weight_sync_schemes is not None + and not weight_sync_schemes + and weight_updater is None + and (isinstance(policy, nn.Module) or any(policy_factory)) + ): + # Set up a default local shared-memory sync scheme for the policy. + # This is used to propagate weights from the orchestrator policy + # (possibly combined with a policy_factory) down to worker policies. + weight_sync_schemes["policy"] = SharedMemWeightSyncScheme() + + self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) + + # Store policy and policy_factory - temporary set to make them visible to the receiver + self.policy = policy + self.policy_factory = policy_factory + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + + self._setup_multi_policy_and_weights( + self.policy, self.policy_factory, weight_updater, weight_sync_schemes + ) + + # Set up policy version tracking + self._setup_multi_policy_version_tracking(track_policy_version) + + # # Set up fallback policy for weight extraction + # self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) + + # Set up total frames and other parameters + self._setup_multi_total_frames( + total_frames, total_frames_per_batch, frames_per_batch + ) + self.reset_at_each_iter = reset_at_each_iter + self.postprocs = postproc + self.max_frames_per_traj = ( + int(max_frames_per_traj) if max_frames_per_traj is not None else 0 + ) + + # Set up split trajectories + self.requested_frames_per_batch = total_frames_per_batch + self.reset_when_done = reset_when_done + self._setup_split_trajs(split_trajs, reset_when_done) + + # Set up other parameters + self.init_random_frames = ( + int(init_random_frames) if init_random_frames is not None else 0 + ) + self.update_at_each_batch = update_at_each_batch + self.exploration_type = exploration_type + self.frames_per_worker = np.inf + + # Set up preemptive threshold + self._setup_preemptive_threshold(preemptive_threshold) + + # Run worker processes + try: + self._run_processes() + except Exception as e: + self.shutdown(raise_on_error=False) + raise e + + # Set up frame tracking and other options + self._exclude_private_keys = True + self._frames = 0 + self._iter = -1 + + # Validate cat_results + self._validate_cat_results(cat_results) + + def _setup_workers_and_env_fns( + self, + create_env_fn: Sequence[Callable] | Callable, + num_workers: int | None, + frames_per_batch: int | Sequence[int], + ) -> tuple[list[Callable], int]: + """Set up workers and environment functions.""" + if isinstance(create_env_fn, Sequence): + self.num_workers = len(create_env_fn) + else: + self.num_workers = num_workers + create_env_fn = [create_env_fn] * self.num_workers + + if ( + 
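+            # A sequence must supply exactly one value per worker, e.g.
+            # frames_per_batch=(64, 64) with two workers gives 128 frames per yielded batch.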
isinstance(frames_per_batch, Sequence) + and len(frames_per_batch) != self.num_workers + ): + raise ValueError( + "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." + f"Got {len(frames_per_batch)} values for {self.num_workers} workers." + ) + + self._frames_per_batch = frames_per_batch + total_frames_per_batch = ( + sum(frames_per_batch) + if isinstance(frames_per_batch, Sequence) + else frames_per_batch + ) + + return create_env_fn, total_frames_per_batch + + def _setup_env_kwargs( + self, create_env_kwargs: Sequence[dict] | dict | None + ) -> list[dict]: + """Set up environment kwargs for each worker.""" + if isinstance(create_env_kwargs, Mapping): + create_env_kwargs = [create_env_kwargs] * self.num_workers + elif create_env_kwargs is None: + create_env_kwargs = [{}] * self.num_workers + elif isinstance(create_env_kwargs, (tuple, list)): + create_env_kwargs = list(create_env_kwargs) + if len(create_env_kwargs) != self.num_workers: + raise ValueError( + f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" + ) + return create_env_kwargs + + def _setup_multi_replay_buffer( + self, + local_init_rb: bool | None, + replay_buffer: ReplayBuffer | None, + replay_buffer_chunk: bool | None, + extend_buffer: bool, + ) -> None: + """Set up replay buffer for multi-process collector.""" + # Handle local_init_rb deprecation + if local_init_rb is None: + local_init_rb = False + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + self._check_replay_buffer_init() + + if replay_buffer_chunk is not None: + if extend_buffer is None: + replay_buffer_chunk = extend_buffer + warnings.warn( + "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", + DeprecationWarning, + ) + elif extend_buffer != replay_buffer_chunk: + raise ValueError( + "conflicting values for replay_buffer_chunk and extend_buffer." + ) + self.extend_buffer = extend_buffer + + if ( + replay_buffer is not None + and hasattr(replay_buffer, "shared") + and not replay_buffer.shared + ): + torchrl_logger.warning("Replay buffer is not shared. Sharing it.") + replay_buffer.share() + + def _setup_policy_factory( + self, policy_factory: Callable | list[Callable] | None + ) -> list[Callable | None]: + """Set up policy factory for each worker.""" + if not isinstance(policy_factory, Sequence): + policy_factory = [policy_factory] * self.num_workers + return policy_factory + + def _setup_multi_policy_and_weights( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy for multi-process collector. + + With weight sync schemes: validates and stores policy without weight extraction. + With weight updater: extracts weights and creates stateful policies. 
+ + When both policy and policy_factory are provided (with weight_sync_schemes): + - The policy is used ONLY for weight extraction via get_model() + - The policy is NOT depopulated of its weights + - The policy is NOT sent to workers + - The policy_factory is used to create policies on workers + """ + if any(policy_factory) and policy is not None: + if weight_sync_schemes is None: + raise TypeError( + "policy_factory and policy are mutually exclusive when not using weight_sync_schemes. " + "When using weight_sync_schemes, policy can be provided alongside policy_factory " + "for weight extraction purposes only (the policy will not be sent to workers)." + ) + # Store policy as fallback for weight extraction only + # The policy keeps its weights and is NOT sent to workers + self._fallback_policy = policy + + if weight_sync_schemes is not None: + weight_sync_policy = weight_sync_schemes.get("policy") + if weight_sync_policy is None: + return + # # If we only have a policy_factory (no policy instance), the scheme must + # # be pre-initialized on the sender, since there is no policy on the + # # collector to extract weights from. + # if any(p is not None for p in policy_factory) and policy is None: + # if not weight_sync_policy.initialized_on_sender: + # raise RuntimeError( + # "the weight sync scheme must be initialized on sender ahead of time " + # "when passing a policy_factory without a policy instance on the collector. " + # f"Got {policy_factory=}" + # ) + # # When a policy instance is provided alongside a policy_factory, the scheme + # # can rely on the collector context (and its policy) to extract weights. + # # Weight sync scheme initialization then happens in _run_processes where + # # pipes and workers are available. + else: + # Using legacy weight updater - extract weights and create stateful policies + self._setup_multi_policy_and_weights_legacy( + policy, policy_factory, weight_updater, weight_sync_schemes + ) + + def _setup_multi_policy_and_weights_legacy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy and extract weights for each device. + + Creates stateful policies with weights extracted and placed in shared memory. + Used with weight updater for in-place weight replacement. 
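+
+        A minimal sketch of the per-device extraction this helper performs (assuming
+        the policy is a plain ``nn.Module``; the names below are illustrative):
+
+            >>> weights = TensorDict.from_module(policy)  # structured view of the parameters
+            >>> weights = weights.share_memory_()         # shared with worker processes (CPU policies)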
+ """ + self._policy_weights_dict = {} + self._fallback_policy = None # Policy to use for weight extraction fallback + + if not any(policy_factory): + for policy_device, env_maker, env_maker_kwargs in _zip_strict( + self.policy_device, self.create_env_fn, self.create_env_kwargs + ): + policy_new_device, get_weights_fn = self._get_policy_and_device( + policy=policy, + policy_device=policy_device, + env_maker=env_maker, + env_maker_kwargs=env_maker_kwargs, + ) + if type(policy_new_device) is not type(policy): + policy = policy_new_device + weights = ( + TensorDict.from_module(policy_new_device) + if isinstance(policy_new_device, nn.Module) + else TensorDict() + ) + # For multi-process collectors, ensure weights are in shared memory + if policy_device and policy_device.type == "cpu": + weights = weights.share_memory_() + self._policy_weights_dict[policy_device] = weights + # Store the first policy instance for fallback weight extraction + if self._fallback_policy is None: + self._fallback_policy = policy_new_device + self._get_weights_fn = get_weights_fn + if weight_updater is None: + # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default + if weight_sync_schemes is None: + weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} + self._weight_sync_schemes = weight_sync_schemes + elif weight_updater is None: + warnings.warn( + "weight_updater is None, but policy_factory is provided. This means that the server will " + "not know how to send the weights to the workers. If the workers can handle their weight synchronization " + "on their own (via some specialized worker type / constructor) this may well work, but make sure " + "your weight synchronization strategy is properly set. To suppress this warning, you can use " + "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " + "This will work whenever your inference and training policies are nn.Module instances with similar structures." + ) + + def _setup_multi_weight_sync( + self, + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up weight synchronization for multi-process collector.""" + if weight_sync_schemes is not None: + # Use weight sync schemes for weight distribution + self._weight_sync_schemes = weight_sync_schemes + # Senders will be created in _run_processes + self.weight_updater = None + else: + # Use weight updater for weight distribution + self.weight_updater = weight_updater + self._weight_sync_schemes = None + + def _setup_multi_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking for multi-process collector.""" + self.policy_version_tracker = track_policy_version + if PolicyVersion is not None: + if isinstance(track_policy_version, bool) and track_policy_version: + self.policy_version_tracker = PolicyVersion() + elif hasattr(track_policy_version, "increment_version"): + self.policy_version_tracker = track_policy_version + else: + self.policy_version_tracker = None + else: + if track_policy_version: + raise ImportError( + "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." 
+ ) + self.policy_version_tracker = None + + # TODO: Remove this + def _setup_fallback_policy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up fallback policy for weight extraction when using policy_factory.""" + # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided + # If policy_factory was used, create a policy instance to use as fallback + if policy is None and any(policy_factory) and weight_sync_schemes is not None: + if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: + first_factory = ( + policy_factory[0] + if isinstance(policy_factory, list) + else policy_factory + ) + if first_factory is not None: + # Create a policy instance for weight extraction + # This will be a reference to a policy with the same structure + # For shared memory, modifications to any policy will be visible here + self._fallback_policy = first_factory() + + def _setup_multi_total_frames( + self, + total_frames: int, + total_frames_per_batch: int, + frames_per_batch: int | Sequence[int], + ) -> None: + """Validate and set total frames for multi-process collector.""" + if total_frames is None or total_frames < 0: + total_frames = float("inf") + else: + remainder = total_frames % total_frames_per_batch + if remainder != 0 and RL_WARNINGS: + warnings.warn( + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " + f"This means {total_frames_per_batch - remainder} additional frames will be collected. " + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.total_frames = ( + int(total_frames) if total_frames != float("inf") else total_frames + ) + + def _setup_split_trajs( + self, split_trajs: bool | None, reset_when_done: bool + ) -> None: + """Set up split trajectories option.""" + if split_trajs is None: + split_trajs = False + elif not reset_when_done and split_trajs: + raise RuntimeError( + "Cannot split trajectories when reset_when_done is False." + ) + self.split_trajs = split_trajs + + def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None: + """Set up preemptive threshold for early stopping.""" + if preemptive_threshold is not None: + if _is_osx: + raise NotImplementedError( + "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform." + ) + self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) + manager = _InterruptorManager() + manager.start() + self.interruptor = manager._Interruptor() + else: + self.preemptive_threshold = 1.0 + self.interruptor = None + + def _validate_cat_results(self, cat_results: str | int | None) -> None: + """Validate cat_results parameter.""" + if cat_results is not None and ( + not isinstance(cat_results, (int, str)) + or (isinstance(cat_results, str) and cat_results != "stack") + ): + raise ValueError( + "cat_results must be a string ('stack') " + f"or an integer representing the cat dimension. Got {cat_results}." + ) + # Lazy import to avoid circular dependency + from torchrl.collectors._multi_sync import MultiSyncDataCollector + + if not isinstance(self, MultiSyncDataCollector) and cat_results not in ( + "stack", + None, + ): + raise ValueError( + "cat_results can only be used with ``MultiSyncDataCollector``." 
+ ) + self.cat_results = cat_results + + def _check_replay_buffer_init(self): + if self.replay_buffer is None: + return + is_init = hasattr(self.replay_buffer, "_storage") and getattr( + self.replay_buffer._storage, "initialized", True + ) + if not is_init: + if self.local_init_rb: + # New behavior: storage handles all coordination itself + # Nothing to do here - the storage will coordinate during first write + self.replay_buffer.share() + return + + # Legacy behavior: fake tensordict initialization + if isinstance(self.create_env_fn[0], EnvCreator): + fake_td = self.create_env_fn[0].meta_data.tensordict + elif isinstance(self.create_env_fn[0], EnvBase): + fake_td = self.create_env_fn[0].fake_tensordict() + else: + fake_td = self.create_env_fn[0]( + **self.create_env_kwargs[0] + ).fake_tensordict() + fake_td["collector", "traj_ids"] = torch.zeros( + fake_td.shape, dtype=torch.long + ) + # Use extend to avoid time-related transforms to fail + self.replay_buffer.extend(fake_td.unsqueeze(-1)) + self.replay_buffer.empty() + + @classmethod + def _total_workers_from_env(cls, env_creators): + if isinstance(env_creators, (tuple, list)): + return sum( + cls._total_workers_from_env(env_creator) for env_creator in env_creators + ) + from torchrl.envs import ParallelEnv + + if isinstance(env_creators, ParallelEnv): + return env_creators.num_workers + return 1 + + def _get_devices( + self, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + # convert all devices to lists + if not isinstance(storing_device, (list, tuple)): + storing_device = [ + storing_device, + ] * self.num_workers + if not isinstance(policy_device, (list, tuple)): + policy_device = [ + policy_device, + ] * self.num_workers + if not isinstance(env_device, (list, tuple)): + env_device = [ + env_device, + ] * self.num_workers + if not isinstance(device, (list, tuple)): + device = [ + device, + ] * self.num_workers + if not ( + len(device) + == len(storing_device) + == len(policy_device) + == len(env_device) + == self.num_workers + ): + raise RuntimeError( + f"THe length of the devices does not match the number of workers: {self.num_workers}." 
+ ) + storing_device, policy_device, env_device = zip( + *[ + SyncDataCollector._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + for (storing_device, policy_device, env_device, device) in zip( + storing_device, policy_device, env_device, device + ) + ] + ) + return storing_device, policy_device, env_device + + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: + raise NotImplementedError + + @property + def _queue_len(self) -> int: + raise NotImplementedError + + def _run_processes(self) -> None: + if self.num_threads is None: + total_workers = self._total_workers_from_env(self.create_env_fn) + self.num_threads = max( + 1, torch.get_num_threads() - total_workers + ) # 1 more thread for this proc + + # Set up for worker processes + torch.set_num_threads(self.num_threads) + queue_out = mp.Queue(self._queue_len) # sends data from proc to main + self.procs = [] + self._traj_pool = _TrajectoryPool(lock=True) + + # Create all pipes upfront (needed for weight sync scheme initialization) + # Store as list of (parent, child) tuples for use in worker creation + pipe_pairs = [mp.Pipe() for _ in range(self.num_workers)] + # Extract parent pipes for external use (e.g., polling, receiving messages) + self.pipes = [pipe_parent for pipe_parent, _ in pipe_pairs] + + # Initialize all weight sync schemes now that pipes are available + # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes) + # can be initialized here since all required resources exist + if self._weight_sync_schemes: + for model_id, scheme in self._weight_sync_schemes.items(): + if not scheme.initialized_on_sender: + torchrl_logger.debug( + f"Init scheme {type(scheme)} on sender side of {type(self)} with {model_id=} and model {_resolve_model(self, model_id)}." + ) + scheme.init_on_sender(model_id=model_id, context=self) + + # Create a policy on the right device + policy_factory = self.policy_factory + has_policy_factory = any(policy_factory) + if has_policy_factory: + policy_factory = [ + CloudpickleWrapper(_policy_factory) + for _policy_factory in policy_factory + ] + + for i, (env_fun, env_fun_kwargs) in enumerate( + zip(self.create_env_fn, self.create_env_kwargs) + ): + pipe_parent, pipe_child = pipe_pairs[i] # use pre-created pipes + if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( + env_fun, EnvBase + ): # to avoid circular imports + env_fun = CloudpickleWrapper(env_fun) + + policy_device = self.policy_device[i] + storing_device = self.storing_device[i] + env_device = self.env_device[i] + + # Prepare policy for worker based on weight synchronization method. + # IMPORTANT: when a policy_factory is provided, the policy instance + # is used ONLY on the main process (for weight extraction etc.) and + # is NOT sent to workers. + policy = self.policy + + if self._weight_sync_schemes: + # With weight sync schemes, send stateless policies. + # Schemes handle weight distribution on worker side. + if has_policy_factory: + # Factory will create policy in worker; don't send policy. + policy_to_send = None + cm = contextlib.nullcontext() + elif policy is not None: + # Send policy with meta-device parameters (empty structure) - schemes apply weights + policy_to_send = policy + cm = _make_meta_policy(policy) + else: + policy_to_send = None + cm = contextlib.nullcontext() + elif hasattr(self, "_policy_weights_dict"): + # LEGACY: + # With weight updater, use in-place weight replacement. 
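+                # Illustrative sketch (assuming the policy is a plain ``nn.Module``) of the
+                # shared-weight swap performed just below via ``TensorDict.to_module``:
+                #     shared = TensorDict.from_module(policy).share_memory_()
+                #     with shared.to_module(policy):
+                #         ...  # while inside the context, the serialized policy references the shared tensors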
+ # Take the weights and locally dispatch them to the policy before sending. + # This ensures a given set of shared weights for a device are shared + # for all policies that rely on that device. + policy_weights = self._policy_weights_dict.get(policy_device) + if has_policy_factory: + # Even in legacy mode, when a policy_factory is present, do not + # send the stateful policy down to workers. + policy_to_send = None + cm = contextlib.nullcontext() + else: + policy_to_send = policy + if policy is not None and policy_weights is not None: + cm = policy_weights.to_module(policy) + else: + cm = contextlib.nullcontext() + else: + # Parameter-less policy. + cm = contextlib.nullcontext() + # When a policy_factory exists, never send the policy instance. + policy_to_send = None if has_policy_factory else policy + + with cm: + kwargs = { + "policy_factory": policy_factory[i], + "pipe_parent": pipe_parent, + "pipe_child": pipe_child, + "queue_out": queue_out, + "create_env_fn": env_fun, + "create_env_kwargs": env_fun_kwargs, + "policy": policy_to_send, + "max_frames_per_traj": self.max_frames_per_traj, + "frames_per_batch": self.frames_per_batch_worker(worker_idx=i), + "reset_at_each_iter": self.reset_at_each_iter, + "policy_device": policy_device, + "storing_device": storing_device, + "env_device": env_device, + "exploration_type": self.exploration_type, + "reset_when_done": self.reset_when_done, + "idx": i, + "interruptor": self.interruptor, + "set_truncated": self.set_truncated, + "use_buffers": self._use_buffers, + "replay_buffer": self.replay_buffer, + "extend_buffer": self.extend_buffer, + "traj_pool": self._traj_pool, + "trust_policy": self.trust_policy, + "compile_policy": self.compiled_policy_kwargs + if self.compiled_policy + else False, + "cudagraph_policy": self.cudagraphed_policy_kwargs + if self.cudagraphed_policy + else False, + "no_cuda_sync": self.no_cuda_sync, + "collector_class": self.collector_class, + "postproc": self.postprocs + if self.replay_buffer is not None + else None, + "weight_sync_schemes": self._weight_sync_schemes, + "worker_idx": i, # Worker index for queue-based weight distribution + } + proc = _ProcessNoWarn( + target=_main_async_collector, + num_threads=self.num_sub_threads, + kwargs=kwargs, + ) + # proc.daemon can't be set as daemonic processes may be launched by the process itself + try: + proc.start() + except TypeError as err: + if "cannot pickle" in str(err): + raise RuntimeError( + "A non-serializable object was passed to the collector workers." + ) from err + except RuntimeError as err: + if "Cowardly refusing to serialize non-leaf tensor" in str(err): + raise RuntimeError( + "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " + "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" + "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" + "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." + ) from err + else: + raise err + except _pickle.PicklingError as err: + if "" in str(err): + raise RuntimeError( + """Can't open a process with doubly cloud-pickled lambda function. +This error is likely due to an attempt to use a ParallelEnv in a +multiprocessed data collector. 
To do this, consider wrapping your +lambda function in an `torchrl.envs.EnvCreator` wrapper as follows: +`env = ParallelEnv(N, EnvCreator(my_lambda_function))`. +This will not only ensure that your lambda function is cloud-pickled once, but +also that the state dict is synchronised across processes if needed.""" + ) from err + pipe_child.close() + self.procs.append(proc) + + # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" + # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: + # Workers will call receiver.collect() during init and may block waiting for data + if self._weight_sync_schemes: + # start with policy + policy_scheme = self._weight_sync_schemes.get("policy") + if policy_scheme is not None: + policy_scheme.connect() + for key, scheme in self._weight_sync_schemes.items(): + if key == "policy": + continue + scheme.connect() + + # Wait for workers to be ready + for i, pipe_parent in enumerate(self.pipes): + pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) + try: + msg = pipe_parent.recv() + except EOFError as e: + raise RuntimeError( + f"Worker {i} failed to initialize and closed the connection before sending status. " + f"This typically indicates that the worker process crashed during initialization. " + f"Check the worker process logs for the actual error." + ) from e + if msg != "instantiated": + # Check if it's an error dict from worker + if isinstance(msg, dict) and msg.get("error"): + # Reconstruct the exception from the worker + exc_type_name = msg["exception_type"] + exc_msg = msg["exception_msg"] + traceback_str = msg["traceback"] + + # Try to get the actual exception class + exc_class = None + exc_module = msg["exception_module"] + + if exc_module == "builtins": + # Get from builtins + import builtins + + exc_class = getattr(builtins, exc_type_name, None) + else: + # Try to import from the module + try: + import importlib + + mod = importlib.import_module(exc_module) + exc_class = getattr(mod, exc_type_name, None) + except Exception: + pass + + # Re-raise with original exception type if possible + if exc_class is not None: + raise exc_class( + f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Fall back to RuntimeError if we can't get the original type + raise RuntimeError( + f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Legacy string error message + raise RuntimeError(msg) + + self.queue_out = queue_out + self.closed = False + + _running_free = False + + def start(self): + """Starts the collector(s) for asynchronous data collection. + + The collected data is stored in the provided replay buffer. This method initiates the background collection of + data across multiple processes, allowing for decoupling of data collection and training. + + Raises: + RuntimeError: If no replay buffer is defined during the collector's initialization. + + Example: + >>> from torchrl.modules import RandomPolicy >>> >>> import time + >>> from functools import partial + >>> + >>> import tqdm + >>> + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> from torchrl.data import LazyTensorStorage, ReplayBuffer + >>> from torchrl.envs import GymEnv, set_gym_backend + >>> import ale_py + >>> + >>> # Set the gym backend to gymnasium + >>> set_gym_backend("gymnasium").set() + >>> + >>> if __name__ == "__main__": + ... # Create a random policy for the Pong environment + ... env_fn = partial(GymEnv, "ALE/Pong-v5") + ... 
policy = RandomPolicy(env_fn().action_spec) + ... + ... # Initialize a shared replay buffer + ... rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True) + ... + ... # Create a multi-async data collector with 16 environments + ... num_envs = 16 + ... collector = MultiaSyncDataCollector( + ... [env_fn] * num_envs, + ... policy=policy, + ... replay_buffer=rb, + ... frames_per_batch=num_envs * 16, + ... total_frames=-1, + ... ) + ... + ... # Progress bar to track the number of collected frames + ... pbar = tqdm.tqdm(total=100_000) + ... + ... # Start the collector asynchronously + ... collector.start() + ... + ... # Track the write count of the replay buffer + ... prec_wc = 0 + ... while True: + ... wc = rb.write_count + ... c = wc - prec_wc + ... prec_wc = wc + ... + ... # Update the progress bar + ... pbar.update(c) + ... pbar.set_description(f"Write Count: {rb.write_count}") + ... + ... # Check the write count every 0.5 seconds + ... time.sleep(0.5) + ... + ... # Stop when the desired number of frames is reached + ... if rb.write_count > 100_000: + ... break + ... + ... # Shut down the collector + ... collector.async_shutdown() + """ + if self.replay_buffer is None: + raise RuntimeError("Replay buffer must be defined for execution.") + if self.init_random_frames is not None and self.init_random_frames > 0: + raise RuntimeError( + "Cannot currently start() a collector that requires random frames. Please submit a feature request on github." + ) + self._running_free = True + for pipe in self.pipes: + pipe.send((None, "run_free")) + + @contextlib.contextmanager + def pause(self): + """Context manager that pauses the collector if it is running free.""" + if self._running_free: + for pipe in self.pipes: + pipe.send((None, "pause")) + # Make sure all workers are paused + for _ in self.pipes: + idx, msg = self.queue_out.get() + if msg != "paused": + raise ValueError(f"Expected paused, but got {msg=}.") + torchrl_logger.debug(f"Worker {idx} is paused.") + self._running_free = False + yield None + for pipe in self.pipes: + pipe.send((None, "restart")) + self._running_free = True + else: + raise RuntimeError("Collector cannot be paused.") + + def __del__(self): + try: + self.shutdown() + except Exception: + # an AttributeError will typically be raised if the collector is deleted when the program ends. + # In the future, insignificant changes to the close method may change the error type. + # We explicitly assume that any error raised during closure in + # __del__ will not affect the program. + pass + + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + """Shuts down all processes. This operation is irreversible. + + Args: + timeout (float, optional): The timeout for closing pipes between workers. + close_env (bool, optional): Whether to close the environment. Defaults to `True`. + raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. + """ + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed."
+ ) + try: + self._shutdown_main(timeout) + except Exception as e: + if raise_on_error: + raise e + else: + pass + + def _shutdown_main(self, timeout: float | None = None) -> None: + if timeout is None: + timeout = 10 + try: + if self.closed: + return + _check_for_faulty_process(self.procs) + all_closed = [False] * self.num_workers + rep = 0 + for idx in range(self.num_workers): + if all_closed[idx]: + continue + if not self.procs[idx].is_alive(): + continue + self.pipes[idx].send((None, "close")) + + while not all(all_closed) and rep < 1000: + rep += 1 + for idx in range(self.num_workers): + if all_closed[idx]: + continue + if not self.procs[idx].is_alive(): + all_closed[idx] = True + continue + try: + if self.pipes[idx].poll(timeout / 1000 / self.num_workers): + msg = self.pipes[idx].recv() + if msg != "closed": + raise RuntimeError(f"got {msg} but expected 'close'") + all_closed[idx] = True + else: + continue + except BrokenPipeError: + all_closed[idx] = True + continue + self.closed = True + + self.queue_out.close() + for pipe in self.pipes: + pipe.close() + for proc in self.procs: + proc.join(1.0) + finally: + import torchrl + + num_threads = min( + torchrl._THREAD_POOL_INIT, + torch.get_num_threads() + + self._total_workers_from_env(self.create_env_fn), + ) + torch.set_num_threads(num_threads) + + for proc in self.procs: + if proc.is_alive(): + proc.terminate() + + def async_shutdown(self, timeout: float | None = None): + return self.shutdown(timeout=timeout) + + def set_seed(self, seed: int, static_seed: bool = False) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed: integer representing the seed to be used for the environment. + static_seed (bool, optional): if ``True``, the seed is not incremented. + Defaults to False + + Returns: + Output seed. This is useful when more than one environment is + contained in the DataCollector, as the seed will be incremented for + each of these. The resulting seed is the seed of the last + environment. + + Examples: + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + _check_for_faulty_process(self.procs) + for idx in range(self.num_workers): + self.pipes[idx].send(((seed, static_seed), "seed")) + new_seed, msg = self.pipes[idx].recv() + if msg != "seeded": + raise RuntimeError(f"Expected msg='seeded', got {msg}") + seed = new_seed + self.reset() + return seed + + def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + """Resets the environments to a new initial state. + + Args: + reset_idx: Optional. Sequence indicating which environments have + to be reset. If None, all environments are reset. 
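+
+        For example (illustrative), to reset only the first of four workers:
+
+            >>> collector.reset(reset_idx=[True, False, False, False])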
+ + """ + _check_for_faulty_process(self.procs) + + if reset_idx is None: + reset_idx = [True for _ in range(self.num_workers)] + for idx in range(self.num_workers): + if reset_idx[idx]: + self.pipes[idx].send((None, "reset")) + for idx in range(self.num_workers): + if reset_idx[idx]: + j, msg = self.pipes[idx].recv() + if msg != "reset": + raise RuntimeError(f"Expected msg='reset', got {msg}") + + def state_dict(self) -> OrderedDict: + """Returns the state_dict of the data collector. + + Each field represents a worker containing its own state_dict. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((None, "state_dict")) + state_dict = OrderedDict() + for idx in range(self.num_workers): + _state_dict, msg = self.pipes[idx].recv() + if msg != "state_dict": + raise RuntimeError(f"Expected msg='state_dict', got {msg}") + state_dict[f"worker{idx}"] = _state_dict + state_dict.update({"frames": self._frames, "iter": self._iter}) + + return state_dict + + def load_state_dict(self, state_dict: OrderedDict) -> None: + """Loads the state_dict on the workers. + + Args: + state_dict (OrderedDict): state_dict of the form + ``{"worker0": state_dict0, "worker1": state_dict1}``. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) + for idx in range(self.num_workers): + _, msg = self.pipes[idx].recv() + if msg != "loaded": + raise RuntimeError(f"Expected msg='loaded', got {msg}") + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] + + def increment_version(self): + """Increment the policy version.""" + if self.policy_version_tracker is not None: + if not hasattr(self.policy_version_tracker, "increment_version"): + raise RuntimeError( + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." + ) + self.policy_version_tracker.increment_version() + + @property + def policy_version(self) -> str | int | None: + """The current policy version.""" + if not hasattr(self.policy_version_tracker, "version"): + return None + return self.policy_version_tracker.version + + def get_policy_version(self) -> str | int | None: + """Get the current policy version. + + This method exists to support remote calls in Ray actors, since properties + cannot be accessed directly through Ray's RPC mechanism. + + Returns: + The current version number (int) or UUID (str), or None if version tracking is disabled. + """ + return self.policy_version + + def getattr_policy(self, attr): + """Get an attribute from the policy of the first worker. + + Args: + attr (str): The attribute name to retrieve from the policy. + + Returns: + The attribute value from the policy of the first worker. + + Raises: + AttributeError: If the attribute doesn't exist on the policy. + """ + _check_for_faulty_process(self.procs) + + # Send command to first worker (index 0) + self.pipes[0].send((attr, "getattr_policy")) + result, msg = self.pipes[0].recv() + if msg != "getattr_policy": + raise RuntimeError(f"Expected msg='getattr_policy', got {msg}") + + # If the worker returned an AttributeError, re-raise it + if isinstance(result, AttributeError): + raise result + + return result + + def getattr_env(self, attr): + """Get an attribute from the environment of the first worker. + + Args: + attr (str): The attribute name to retrieve from the environment. + + Returns: + The attribute value from the environment of the first worker. + + Raises: + AttributeError: If the attribute doesn't exist on the environment. 
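+
+        For example (illustrative):
+
+            >>> action_spec = collector.getattr_env("action_spec")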
+ """ + _check_for_faulty_process(self.procs) + + # Send command to first worker (index 0) + self.pipes[0].send((attr, "getattr_env")) + result, msg = self.pipes[0].recv() + if msg != "getattr_env": + raise RuntimeError(f"Expected msg='getattr_env', got {msg}") + + # If the worker returned an AttributeError, re-raise it + if isinstance(result, AttributeError): + raise result + + return result + + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + return getattr(self.replay_buffer, attr) + + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). + + Args: + model_id: Model identifier (e.g., "policy", "value_net") + + Returns: + The model instance + + Raises: + ValueError: If model_id is not recognized + """ + if model_id == "policy": + # Return the fallback policy instance + if (fallback_policy := getattr(self, "_fallback_policy", None)) is not None: + return fallback_policy + elif hasattr(self, "policy") and self.policy is not None: + return self.policy + else: + raise ValueError(f"No policy found for model_id '{model_id}'") + else: + # Try to resolve via attribute access + return _resolve_model(self, model_id) + + def get_cached_weights(self, model_id: str): + """Get cached shared memory weights if available (for weight sync schemes). + + Args: + model_id: Model identifier + + Returns: + Cached TensorDict weights or None if not available + """ + if model_id == "policy" and hasattr(self, "_policy_weights_dict"): + # Get the policy device (first device if list) + policy_device = self.policy_device + if isinstance(policy_device, (list, tuple)): + policy_device = policy_device[0] if len(policy_device) > 0 else None + + # Return cached weights for this device + return self._policy_weights_dict.get(policy_device) + return None + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | nn.Module | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Update weights on workers. + + Weight sync schemes now use background threads on the receiver side. + The scheme's send() method: + 1. Puts weights in the queue (or updates shared memory) + 2. Sends a "receive" instruction to the worker's background thread + 3. Waits for acknowledgment (if sync=True) + + No pipe signaling is needed - the scheme handles everything internally. 
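+
+        Illustrative call from the training process (assuming ``trainer_policy`` is an
+        ``nn.Module`` mirroring the collector policy):
+
+            >>> weights = TensorDict.from_module(trainer_policy)
+            >>> collector.update_policy_weights_(weights, worker_ids=[0, 1])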
+ """ + # Call parent implementation which calls scheme.send() + # The scheme handles instruction delivery and acknowledgments + super()._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) + + # for RPC + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_multi_sync.py b/torchrl/collectors/_multi_sync.py new file mode 100644 index 00000000000..1f756a8b26d --- /dev/null +++ b/torchrl/collectors/_multi_sync.py @@ -0,0 +1,438 @@ +from __future__ import annotations + +import collections +import time +import warnings +from collections import OrderedDict +from collections.abc import Iterator, Sequence +from queue import Empty + +import torch + +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl import logger as torchrl_logger +from torchrl._utils import ( + _check_for_faulty_process, + accept_remote_rref_udf_invocation, + RL_WARNINGS, +) +from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors.utils import split_trajectories + + +@accept_remote_rref_udf_invocation +class MultiSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes synchronously. + + .. aafig:: + + +----------------------------------------------------------------------+ + | "MultiSyncDataCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | Main | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor" | "step" | "step" | "actor" | | + | | | | | | + | | "actor" | | | + | | | | | + | "yield batch of traj 1"------->"collect, train"| + | | | + | "step" | "step" | "step" | "step" | "step" | "step" | | + | | | | | | | | + | "actor" | "actor" | | | | + | | "step" | "step" | "actor" | | + | | | | | | + | "step" | "step" | "actor" | "step" | "step" | | + | | | | | | | + | "actor" | | "actor" | | + | "yield batch of traj 2"------->"collect, train"| + | | | + +----------------------------------------------------------------------+ + + Envs can be identical or different. + + The collection starts when the next item of the collector is queried, + and no environment step is computed in between the reception of a batch of + trajectory and the start of the next collection. + This class can be safely used with online RL sota-implementations. + + .. note:: + Python requires multiprocessed code to be instantiated within a main guard: + + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + ... collector = MultiSyncDataCollector(...) + + See https://docs.python.org/3/library/multiprocessing.html for more info. 
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + ... collector.shutdown() + ... del collector + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + + """ + + __doc__ += _MultiDataCollector.__doc__ + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." + ) + if hasattr(self, "out_buffer"): + del self.out_buffer + if hasattr(self, "buffers"): + del self.buffers + try: + return super().shutdown(timeout=timeout) + except Exception as e: + if raise_on_error: + raise e + else: + pass + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. 
Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: + if worker_idx is not None and isinstance(self._frames_per_batch, Sequence): + return self._frames_per_batch[worker_idx] + if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS: + warnings.warn( + f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," + f" this results in more frames_per_batch per iteration that requested." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + frames_per_batch_worker = -( + -self.requested_frames_per_batch // self.num_workers + ) + return frames_per_batch_worker + + @property + def _queue_len(self) -> int: + return self.num_workers + + def iterator(self) -> Iterator[TensorDictBase]: + cat_results = self.cat_results + if cat_results is None: + cat_results = "stack" + + self.buffers = {} + dones = [False for _ in range(self.num_workers)] + workers_frames = [0 for _ in range(self.num_workers)] + same_device = None + self.out_buffer = None + preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 + + while not all(dones) and self._frames < self.total_frames: + _check_for_faulty_process(self.procs) + if self.update_at_each_batch: + self.update_policy_weights_() + + for idx in range(self.num_workers): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + msg = "continue_random" + else: + msg = "continue" + # Debug: sending 'continue' + self.pipes[idx].send((None, msg)) + + self._iter += 1 + + if preempt: + self.interruptor.start_collection() + while self.queue_out.qsize() < int( + self.num_workers * self.preemptive_threshold + ): + continue + self.interruptor.stop_collection() + # Now wait for stragglers to return + while self.queue_out.qsize() < int(self.num_workers): + continue + + recv = collections.deque() + t0 = time.time() + while len(recv) < self.num_workers and ( + (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT) + ): + for _ in range(self.num_workers): + try: + new_data, j = self.queue_out.get(timeout=_TIMEOUT) + recv.append((new_data, j)) + except (TimeoutError, Empty): + _check_for_faulty_process(self.procs) + if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT): + try: + self.shutdown() + finally: + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
+ ) + + for _ in range(self.num_workers): + new_data, j = recv.popleft() + use_buffers = self._use_buffers + if self.replay_buffer is not None: + idx = new_data + workers_frames[idx] = workers_frames[ + idx + ] + self.frames_per_batch_worker(worker_idx=idx) + continue + elif j == 0 or not use_buffers: + try: + data, idx = new_data + self.buffers[idx] = data + if use_buffers is None and j > 0: + self._use_buffers = False + except TypeError: + if use_buffers is None: + self._use_buffers = True + idx = new_data + else: + raise + else: + idx = new_data + + if preempt: + # mask buffers if cat, and create a mask if stack + if cat_results != "stack": + buffers = {} + for worker_idx, buffer in self.buffers.items(): + valid = buffer.get(("collector", "traj_ids")) != -1 + if valid.ndim > 2: + valid = valid.flatten(0, -2) + if valid.ndim == 2: + valid = valid.any(0) + buffers[worker_idx] = buffer[..., valid] + else: + for buffer in self.buffers.values(): + with buffer.unlock_(): + buffer.set( + ("collector", "mask"), + buffer.get(("collector", "traj_ids")) != -1, + ) + buffers = self.buffers + else: + buffers = self.buffers + + # Skip frame counting if this worker didn't send data this iteration + # (happens when reusing buffers or on first iteration with some workers) + if idx not in buffers: + continue + + workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() + + if workers_frames[idx] >= self.total_frames: + dones[idx] = True + + if self.replay_buffer is not None: + yield + self._frames += sum( + [ + self.frames_per_batch_worker(worker_idx=worker_idx) + for worker_idx in range(self.num_workers) + ] + ) + continue + + # we have to correct the traj_ids to make sure that they don't overlap + # We can count the number of frames collected for free in this loop + n_collected = 0 + for idx in buffers.keys(): + buffer = buffers[idx] + traj_ids = buffer.get(("collector", "traj_ids")) + if preempt: + if cat_results == "stack": + mask_frames = buffer.get(("collector", "traj_ids")) != -1 + n_collected += mask_frames.sum().cpu() + else: + n_collected += traj_ids.numel() + else: + n_collected += traj_ids.numel() + + if same_device is None: + prev_device = None + same_device = True + for item in self.buffers.values(): + if prev_device is None: + prev_device = item.device + else: + same_device = same_device and (item.device == prev_device) + + if cat_results == "stack": + stack = ( + torch.stack if self._use_buffers else TensorDict.maybe_dense_stack + ) + if same_device: + self.out_buffer = stack(list(buffers.values()), 0) + else: + self.out_buffer = stack( + [item.cpu() for item in buffers.values()], 0 + ) + else: + if self._use_buffers is None: + torchrl_logger.warning( + "use_buffer not specified and not yet inferred from data, assuming `True`." + ) + elif not self._use_buffers: + raise RuntimeError( + "Cannot concatenate results with use_buffers=False" + ) + try: + if same_device: + self.out_buffer = torch.cat(list(buffers.values()), cat_results) + else: + self.out_buffer = torch.cat( + [item.cpu() for item in buffers.values()], cat_results + ) + except RuntimeError as err: + if ( + preempt + and cat_results != -1 + and "Sizes of tensors must match" in str(err) + ): + raise RuntimeError( + "The value provided to cat_results isn't compatible with the collectors outputs. " + "Consider using `cat_results=-1`." + ) + raise + + # TODO: why do we need to do cat inplace and clone? 
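A minimal sketch of how ``cat_results`` affects the shape of the yielded batch, assuming two single-env workers and ``frames_per_batch=200`` (i.e. 100 frames per worker), with ``env_maker`` and ``policy`` defined as in the example above:

    >>> collector = MultiSyncDataCollector(
    ...     create_env_fn=[env_maker, env_maker],
    ...     policy=policy,
    ...     total_frames=2000,
    ...     frames_per_batch=200,
    ...     cat_results="stack",  # worker buffers are stacked: batch shape [2, 100]
    ... )
    >>> # With cat_results=-1 the per-worker buffers are concatenated along the
    >>> # last (time) dimension instead, giving a flat batch of shape [200].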
+ if self.split_trajs: + out = split_trajectories(self.out_buffer, prefix="collector") + else: + out = self.out_buffer + if cat_results in (-1, "stack"): + out.refine_names(*[None] * (out.ndim - 1) + ["time"]) + + self._frames += n_collected + + if self.postprocs: + self.postprocs = ( + self.postprocs.to(out.device) + if hasattr(self.postprocs, "to") + else self.postprocs + ) + out = self.postprocs(out) + if self._exclude_private_keys: + excluded_keys = [key for key in out.keys() if key.startswith("_")] + if excluded_keys: + out = out.exclude(*excluded_keys) + yield out + del out + + del self.buffers + self.out_buffer = None + # We shall not call shutdown just yet as user may want to retrieve state_dict + # self._shutdown_main() + + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) + + # for RPC + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py new file mode 100644 index 00000000000..54570501cc2 --- /dev/null +++ b/torchrl/collectors/_runner.py @@ -0,0 +1,402 @@ +from __future__ import annotations + +import queue +from collections.abc import Callable +from functools import partial +from multiprocessing import connection, queues +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDict, TensorDictBase + +from torchrl import logger as torchrl_logger +from torchrl._utils import VERBOSE +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import ( + _MAX_IDLE_COUNT, + _MIN_TIMEOUT, + _TIMEOUT, + DEFAULT_EXPLORATION_TYPE, +) +from torchrl.collectors._single import SyncDataCollector + +from torchrl.collectors.utils import ( + _cast, + _make_policy_factory, + _map_to_cpu_if_needed, + _TrajectoryPool, +) +from torchrl.data import ReplayBuffer +from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.utils import ExplorationType +from torchrl.weight_update import WeightSyncScheme + + +def _main_async_collector( + pipe_parent: connection.Connection, + pipe_child: connection.Connection, + queue_out: queues.Queue, + create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821 + create_env_kwargs: dict[str, Any], + policy: Callable[[TensorDictBase], TensorDictBase], + max_frames_per_traj: int, + frames_per_batch: int, + reset_at_each_iter: bool, + storing_device: torch.device | str | int | None, + env_device: torch.device | str | int | None, + policy_device: torch.device | str | int | None, + idx: int = 0, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + verbose: bool = VERBOSE, + interruptor=None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + traj_pool: _TrajectoryPool = None, + trust_policy: bool = False, + compile_policy: bool = False, + cudagraph_policy: bool = False, + no_cuda_sync: bool = False, + policy_factory: Callable | None = None, + collector_class: type | Callable[[], DataCollectorBase] | None = None, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + worker_idx: int | None = None, +) -> None: + if collector_class is None: + collector_class = SyncDataCollector + pipe_parent.close() + # init variables that will be cleared when closing + collected_tensordict = data = next_data 
= data_in = inner_collector = dc_iter = None + + # Make a policy-factory out of the policy + policy_factory = partial( + _make_policy_factory, + policy=policy, + policy_factory=policy_factory, + weight_sync_scheme=weight_sync_schemes.get("policy") + if weight_sync_schemes + else None, + worker_idx=worker_idx, + pipe=pipe_child, + ) + policy = None + try: + collector_class._ignore_rb = extend_buffer + inner_collector = collector_class( + create_env_fn, + create_env_kwargs=create_env_kwargs, + policy=policy, + policy_factory=policy_factory, + total_frames=-1, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + postproc=postproc, + split_trajs=False, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + return_same_td=replay_buffer is None, + interruptor=interruptor, + set_truncated=set_truncated, + use_buffers=use_buffers, + replay_buffer=replay_buffer, + extend_buffer=False, + traj_pool=traj_pool, + trust_policy=trust_policy, + compile_policy=compile_policy, + cudagraph_policy=cudagraph_policy, + no_cuda_sync=no_cuda_sync, + # We don't pass the weight sync scheme as only the sender has the weight sync scheme within. + # weight_sync_schemes=weight_sync_schemes, + worker_idx=worker_idx, + ) + # Set up weight receivers for worker process using the standard register_scheme_receiver API. + # This properly initializes the schemes on the receiver side and stores them in _receiver_schemes. + if weight_sync_schemes: + inner_collector.register_scheme_receiver(weight_sync_schemes) + + use_buffers = inner_collector._use_buffers + if verbose: + torchrl_logger.debug("Sync data collector created") + dc_iter = iter(inner_collector) + j = 0 + pipe_child.send("instantiated") + except Exception as e: + # Send error information to main process + # We send a dict with the exception info so we can recreate it in the main process + import traceback + + error_info = { + "error": True, + "exception_type": type(e).__name__, + "exception_module": type(e).__module__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + } + try: + pipe_child.send(error_info) + except Exception: + # If pipe is broken, nothing we can do + pass + return + + has_timed_out = False + counter = 0 + run_free = False + while True: + _timeout = _TIMEOUT if not has_timed_out else 1e-3 + if not run_free and pipe_child.poll(_timeout): + counter = 0 + try: + data_in, msg = pipe_child.recv() + if verbose: + torchrl_logger.debug(f"mp worker {idx} received {msg}") + except EOFError: + torchrl_logger.debug( + f"Failed to receive data. Last message received: {msg}" + ) + raise + elif not run_free: + if verbose: + torchrl_logger.debug(f"poll failed, j={j}, worker={idx}") + # default is "continue" (after first iteration) + # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe + # in that case, the main process probably expects the worker to continue collect data + if has_timed_out: + counter = 0 + # has_timed_out is True if the process failed to send data, which will + # typically occur if main has taken another batch (i.e. the queue is Full). + # In this case, msg is the previous msg sent by main, which will typically be "continue" + # If it's not the case, it is not expected that has_timed_out is True. 
+ if msg not in ("continue", "continue_random"): + raise RuntimeError(f"Unexpected message after time out: msg={msg}") + else: + # if has_timed_out is False, then the time out does not come from the fact that the queue is Full. + # this means that our process has been waiting for a command from main in vain, while main was not + # receiving data. + # This will occur if main is busy doing something else (e.g. computing loss etc). + + counter += _timeout + if verbose: + torchrl_logger.debug(f"mp worker {idx} has counter {counter}") + if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): + raise RuntimeError( + f"This process waited for {counter} seconds " + f"without receiving a command from main. Consider increasing the maximum idle count " + f"if this is expected via the environment variable MAX_IDLE_COUNT " + f"(current value is {_MAX_IDLE_COUNT})." + f"\nIf this occurs at the end of a function or program, it means that your collector has not been " + f"collected, consider calling `collector.shutdown()` before ending the program." + ) + continue + else: + # placeholder, will be checked after + if msg != "continue": + torchrl_logger.debug(f"mp worker {idx} will reset {msg} to 'continue'") + msg = "continue" + if msg == "run_free": + run_free = True + msg = "continue" + if run_free: + # Capture shutdown / update / seed signal, but continue should not be expected + if pipe_child.poll(1e-4): + data_in, msg = pipe_child.recv() + torchrl_logger.debug( + f"mp worker {idx} received {msg} while running free" + ) + if msg == "continue": + # Switch back to run_free = False + run_free = False + if msg == "pause": + queue_out.put((idx, "paused"), timeout=_TIMEOUT) + while not pipe_child.poll(1e-2): + continue + data_in, msg = pipe_child.recv() + if msg != "restart": + raise RuntimeError(f"Expected msg='restart', got {msg=}") + msg = "continue" + else: + data_in = None + # TODO: this does not work with random frames + msg = "continue" + # Note: Weight updates are handled by background threads in weight sync schemes. + # The scheme's background receiver thread listens for "receive" instructions. + + if msg == "update": + # Legacy - weight updater + torchrl_logger.debug(f"mp worker {idx} updating the params...") + inner_collector.update_policy_weights_(policy_weights=data_in) + pipe_child.send((j, "updated")) + has_timed_out = False + continue + + # Note: Weight updates are now handled by background threads in the weight sync schemes. + # The scheme's background receiver thread listens for "receive" instructions and + # applies weights automatically. No explicit message handling needed here. + + if msg in ("continue", "continue_random"): + if msg == "continue_random": + inner_collector.init_random_frames = float("inf") + else: + inner_collector.init_random_frames = -1 + + next_data = next(dc_iter) + if pipe_child.poll(_MIN_TIMEOUT): + # in this case, main send a message to the worker while it was busy collecting trajectories. + # In that case, we skip the collected trajectory and get the message from main. This is faster than + # sending the trajectory in the queue until timeout when it's never going to be received. 
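A minimal sketch of how the idle limit referenced in the error above can be raised when the main process is legitimately busy for long stretches (e.g. a slow loss computation), assuming ``MAX_IDLE_COUNT`` is read from the environment once at import time into ``_MAX_IDLE_COUNT``:

    >>> import os
    >>> os.environ["MAX_IDLE_COUNT"] = "2000"  # set before importing torchrl collectors
    >>> from torchrl.collectors import MultiSyncDataCollector
    >>> # Workers will now wait up to 2000 * _TIMEOUT seconds for a command from
    >>> # the main process before raising the RuntimeError above.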
+ continue + + if replay_buffer is not None: + if extend_buffer: + next_data.names = None + replay_buffer.extend(next_data) + + if run_free: + continue + + try: + queue_out.put((idx, j), timeout=_TIMEOUT) + if verbose: + torchrl_logger.debug(f"mp worker {idx} successfully sent data") + j += 1 + has_timed_out = False + continue + except queue.Full: + if verbose: + torchrl_logger.debug(f"mp worker {idx} has timed out") + has_timed_out = True + continue + + if j == 0 or not use_buffers: + collected_tensordict = next_data + if ( + storing_device is not None + and collected_tensordict.device != storing_device + ): + raise RuntimeError( + f"expected device to be {storing_device} but got {collected_tensordict.device}" + ) + if use_buffers: + # If policy and env are on cpu, we put in shared mem, + # if policy is on cuda and env on cuda, we are fine with this + # If policy is on cuda and env on cpu (or opposite) we put tensors that + # are on cpu in shared mem. + MPS_ERROR = ( + "tensors on mps device cannot be put in shared memory. Make sure " + "the shared device (aka storing_device) is set to CPU." + ) + if collected_tensordict.device is not None: + # placeholder in case we need different behaviors + if collected_tensordict.device.type in ("cpu",): + collected_tensordict.share_memory_() + elif collected_tensordict.device.type in ("mps",): + raise RuntimeError(MPS_ERROR) + elif collected_tensordict.device.type == "cuda": + collected_tensordict.share_memory_() + else: + raise NotImplementedError( + f"Device {collected_tensordict.device} is not supported in multi-collectors yet." + ) + else: + # make sure each cpu tensor is shared - assuming non-cpu devices are shared + def cast_tensor(x, MPS_ERROR=MPS_ERROR): + if x.device.type in ("cpu",): + x.share_memory_() + if x.device.type in ("mps",): + RuntimeError(MPS_ERROR) + + collected_tensordict.apply(cast_tensor, filter_empty=True) + data = (collected_tensordict, idx) + else: + if next_data is not collected_tensordict: + raise RuntimeError( + "SyncDataCollector should return the same tensordict modified in-place." + ) + data = idx # flag the worker that has sent its data + try: + queue_out.put((data, j), timeout=_TIMEOUT) + if verbose: + torchrl_logger.debug(f"mp worker {idx} successfully sent data") + j += 1 + has_timed_out = False + continue + except queue.Full: + if verbose: + torchrl_logger.debug(f"mp worker {idx} has timed out") + has_timed_out = True + continue + + if msg == "seed": + data_in, static_seed = data_in + new_seed = inner_collector.set_seed(data_in, static_seed=static_seed) + torch.manual_seed(data_in) + np.random.seed(data_in) + pipe_child.send((new_seed, "seeded")) + has_timed_out = False + continue + + elif msg == "reset": + inner_collector.reset() + pipe_child.send((j, "reset")) + continue + + elif msg == "state_dict": + from torch.utils._pytree import tree_map + + state_dict = inner_collector.state_dict() + # Map exotic devices (MPS, NPU, etc.) 
to CPU for multiprocessing compatibility + # CPU and CUDA tensors are already shareable and don't need conversion BUT we need to clone the CUDA tensors in case they were sent from main (cannot send cuda tensors back and forth) + state_dict = tree_map(_map_to_cpu_if_needed, state_dict) + state_dict = TensorDict(state_dict) + state_dict = state_dict.clone().apply(_cast, state_dict).to_dict() + pipe_child.send((state_dict, "state_dict")) + has_timed_out = False + continue + + elif msg == "load_state_dict": + state_dict = data_in + inner_collector.load_state_dict(state_dict) + del state_dict + pipe_child.send((j, "loaded")) + has_timed_out = False + continue + + elif msg == "getattr_policy": + attr_name = data_in + try: + result = getattr(inner_collector.policy, attr_name) + pipe_child.send((result, "getattr_policy")) + except AttributeError as e: + pipe_child.send((e, "getattr_policy")) + has_timed_out = False + continue + + elif msg == "getattr_env": + attr_name = data_in + try: + result = getattr(inner_collector.env, attr_name) + pipe_child.send((result, "getattr_env")) + except AttributeError as e: + pipe_child.send((e, "getattr_env")) + has_timed_out = False + continue + + elif msg == "close": + del collected_tensordict, data, next_data, data_in + inner_collector.shutdown() + del inner_collector, dc_iter + pipe_child.send("closed") + if verbose: + torchrl_logger.debug(f"collector {idx} closed") + break + + else: + raise Exception(f"Unrecognized message {msg}") diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py new file mode 100644 index 00000000000..cac92dd0b72 --- /dev/null +++ b/torchrl/collectors/_single.py @@ -0,0 +1,1850 @@ +from __future__ import annotations + +import contextlib +import threading +import warnings +import weakref +from collections import OrderedDict +from collections.abc import Callable, Iterator, Sequence +from textwrap import indent +from typing import Any + +import torch + +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase +from torch import nn +from torchrl import compile_with_warmup, logger as torchrl_logger +from torchrl._utils import ( + _ends_with, + _make_ordinal_device, + _replace_last, + accept_remote_rref_udf_invocation, + prod, + RL_WARNINGS, +) +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import ( + cudagraph_mark_step_begin, + DEFAULT_EXPLORATION_TYPE, + ExplorationType, +) +from torchrl.collectors.utils import _TrajectoryPool, split_trajectories +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.data import ReplayBuffer +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs import EnvBase, EnvCreator, StepCounter, TransformedEnv +from torchrl.envs.common import _do_nothing +from torchrl.envs.llm.transforms import PolicyVersion +from torchrl.envs.utils import ( + _aggregate_end_of_traj, + _make_compatible_policy, + set_exploration_type, +) +from torchrl.modules import RandomPolicy +from torchrl.weight_update import WeightSyncScheme +from torchrl.weight_update.utils import _resolve_model + + +@accept_remote_rref_udf_invocation +class SyncDataCollector(DataCollectorBase): + """Generic data collector for RL problems. Requires an environment constructor and a policy. + + Args: + create_env_fn (Callable or EnvBase): a callable that returns an instance of + :class:`~torchrl.envs.EnvBase` class, or the env itself. 
+ policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the ``policy_factory`` should be used instead. + + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + frames_per_batch (int): A keyword-only argument representing the total + number of elements in a batch. + total_frames (int): A keyword-only argument representing the total + number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). 
Defaults to ``None``. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + create_env_kwargs (dict, optional): Dictionary of kwargs for + ``create_env_fn``. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). + Once a trajectory reaches ``n_steps``, the environment is reset. + If the environment wraps multiple environments together, the number + of steps is tracked for each environment independently. Negative + values are allowed, in which case this argument is ignored. + Defaults to ``None`` (i.e., no maximum number of steps). + init_random_frames (int, optional): Number of frames for which the + policy is ignored before it is called. This feature is mainly + intended to be used in offline/model-based settings, where a + batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). + reset_at_each_iter (bool, optional): Whether environments should be reset + at the beginning of a batch collection. + Defaults to ``False``. + postproc (Callable, optional): A post-processing transform, such as + a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` + instance. + + .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer + as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`. + + Defaults to ``None``. + split_trajs (bool, optional): Boolean indicating whether the resulting + TensorDict should be split according to the trajectories. + See :func:`~torchrl.collectors.utils.split_trajectories` for more + information. + Defaults to ``False``. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, + ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` + or ``torchrl.envs.utils.ExplorationType.MEAN``. + return_same_td (bool, optional): if ``True``, the same TensorDict + will be returned at each iteration, with its values + updated. This feature should be used cautiously: if the same + tensordict is added to a replay buffer for instance, + the whole content of the buffer will be identical. + Default is ``False``. + interruptor (_Interruptor, optional): + An _Interruptor object that can be used from outside the class to control rollout collection. + The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement + strategies such as preeptively stopping rollout collection. + Default is ``False``. + set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding + ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of + a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. + Truncated keys can be set through ``env.add_truncated_keys``. 
+ Defaults to ``False``. + use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. + This isn't compatible with environments with dynamic specs. Defaults to ``True`` + for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts + but populate the buffer instead. + Defaults to ``None``. + + .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. + If the buffer needs to be populated with individual frames as they are collected, + set ``extend_buffer=False`` (deprecated). + + .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires + `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. + + extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not + with single steps. Defaults to `True`. + + .. note:: Setting this to `False` is deprecated and will be removed in a future version. + Extending the buffer with entire rollouts is the recommended approach for better + compatibility with postprocessing and trajectory splitting. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkills `_) cuda synchronization may cause unexpected + crashes. + Defaults to ``False``. + weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. + Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): **Not supported for SyncDataCollector**. + SyncDataCollector is a leaf collector and cannot send weights to sub-collectors. + Providing this parameter will raise a ValueError. + Use ``weight_recv_schemes`` if you need to receive weights from a parent collector. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This enables cascading weight updates in hierarchies like: + RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector. + Defaults to ``None``. + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. 
+ Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track + the policy version. + Defaults to `False`. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector( + ... create_env_fn=env_maker, + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... ) + >>> for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + >>> del collector + + The collector delivers batches of data that are marked with a ``"time"`` + dimension. 
+ + Examples: + >>> assert data.names[-1] == "time" + + """ + + _ignore_rb: bool = False + + def __init__( + self, + create_env_fn: ( + EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 + ), # noqa: F821 + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + policy_factory: Callable[[], Callable] | None = None, + frames_per_batch: int, + total_frames: int = -1, + device: DEVICE_TYPING | None = None, + storing_device: DEVICE_TYPING | None = None, + policy_device: DEVICE_TYPING | None = None, + env_device: DEVICE_TYPING | None = None, + create_env_kwargs: dict[str, Any] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + return_same_td: bool = False, + reset_when_done: bool = True, + interruptor=None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + local_init_rb: bool | None = None, + trust_policy: bool | None = None, + compile_policy: bool | dict[str, Any] | None = None, + cudagraph_policy: bool | dict[str, Any] | None = None, + no_cuda_sync: bool = False, + weight_updater: WeightUpdaterBase + | Callable[[], WeightUpdaterBase] + | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, + track_policy_version: bool = False, + worker_idx: int | None = None, + **kwargs, + ): + self.closed = True + self.worker_idx = worker_idx + + # Note: weight_sync_schemes can be used to send weights to components + # within the environment (e.g., RayModuleTransform), not just sub-collectors + + # Initialize environment + env = self._init_env(create_env_fn, create_env_kwargs) + + # Initialize policy + policy = self._init_policy(policy, policy_factory, env, trust_policy) + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Handle trajectory pool and validate kwargs + self._traj_pool_val = kwargs.pop("traj_pool", None) + if kwargs: + raise TypeError( + f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
+ ) + + # Set up devices and synchronization + self._setup_devices( + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + no_cuda_sync=no_cuda_sync, + ) + + self.env: EnvBase = env + del env + + # Set up policy version tracking + self._setup_policy_version_tracking(track_policy_version) + + # Set up replay buffer + self._setup_replay_buffer( + replay_buffer=replay_buffer, + extend_buffer=extend_buffer, + local_init_rb=local_init_rb, + postproc=postproc, + split_trajs=split_trajs, + return_same_td=return_same_td, + use_buffers=use_buffers, + ) + + self.closed = False + + # Validate reset_when_done + if not reset_when_done: + raise ValueError("reset_when_done is deprecated.") + self.reset_when_done = reset_when_done + self.n_env = self.env.batch_size.numel() + + # Register collector with policy and env + if hasattr(policy, "register_collector"): + policy.register_collector(self) + if hasattr(self.env, "register_collector"): + self.env.register_collector(self) + + # Set up policy and weights + self._setup_policy_and_weights(policy) + + # Apply environment device + self._apply_env_device() + + # Set up max frames per trajectory + self._setup_max_frames_per_traj(max_frames_per_traj) + + # Validate and set total frames + self.reset_at_each_iter = reset_at_each_iter + self._setup_total_frames(total_frames, frames_per_batch) + + # Set up init random frames + self._setup_init_random_frames(init_random_frames, frames_per_batch) + + # Set up postproc + self._setup_postproc(postproc) + + # Calculate frames per batch + self._setup_frames_per_batch(frames_per_batch) + + # Set exploration and other options + self.exploration_type = ( + exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE + ) + self.return_same_td = return_same_td + self.set_truncated = set_truncated + + # Create shuttle and rollout buffers + self._make_shuttle() + self._maybe_make_final_rollout(make_rollout=self._use_buffers) + self._set_truncated_keys() + + # Set split trajectories option + if split_trajs is None: + split_trajs = False + self.split_trajs = split_trajs + self._exclude_private_keys = True + + # Set up interruptor and frame tracking + self.interruptor = interruptor + self._frames = 0 + self._iter = -1 + + # Set up weight synchronization + self._setup_weight_sync(weight_updater, weight_sync_schemes) + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + + def _init_env( + self, + create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], + create_env_kwargs: dict[str, Any] | None, + ) -> EnvBase: + """Initialize and configure the environment.""" + from torchrl.envs.batched_envs import BatchedEnvBase + + if create_env_kwargs is None: + create_env_kwargs = {} + + if not isinstance(create_env_fn, EnvBase): + env = create_env_fn(**create_env_kwargs) + else: + env = create_env_fn + if create_env_kwargs: + if not isinstance(env, BatchedEnvBase): + raise RuntimeError( + "kwargs were passed to SyncDataCollector but they can't be set " + f"on environment of type {type(create_env_fn)}." 
+ ) + env.update_kwargs(create_env_kwargs) + return env + + def _init_policy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: Callable[[], Callable] | None, + env: EnvBase, + trust_policy: bool | None, + ) -> TensorDictModule | Callable: + """Initialize and configure the policy before device placement / wrapping.""" + if policy is None: + if policy_factory is not None: + policy = policy_factory() + else: + policy = RandomPolicy(env.full_action_spec) + elif policy_factory is not None: + raise TypeError("policy_factory cannot be used with policy argument.") + + if trust_policy is None: + trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) + self.trust_policy = trust_policy + + return policy + + def _setup_devices( + self, + device: DEVICE_TYPING | None, + storing_device: DEVICE_TYPING | None, + policy_device: DEVICE_TYPING | None, + env_device: DEVICE_TYPING | None, + no_cuda_sync: bool, + ) -> None: + """Set up devices and synchronization functions.""" + storing_device, policy_device, env_device = self._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + + self.storing_device = storing_device + self._sync_storage = self._get_sync_fn(storing_device) + + self.env_device = env_device + self._sync_env = self._get_sync_fn(env_device) + + self.policy_device = policy_device + self._sync_policy = self._get_sync_fn(policy_device) + + self.device = device + self.no_cuda_sync = no_cuda_sync + self._cast_to_policy_device = self.policy_device != self.env_device + + def _get_sync_fn(self, device: torch.device | None) -> Callable: + """Get the appropriate synchronization function for a device.""" + if device is not None and device.type != "cuda": + # Cuda handles sync + if torch.cuda.is_available(): + return torch.cuda.synchronize + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): + return torch.mps.synchronize + elif hasattr(torch, "npu") and torch.npu.is_available(): + return torch.npu.synchronize + elif device.type == "cpu": + return _do_nothing + else: + raise RuntimeError("Non supported device") + else: + return _do_nothing + + def _setup_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking if requested.""" + self.policy_version_tracker = track_policy_version + if isinstance(track_policy_version, bool) and track_policy_version: + from torchrl.envs.batched_envs import BatchedEnvBase + + if isinstance(self.env, BatchedEnvBase): + raise RuntimeError( + "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " + "and pass that transform to the collector." 
+ ) + self.policy_version_tracker = PolicyVersion() + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + elif hasattr(track_policy_version, "increment_version"): + self.policy_version_tracker = track_policy_version + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + else: + self.policy_version_tracker = None + + def _setup_replay_buffer( + self, + replay_buffer: ReplayBuffer | None, + extend_buffer: bool, + local_init_rb: bool | None, + postproc: Callable | None, + split_trajs: bool | None, + return_same_td: bool, + use_buffers: bool | None, + ) -> None: + """Set up replay buffer configuration and validate compatibility.""" + self.replay_buffer = replay_buffer + self.extend_buffer = extend_buffer + + # Handle local_init_rb deprecation + if local_init_rb is None: + local_init_rb = False + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + # Validate replay buffer compatibility + if self.replay_buffer is not None and not self._ignore_rb: + if postproc is not None and not self.extend_buffer: + raise TypeError( + "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." + ) + if split_trajs not in (None, False) and not self.extend_buffer: + raise TypeError( + "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." + ) + if return_same_td: + raise TypeError( + "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." + ) + if use_buffers: + raise TypeError("replay_buffer is exclusive with use_buffers.") + + if use_buffers is None: + use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None + self._use_buffers = use_buffers + + def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: + """Set up policy, wrapped policy, and extract weights.""" + # Store weak reference to original policy before any transformations + # This allows update_policy_weights_ to sync from the original when no scheme is configured + if isinstance(policy, nn.Module): + self._orig_policy_ref = weakref.ref(policy) + else: + self._orig_policy_ref = None + + # Check if policy has meta-device parameters (sent from weight sync schemes) + # In that case, skip device placement - weights will come from the receiver + has_meta_params = False + if isinstance(policy, nn.Module): + for p in policy.parameters(): + if p.device.type == "meta": + has_meta_params = True + break + + if has_meta_params: + # Policy has meta params - sent from weight sync schemes + # Skip device placement, weights will come from receiver + # Keep policy on meta device until weights are loaded + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." 
+ ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # For meta-parameter policies, keep the internal (worker-side) policy + # as the reference for collector state_dict / load_state_dict. + if isinstance(self.policy, nn.Module): + self._policy_w_state_dict = self.policy + + # Don't extract weights yet - they're on meta device (empty) + self.policy_weights = TensorDict() + self.get_weights_fn = None + else: + # Normal path: move policy to correct device + policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) + + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." + ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Use the internal, unwrapped policy (cast to the correct device) as the + # reference for state_dict / load_state_dict and legacy weight extractors. + if isinstance(self.policy, nn.Module): + self._policy_w_state_dict = self.policy + + # Extract policy weights from the uncompiled wrapped policy + # Access _wrapped_policy_uncompiled directly to avoid triggering compilation. + if isinstance(self._wrapped_policy_uncompiled, nn.Module): + self.policy_weights = TensorDict.from_module( + self._wrapped_policy_uncompiled, as_module=True + ).data + else: + self.policy_weights = TensorDict() + + # If policy doesn't have meta params, compile immediately + # Otherwise, defer until first use (after weights are loaded) + if not has_meta_params and (self.compiled_policy or self.cudagraphed_policy): + self._wrapped_policy_maybe_compiled = self._compile_wrapped_policy( + self._wrapped_policy_uncompiled + ) + + def _compile_wrapped_policy(self, policy): + """Apply compilation and/or cudagraph to a policy.""" + if self.compiled_policy: + policy = compile_with_warmup(policy, **self.compiled_policy_kwargs) + if self.cudagraphed_policy: + policy = CudaGraphModule( + policy, + in_keys=[], + out_keys=[], + device=self.policy_device, + **self.cudagraphed_policy_kwargs, + ) + return policy + + @property + def _wrapped_policy(self): + """Returns the compiled policy, compiling it lazily if needed.""" + if (policy := self._wrapped_policy_maybe_compiled) is None: + if self.compiled_policy or self.cudagraphed_policy: + policy = ( + self._wrapped_policy_maybe_compiled + ) = self._compile_wrapped_policy(self._wrapped_policy_uncompiled) + else: + policy = ( + self._wrapped_policy_maybe_compiled + ) = self._wrapped_policy_uncompiled + return policy + + @property + def _orig_policy(self): + """Returns the original policy passed to the collector, if still alive.""" + if self._orig_policy_ref is not None: + return self._orig_policy_ref() + return None + + @_wrapped_policy.setter + def _wrapped_policy(self, value): + """Allow setting the wrapped policy during initialization.""" + self._wrapped_policy_uncompiled = value + self._wrapped_policy_maybe_compiled = None + + def _apply_env_device(self) -> None: + """Apply device to environment if specified.""" + if self.env_device: + self.env: EnvBase = self.env.to(self.env_device) + elif self.env.device is not None: + # Use the device of the env if none was provided + 
self.env_device = self.env.device + + # Check if we need to cast to env device + self._cast_to_env_device = self._cast_to_policy_device or ( + self.env.device != self.storing_device + ) + + def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: + """Set up maximum frames per trajectory and add StepCounter if needed.""" + self.max_frames_per_traj = ( + int(max_frames_per_traj) if max_frames_per_traj is not None else 0 + ) + if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: + # Check that there is no StepCounter yet + for key in self.env.output_spec.keys(True, True): + if isinstance(key, str): + key = (key,) + if "step_count" in key: + raise ValueError( + "A 'step_count' key is already present in the environment " + "and the 'max_frames_per_traj' argument may conflict with " + "a 'StepCounter' that has already been set. " + "Possible solutions: Set max_frames_per_traj to 0 or " + "remove the StepCounter limit from the environment transforms." + ) + self.env = TransformedEnv( + self.env, StepCounter(max_steps=self.max_frames_per_traj) + ) + + def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: + """Validate and set total frames.""" + if total_frames is None or total_frames < 0: + total_frames = float("inf") + else: + remainder = total_frames % frames_per_batch + if remainder != 0 and RL_WARNINGS: + warnings.warn( + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " + f"This means {frames_per_batch - remainder} additional frames will be collected." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.total_frames = ( + int(total_frames) if total_frames != float("inf") else total_frames + ) + + def _setup_init_random_frames( + self, init_random_frames: int | None, frames_per_batch: int + ) -> None: + """Set up initial random frames.""" + self.init_random_frames = ( + int(init_random_frames) if init_random_frames not in (None, -1) else 0 + ) + if ( + init_random_frames not in (-1, None, 0) + and init_random_frames % frames_per_batch != 0 + and RL_WARNINGS + ): + warnings.warn( + f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " + f" this results in more init_random_frames than requested" + f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + + def _setup_postproc(self, postproc: Callable | None) -> None: + """Set up post-processing transform.""" + self.postproc = postproc + if ( + self.postproc is not None + and hasattr(self.postproc, "to") + and self.storing_device + ): + postproc = self.postproc.to(self.storing_device) + if postproc is not self.postproc and postproc is not None: + self.postproc = postproc + + def _setup_frames_per_batch(self, frames_per_batch: int) -> None: + """Calculate and validate frames per batch.""" + if frames_per_batch % self.n_env != 0 and RL_WARNINGS: + warnings.warn( + f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " + f" this results in more frames_per_batch per iteration that requested" + f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " + "To silence this message, set the environment variable RL_WARNINGS to False." 
+ ) + self.frames_per_batch = -(-frames_per_batch // self.n_env) + self.requested_frames_per_batch = self.frames_per_batch * self.n_env + + def _setup_weight_sync( + self, + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up weight synchronization system.""" + if weight_sync_schemes is not None: + # Use new simplified weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + # Initialize and synchronize schemes that need sender-side setup + # (e.g., RayModuleTransformScheme for updating transforms in the env) + for model_id, scheme in weight_sync_schemes.items(): + if not scheme.initialized_on_sender: + scheme.init_on_sender(model_id=model_id, context=self) + if not scheme.synchronized_on_sender: + scheme.connect() + self.weight_updater = None # Don't use legacy system + elif weight_updater is not None: + # Use legacy weight updater system if explicitly provided + if not isinstance(weight_updater, WeightUpdaterBase): + if callable(weight_updater): + weight_updater = weight_updater() + else: + raise TypeError( + f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." + ) + warnings.warn( + "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. " + "This will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + else: + # No weight sync needed for single-process collectors + self.weight_updater = None + self._weight_sync_schemes = None + + @property + def _traj_pool(self): + pool = getattr(self, "_traj_pool_val", None) + if pool is None: + pool = self._traj_pool_val = _TrajectoryPool() + return pool + + def _make_shuttle(self): + # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env + with torch.no_grad(): + self._carrier = self.env.reset() + if self.policy_device != self.env_device or self.env_device is None: + self._shuttle_has_no_device = True + self._carrier.clear_device_() + else: + self._shuttle_has_no_device = False + + traj_ids = self._traj_pool.get_traj_and_increment( + self.n_env, device=self.storing_device + ).view(self.env.batch_size) + self._carrier.set( + ("collector", "traj_ids"), + traj_ids, + ) + + def _maybe_make_final_rollout(self, make_rollout: bool): + if make_rollout: + with torch.no_grad(): + self._final_rollout = self.env.fake_tensordict() + + # If storing device is not None, we use this to cast the storage. + # If it is None and the env and policy are on the same device, + # the storing device is already the same as those, so we don't need + # to consider this use case. + # In all other cases, we can't really put a device on the storage, + # since at least one data source has a device that is not clear. 
+ if self.storing_device: + self._final_rollout = self._final_rollout.to( + self.storing_device, non_blocking=True + ) + else: + # erase all devices + self._final_rollout.clear_device_() + + # Check if policy has meta-device parameters (not yet initialized) + has_meta_params = False + if hasattr(self, "_wrapped_policy_uncompiled") and isinstance( + self._wrapped_policy_uncompiled, nn.Module + ): + for p in self._wrapped_policy_uncompiled.parameters(): + if p.device.type == "meta": + has_meta_params = True + break + + # If the policy has a valid spec, we use it + self._policy_output_keys = set() + if ( + make_rollout + and hasattr( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy, + "spec", + ) + and ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec + is not None + and all( + v is not None + for v in ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec.values(True, True) + ) + ): + if any( + key not in self._final_rollout.keys(isinstance(key, tuple)) + for key in ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec.keys(True, True) + ): + # if policy spec is non-empty, all the values are not None and the keys + # match the out_keys we assume the user has given all relevant information + # the policy could have more keys than the env: + policy_spec = ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec + if policy_spec.ndim < self._final_rollout.ndim: + policy_spec = policy_spec.expand(self._final_rollout.shape) + for key, spec in policy_spec.items(True, True): + self._policy_output_keys.add(key) + if key in self._final_rollout.keys(True): + continue + self._final_rollout.set(key, spec.zero()) + elif ( + not make_rollout + and hasattr( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy, + "out_keys", + ) + and ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).out_keys + ): + self._policy_output_keys = list( + ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).out_keys + ) + elif has_meta_params: + # Policy has meta params and no spec/out_keys - defer initialization + # Mark that we need to initialize later when weights are loaded + self._policy_output_keys = set() + if make_rollout: + # We'll populate keys on first actual rollout after weights are loaded + self._final_rollout_needs_init = True + else: + if make_rollout: + # otherwise, we perform a small number of steps with the policy to + # determine the relevant keys with which to pre-populate _final_rollout. + # This is the safest thing to do if the spec has None fields or if there is + # no spec at all. + # See #505 for additional context. 
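+ # (Concretely: the policy is run once below on a copy of the reset data, and any
+ # entry it adds or modifies in-place -- e.g. an "action" key -- is recorded as a
+ # policy output key used to pre-populate the rollout buffer.)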
+ self._final_rollout.update(self._carrier.copy()) + with torch.no_grad(): + policy_input = self._carrier.copy() + if self.policy_device: + policy_input = policy_input.to(self.policy_device) + # we cast to policy device, we'll deal with the device later + policy_input_copy = policy_input.copy() + policy_input_clone = ( + policy_input.clone() + ) # to test if values have changed in-place + if self.compiled_policy: + cudagraph_mark_step_begin() + policy_output = self._wrapped_policy(policy_input) + + # check that we don't have exclusive keys, because they don't appear in keys + def check_exclusive(val): + if ( + isinstance(val, LazyStackedTensorDict) + and val._has_exclusive_keys + ): + raise RuntimeError( + "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " + "Consider using a placeholder for missing keys." + ) + + policy_output._fast_apply( + check_exclusive, call_on_nested=True, filter_empty=True + ) + + # Use apply, because it works well with lazy stacks + # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit + # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has + # changed them here). + # This will cause a failure to update entries when policy and env device mismatch and + # casting is necessary. + def filter_policy(name, value_output, value_input, value_input_clone): + if (value_input is None) or ( + (value_output is not value_input) + and ( + value_output.device != value_input_clone.device + or ~torch.isclose(value_output, value_input_clone).any() + ) + ): + return value_output + + filtered_policy_output = policy_output.apply( + filter_policy, + policy_input_copy, + policy_input_clone, + default=None, + filter_empty=True, + named=True, + ) + self._policy_output_keys = list( + self._policy_output_keys.union( + set(filtered_policy_output.keys(True, True)) + ) + ) + if make_rollout: + self._final_rollout.update( + policy_output.select(*self._policy_output_keys) + ) + del filtered_policy_output, policy_output, policy_input + + _env_output_keys = [] + for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: + _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) + self._env_output_keys = _env_output_keys + if make_rollout: + self._final_rollout = ( + self._final_rollout.unsqueeze(-1) + .expand(*self.env.batch_size, self.frames_per_batch) + .clone() + .zero_() + ) + + # in addition to outputs of the policy, we add traj_ids to + # _final_rollout which will be collected during rollout + self._final_rollout.set( + ("collector", "traj_ids"), + torch.zeros( + *self._final_rollout.batch_size, + dtype=torch.int64, + device=self.storing_device, + ), + ) + self._final_rollout.refine_names(..., "time") + + def _set_truncated_keys(self): + self._truncated_keys = [] + if self.set_truncated: + if not any(_ends_with(key, "truncated") for key in self.env.done_keys): + raise RuntimeError( + "set_truncated was set to True but no truncated key could be found " + "in the environment. Make sure the truncated keys are properly set using " + "`env.add_truncated_keys()` before passing the env to the collector." 
+ ) + self._truncated_keys = [ + key for key in self.env.done_keys if _ends_with(key, "truncated") + ] + + @classmethod + def _get_devices( + cls, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + device = _make_ordinal_device(torch.device(device) if device else device) + storing_device = _make_ordinal_device( + torch.device(storing_device) if storing_device else device + ) + policy_device = _make_ordinal_device( + torch.device(policy_device) if policy_device else device + ) + env_device = _make_ordinal_device( + torch.device(env_device) if env_device else device + ) + if storing_device is None and (env_device == policy_device): + storing_device = env_device + return storing_device, policy_device, env_device + + # for RPC + def next(self): + return super().next() + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def _maybe_fallback_update( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + model_id: str | None = None, + ) -> None: + """Copy weights from original policy to internal policy when no scheme configured.""" + if model_id is not None and model_id != "policy": + return + + # Get source weights - either from argument or from original policy + if policy_or_weights is not None: + weights = self._extract_weights_if_needed(policy_or_weights, "policy") + elif self._orig_policy is not None: + weights = TensorDict.from_module(self._orig_policy) + else: + return + + # Apply to internal policy + if ( + hasattr(self, "_policy_w_state_dict") + and self._policy_w_state_dict is not None + ): + TensorDict.from_module(self._policy_w_state_dict).data.update_(weights.data) + + def set_seed(self, seed: int, static_seed: bool = False) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed (int): integer representing the seed to be used for the environment. + static_seed(bool, optional): if ``True``, the seed is not incremented. + Defaults to False + + Returns: + Output seed. This is useful when more than one environment is contained in the DataCollector, as the + seed will be incremented for each of these. The resulting seed is the seed of the last environment. 
+ + Examples: + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = ParallelEnv(6, env_fn) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + out = self.env.set_seed(seed, static_seed=static_seed) + return out + + def _increment_frames(self, numel): + self._frames += numel + completed = self._frames >= self.total_frames + if completed: + self.env.close() + return completed + + def iterator(self) -> Iterator[TensorDictBase]: + """Iterates through the DataCollector. + + Yields: TensorDictBase objects containing (chunks of) trajectories + + """ + if ( + not self.no_cuda_sync + and self.storing_device + and self.storing_device.type == "cuda" + ): + stream = torch.cuda.Stream(self.storing_device, priority=-1) + event = stream.record_event() + streams = [stream] + events = [event] + elif not self.no_cuda_sync and self.storing_device is None: + streams = [] + events = [] + # this way of checking cuda is robust to lazy stacks with mismatching shapes + cuda_devices = set() + + def cuda_check(tensor: torch.Tensor): + if tensor.is_cuda: + cuda_devices.add(tensor.device) + + if not self._use_buffers: + # This may be a bit dangerous as `torch.device("cuda")` may not have a precise + # device associated, whereas `tensor.device` always has + for spec in self.env.specs.values(True, True): + if spec.device is not None and spec.device.type == "cuda": + if ":" not in str(spec.device): + raise RuntimeError( + "A cuda spec did not have a device associated. Make sure to " + "pass `'cuda:device_num'` to each spec device." + ) + cuda_devices.add(spec.device) + else: + self._final_rollout.apply(cuda_check, filter_empty=True) + for device in cuda_devices: + streams.append(torch.cuda.Stream(device, priority=-1)) + events.append(streams[-1].record_event()) + else: + streams = [] + events = [] + with contextlib.ExitStack() as stack: + for stream in streams: + stack.enter_context(torch.cuda.stream(stream)) + + while self._frames < self.total_frames: + self._iter += 1 + torchrl_logger.debug("Collector: rollout.") + tensordict_out = self.rollout() + if tensordict_out is None: + # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out + # frames are updated within the rollout function + torchrl_logger.debug("Collector: No tensordict_out. Yielding.") + yield + continue + self._increment_frames(tensordict_out.numel()) + tensordict_out = self._postproc(tensordict_out) + torchrl_logger.debug("Collector: postproc done.") + if self.return_same_td: + # This is used with multiprocessed collectors to use the buffers + # stored in the tensordict. + if events: + for event in events: + event.record() + event.synchronize() + yield tensordict_out + elif self.replay_buffer is not None and not self._ignore_rb: + self.replay_buffer.extend(tensordict_out) + torchrl_logger.debug( + f"Collector: Added {tensordict_out.numel()} frames to replay buffer. " + "Buffer write count: {self.replay_buffer.write_count}. Yielding." + ) + yield + else: + # we must clone the values, as the tensordict is updated in-place. 
+ # otherwise the following code may break:
+ # >>> for i, data in enumerate(collector):
+ # >>> if i == 0:
+ # >>> data0 = data
+ # >>> elif i == 1:
+ # >>> data1 = data
+ # >>> else:
+ # >>> break
+ # >>> assert data0["done"] is not data1["done"]
+ yield tensordict_out.clone()
+
+ def start(self):
+ """Starts the collector in a separate thread for asynchronous data collection.
+
+ The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
+ collection from training, allowing your training loop to run independently of the data collection process.
+
+ Raises:
+ RuntimeError: If no replay buffer is defined during the collector's initialization.
+
+ Example:
+ >>> from torchrl.modules import RandomPolicy
+ >>>
+ >>> import time
+ >>> from functools import partial
+ >>>
+ >>> import tqdm
+ >>>
+ >>> from torchrl.collectors import SyncDataCollector
+ >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
+ >>> from torchrl.envs import GymEnv, set_gym_backend
+ >>> import ale_py
+ >>>
+ >>> # Set the gym backend to gymnasium
+ >>> set_gym_backend("gymnasium").set()
+ >>>
+ >>> if __name__ == "__main__":
+ ... # Create a random policy for the Pong environment
+ ... env = GymEnv("ALE/Pong-v5")
+ ... policy = RandomPolicy(env.action_spec)
+ ...
+ ... # Initialize a shared replay buffer
+ ... rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
+ ...
+ ... # Create a synchronous data collector
+ ... collector = SyncDataCollector(
+ ... env,
+ ... policy=policy,
+ ... replay_buffer=rb,
+ ... frames_per_batch=256,
+ ... total_frames=-1,
+ ... )
+ ...
+ ... # Progress bar to track the number of collected frames
+ ... pbar = tqdm.tqdm(total=100_000)
+ ...
+ ... # Start the collector asynchronously
+ ... collector.start()
+ ...
+ ... # Track the write count of the replay buffer
+ ... prec_wc = 0
+ ... while True:
+ ... wc = rb.write_count
+ ... c = wc - prec_wc
+ ... prec_wc = wc
+ ...
+ ... # Update the progress bar
+ ... pbar.update(c)
+ ... pbar.set_description(f"Write Count: {rb.write_count}")
+ ...
+ ... # Check the write count every 0.5 seconds
+ ... time.sleep(0.5)
+ ...
+ ... # Stop when the desired number of frames is reached
+ ... if rb.write_count > 100_000:
+ ... break
+ ...
+ ... # Shut down the collector
+ ... collector.async_shutdown()
+ """
+ if self.replay_buffer is None:
+ raise RuntimeError("Replay buffer must be defined for execution.")
+ if not self.is_running():
+ self._stop = False
+ self._thread = threading.Thread(target=self._run_iterator)
+ self._thread.daemon = (
+ True # So that the thread dies when the main program exits
+ )
+ self._thread.start()
+
+ def _run_iterator(self):
+ for _ in self:
+ if self._stop:
+ return
+
+ def is_running(self):
+ return hasattr(self, "_thread") and self._thread.is_alive()
+
+ def async_shutdown(
+ self, timeout: float | None = None, close_env: bool = True
+ ) -> None:
+ """Finishes processes started by ray.init() during async execution."""
+ self._stop = True
+ if hasattr(self, "_thread") and self._thread.is_alive():
+ self._thread.join(timeout=timeout)
+ self.shutdown(close_env=close_env)
+
+ def _postproc(self, tensordict_out):
+ if self.split_trajs:
+ tensordict_out = split_trajectories(tensordict_out, prefix="collector")
+ if self.postproc is not None:
+ tensordict_out = self.postproc(tensordict_out)
+ if self._exclude_private_keys:
+
+ def is_private(key):
+ if isinstance(key, str) and key.startswith("_"):
+ return True
+ if isinstance(key, tuple) and any(_key.startswith("_") for _key in key):
+ return True
+ return False
+
+ excluded_keys = [
+ key for key in tensordict_out.keys(True) if is_private(key)
+ ]
+ tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True)
+ return tensordict_out
+
+ def _update_traj_ids(self, env_output) -> None:
+ # we can't use the reset keys because they're gone
+ traj_sop = _aggregate_end_of_traj(
+ env_output.get("next"), done_keys=self.env.done_keys
+ )
+ if traj_sop.any():
+ device = self.storing_device
+
+ traj_ids = self._carrier.get(("collector", "traj_ids"))
+ if device is not None:
+ traj_ids = traj_ids.to(device)
+ traj_sop = traj_sop.to(device)
+ elif traj_sop.device != traj_ids.device:
+ traj_sop = traj_sop.to(traj_ids.device)
+
+ pool = self._traj_pool
+ new_traj = pool.get_traj_and_increment(
+ traj_sop.sum(), device=traj_sop.device
+ )
+ traj_ids = traj_ids.masked_scatter(traj_sop, new_traj)
+ self._carrier.set(("collector", "traj_ids"), traj_ids)
+
+ @torch.no_grad()
+ def rollout(self) -> TensorDictBase:
+ """Computes a rollout in the environment using the provided policy.
+
+ Returns:
+ TensorDictBase containing the computed rollout.
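+
+ Examples:
+ >>> # sketch: assumes ``collector`` is an initialized SyncDataCollector without a
+ >>> # replay buffer attached; ``rollout()`` is called directly here for illustration
+ >>> data = collector.rollout()
+ >>> assert data.names[-1] == "time" # trailing dimension is the time dimension
+ >>> assert ("collector", "traj_ids") in data.keys(True)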
+ + """ + if self.reset_at_each_iter: + self._carrier.update(self.env.reset()) + + # self._shuttle.fill_(("collector", "step_count"), 0) + if self._use_buffers: + self._final_rollout.fill_(("collector", "traj_ids"), -1) + else: + pass + tensordicts = [] + with set_exploration_type(self.exploration_type): + for t in range(self.frames_per_batch): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + self.env.rand_action(self._carrier) + if ( + self.policy_device is not None + and self.policy_device != self.env_device + ): + # TODO: This may break with exclusive / ragged lazy stacks + self._carrier.apply( + lambda name, val: val.to( + device=self.policy_device, non_blocking=True + ) + if name in self._policy_output_keys + else val, + out=self._carrier, + named=True, + nested_keys=True, + ) + else: + if self._cast_to_policy_device: + if self.policy_device is not None: + # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking + non_blocking = ( + not self.no_cuda_sync + or self.policy_device.type == "cuda" + ) + policy_input = self._carrier.to( + self.policy_device, + non_blocking=non_blocking, + ) + if not self.no_cuda_sync: + self._sync_policy() + elif self.policy_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # policy_input = self._shuttle.clear_device_() + policy_input = self._carrier + else: + policy_input = self._carrier + # we still do the assignment for security + if self.compiled_policy: + cudagraph_mark_step_begin() + policy_output = self._wrapped_policy(policy_input) + if self.compiled_policy: + policy_output = policy_output.clone() + if self._carrier is not policy_output: + # ad-hoc update shuttle + self._carrier.update( + policy_output, keys_to_update=self._policy_output_keys + ) + + if self._cast_to_env_device: + if self.env_device is not None: + non_blocking = ( + not self.no_cuda_sync or self.env_device.type == "cuda" + ) + env_input = self._carrier.to( + self.env_device, non_blocking=non_blocking + ) + if not self.no_cuda_sync: + self._sync_env() + elif self.env_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # env_input = self._shuttle.clear_device_() + env_input = self._carrier + else: + env_input = self._carrier + env_output, env_next_output = self.env.step_and_maybe_reset(env_input) + + if self._carrier is not env_output: + # ad-hoc update shuttle + next_data = env_output.get("next") + if self._shuttle_has_no_device: + # Make sure + next_data.clear_device_() + self._carrier.set("next", next_data) + + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + torchrl_logger.debug( + f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." + ) + self.replay_buffer.add(self._carrier) + if self._increment_frames(self._carrier.numel()): + return + else: + if self.storing_device is not None: + torchrl_logger.debug( + f"Collector: Moving to {self.storing_device} and adding to queue." 
+ ) + non_blocking = ( + not self.no_cuda_sync or self.storing_device.type == "cuda" + ) + tensordicts.append( + self._carrier.to( + self.storing_device, non_blocking=non_blocking + ) + ) + if not self.no_cuda_sync: + self._sync_storage() + else: + tensordicts.append(self._carrier) + + # carry over collector data without messing up devices + collector_data = self._carrier.get("collector").copy() + self._carrier = env_next_output + if self._shuttle_has_no_device: + self._carrier.clear_device_() + self._carrier.set("collector", collector_data) + self._update_traj_ids(env_output) + + if ( + self.interruptor is not None + and self.interruptor.collection_stopped() + ): + torchrl_logger.debug("Collector: Interruptor stopped.") + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + return + result = self._final_rollout + if self._use_buffers: + try: + torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout[..., : t + 1], + ) + except RuntimeError: + with self._final_rollout.unlock_(): + torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout[..., : t + 1], + ) + else: + result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) + break + else: + if self._use_buffers: + torchrl_logger.debug("Returning final rollout within buffer.") + result = self._final_rollout + try: + result = torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout, + ) + + except RuntimeError: + with self._final_rollout.unlock_(): + result = torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout, + ) + elif ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + return + else: + torchrl_logger.debug( + "Returning final rollout with NO buffer (maybe_dense_stack)." 
+ ) + result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) + result.refine_names(..., "time") + + return self._maybe_set_truncated(result) + + def _maybe_set_truncated(self, final_rollout): + last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) + for truncated_key in self._truncated_keys: + truncated = final_rollout["next", truncated_key] + truncated[last_step] = True + final_rollout["next", truncated_key] = truncated + done = final_rollout["next", _replace_last(truncated_key, "done")] + final_rollout["next", _replace_last(truncated_key, "done")] = ( + done | truncated + ) + return final_rollout + + @torch.no_grad() + def reset(self, index=None, **kwargs) -> None: + """Resets the environments to a new initial state.""" + # metadata + collector_metadata = self._carrier.get("collector").clone() + if index is not None: + # check that the env supports partial reset + if prod(self.env.batch_size) == 0: + raise RuntimeError("resetting unique env with index is not permitted.") + for reset_key, done_keys in zip( + self.env.reset_keys, self.env.done_keys_groups + ): + _reset = torch.zeros( + self.env.full_done_spec[done_keys[0]].shape, + dtype=torch.bool, + device=self.env.device, + ) + _reset[index] = 1 + self._carrier.set(reset_key, _reset) + else: + _reset = None + self._carrier.zero_() + + self._carrier.update(self.env.reset(**kwargs), inplace=True) + collector_metadata["traj_ids"] = ( + collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() + ) + self._carrier["collector"] = collector_metadata + + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + """Shuts down all workers and/or closes the local environment. + + Args: + timeout (float, optional): The timeout for closing pipes between workers. + No effect for this class. + close_env (bool, optional): Whether to close the environment. Defaults to `True`. + raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. + """ + try: + if not self.closed: + self.closed = True + del self._carrier + if self._use_buffers: + del self._final_rollout + if close_env and not self.env.is_closed: + self.env.close(raise_if_closed=raise_on_error) + del self.env + return + except Exception as e: + if raise_on_error: + raise e + else: + pass + + def __del__(self): + try: + self.shutdown() + except Exception: + # an AttributeError will typically be raised if the collector is deleted when the program ends. + # In the future, insignificant changes to the close method may change the error type. + # We excplicitely assume that any error raised during closure in + # __del__ will not affect the program. + pass + + def state_dict(self) -> OrderedDict: + """Returns the local state_dict of the data collector (environment and policy). + + Returns: + an ordered dictionary with fields :obj:`"policy_state_dict"` and + `"env_state_dict"`. 
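+
+ Examples:
+ >>> # sketch: assumes ``collector`` wraps an ``nn.Module`` policy
+ >>> sd = collector.state_dict() # also carries the "frames" and "iter" counters
+ >>> collector.load_state_dict(sd)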
+ + """ + from torchrl.envs.batched_envs import BatchedEnvBase + + if isinstance(self.env, TransformedEnv): + env_state_dict = self.env.transform.state_dict() + elif isinstance(self.env, BatchedEnvBase): + env_state_dict = self.env.state_dict() + else: + env_state_dict = OrderedDict() + + if hasattr(self, "_policy_w_state_dict"): + policy_state_dict = self._policy_w_state_dict.state_dict() + state_dict = OrderedDict( + policy_state_dict=policy_state_dict, + env_state_dict=env_state_dict, + ) + else: + state_dict = OrderedDict(env_state_dict=env_state_dict) + + state_dict.update({"frames": self._frames, "iter": self._iter}) + + return state_dict + + def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: + """Loads a state_dict on the environment and policy. + + Args: + state_dict (OrderedDict): ordered dictionary containing the fields + `"policy_state_dict"` and :obj:`"env_state_dict"`. + + """ + strict = kwargs.get("strict", True) + if strict or "env_state_dict" in state_dict: + self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) + if strict or "policy_state_dict" in state_dict: + if not hasattr(self, "_policy_w_state_dict"): + raise ValueError( + "Underlying policy does not have state_dict to load policy_state_dict into." + ) + self._policy_w_state_dict.load_state_dict( + state_dict["policy_state_dict"], **kwargs + ) + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] + + def __repr__(self) -> str: + try: + env_str = indent(f"env={self.env}", 4 * " ") + policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") + td_out_str = repr(getattr(self, "_final_rollout", None)) + if len(td_out_str) > 50: + td_out_str = td_out_str[:50] + "..." + td_out_str = indent(f"td_out={td_out_str}", 4 * " ") + string = ( + f"{self.__class__.__name__}(" + f"\n{env_str}," + f"\n{policy_str}," + f"\n{td_out_str}," + f"\nexploration={self.exploration_type})" + ) + return string + except Exception: + return f"{type(self).__name__}(not_init)" + + def increment_version(self): + """Increment the policy version.""" + if self.policy_version_tracker is not None: + if not hasattr(self.policy_version_tracker, "increment_version"): + raise RuntimeError( + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." + ) + self.policy_version_tracker.increment_version() + + @property + def policy_version(self) -> str | int | None: + """The current policy version.""" + if not hasattr(self.policy_version_tracker, "version"): + return None + return self.policy_version_tracker.version + + def get_policy_version(self) -> str | int | None: + """Get the current policy version. + + This method exists to support remote calls in Ray actors, since properties + cannot be accessed directly through Ray's RPC mechanism. + + Returns: + The current version number (int) or UUID (str), or None if version tracking is disabled. + """ + return self.policy_version + + def getattr_policy(self, attr): + """Get an attribute from the policy.""" + # send command to policy to return the attr + return getattr(self._wrapped_policy, attr) + + def getattr_env(self, attr): + """Get an attribute from the environment.""" + # send command to env to return the attr + return getattr(self.env, attr) + + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + # send command to rb to return the attr + return getattr(self.replay_buffer, attr) + + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). 
+
+ Args:
+ model_id: Model identifier (e.g., "policy", "value_net")
+
+ Returns:
+ The model instance
+
+ Raises:
+ ValueError: If model_id is not recognized
+ """
+ if model_id == "policy":
+ # Return the unwrapped policy instance for weight synchronization
+ # The unwrapped policy has the same parameter structure as what's
+ # extracted in the main process, avoiding key mismatches when
+ # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule)
+ if hasattr(self, "policy") and self.policy is not None:
+ return self.policy
+ else:
+ raise ValueError(f"No policy found for model_id '{model_id}'")
+ else:
+ return _resolve_model(self, model_id)
+
+ def _receive_weights_scheme(self):
+ return super()._receive_weights_scheme()
diff --git a/torchrl/collectors/_single_async.py b/torchrl/collectors/_single_async.py
new file mode 100644
index 00000000000..131c913b184
--- /dev/null
+++ b/torchrl/collectors/_single_async.py
@@ -0,0 +1,248 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from collections.abc import Callable, Sequence
+from typing import Any
+
+from tensordict import TensorDictBase
+from tensordict.nn import TensorDictModule
+
+from torchrl._utils import accept_remote_rref_udf_invocation
+from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE, ExplorationType
+from torchrl.collectors._multi_async import MultiaSyncDataCollector
+from torchrl.data.utils import DEVICE_TYPING
+from torchrl.envs import EnvBase
+
+
+@accept_remote_rref_udf_invocation
+class aSyncDataCollector(MultiaSyncDataCollector):
+ """Runs a single DataCollector in a separate process.
+
+ This is mostly useful for offline RL paradigms where the policy being
+ trained can differ from the policy used to collect data. In online
+ settings, a regular DataCollector should be preferred. This class is
+ merely a wrapper around a MultiaSyncDataCollector where a single process
+ is being created.
+
+ Args:
+ create_env_fn (Callable): Callable returning an instance of EnvBase
+ policy (Callable): Policy to be executed in the environment.
+ Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
+ If ``None`` is provided, the policy used will be a
+ :class:`~torchrl.collectors.RandomPolicy` instance with the environment
+ ``action_spec``.
+ Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
+ This is the recommended usage of the collector.
+ Other callables are accepted too:
+ If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
+ instance) it will be wrapped in a `nn.Module` first.
+ Then, the collector will try to assess if these
+ modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
+
+ - If the policy forward signature matches any of ``forward(self, tensordict)``,
+ ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
+ any typing with a single argument typed as a subclass of ``TensorDictBase``)
+ then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
+
+ - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
+
+ .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+ pickled directly), the ``policy_factory`` should be used instead.
+
+ Keyword Args:
+ policy_factory (Callable[[], Callable], optional): a callable that returns
+ a policy instance.
This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + frames_per_batch (int): A keyword-only argument representing the + total number of elements in a batch. + total_frames (int, optional): A keyword-only argument representing the + total number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). 
+ Once a trajectory reaches ``n_steps``, the environment is reset.
+ If the environment wraps multiple environments together, the number
+ of steps is tracked for each environment independently. Negative
+ values are allowed, in which case this argument is ignored.
+ Defaults to ``None`` (i.e. no maximum number of steps).
+ init_random_frames (int, optional): Number of frames for which the
+ policy is ignored before it is called. This feature is mainly
+ intended to be used in offline/model-based settings, where a
+ batch of random trajectories can be used to initialize training.
+ If provided, it will be rounded up to the closest multiple of frames_per_batch.
+ Defaults to ``None`` (i.e. no random frames).
+ reset_at_each_iter (bool, optional): Whether environments should be reset
+ at the beginning of a batch collection.
+ Defaults to ``False``.
+ postproc (Callable, optional): A post-processing transform, such as
+ a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+ instance.
+ Defaults to ``None``.
+ split_trajs (bool, optional): Boolean indicating whether the resulting
+ TensorDict should be split according to the trajectories.
+ See :func:`~torchrl.collectors.utils.split_trajectories` for more
+ information.
+ Defaults to ``False``.
+ exploration_type (ExplorationType, optional): interaction mode to be used when
+ collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+ ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+ or ``torchrl.envs.utils.ExplorationType.MEAN``.
+ reset_when_done (bool, optional): if ``True`` (default), an environment
+ that returns a ``True`` value in its ``"done"`` or ``"truncated"``
+ entry will be reset at the corresponding indices.
+ update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
+ will be called before (sync) or after (async) each data collection.
+ Defaults to ``False``.
+ preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+ that will be allowed to finish collecting their rollout before the rest are forced to end early.
+ num_threads (int, optional): number of threads for this process.
+ Defaults to the number of workers.
+ num_sub_threads (int, optional): number of threads of the subprocesses.
+ Should be equal to one plus the number of processes launched within
+ each subprocess (or one if a single process is launched).
+ Defaults to 1 for safety: if none is indicated, launching multiple
+ workers may overload the cpu and harm performance.
+ set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
+ ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
+ a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
+ Truncated keys can be set through ``env.add_truncated_keys``.
+ Defaults to ``False``.
+ track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+ This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+ Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+ the policy version.
+ Defaults to `False`.
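+
+ Examples:
+ >>> # minimal sketch: assumes gymnasium's Pendulum-v1 is available
+ >>> from torchrl.collectors import aSyncDataCollector
+ >>> from torchrl.envs.libs.gym import GymEnv
+ >>> collector = aSyncDataCollector(
+ ... lambda: GymEnv("Pendulum-v1"),
+ ... policy=None, # falls back to a RandomPolicy built from the env action_spec
+ ... frames_per_batch=200,
+ ... total_frames=2000,
+ ... )
+ >>> for data in collector: # each ``data`` batch holds 200 frames
+ ... break
+ >>> collector.shutdown()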
+ + """ + + def __init__( + self, + create_env_fn: Callable[[], EnvBase], + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + policy_factory: Callable[[], Callable] | None = None, + frames_per_batch: int, + total_frames: int | None = -1, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Sequence[dict[str, Any]] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float | None = None, + num_threads: int | None = None, + num_sub_threads: int = 1, + set_truncated: bool = False, + track_policy_version: bool = False, + **kwargs, + ): + super().__init__( + create_env_fn=[create_env_fn], + policy=policy, + policy_factory=policy_factory, + total_frames=total_frames, + create_env_kwargs=[create_env_kwargs] + if create_env_kwargs + else create_env_kwargs, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + init_random_frames=init_random_frames, + postproc=postproc, + split_trajs=split_trajs, + device=device, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + update_at_each_batch=update_at_each_batch, + preemptive_threshold=preemptive_threshold, + num_threads=num_threads, + num_sub_threads=num_sub_threads, + set_truncated=set_truncated, + track_policy_version=track_policy_version, + **kwargs, + ) + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + return super().shutdown( + timeout=timeout, close_env=close_env, raise_on_error=raise_on_error + ) + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b7be73d243f..5af173a40c4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2,4973 +2,47 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+"""Re-exports of collector classes for backward compatibility.""" from __future__ import annotations -import _pickle -import abc -import collections -import contextlib -import functools -import os -import queue -import sys -import threading -import time -import typing -import warnings -from collections import defaultdict, OrderedDict -from collections.abc import Callable, Iterator, Mapping, Sequence -from copy import deepcopy -from multiprocessing import connection, queues -from multiprocessing.managers import SyncManager -from queue import Empty -from textwrap import indent -from typing import Any, TypeVar - -import numpy as np -import torch -import torch.nn as nn - -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from tensordict.base import NO_DEFAULT -from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase -from tensordict.utils import _zip_strict, Buffer -from torch import multiprocessing as mp -from torch.nn import Parameter -from torch.utils.data import IterableDataset - -from torchrl._utils import ( - _check_for_faulty_process, - _ends_with, - _make_ordinal_device, - _ProcessNoWarn, - _replace_last, - accept_remote_rref_udf_invocation, - compile_with_warmup, - logger as torchrl_logger, - prod, - rl_warnings, - VERBOSE, -) -from torchrl.collectors.utils import split_trajectories -from torchrl.collectors.weight_update import WeightUpdaterBase -from torchrl.data import ReplayBuffer -from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING -from torchrl.envs.common import _do_nothing, EnvBase -from torchrl.envs.env_creator import EnvCreator - -from torchrl.envs.llm.transforms.policy_version import PolicyVersion -from torchrl.envs.transforms import StepCounter, TransformedEnv -from torchrl.envs.utils import ( - _aggregate_end_of_traj, - _make_compatible_policy, - ExplorationType, - RandomPolicy, - set_exploration_type, -) -from torchrl.weight_update import SharedMemWeightSyncScheme -from torchrl.weight_update.weight_sync_schemes import ( - _resolve_model, - MultiProcessWeightSyncScheme, - WeightReceiver, - WeightSender, - WeightSyncScheme, +from torchrl.collectors._base import DataCollectorBase + +# Re-export constants for backward compatibility +from torchrl.collectors._constants import ( + _Interruptor, + _InterruptorManager, + _is_osx, + _MAX_IDLE_COUNT, + _MIN_TIMEOUT, + _TIMEOUT, + cudagraph_mark_step_begin, + DEFAULT_EXPLORATION_TYPE, + INSTANTIATE_TIMEOUT, ) -try: - from torch.compiler import cudagraph_mark_step_begin -except ImportError: - - def cudagraph_mark_step_begin(): - """Placeholder for missing cudagraph_mark_step_begin method.""" - raise NotImplementedError("cudagraph_mark_step_begin not implemented.") - - -_TIMEOUT = 1.0 -INSTANTIATE_TIMEOUT = 20 -_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory -# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. -_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)) - -DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM - -_is_osx = sys.platform.startswith("darwin") - -T = TypeVar("T") - - -class _Interruptor: - """A class for managing the collection state of a process. - - This class provides methods to start and stop collection, and to check - whether collection has been stopped. The collection state is protected - by a lock to ensure thread-safety. 
- """ - - # interrupter vs interruptor: google trends seems to indicate that "or" is more - # widely used than "er" even if my IDE complains about that... - def __init__(self): - self._collect = True - self._lock = mp.Lock() - - def start_collection(self): - with self._lock: - self._collect = True - - def stop_collection(self): - with self._lock: - self._collect = False - - def collection_stopped(self): - with self._lock: - return self._collect is False - - -class _InterruptorManager(SyncManager): - """A custom SyncManager for managing the collection state of a process. - - This class extends the SyncManager class and allows to share an Interruptor object - between processes. - """ - - -_InterruptorManager.register("_Interruptor", _Interruptor) - - -def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: - """Maps the tensors to CPU through a nested dictionary.""" - return OrderedDict( - **{ - k: recursive_map_to_cpu(item) - if isinstance(item, OrderedDict) - else item.cpu() - if isinstance(item, torch.Tensor) - else item - for k, item in dictionary.items() - } - ) - - -class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): - """Base class for data collectors.""" - - _task = None - _iterator = None - total_frames: int - requested_frames_per_batch: int - frames_per_batch: int - trust_policy: bool - compiled_policy: bool - cudagraphed_policy: bool - _weight_updater: WeightUpdaterBase | None = None - _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None - _weight_senders: dict[str, WeightSender] | None = None - _weight_receivers: dict[str, WeightReceiver] | None = None - verbose: bool = False - - @property - def weight_updater(self) -> WeightUpdaterBase: - return self._weight_updater - - @weight_updater.setter - def weight_updater(self, value: WeightUpdaterBase | None): - if value is not None: - if not isinstance(value, WeightUpdaterBase) and callable( - value - ): # Fall back to default constructor - value = value() - value.register_collector(self) - if value.collector is not self: - raise RuntimeError("Failed to register collector.") - self._weight_updater = value - - def _get_policy_and_device( - self, - policy: Callable[[Any], Any] | None = None, - policy_device: Any = NO_DEFAULT, - env_maker: Any | None = None, - env_maker_kwargs: dict[str, Any] | None = None, - ) -> tuple[TensorDictModule, None | Callable[[], dict]]: - """Util method to get a policy and its device given the collector __init__ inputs. - - We want to copy the policy and then move the data there, not call policy.to(device). - - Args: - policy (TensorDictModule, optional): a policy to be used - policy_device (torch.device, optional): the device where the policy should be placed. - Defaults to self.policy_device - env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. - env_maker_kwargs (a dict, optional): the env_maker function kwargs. 
- - """ - if policy_device is NO_DEFAULT: - policy_device = self.policy_device - - if not policy_device: - return policy, None - - if isinstance(policy, nn.Module): - param_and_buf = TensorDict.from_module(policy, as_module=True) - else: - # Because we want to reach the warning - param_and_buf = TensorDict() - - i = -1 - for p in param_and_buf.values(True, True): - i += 1 - if p.device != policy_device: - # Then we need casting - break - else: - if i == -1 and not self.trust_policy: - # We trust that the policy policy device is adequate - warnings.warn( - "A policy device was provided but no parameter/buffer could be found in " - "the policy. Casting to policy_device is therefore impossible. " - "The collector will trust that the devices match. To suppress this " - "warning, set `trust_policy=True` when building the collector." - ) - return policy, None - - # Create a stateless policy, then populate this copy with params on device - def get_original_weights(policy=policy): - td = TensorDict.from_module(policy) - return td.data - - # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function - with param_and_buf.data.to("meta").to_module(policy): - policy_new_device = deepcopy(policy) - - param_and_buf_new_device = param_and_buf.apply( - functools.partial(_map_weight, policy_device=policy_device), - filter_empty=False, - ) - param_and_buf_new_device.to_module(policy_new_device) - # Sanity check - if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set( - get_original_weights().keys(True, True) - ): - raise RuntimeError("Failed to map weights. The weight sets mismatch.") - return policy_new_device, get_original_weights - - def start(self): - """Starts the collector for asynchronous data collection. - - This method initiates the background collection of data, allowing for decoupling of data collection and training. - - The collected data is typically stored in a replay buffer passed during the collector's initialization. - - .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown` - when you're done with it to free up resources. - - .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature. - Ensure you understand the implications for your specific algorithm before using this mode. - - Raises: - NotImplementedError: If not implemented by a subclass. - """ - raise NotImplementedError( - f"Collector start() is not implemented for {type(self).__name__}." - ) - - @contextlib.contextmanager - def pause(self): - """Context manager that pauses the collector if it is running free.""" - raise NotImplementedError( - f"Collector pause() is not implemented for {type(self).__name__}." - ) - - def async_shutdown( - self, timeout: float | None = None, close_env: bool = True - ) -> None: - """Shuts down the collector when started asynchronously with the `start` method. - - Args: - timeout (float, optional): The maximum time to wait for the collector to shutdown. - close_env (bool, optional): If True, the collector will close the contained environment. - Defaults to `True`. - - .. seealso:: :meth:`~.start` - - """ - return self.shutdown(timeout=timeout, close_env=close_env) - - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For the new weight sync scheme system, weight preparation is handled - by the scheme's prepare_weights() method. 
This method now only handles - legacy weight updater cases. - - Args: - weights: Either already-extracted weights or a model to extract from. - model_id: The model identifier for resolving string paths. - - Returns: - Extracted weights in the appropriate format. - """ - # New weight sync schemes handle preparation themselves - if self._weight_sync_schemes: - # Just pass through - WeightSender will call scheme.prepare_weights() - return weights - - # Legacy weight updater path - return self._legacy_extract_weights(weights, model_id) - - def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: - """Legacy weight extraction for old weight updater system. - - Args: - weights: Either already-extracted weights or a model to extract from. - model_id: The model identifier. - - Returns: - Extracted weights. - """ - if weights is None: - if model_id == "policy" and hasattr(self, "policy_weights"): - return self.policy_weights - elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): - policy_device = ( - self.policy_device - if not isinstance(self.policy_device, (list, tuple)) - else self.policy_device[0] - ) - return self._policy_weights_dict.get(policy_device) - return None - - return weights - - @property - def _legacy_weight_updater(self) -> bool: - return self._weight_updater is not None - - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. - - This method ensures that the policy weights used by the data collector are synchronized with the latest - trained weights. It supports both local and remote weight updates, depending on the configuration of the - data collector. The local (download) update is performed before the remote (upload) update, such that weights - can be transferred to the children workers from a server. - - Args: - policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be: - - TensorDictModuleBase: A policy module whose weights will be extracted - - TensorDictBase: A TensorDict containing weights - - dict: A regular dict containing weights - - None: Will try to get weights from server using _get_server_weights() - worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the - workers that need to be updated. This is relevant when the collector has more than one worker associated - with it. - model_id (str | None, optional): The model identifier to update. If provided, only updates this specific - model. Cannot be used together with weights_dict. - weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating - multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. - Cannot be used together with model_id or policy_or_weights. - - Raises: - TypeError: If `worker_ids` is provided but no `weight_updater` is configured. - ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). - - .. note:: Users should extend the `WeightUpdaterBase` classes to customize - the weight update logic for specific use cases. This method should not be overwritten. - - .. 
seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and - :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. - - """ - if self._legacy_weight_updater: - return self._legacy_weight_update_impl( - policy_or_weights=policy_or_weights, - worker_ids=worker_ids, - model_id=model_id, - weights_dict=weights_dict, - **kwargs, - ) - else: - return self._weight_update_impl( - policy_or_weights=policy_or_weights, - worker_ids=worker_ids, - model_id=model_id, - weights_dict=weights_dict, - **kwargs, - ) - - def _legacy_weight_update_impl( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - if weights_dict is not None: - raise ValueError("weights_dict is not supported with legacy weight updater") - if model_id is not None: - raise ValueError("model_id is not supported with legacy weight updater") - # Fall back to old weight updater system - self.weight_updater( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def _weight_update_impl( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - if weights_dict is not None and model_id is not None: - raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") - - if weights_dict is not None and policy_or_weights is not None: - raise ValueError( - "Cannot specify both 'weights_dict' and 'policy_or_weights'" - ) - - if policy_or_weights is not None: - weights_dict = {"policy": policy_or_weights} - - # Priority: new weight sync schemes > old weight updater system - if self._weight_senders: - if model_id is not None: - # Compose weight_dict - weights_dict = {model_id: policy_or_weights} - if weights_dict is None: - if "policy" in self._weight_senders: - weights_dict = {"policy": policy_or_weights} - elif len(self._weight_senders) == 1: - single_model_id = next(iter(self._weight_senders.keys())) - weights_dict = {single_model_id: policy_or_weights} - else: - raise ValueError( - "Cannot determine the model to update. Please provide a weights_dict." - ) - for target_model_id, weights in weights_dict.items(): - if target_model_id not in self._weight_senders: - raise KeyError( - f"Model '{target_model_id}' not found in registered weight senders. 
" - f"Available models: {list(self._weight_senders.keys())}" - ) - processed_weights = self._extract_weights_if_needed( - weights, target_model_id - ) - # Use new send() API with worker_ids support - self._weight_senders[target_model_id].send( - weights=processed_weights, worker_ids=worker_ids - ) - elif self._weight_updater is not None: - # unreachable - raise RuntimeError - else: - return self.receive_weights(policy_or_weights) - - def receive_weights(self, policy_or_weights: TensorDictBase | None = None): - # No weight updater configured - # For single-process collectors, apply weights locally if explicitly provided - if policy_or_weights is not None: - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - # Use WeightStrategy to apply weights properly - strategy = WeightStrategy(extract_as="tensordict") - - # Extract weights if needed - if isinstance(policy_or_weights, nn.Module): - weights = strategy.extract_weights(policy_or_weights) - else: - weights = policy_or_weights - - # Apply to local policy - if hasattr(self, "policy") and isinstance(self.policy, nn.Module): - strategy.apply_weights(self.policy, weights) - elif ( - hasattr(self, "_original_policy") - and isinstance(self._original_policy, nn.Module) - and hasattr(self, "policy") - and isinstance(self.policy, nn.Module) - ): - # If no weights were provided, mirror weights from the original (trainer) policy - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as="tensordict") - weights = strategy.extract_weights(self._original_policy) - # Cast weights to the policy device before applying - if self.policy_device is not None: - weights = weights.to(self.policy_device) - strategy.apply_weights(self.policy, weights) - # Otherwise, no action needed - policy is local and changes are immediately visible - - def __iter__(self) -> Iterator[TensorDictBase]: - try: - yield from self.iterator() - except Exception: - self.shutdown() - raise - - def next(self): - try: - if self._iterator is None: - self._iterator = iter(self) - out = next(self._iterator) - # if any, we don't want the device ref to be passed in distributed settings - if out is not None and (out.device != "cpu"): - out = out.copy().clear_device_() - return out - except StopIteration: - return None - - @abc.abstractmethod - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - raise NotImplementedError - - @abc.abstractmethod - def iterator(self) -> Iterator[TensorDictBase]: - raise NotImplementedError - - @abc.abstractmethod - def set_seed(self, seed: int, static_seed: bool = False) -> int: - raise NotImplementedError - - @abc.abstractmethod - def state_dict(self) -> OrderedDict: - raise NotImplementedError - - @abc.abstractmethod - def load_state_dict(self, state_dict: OrderedDict) -> None: - raise NotImplementedError - - def _read_compile_kwargs(self, compile_policy, cudagraph_policy): - self.compiled_policy = compile_policy not in (False, None) - self.cudagraphed_policy = cudagraph_policy not in (False, None) - self.compiled_policy_kwargs = ( - {} if not isinstance(compile_policy, typing.Mapping) else compile_policy - ) - self.cudagraphed_policy_kwargs = ( - {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy - ) - - def __repr__(self) -> str: - string = f"{self.__class__.__name__}()" - return string - - def __class_getitem__(self, index): - raise NotImplementedError - - def __len__(self) -> int: - if 
self.total_frames > 0: - return -(self.total_frames // -self.requested_frames_per_batch) - raise RuntimeError("Non-terminating collectors do not have a length") - - def init_updater(self, *args, **kwargs): - """Initialize the weight updater with custom arguments. - - This method passes the arguments to the weight updater's init method. - If no weight updater is set, this is a no-op. - - Args: - *args: Positional arguments for weight updater initialization - **kwargs: Keyword arguments for weight updater initialization - """ - if self.weight_updater is not None: - self.weight_updater.init(*args, **kwargs) - - -@accept_remote_rref_udf_invocation -class SyncDataCollector(DataCollectorBase): - """Generic data collector for RL problems. Requires an environment constructor and a policy. - - Args: - create_env_fn (Callable or EnvBase): a callable that returns an instance of - :class:`~torchrl.envs.EnvBase` class, or the env itself. - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided, the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], optional): a callable that returns - a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - frames_per_batch (int): A keyword-only argument representing the total - number of elements in a batch. - total_frames (int): A keyword-only argument representing the total - number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. - Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (endless collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. 
- If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - create_env_kwargs (dict, optional): Dictionary of kwargs for - ``create_env_fn``. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. - If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e., no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - - .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer - as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`. - - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. 
Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
-            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
-            or ``torchrl.envs.utils.ExplorationType.MEAN``.
-        return_same_td (bool, optional): if ``True``, the same TensorDict
-            will be returned at each iteration, with its values
-            updated. This feature should be used cautiously: if the same
-            tensordict is added to a replay buffer for instance,
-            the whole content of the buffer will be identical.
-            Default is ``False``.
-        interruptor (_Interruptor, optional):
-            An _Interruptor object that can be used from outside the class to control rollout collection.
-            The _Interruptor class has methods ``start_collection`` and ``stop_collection``, which allow
-            implementing strategies such as preemptively stopping rollout collection.
-            Defaults to ``None``.
-        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
-            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
-            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
-            Truncated keys can be set through ``env.add_truncated_keys``.
-            Defaults to ``False``.
-        use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
-            This isn't compatible with environments with dynamic specs. Defaults to ``True``
-            for envs without dynamic specs, ``False`` for others.
-        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
-            but populate the buffer instead.
-            Defaults to ``None``.
-
-            .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts.
-                If the buffer needs to be populated with individual frames as they are collected,
-                set ``extend_buffer=False`` (deprecated).
-
-            .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires
-                `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms.
-
-        extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not
-            with single steps. Defaults to `True`.
-
-            .. note:: Setting this to `False` is deprecated and will be removed in a future version.
-                Extending the buffer with entire rollouts is the recommended approach for better
-                compatibility with postprocessing and trajectory splitting.
-        trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be assumed to be
-            compatible with the collector. This defaults to ``True`` for CudaGraphModules
-            and ``False`` otherwise.
-        compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled
-            using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it
-            will be used to compile the policy.
-        cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
-            in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
-            If a dictionary of kwargs is passed, it will be used to wrap the policy.
-        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
-            For environments running directly on CUDA (`IsaacLab `_
-            or `ManiSkills `_) CUDA synchronization may cause unexpected
-            crashes.
-            Defaults to ``False``.
-        weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase`
-            or its subclass, responsible for updating the policy weights on remote inference workers.
- This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. - Consider using a constructor if the updater needs to be serialized. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. - Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector( - ... create_env_fn=env_maker, - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... ) - >>> for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - >>> del collector - - The collector delivers batches of data that are marked with a ``"time"`` - dimension. 
- - Examples: - >>> assert data.names[-1] == "time" - - """ - - _ignore_rb: bool = False - - def __init__( - self, - create_env_fn: ( - EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 - ), # noqa: F821 - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - policy_factory: Callable[[], Callable] | None = None, - frames_per_batch: int, - total_frames: int = -1, - device: DEVICE_TYPING | None = None, - storing_device: DEVICE_TYPING | None = None, - policy_device: DEVICE_TYPING | None = None, - env_device: DEVICE_TYPING | None = None, - create_env_kwargs: dict[str, Any] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - return_same_td: bool = False, - reset_when_done: bool = True, - interruptor=None, - set_truncated: bool = False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - local_init_rb: bool | None = None, - trust_policy: bool | None = None, - compile_policy: bool | dict[str, Any] | None = None, - cudagraph_policy: bool | dict[str, Any] | None = None, - no_cuda_sync: bool = False, - weight_updater: WeightUpdaterBase - | Callable[[], WeightUpdaterBase] - | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, - track_policy_version: bool = False, - **kwargs, - ): - self.closed = True - - # Initialize environment - env = self._init_env(create_env_fn, create_env_kwargs) - - # Initialize policy - policy = self._init_policy(policy, policy_factory, env, trust_policy) - self._read_compile_kwargs(compile_policy, cudagraph_policy) - - # Handle trajectory pool and validate kwargs - self._traj_pool_val = kwargs.pop("traj_pool", None) - if kwargs: - raise TypeError( - f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
- ) - - # Set up devices and synchronization - self._setup_devices( - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - no_cuda_sync=no_cuda_sync, - ) - - self.env: EnvBase = env - del env - - # Set up policy version tracking - self._setup_policy_version_tracking(track_policy_version) - - # Set up replay buffer - self._setup_replay_buffer( - replay_buffer=replay_buffer, - extend_buffer=extend_buffer, - local_init_rb=local_init_rb, - postproc=postproc, - split_trajs=split_trajs, - return_same_td=return_same_td, - use_buffers=use_buffers, - ) - - self.closed = False - - # Validate reset_when_done - if not reset_when_done: - raise ValueError("reset_when_done is deprecated.") - self.reset_when_done = reset_when_done - self.n_env = self.env.batch_size.numel() - - # Register collector with policy and env - if hasattr(policy, "register_collector"): - policy.register_collector(self) - if hasattr(self.env, "register_collector"): - self.env.register_collector(self) - - # Set up policy and weights - self._setup_policy_and_weights(policy) - - # Apply environment device - self._apply_env_device() - - # Set up max frames per trajectory - self._setup_max_frames_per_traj(max_frames_per_traj) - - # Validate and set total frames - self.reset_at_each_iter = reset_at_each_iter - self._setup_total_frames(total_frames, frames_per_batch) - - # Set up init random frames - self._setup_init_random_frames(init_random_frames, frames_per_batch) - - # Set up postproc - self._setup_postproc(postproc) - - # Calculate frames per batch - self._setup_frames_per_batch(frames_per_batch) - - # Set exploration and other options - self.exploration_type = ( - exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE - ) - self.return_same_td = return_same_td - self.set_truncated = set_truncated - - # Create shuttle and rollout buffers - self._make_shuttle() - self._maybe_make_final_rollout(make_rollout=self._use_buffers) - self._set_truncated_keys() - - # Set split trajectories option - if split_trajs is None: - split_trajs = False - self.split_trajs = split_trajs - self._exclude_private_keys = True - - # Set up interruptor and frame tracking - self.interruptor = interruptor - self._frames = 0 - self._iter = -1 - - # Set up weight synchronization - self._setup_weight_sync(weight_updater, weight_sync_schemes) - - def _init_env( - self, - create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], - create_env_kwargs: dict[str, Any] | None, - ) -> EnvBase: - """Initialize and configure the environment.""" - from torchrl.envs.batched_envs import BatchedEnvBase - - if create_env_kwargs is None: - create_env_kwargs = {} - - if not isinstance(create_env_fn, EnvBase): - env = create_env_fn(**create_env_kwargs) - else: - env = create_env_fn - if create_env_kwargs: - if not isinstance(env, BatchedEnvBase): - raise RuntimeError( - "kwargs were passed to SyncDataCollector but they can't be set " - f"on environment of type {type(create_env_fn)}." 
- ) - env.update_kwargs(create_env_kwargs) - return env - - def _init_policy( - self, - policy: TensorDictModule | Callable | None, - policy_factory: Callable[[], Callable] | None, - env: EnvBase, - trust_policy: bool | None, - ) -> TensorDictModule | Callable: - """Initialize and configure the policy.""" - if policy is None: - if policy_factory is not None: - policy = policy_factory() - else: - policy = RandomPolicy(env.full_action_spec) - elif policy_factory is not None: - raise TypeError("policy_factory cannot be used with policy argument.") - - # If the underlying policy has a state_dict, keep a reference to it - if hasattr(policy, "state_dict"): - self._policy_w_state_dict = policy - - if trust_policy is None: - trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) - self.trust_policy = trust_policy - - return policy - - def _setup_devices( - self, - device: DEVICE_TYPING | None, - storing_device: DEVICE_TYPING | None, - policy_device: DEVICE_TYPING | None, - env_device: DEVICE_TYPING | None, - no_cuda_sync: bool, - ) -> None: - """Set up devices and synchronization functions.""" - storing_device, policy_device, env_device = self._get_devices( - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - device=device, - ) - - self.storing_device = storing_device - self._sync_storage = self._get_sync_fn(storing_device) - - self.env_device = env_device - self._sync_env = self._get_sync_fn(env_device) - - self.policy_device = policy_device - self._sync_policy = self._get_sync_fn(policy_device) - - self.device = device - self.no_cuda_sync = no_cuda_sync - self._cast_to_policy_device = self.policy_device != self.env_device - - def _get_sync_fn(self, device: torch.device | None) -> Callable: - """Get the appropriate synchronization function for a device.""" - if device is not None and device.type != "cuda": - # Cuda handles sync - if torch.cuda.is_available(): - return torch.cuda.synchronize - elif torch.backends.mps.is_available() and hasattr(torch, "mps"): - return torch.mps.synchronize - elif hasattr(torch, "npu") and torch.npu.is_available(): - return torch.npu.synchronize - elif device.type == "cpu": - return _do_nothing - else: - raise RuntimeError("Non supported device") - else: - return _do_nothing - - def _setup_policy_version_tracking( - self, track_policy_version: bool | PolicyVersion - ) -> None: - """Set up policy version tracking if requested.""" - self.policy_version_tracker = track_policy_version - if isinstance(track_policy_version, bool) and track_policy_version: - from torchrl.envs.batched_envs import BatchedEnvBase - - if isinstance(self.env, BatchedEnvBase): - raise RuntimeError( - "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " - "and pass that transform to the collector." 
- ) - self.policy_version_tracker = PolicyVersion() - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - elif hasattr(track_policy_version, "increment_version"): - self.policy_version_tracker = track_policy_version - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - else: - self.policy_version_tracker = None - - def _setup_replay_buffer( - self, - replay_buffer: ReplayBuffer | None, - extend_buffer: bool, - local_init_rb: bool | None, - postproc: Callable | None, - split_trajs: bool | None, - return_same_td: bool, - use_buffers: bool | None, - ) -> None: - """Set up replay buffer configuration and validate compatibility.""" - self.replay_buffer = replay_buffer - self.extend_buffer = extend_buffer - - # Handle local_init_rb deprecation - if local_init_rb is None: - local_init_rb = False - if replay_buffer is not None and not local_init_rb: - warnings.warn( - "local_init_rb=False is deprecated and will be removed in v0.12. " - "The new storage-level initialization provides better performance.", - FutureWarning, - ) - self.local_init_rb = local_init_rb - - # Validate replay buffer compatibility - if self.replay_buffer is not None and not self._ignore_rb: - if postproc is not None and not self.extend_buffer: - raise TypeError( - "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." - ) - if split_trajs not in (None, False) and not self.extend_buffer: - raise TypeError( - "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." - ) - if return_same_td: - raise TypeError( - "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." - ) - if use_buffers: - raise TypeError("replay_buffer is exclusive with use_buffers.") - - if use_buffers is None: - use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None - self._use_buffers = use_buffers - - def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: - """Set up policy, wrapped policy, and extract weights.""" - self._original_policy = policy - policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) - - if not self.trust_policy: - self.policy = policy - env = getattr(self, "env", None) - try: - wrapped_policy = _make_compatible_policy( - policy=policy, - observation_spec=getattr(env, "observation_spec", None), - env=self.env, - ) - except (TypeError, AttributeError, ValueError) as err: - raise TypeError( - "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True." 
- ) from err - self._wrapped_policy = wrapped_policy - else: - self.policy = self._wrapped_policy = policy - - # Extract policy weights - if isinstance(self._wrapped_policy, nn.Module): - self.policy_weights = TensorDict.from_module( - self._wrapped_policy, as_module=True - ).data - else: - self.policy_weights = TensorDict() - - # Apply compilation/cudagraph - if self.compiled_policy: - self._wrapped_policy = compile_with_warmup( - self._wrapped_policy, **self.compiled_policy_kwargs - ) - if self.cudagraphed_policy: - self._wrapped_policy = CudaGraphModule( - self._wrapped_policy, - in_keys=[], - out_keys=[], - device=self.policy_device, - **self.cudagraphed_policy_kwargs, - ) - - def _apply_env_device(self) -> None: - """Apply device to environment if specified.""" - if self.env_device: - self.env: EnvBase = self.env.to(self.env_device) - elif self.env.device is not None: - # Use the device of the env if none was provided - self.env_device = self.env.device - - # Check if we need to cast to env device - self._cast_to_env_device = self._cast_to_policy_device or ( - self.env.device != self.storing_device - ) - - def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: - """Set up maximum frames per trajectory and add StepCounter if needed.""" - self.max_frames_per_traj = ( - int(max_frames_per_traj) if max_frames_per_traj is not None else 0 - ) - if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: - # Check that there is no StepCounter yet - for key in self.env.output_spec.keys(True, True): - if isinstance(key, str): - key = (key,) - if "step_count" in key: - raise ValueError( - "A 'step_count' key is already present in the environment " - "and the 'max_frames_per_traj' argument may conflict with " - "a 'StepCounter' that has already been set. " - "Possible solutions: Set max_frames_per_traj to 0 or " - "remove the StepCounter limit from the environment transforms." - ) - self.env = TransformedEnv( - self.env, StepCounter(max_steps=self.max_frames_per_traj) - ) - - def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: - """Validate and set total frames.""" - if total_frames is None or total_frames < 0: - total_frames = float("inf") - else: - remainder = total_frames % frames_per_batch - if remainder != 0 and rl_warnings(): - warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " - f"This means {frames_per_batch - remainder} additional frames will be collected." - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.total_frames = ( - int(total_frames) if total_frames != float("inf") else total_frames - ) - - def _setup_init_random_frames( - self, init_random_frames: int | None, frames_per_batch: int - ) -> None: - """Set up initial random frames.""" - self.init_random_frames = ( - int(init_random_frames) if init_random_frames not in (None, -1) else 0 - ) - if ( - init_random_frames not in (-1, None, 0) - and init_random_frames % frames_per_batch != 0 - and rl_warnings() - ): - warnings.warn( - f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " - f" this results in more init_random_frames than requested" - f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." - "To silence this message, set the environment variable RL_WARNINGS to False." 
- ) - - def _setup_postproc(self, postproc: Callable | None) -> None: - """Set up post-processing transform.""" - self.postproc = postproc - if ( - self.postproc is not None - and hasattr(self.postproc, "to") - and self.storing_device - ): - postproc = self.postproc.to(self.storing_device) - if postproc is not self.postproc and postproc is not None: - self.postproc = postproc - - def _setup_frames_per_batch(self, frames_per_batch: int) -> None: - """Calculate and validate frames per batch.""" - if frames_per_batch % self.n_env != 0 and rl_warnings(): - warnings.warn( - f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " - f" this results in more frames_per_batch per iteration that requested" - f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.frames_per_batch = -(-frames_per_batch // self.n_env) - self.requested_frames_per_batch = self.frames_per_batch * self.n_env - - def _setup_weight_sync( - self, - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up weight synchronization system.""" - if weight_sync_schemes is not None: - # Use new simplified weight synchronization system - self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # For single-process collectors, we don't need senders/receivers - # The policy is local and changes are immediately visible - # Senders will be set up in multiprocess collectors during _run_processes - self.weight_updater = None # Don't use legacy system - elif weight_updater is not None: - # Use legacy weight updater system if explicitly provided - if not isinstance(weight_updater, WeightUpdaterBase): - if callable(weight_updater): - weight_updater = weight_updater() - else: - raise TypeError( - f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." - ) - warnings.warn( - "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. " - "This will be removed in a future version.", - DeprecationWarning, - stacklevel=2, - ) - self.weight_updater = weight_updater - self._weight_sync_schemes = None - self._weight_senders = {} - else: - # No weight sync needed for single-process collectors - self.weight_updater = None - self._weight_sync_schemes = None - self._weight_senders = {} - - @property - def _traj_pool(self): - pool = getattr(self, "_traj_pool_val", None) - if pool is None: - pool = self._traj_pool_val = _TrajectoryPool() - return pool - - def _make_shuttle(self): - # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env - with torch.no_grad(): - self._shuttle = self.env.reset() - if self.policy_device != self.env_device or self.env_device is None: - self._shuttle_has_no_device = True - self._shuttle.clear_device_() - else: - self._shuttle_has_no_device = False - - traj_ids = self._traj_pool.get_traj_and_increment( - self.n_env, device=self.storing_device - ).view(self.env.batch_size) - self._shuttle.set( - ("collector", "traj_ids"), - traj_ids, - ) - - def _maybe_make_final_rollout(self, make_rollout: bool): - if make_rollout: - with torch.no_grad(): - self._final_rollout = self.env.fake_tensordict() - - # If storing device is not None, we use this to cast the storage. 
- # If it is None and the env and policy are on the same device, - # the storing device is already the same as those, so we don't need - # to consider this use case. - # In all other cases, we can't really put a device on the storage, - # since at least one data source has a device that is not clear. - if self.storing_device: - self._final_rollout = self._final_rollout.to( - self.storing_device, non_blocking=True - ) - else: - # erase all devices - self._final_rollout.clear_device_() - - # If the policy has a valid spec, we use it - self._policy_output_keys = set() - if ( - make_rollout - and hasattr(self._wrapped_policy, "spec") - and self._wrapped_policy.spec is not None - and all(v is not None for v in self._wrapped_policy.spec.values(True, True)) - ): - if any( - key not in self._final_rollout.keys(isinstance(key, tuple)) - for key in self._wrapped_policy.spec.keys(True, True) - ): - # if policy spec is non-empty, all the values are not None and the keys - # match the out_keys we assume the user has given all relevant information - # the policy could have more keys than the env: - policy_spec = self._wrapped_policy.spec - if policy_spec.ndim < self._final_rollout.ndim: - policy_spec = policy_spec.expand(self._final_rollout.shape) - for key, spec in policy_spec.items(True, True): - self._policy_output_keys.add(key) - if key in self._final_rollout.keys(True): - continue - self._final_rollout.set(key, spec.zero()) - elif ( - not make_rollout - and hasattr(self._wrapped_policy, "out_keys") - and self._wrapped_policy.out_keys - ): - self._policy_output_keys = list(self._wrapped_policy.out_keys) - else: - if make_rollout: - # otherwise, we perform a small number of steps with the policy to - # determine the relevant keys with which to pre-populate _final_rollout. - # This is the safest thing to do if the spec has None fields or if there is - # no spec at all. - # See #505 for additional context. - self._final_rollout.update(self._shuttle.copy()) - with torch.no_grad(): - policy_input = self._shuttle.copy() - if self.policy_device: - policy_input = policy_input.to(self.policy_device) - # we cast to policy device, we'll deal with the device later - policy_input_copy = policy_input.copy() - policy_input_clone = ( - policy_input.clone() - ) # to test if values have changed in-place - if self.compiled_policy: - cudagraph_mark_step_begin() - policy_output = self._wrapped_policy(policy_input) - - # check that we don't have exclusive keys, because they don't appear in keys - def check_exclusive(val): - if ( - isinstance(val, LazyStackedTensorDict) - and val._has_exclusive_keys - ): - raise RuntimeError( - "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " - "Consider using a placeholder for missing keys." - ) - - policy_output._fast_apply( - check_exclusive, call_on_nested=True, filter_empty=True - ) - - # Use apply, because it works well with lazy stacks - # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit - # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has - # changed them here). - # This will cause a failure to update entries when policy and env device mismatch and - # casting is necessary. 
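-                # filter_policy below keeps an entry only if the policy plausibly produced it:
-                # the input value was missing, or the output is a different tensor whose device
-                # or values differ from the cloned input.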
- def filter_policy(name, value_output, value_input, value_input_clone): - if (value_input is None) or ( - (value_output is not value_input) - and ( - value_output.device != value_input_clone.device - or ~torch.isclose(value_output, value_input_clone).any() - ) - ): - return value_output - - filtered_policy_output = policy_output.apply( - filter_policy, - policy_input_copy, - policy_input_clone, - default=None, - filter_empty=True, - named=True, - ) - self._policy_output_keys = list( - self._policy_output_keys.union( - set(filtered_policy_output.keys(True, True)) - ) - ) - if make_rollout: - self._final_rollout.update( - policy_output.select(*self._policy_output_keys) - ) - del filtered_policy_output, policy_output, policy_input - - _env_output_keys = [] - for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: - _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) - self._env_output_keys = _env_output_keys - if make_rollout: - self._final_rollout = ( - self._final_rollout.unsqueeze(-1) - .expand(*self.env.batch_size, self.frames_per_batch) - .clone() - .zero_() - ) - - # in addition to outputs of the policy, we add traj_ids to - # _final_rollout which will be collected during rollout - self._final_rollout.set( - ("collector", "traj_ids"), - torch.zeros( - *self._final_rollout.batch_size, - dtype=torch.int64, - device=self.storing_device, - ), - ) - self._final_rollout.refine_names(..., "time") - - def _set_truncated_keys(self): - self._truncated_keys = [] - if self.set_truncated: - if not any(_ends_with(key, "truncated") for key in self.env.done_keys): - raise RuntimeError( - "set_truncated was set to True but no truncated key could be found " - "in the environment. Make sure the truncated keys are properly set using " - "`env.add_truncated_keys()` before passing the env to the collector." - ) - self._truncated_keys = [ - key for key in self.env.done_keys if _ends_with(key, "truncated") - ] - - @classmethod - def _get_devices( - cls, - *, - storing_device: torch.device, - policy_device: torch.device, - env_device: torch.device, - device: torch.device, - ): - device = _make_ordinal_device(torch.device(device) if device else device) - storing_device = _make_ordinal_device( - torch.device(storing_device) if storing_device else device - ) - policy_device = _make_ordinal_device( - torch.device(policy_device) if policy_device else device - ) - env_device = _make_ordinal_device( - torch.device(env_device) if env_device else device - ) - if storing_device is None and (env_device == policy_device): - storing_device = env_device - return storing_device, policy_device, env_device - - # for RPC - def next(self): - return super().next() - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Sets the seeds of the environments stored in the DataCollector. - - Args: - seed (int): integer representing the seed to be used for the environment. 
- static_seed(bool, optional): if ``True``, the seed is not incremented. - Defaults to False - - Returns: - Output seed. This is useful when more than one environment is contained in the DataCollector, as the - seed will be incremented for each of these. The resulting seed is the seed of the last environment. - - Examples: - >>> from torchrl.envs import ParallelEnv - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_fn = lambda: GymEnv("Pendulum-v1") - >>> env_fn_parallel = ParallelEnv(6, env_fn) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) - >>> out_seed = collector.set_seed(1) # out_seed = 6 - - """ - out = self.env.set_seed(seed, static_seed=static_seed) - return out - - def _increment_frames(self, numel): - self._frames += numel - completed = self._frames >= self.total_frames - if completed: - self.env.close() - return completed - - def iterator(self) -> Iterator[TensorDictBase]: - """Iterates through the DataCollector. - - Yields: TensorDictBase objects containing (chunks of) trajectories - - """ - if ( - not self.no_cuda_sync - and self.storing_device - and self.storing_device.type == "cuda" - ): - stream = torch.cuda.Stream(self.storing_device, priority=-1) - event = stream.record_event() - streams = [stream] - events = [event] - elif not self.no_cuda_sync and self.storing_device is None: - streams = [] - events = [] - # this way of checking cuda is robust to lazy stacks with mismatching shapes - cuda_devices = set() - - def cuda_check(tensor: torch.Tensor): - if tensor.is_cuda: - cuda_devices.add(tensor.device) - - if not self._use_buffers: - # This may be a bit dangerous as `torch.device("cuda")` may not have a precise - # device associated, whereas `tensor.device` always has - for spec in self.env.specs.values(True, True): - if spec.device is not None and spec.device.type == "cuda": - if ":" not in str(spec.device): - raise RuntimeError( - "A cuda spec did not have a device associated. Make sure to " - "pass `'cuda:device_num'` to each spec device." - ) - cuda_devices.add(spec.device) - else: - self._final_rollout.apply(cuda_check, filter_empty=True) - for device in cuda_devices: - streams.append(torch.cuda.Stream(device, priority=-1)) - events.append(streams[-1].record_event()) - else: - streams = [] - events = [] - with contextlib.ExitStack() as stack: - for stream in streams: - stack.enter_context(torch.cuda.stream(stream)) - - while self._frames < self.total_frames: - self._iter += 1 - if self.verbose: - torchrl_logger.info("Collector: rollout.") - tensordict_out = self.rollout() - if tensordict_out is None: - # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out - # frames are updated within the rollout function - if self.verbose: - torchrl_logger.info("Collector: No tensordict_out. Yielding.") - yield - continue - self._increment_frames(tensordict_out.numel()) - tensordict_out = self._postproc(tensordict_out) - if self.verbose: - torchrl_logger.info("Collector: postproc done.") - if self.return_same_td: - # This is used with multiprocessed collectors to use the buffers - # stored in the tensordict. 
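-                    # The same pre-allocated tensordict is yielded again at every iteration,
-                    # so consumers that need to keep a batch should clone it (e.g.
-                    # ``data = next(iter(collector)).clone()``, an illustrative sketch).
-                    # Pending CUDA work is synchronized below before the buffer is handed over.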
-                    if events:
-                        for event in events:
-                            event.record()
-                            event.synchronize()
-                    yield tensordict_out
-                elif self.replay_buffer is not None and not self._ignore_rb:
-                    self.replay_buffer.extend(tensordict_out)
-                    if self.verbose:
-                        torchrl_logger.info(
-                            f"Collector: Added {tensordict_out.numel()} frames to replay buffer. "
-                            f"Buffer write count: {self.replay_buffer.write_count}. Yielding."
-                        )
-                    yield
-                else:
-                    # we must clone the values, as the tensordict is updated in-place.
-                    # otherwise the following code may break:
-                    # >>> for i, data in enumerate(collector):
-                    # >>>     if i == 0:
-                    # >>>         data0 = data
-                    # >>>     elif i == 1:
-                    # >>>         data1 = data
-                    # >>>     else:
-                    # >>>         break
-                    # >>> assert data0["done"] is not data1["done"]
-                    yield tensordict_out.clone()
-
-    def start(self):
-        """Starts the collector in a separate thread for asynchronous data collection.
-
-        The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
-        collection from training, allowing your training loop to run independently of the data collection process.
-
-        Raises:
-            RuntimeError: If no replay buffer is defined during the collector's initialization.
-
-        Example:
-            >>> import time
-            >>> from functools import partial
-            >>>
-            >>> import tqdm
-            >>>
-            >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
-            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
-            >>> from torchrl.envs import GymEnv, set_gym_backend
-            >>> import ale_py
-            >>>
-            >>> # Set the gym backend to gymnasium
-            >>> set_gym_backend("gymnasium").set()
-            >>>
-            >>> if __name__ == "__main__":
-            ...     # Create a random policy for the Pong environment
-            ...     env = GymEnv("ALE/Pong-v5")
-            ...     policy = RandomPolicy(env.action_spec)
-            ...
-            ...     # Initialize a shared replay buffer
-            ...     rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
-            ...
-            ...     # Create a synchronous data collector
-            ...     collector = SyncDataCollector(
-            ...         env,
-            ...         policy=policy,
-            ...         replay_buffer=rb,
-            ...         frames_per_batch=256,
-            ...         total_frames=-1,
-            ...     )
-            ...
-            ...     # Progress bar to track the number of collected frames
-            ...     pbar = tqdm.tqdm(total=100_000)
-            ...
-            ...     # Start the collector asynchronously
-            ...     collector.start()
-            ...
-            ...     # Track the write count of the replay buffer
-            ...     prec_wc = 0
-            ...     while True:
-            ...         wc = rb.write_count
-            ...         c = wc - prec_wc
-            ...         prec_wc = wc
-            ...
-            ...         # Update the progress bar
-            ...         pbar.update(c)
-            ...         pbar.set_description(f"Write Count: {rb.write_count}")
-            ...
-            ...         # Check the write count every 0.5 seconds
-            ...         time.sleep(0.5)
-            ...
-            ...         # Stop when the desired number of frames is reached
-            ...         if rb.write_count > 100_000:
-            ...             break
-            ...
-            ...     # Shut down the collector
-            ...
collector.async_shutdown() - """ - if self.replay_buffer is None: - raise RuntimeError("Replay buffer must be defined for execution.") - if not self.is_running(): - self._stop = False - self._thread = threading.Thread(target=self._run_iterator) - self._thread.daemon = ( - True # So that the thread dies when the main program exits - ) - self._thread.start() - - def _run_iterator(self): - for _ in self: - if self._stop: - return - - def is_running(self): - return hasattr(self, "_thread") and self._thread.is_alive() - - def async_shutdown( - self, timeout: float | None = None, close_env: bool = True - ) -> None: - """Finishes processes started by ray.init() during async execution.""" - self._stop = True - if hasattr(self, "_thread") and self._thread.is_alive(): - self._thread.join(timeout=timeout) - self.shutdown(close_env=close_env) - - def _postproc(self, tensordict_out): - if self.split_trajs: - tensordict_out = split_trajectories(tensordict_out, prefix="collector") - if self.postproc is not None: - tensordict_out = self.postproc(tensordict_out) - if self._exclude_private_keys: - - def is_private(key): - if isinstance(key, str) and key.startswith("_"): - return True - if isinstance(key, tuple) and any(_key.startswith("_") for _key in key): - return True - return False - - excluded_keys = [ - key for key in tensordict_out.keys(True) if is_private(key) - ] - tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True) - return tensordict_out - - def _update_traj_ids(self, env_output) -> None: - # we can't use the reset keys because they're gone - traj_sop = _aggregate_end_of_traj( - env_output.get("next"), done_keys=self.env.done_keys - ) - if traj_sop.any(): - device = self.storing_device - - traj_ids = self._shuttle.get(("collector", "traj_ids")) - if device is not None: - traj_ids = traj_ids.to(device) - traj_sop = traj_sop.to(device) - elif traj_sop.device != traj_ids.device: - traj_sop = traj_sop.to(traj_ids.device) - - pool = self._traj_pool - new_traj = pool.get_traj_and_increment( - traj_sop.sum(), device=traj_sop.device - ) - traj_ids = traj_ids.masked_scatter(traj_sop, new_traj) - self._shuttle.set(("collector", "traj_ids"), traj_ids) - - @torch.no_grad() - def rollout(self) -> TensorDictBase: - """Computes a rollout in the environment using the provided policy. - - Returns: - TensorDictBase containing the computed rollout. 
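-
-        Examples:
-            A minimal sketch, assuming ``collector`` is an already-built :class:`SyncDataCollector`
-            that does not write to a replay buffer (``rollout`` is normally called for you by
-            :meth:`iterator`):
-
-            >>> batch = collector.rollout()
-            >>> batch.names[-1]
-            'time'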
- - """ - if self.reset_at_each_iter: - self._shuttle.update(self.env.reset()) - - # self._shuttle.fill_(("collector", "step_count"), 0) - if self._use_buffers: - self._final_rollout.fill_(("collector", "traj_ids"), -1) - else: - pass - tensordicts = [] - with set_exploration_type(self.exploration_type): - for t in range(self.frames_per_batch): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - self.env.rand_action(self._shuttle) - if ( - self.policy_device is not None - and self.policy_device != self.env_device - ): - # TODO: This may break with exclusive / ragged lazy stacks - self._shuttle.apply( - lambda name, val: val.to( - device=self.policy_device, non_blocking=True - ) - if name in self._policy_output_keys - else val, - out=self._shuttle, - named=True, - nested_keys=True, - ) - else: - if self._cast_to_policy_device: - if self.policy_device is not None: - # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking - non_blocking = ( - not self.no_cuda_sync - or self.policy_device.type == "cuda" - ) - policy_input = self._shuttle.to( - self.policy_device, - non_blocking=non_blocking, - ) - if not self.no_cuda_sync: - self._sync_policy() - elif self.policy_device is None: - # we know the tensordict has a device otherwise we would not be here - # we can pass this, clear_device_ must have been called earlier - # policy_input = self._shuttle.clear_device_() - policy_input = self._shuttle - else: - policy_input = self._shuttle - # we still do the assignment for security - if self.compiled_policy: - cudagraph_mark_step_begin() - policy_output = self._wrapped_policy(policy_input) - if self.compiled_policy: - policy_output = policy_output.clone() - if self._shuttle is not policy_output: - # ad-hoc update shuttle - self._shuttle.update( - policy_output, keys_to_update=self._policy_output_keys - ) - - if self._cast_to_env_device: - if self.env_device is not None: - non_blocking = ( - not self.no_cuda_sync or self.env_device.type == "cuda" - ) - env_input = self._shuttle.to( - self.env_device, non_blocking=non_blocking - ) - if not self.no_cuda_sync: - self._sync_env() - elif self.env_device is None: - # we know the tensordict has a device otherwise we would not be here - # we can pass this, clear_device_ must have been called earlier - # env_input = self._shuttle.clear_device_() - env_input = self._shuttle - else: - env_input = self._shuttle - env_output, env_next_output = self.env.step_and_maybe_reset(env_input) - - if self._shuttle is not env_output: - # ad-hoc update shuttle - next_data = env_output.get("next") - if self._shuttle_has_no_device: - # Make sure - next_data.clear_device_() - self._shuttle.set("next", next_data) - - if self.verbose: - torchrl_logger.info( - f"Collector: Rollout step completed {self._iter=}." - ) - if ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - if self.verbose: - torchrl_logger.info( - f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." - ) - self.replay_buffer.add(self._shuttle) - if self._increment_frames(self._shuttle.numel()): - return - else: - if self.storing_device is not None: - if self.verbose: - torchrl_logger.info( - f"Collector: Moving to {self.storing_device} and adding to queue." 
- ) - non_blocking = ( - not self.no_cuda_sync or self.storing_device.type == "cuda" - ) - tensordicts.append( - self._shuttle.to( - self.storing_device, non_blocking=non_blocking - ) - ) - if not self.no_cuda_sync: - self._sync_storage() - else: - if self.verbose: - torchrl_logger.info( - "Collector: Adding to queue (no device)." - ) - tensordicts.append(self._shuttle) - - # carry over collector data without messing up devices - collector_data = self._shuttle.get("collector").copy() - self._shuttle = env_next_output - if self._shuttle_has_no_device: - self._shuttle.clear_device_() - self._shuttle.set("collector", collector_data) - self._update_traj_ids(env_output) - - if ( - self.interruptor is not None - and self.interruptor.collection_stopped() - ): - if self.verbose: - torchrl_logger.info("Collector: Interruptor stopped.") - if ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - return - result = self._final_rollout - if self._use_buffers: - try: - torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout[..., : t + 1], - ) - except RuntimeError: - with self._final_rollout.unlock_(): - torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout[..., : t + 1], - ) - else: - result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) - break - else: - if self._use_buffers: - torchrl_logger.info("Returning final rollout within buffer.") - result = self._final_rollout - try: - result = torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout, - ) - - except RuntimeError: - with self._final_rollout.unlock_(): - result = torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout, - ) - elif ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - return - else: - torchrl_logger.info( - "Returning final rollout with NO buffer (maybe_dense_stack)." 
- ) - result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) - result.refine_names(..., "time") - - return self._maybe_set_truncated(result) - - def _maybe_set_truncated(self, final_rollout): - last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) - for truncated_key in self._truncated_keys: - truncated = final_rollout["next", truncated_key] - truncated[last_step] = True - final_rollout["next", truncated_key] = truncated - done = final_rollout["next", _replace_last(truncated_key, "done")] - final_rollout["next", _replace_last(truncated_key, "done")] = ( - done | truncated - ) - return final_rollout - - @torch.no_grad() - def reset(self, index=None, **kwargs) -> None: - """Resets the environments to a new initial state.""" - # metadata - collector_metadata = self._shuttle.get("collector").clone() - if index is not None: - # check that the env supports partial reset - if prod(self.env.batch_size) == 0: - raise RuntimeError("resetting unique env with index is not permitted.") - for reset_key, done_keys in zip( - self.env.reset_keys, self.env.done_keys_groups - ): - _reset = torch.zeros( - self.env.full_done_spec[done_keys[0]].shape, - dtype=torch.bool, - device=self.env.device, - ) - _reset[index] = 1 - self._shuttle.set(reset_key, _reset) - else: - _reset = None - self._shuttle.zero_() - - self._shuttle.update(self.env.reset(**kwargs), inplace=True) - collector_metadata["traj_ids"] = ( - collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() - ) - self._shuttle["collector"] = collector_metadata - - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - """Shuts down all workers and/or closes the local environment. - - Args: - timeout (float, optional): The timeout for closing pipes between workers. - No effect for this class. - close_env (bool, optional): Whether to close the environment. Defaults to `True`. - raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. - """ - try: - if not self.closed: - self.closed = True - del self._shuttle - if self._use_buffers: - del self._final_rollout - if close_env and not self.env.is_closed: - self.env.close(raise_if_closed=raise_on_error) - del self.env - return - except Exception as e: - if raise_on_error: - raise e - else: - pass - - def __del__(self): - try: - self.shutdown() - except Exception: - # an AttributeError will typically be raised if the collector is deleted when the program ends. - # In the future, insignificant changes to the close method may change the error type. - # We excplicitely assume that any error raised during closure in - # __del__ will not affect the program. - pass - - def state_dict(self) -> OrderedDict: - """Returns the local state_dict of the data collector (environment and policy). - - Returns: - an ordered dictionary with fields :obj:`"policy_state_dict"` and - `"env_state_dict"`. 
- - """ - from torchrl.envs.batched_envs import BatchedEnvBase - - if isinstance(self.env, TransformedEnv): - env_state_dict = self.env.transform.state_dict() - elif isinstance(self.env, BatchedEnvBase): - env_state_dict = self.env.state_dict() - else: - env_state_dict = OrderedDict() - - if hasattr(self, "_policy_w_state_dict"): - policy_state_dict = self._policy_w_state_dict.state_dict() - state_dict = OrderedDict( - policy_state_dict=policy_state_dict, - env_state_dict=env_state_dict, - ) - else: - state_dict = OrderedDict(env_state_dict=env_state_dict) - - state_dict.update({"frames": self._frames, "iter": self._iter}) - - return state_dict - - def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: - """Loads a state_dict on the environment and policy. - - Args: - state_dict (OrderedDict): ordered dictionary containing the fields - `"policy_state_dict"` and :obj:`"env_state_dict"`. - - """ - strict = kwargs.get("strict", True) - if strict or "env_state_dict" in state_dict: - self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) - if strict or "policy_state_dict" in state_dict: - if not hasattr(self, "_policy_w_state_dict"): - raise ValueError( - "Underlying policy does not have state_dict to load policy_state_dict into." - ) - self._policy_w_state_dict.load_state_dict( - state_dict["policy_state_dict"], **kwargs - ) - self._frames = state_dict["frames"] - self._iter = state_dict["iter"] - - def __repr__(self) -> str: - try: - env_str = indent(f"env={self.env}", 4 * " ") - policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") - td_out_str = repr(getattr(self, "_final_rollout", None)) - if len(td_out_str) > 50: - td_out_str = td_out_str[:50] + "..." - td_out_str = indent(f"td_out={td_out_str}", 4 * " ") - string = ( - f"{self.__class__.__name__}(" - f"\n{env_str}," - f"\n{policy_str}," - f"\n{td_out_str}," - f"\nexploration={self.exploration_type})" - ) - return string - except Exception: - return f"{type(self).__name__}(not_init)" - - def increment_version(self): - """Increment the policy version.""" - if self.policy_version_tracker is not None: - if not hasattr(self.policy_version_tracker, "increment_version"): - raise RuntimeError( - "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." - ) - self.policy_version_tracker.increment_version() - - @property - def policy_version(self) -> str | int | None: - """The current policy version.""" - if not hasattr(self.policy_version_tracker, "version"): - return None - return self.policy_version_tracker.version - - def get_policy_version(self) -> str | int | None: - """Get the current policy version. - - This method exists to support remote calls in Ray actors, since properties - cannot be accessed directly through Ray's RPC mechanism. - - Returns: - The current version number (int) or UUID (str), or None if version tracking is disabled. - """ - return self.policy_version - - def getattr_policy(self, attr): - """Get an attribute from the policy.""" - # send command to policy to return the attr - return getattr(self._wrapped_policy, attr) - - def getattr_env(self, attr): - """Get an attribute from the environment.""" - # send command to env to return the attr - return getattr(self.env, attr) - - def getattr_rb(self, attr): - """Get an attribute from the replay buffer.""" - # send command to rb to return the attr - return getattr(self.replay_buffer, attr) - - def get_model(self, model_id: str): - """Get model instance by ID (for weight sync schemes). 
- - Args: - model_id: Model identifier (e.g., "policy", "value_net") - - Returns: - The model instance - - Raises: - ValueError: If model_id is not recognized - """ - if model_id == "policy": - # Return the unwrapped policy instance for weight synchronization - # The unwrapped policy has the same parameter structure as what's - # extracted in the main process, avoiding key mismatches when - # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule) - if hasattr(self, "policy") and self.policy is not None: - return self.policy - else: - raise ValueError(f"No policy found for model_id '{model_id}'") - else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") - - -class _MultiDataCollector(DataCollectorBase): - """Runs a given number of DataCollectors on separate processes. - - Args: - create_env_fn (List[Callable]): list of Callables, each returning an - instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable): Policy to be executed in the environment. - Must accept a :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided (default), the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instance) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: - ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable - (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - .. warning:: `policy_factory` is currently not compatible with multiprocessed data - collectors. - - num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored. - Defaults to `None` (workers determined by the `create_env_fn` length). - frames_per_batch (int, Sequence[int]): A keyword-only argument representing the - total number of elements in a batch. If a sequence is provided, represents the number of elements in a - batch per worker. Total number of elements in a batch is then the sum over the sequence. - total_frames (int, optional): A keyword-only argument representing the - total number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised.
- Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (never ending collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. - If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - create_env_kwargs (dict, optional): A dictionary with the - keyword arguments used to create an environment. If a list is - provided, each of its elements will be assigned to a sub-collector. - collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be - :class:`~torchrl.collectors.SyncDataCollector`, - :class:`~torchrl.collectors.MultiSyncDataCollector`, - :class:`~torchrl.collectors.MultiaSyncDataCollector` - or a derived class of these. - Defaults to :class:`~torchrl.collectors.SyncDataCollector`. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. 
- If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, - ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` - or ``torchrl.envs.utils.ExplorationType.MEAN``. - reset_when_done (bool, optional): if ``True`` (default), an environment - that returns a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. - update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()` - will be called before (sync) or after (async) each data collection. - Defaults to ``False``. - preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers - that will be allowed to finish collecting their rollout before the rest are forced to end early. - num_threads (int, optional): number of threads for this process. - Defaults to the number of workers. - num_sub_threads (int, optional): number of threads of the subprocesses. - Should be equal to one plus the number of processes launched within - each subprocess (or one if a single process is launched). - Defaults to 1 for safety: if none is indicated, launching multiple - workers may put too much load on the cpu and harm performance. - cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively). - If ``"stack"``, the data collected from the workers will be stacked along the - first dimension. This is the preferred behavior as it is the most compatible - with the rest of the library. - If ``0``, results will be concatenated along the first dimension - of the outputs, which can be the batched dimension if the environments are - batched or the time dimension if not. - A ``cat_results`` value of ``-1`` will always concatenate results along the - time dimension. This should be preferred over the default. Intermediate values - are also accepted. - Defaults to ``"stack"``. - - .. note:: From v0.5, this argument will default to ``"stack"`` for a better - interoperability with the rest of the library.
- - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. - This isn't compatible with environments with dynamic specs. Defaults to ``True`` - for envs without dynamic specs, ``False`` for others. - replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts - but populate the buffer instead. Defaults to ``None``. - extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not - with single steps. Defaults to `True` for multiprocessed data collectors. - local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize - the replay buffer in the main process (legacy behavior). If ``True``, the storage-level - coordination will handle initialization with real data from worker processes. - Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. - This parameter is deprecated and will be removed in v0.12. - trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be - assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules - and ``False`` otherwise. - compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled - using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it - will be used to compile the policy. - cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped - in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. - If a dictionary of kwargs is passed, it will be used to wrap the policy. - no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. - For environments running directly on CUDA (`IsaacLab `_ - or `ManiSkills `_) cuda synchronization may cause unexpected - crashes. - Defaults to ``False``. - weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` - or its subclass, responsible for updating the policy weights on remote inference workers. - If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, - which handles weight synchronization across multiple processes. - Consider using a constructor if the updater needs to be serialized. - weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models. - If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. - Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. 
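A minimal sketch of how the per-worker sequences documented above are meant to line up; this is an editor's illustration under assumptions (a toy make_env helper, CPU-only devices, arbitrary frame budgets), not code taken from this changeset.

from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs import GymEnv

def make_env():
    # assumed example environment; any callable returning an EnvBase can serve as create_env_fn
    return GymEnv("Pendulum-v1")

if __name__ == "__main__":
    # one frames_per_batch entry and one device entry per worker; the batch size is the sum (192)
    collector = MultiSyncDataCollector(
        create_env_fn=[make_env, make_env, make_env],
        frames_per_batch=[64, 64, 64],
        device=["cpu", "cpu", "cpu"],
        total_frames=1_920,   # divisible by 192, so no rounding warning is emitted
        cat_results="stack",  # data from the three workers is stacked along the first dimension
    )
    for batch in collector:
        pass  # each batch carries 192 frames in total
    collector.shutdown()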
- - """ - - def __init__( - self, - create_env_fn: Sequence[Callable[[], EnvBase]], - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - num_workers: int | None = None, - policy_factory: Callable[[], Callable] - | list[Callable[[], Callable]] - | None = None, - frames_per_batch: int | Sequence[int], - total_frames: int | None = -1, - device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - create_env_kwargs: Sequence[dict] | None = None, - collector_class: type | Callable[[], DataCollectorBase] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - update_at_each_batch: bool = False, - preemptive_threshold: float | None = None, - num_threads: int | None = None, - num_sub_threads: int = 1, - cat_results: str | int | None = None, - set_truncated: bool = False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - replay_buffer_chunk: bool | None = None, - local_init_rb: bool | None = None, - trust_policy: bool | None = None, - compile_policy: bool | dict[str, Any] | None = None, - cudagraph_policy: bool | dict[str, Any] | None = None, - no_cuda_sync: bool = False, - weight_updater: WeightUpdaterBase - | Callable[[], WeightUpdaterBase] - | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, - track_policy_version: bool = False, - ): - self.closed = True - - # Set up workers and environment functions - create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( - create_env_fn, num_workers, frames_per_batch - ) - - # Set up basic configuration - self.set_truncated = set_truncated - self.num_sub_threads = num_sub_threads - self.num_threads = num_threads - self.create_env_fn = create_env_fn - self._read_compile_kwargs(compile_policy, cudagraph_policy) - - # Set up environment kwargs - self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) - - # Set up devices - storing_devices, policy_devices, env_devices = self._get_devices( - storing_device=storing_device, - env_device=env_device, - policy_device=policy_device, - device=device, - ) - self.storing_device = storing_devices - self.policy_device = policy_devices - self.env_device = env_devices - self.collector_class = collector_class - del storing_device, env_device, policy_device, device - self.no_cuda_sync = no_cuda_sync - - # Set up replay buffer - self._use_buffers = use_buffers - self.replay_buffer = replay_buffer - self._setup_multi_replay_buffer( - local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer - ) - - # Set up policy and weights - if trust_policy is None: - trust_policy = policy is not None and isinstance(policy, CudaGraphModule) - self.trust_policy = trust_policy - - policy_factory = self._setup_policy_factory(policy_factory) - - # Set up weight synchronization - if ( - not any(policy_factory) - and not weight_sync_schemes - and weight_updater is None - ): - weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()} - - self._setup_multi_policy_and_weights( - policy, 
policy_factory, weight_updater, weight_sync_schemes - ) - - self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) - - # Set up policy version tracking - self._setup_multi_policy_version_tracking(track_policy_version) - - # Store policy and policy_factory - self.policy = policy - self.policy_factory = policy_factory - - # Set up fallback policy for weight extraction - self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) - - # Set up total frames and other parameters - self._setup_multi_total_frames( - total_frames, total_frames_per_batch, frames_per_batch - ) - self.reset_at_each_iter = reset_at_each_iter - self.postprocs = postproc - self.max_frames_per_traj = ( - int(max_frames_per_traj) if max_frames_per_traj is not None else 0 - ) - - # Set up split trajectories - self.requested_frames_per_batch = total_frames_per_batch - self.reset_when_done = reset_when_done - self._setup_split_trajs(split_trajs, reset_when_done) - - # Set up other parameters - self.init_random_frames = ( - int(init_random_frames) if init_random_frames is not None else 0 - ) - self.update_at_each_batch = update_at_each_batch - self.exploration_type = exploration_type - self.frames_per_worker = np.inf - - # Set up preemptive threshold - self._setup_preemptive_threshold(preemptive_threshold) - - # Run worker processes - try: - self._run_processes() - except Exception as e: - self.shutdown(raise_on_error=False) - raise e - - # Set up frame tracking and other options - self._exclude_private_keys = True - self._frames = 0 - self._iter = -1 - - # Validate cat_results - self._validate_cat_results(cat_results) - - def _setup_workers_and_env_fns( - self, - create_env_fn: Sequence[Callable] | Callable, - num_workers: int | None, - frames_per_batch: int | Sequence[int], - ) -> tuple[list[Callable], int]: - """Set up workers and environment functions.""" - if isinstance(create_env_fn, Sequence): - self.num_workers = len(create_env_fn) - else: - self.num_workers = num_workers - create_env_fn = [create_env_fn] * self.num_workers - - if ( - isinstance(frames_per_batch, Sequence) - and len(frames_per_batch) != self.num_workers - ): - raise ValueError( - "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." - f"Got {len(frames_per_batch)} values for {self.num_workers} workers." 
- ) - - self._frames_per_batch = frames_per_batch - total_frames_per_batch = ( - sum(frames_per_batch) - if isinstance(frames_per_batch, Sequence) - else frames_per_batch - ) - - return create_env_fn, total_frames_per_batch - - def _setup_env_kwargs( - self, create_env_kwargs: Sequence[dict] | dict | None - ) -> list[dict]: - """Set up environment kwargs for each worker.""" - if isinstance(create_env_kwargs, Mapping): - create_env_kwargs = [create_env_kwargs] * self.num_workers - elif create_env_kwargs is None: - create_env_kwargs = [{}] * self.num_workers - elif isinstance(create_env_kwargs, (tuple, list)): - create_env_kwargs = list(create_env_kwargs) - if len(create_env_kwargs) != self.num_workers: - raise ValueError( - f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" - ) - return create_env_kwargs - - def _setup_multi_replay_buffer( - self, - local_init_rb: bool | None, - replay_buffer: ReplayBuffer | None, - replay_buffer_chunk: bool | None, - extend_buffer: bool, - ) -> None: - """Set up replay buffer for multi-process collector.""" - # Handle local_init_rb deprecation - if local_init_rb is None: - local_init_rb = False - if replay_buffer is not None and not local_init_rb: - warnings.warn( - "local_init_rb=False is deprecated and will be removed in v0.12. " - "The new storage-level initialization provides better performance.", - FutureWarning, - ) - self.local_init_rb = local_init_rb - - self._check_replay_buffer_init() - - if replay_buffer_chunk is not None: - if extend_buffer is None: - replay_buffer_chunk = extend_buffer - warnings.warn( - "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", - DeprecationWarning, - ) - elif extend_buffer != replay_buffer_chunk: - raise ValueError( - "conflicting values for replay_buffer_chunk and extend_buffer." - ) - self.extend_buffer = extend_buffer - - if ( - replay_buffer is not None - and hasattr(replay_buffer, "shared") - and not replay_buffer.shared - ): - torchrl_logger.warning("Replay buffer is not shared. 
Sharing it.") - replay_buffer.share() - - def _setup_policy_factory( - self, policy_factory: Callable | list[Callable] | None - ) -> list[Callable | None]: - """Set up policy factory for each worker.""" - if not isinstance(policy_factory, Sequence): - policy_factory = [policy_factory] * self.num_workers - return policy_factory - - def _setup_multi_policy_and_weights( - self, - policy: TensorDictModule | Callable | None, - policy_factory: list[Callable | None], - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up policy and extract weights for each device.""" - self._policy_weights_dict = {} - self._fallback_policy = None # Policy to use for weight extraction fallback - - if any(policy_factory) and policy is not None: - raise TypeError("policy_factory and policy are mutually exclusive") - elif not any(policy_factory): - for policy_device, env_maker, env_maker_kwargs in _zip_strict( - self.policy_device, self.create_env_fn, self.create_env_kwargs - ): - policy_new_device, get_weights_fn = self._get_policy_and_device( - policy=policy, - policy_device=policy_device, - env_maker=env_maker, - env_maker_kwargs=env_maker_kwargs, - ) - if type(policy_new_device) is not type(policy): - policy = policy_new_device - weights = ( - TensorDict.from_module(policy_new_device) - if isinstance(policy_new_device, nn.Module) - else TensorDict() - ) - # For multi-process collectors, ensure weights are in shared memory - if policy_device and policy_device.type == "cpu": - weights = weights.share_memory_() - self._policy_weights_dict[policy_device] = weights - # Store the first policy instance for fallback weight extraction - if self._fallback_policy is None: - self._fallback_policy = policy_new_device - self._get_weights_fn = get_weights_fn - if weight_updater is None: - # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default - if weight_sync_schemes is None: - weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} - elif weight_updater is None: - warnings.warn( - "weight_updater is None, but policy_factory is provided. This means that the server will " - "not know how to send the weights to the workers. If the workers can handle their weight synchronization " - "on their own (via some specialized worker type / constructor) this may well work, but make sure " - "your weight synchronization strategy is properly set. To suppress this warning, you can use " - "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " - "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
- ) - - def _setup_multi_weight_sync( - self, - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up weight synchronization for multi-process collector.""" - if weight_sync_schemes is not None: - # Use new simplified weight synchronization system - self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # Senders will be created in _run_processes when pipes are available - self.weight_updater = None # Don't use legacy system - else: - # Fall back to legacy weight updater system - self.weight_updater = weight_updater - self._weight_sync_schemes = None - self._weight_senders = {} - - def _setup_multi_policy_version_tracking( - self, track_policy_version: bool | PolicyVersion - ) -> None: - """Set up policy version tracking for multi-process collector.""" - self.policy_version_tracker = track_policy_version - if PolicyVersion is not None: - if isinstance(track_policy_version, bool) and track_policy_version: - self.policy_version_tracker = PolicyVersion() - elif hasattr(track_policy_version, "increment_version"): - self.policy_version_tracker = track_policy_version - else: - self.policy_version_tracker = None - else: - if track_policy_version: - raise ImportError( - "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." - ) - self.policy_version_tracker = None - - def _setup_fallback_policy( - self, - policy: TensorDictModule | Callable | None, - policy_factory: list[Callable | None], - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up fallback policy for weight extraction when using policy_factory.""" - # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided - # If policy_factory was used, create a policy instance to use as fallback - if policy is None and any(policy_factory) and weight_sync_schemes is not None: - if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: - first_factory = ( - policy_factory[0] - if isinstance(policy_factory, list) - else policy_factory - ) - if first_factory is not None: - # Create a policy instance for weight extraction - # This will be a reference to a policy with the same structure - # For shared memory, modifications to any policy will be visible here - self._fallback_policy = first_factory() - - def _setup_multi_total_frames( - self, - total_frames: int, - total_frames_per_batch: int, - frames_per_batch: int | Sequence[int], - ) -> None: - """Validate and set total frames for multi-process collector.""" - if total_frames is None or total_frames < 0: - total_frames = float("inf") - else: - remainder = total_frames % total_frames_per_batch - if remainder != 0 and rl_warnings(): - warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " - f"This means {total_frames_per_batch - remainder} additional frames will be collected. " - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.total_frames = ( - int(total_frames) if total_frames != float("inf") else total_frames - ) - - def _setup_split_trajs( - self, split_trajs: bool | None, reset_when_done: bool - ) -> None: - """Set up split trajectories option.""" - if split_trajs is None: - split_trajs = False - elif not reset_when_done and split_trajs: - raise RuntimeError( - "Cannot split trajectories when reset_when_done is False." 
- ) - self.split_trajs = split_trajs - - def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None: - """Set up preemptive threshold for early stopping.""" - if preemptive_threshold is not None: - if _is_osx: - raise NotImplementedError( - "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform." - ) - self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) - manager = _InterruptorManager() - manager.start() - self.interruptor = manager._Interruptor() - else: - self.preemptive_threshold = 1.0 - self.interruptor = None - - def _validate_cat_results(self, cat_results: str | int | None) -> None: - """Validate cat_results parameter.""" - if cat_results is not None and ( - not isinstance(cat_results, (int, str)) - or (isinstance(cat_results, str) and cat_results != "stack") - ): - raise ValueError( - "cat_results must be a string ('stack') " - f"or an integer representing the cat dimension. Got {cat_results}." - ) - if not isinstance(self, MultiSyncDataCollector) and cat_results not in ( - "stack", - None, - ): - raise ValueError( - "cat_results can only be used with ``MultiSyncDataCollector``." - ) - self.cat_results = cat_results - - def _check_replay_buffer_init(self): - if self.replay_buffer is None: - return - is_init = hasattr(self.replay_buffer, "_storage") and getattr( - self.replay_buffer._storage, "initialized", True - ) - if not is_init: - if self.local_init_rb: - # New behavior: storage handles all coordination itself - # Nothing to do here - the storage will coordinate during first write - self.replay_buffer.share() - return - - # Legacy behavior: fake tensordict initialization - if isinstance(self.create_env_fn[0], EnvCreator): - fake_td = self.create_env_fn[0].meta_data.tensordict - elif isinstance(self.create_env_fn[0], EnvBase): - fake_td = self.create_env_fn[0].fake_tensordict() - else: - fake_td = self.create_env_fn[0]( - **self.create_env_kwargs[0] - ).fake_tensordict() - fake_td["collector", "traj_ids"] = torch.zeros( - fake_td.shape, dtype=torch.long - ) - # Use extend to avoid time-related transforms to fail - self.replay_buffer.extend(fake_td.unsqueeze(-1)) - self.replay_buffer.empty() - - @classmethod - def _total_workers_from_env(cls, env_creators): - if isinstance(env_creators, (tuple, list)): - return sum( - cls._total_workers_from_env(env_creator) for env_creator in env_creators - ) - from torchrl.envs import ParallelEnv - - if isinstance(env_creators, ParallelEnv): - return env_creators.num_workers - return 1 - - def _get_devices( - self, - *, - storing_device: torch.device, - policy_device: torch.device, - env_device: torch.device, - device: torch.device, - ): - # convert all devices to lists - if not isinstance(storing_device, (list, tuple)): - storing_device = [ - storing_device, - ] * self.num_workers - if not isinstance(policy_device, (list, tuple)): - policy_device = [ - policy_device, - ] * self.num_workers - if not isinstance(env_device, (list, tuple)): - env_device = [ - env_device, - ] * self.num_workers - if not isinstance(device, (list, tuple)): - device = [ - device, - ] * self.num_workers - if not ( - len(device) - == len(storing_device) - == len(policy_device) - == len(env_device) - == self.num_workers - ): - raise RuntimeError( - f"THe length of the devices does not match the number of workers: {self.num_workers}." 
- ) - storing_device, policy_device, env_device = zip( - *[ - SyncDataCollector._get_devices( - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - device=device, - ) - for (storing_device, policy_device, env_device, device) in zip( - storing_device, policy_device, env_device, device - ) - ] - ) - return storing_device, policy_device, env_device - - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: - raise NotImplementedError - - @property - def _queue_len(self) -> int: - raise NotImplementedError - - def _run_processes(self) -> None: - if self.num_threads is None: - total_workers = self._total_workers_from_env(self.create_env_fn) - self.num_threads = max( - 1, torch.get_num_threads() - total_workers - ) # 1 more thread for this proc - - # Weight senders will be initialized after workers are ready (via init_on_sender) - torch.set_num_threads(self.num_threads) - queue_out = mp.Queue(self._queue_len) # sends data from proc to main - self.procs = [] - self.pipes = [] - self._traj_pool = _TrajectoryPool(lock=True) - # Create a policy on the right device - policy_factory = self.policy_factory - if any(policy_factory): - policy_factory = [ - CloudpickleWrapper(_policy_factory) - for _policy_factory in policy_factory - ] - - for i, (env_fun, env_fun_kwargs) in enumerate( - zip(self.create_env_fn, self.create_env_kwargs) - ): - pipe_parent, pipe_child = mp.Pipe() # send messages to procs - if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( - env_fun, EnvBase - ): # to avoid circular imports - env_fun = CloudpickleWrapper(env_fun) - - policy_device = self.policy_device[i] - storing_device = self.storing_device[i] - env_device = self.env_device[i] - # We take the weights, the policy, and locally dispatch the weights to the policy - # while we send the policy to the remote process. - # This makes sure that a given set of shared weights for a given device are - # shared for all policies that rely on that device. 
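# Editor's sketch (not part of this diff): the shared-weights mechanism relied on below,
# reduced to plain tensordict calls; the module names and shapes are assumptions.
import torch
from torch import nn
from tensordict import TensorDict

trainer_policy = nn.Linear(4, 2)   # policy owned by the main process
worker_policy = nn.Linear(4, 2)    # structurally identical copy, e.g. rebuilt in a worker

# extract the parameters as a TensorDict and move their storage to shared memory
weights = TensorDict.from_module(trainer_policy).share_memory_()

# bind the shared tensors onto the copy: both modules now read from the same storage
weights.to_module(worker_policy)

# an in-place update of the shared weights is immediately visible to the worker copy
with torch.no_grad():
    weights["weight"].zero_()
assert worker_policy.weight.abs().sum() == 0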
- policy = self.policy - policy_weights = self._policy_weights_dict.get(policy_device) - if policy is not None and policy_weights is not None: - cm = policy_weights.to_module(policy) - else: - cm = contextlib.nullcontext() - with cm: - kwargs = { - "policy_factory": policy_factory[i], - "pipe_parent": pipe_parent, - "pipe_child": pipe_child, - "queue_out": queue_out, - "create_env_fn": env_fun, - "create_env_kwargs": env_fun_kwargs, - "policy": policy, - "max_frames_per_traj": self.max_frames_per_traj, - "frames_per_batch": self.frames_per_batch_worker(worker_idx=i), - "reset_at_each_iter": self.reset_at_each_iter, - "policy_device": policy_device, - "storing_device": storing_device, - "env_device": env_device, - "exploration_type": self.exploration_type, - "reset_when_done": self.reset_when_done, - "idx": i, - "interruptor": self.interruptor, - "set_truncated": self.set_truncated, - "use_buffers": self._use_buffers, - "replay_buffer": self.replay_buffer, - "extend_buffer": self.extend_buffer, - "traj_pool": self._traj_pool, - "trust_policy": self.trust_policy, - "compile_policy": self.compiled_policy_kwargs - if self.compiled_policy - else False, - "cudagraph_policy": self.cudagraphed_policy_kwargs - if self.cudagraphed_policy - else False, - "no_cuda_sync": self.no_cuda_sync, - "collector_class": self.collector_class, - "postproc": self.postprocs - if self.replay_buffer is not None - else None, - "weight_sync_schemes": self._weight_sync_schemes, - } - proc = _ProcessNoWarn( - target=_main_async_collector, - num_threads=self.num_sub_threads, - kwargs=kwargs, - ) - # proc.daemon can't be set as daemonic processes may be launched by the process itself - try: - proc.start() - except TypeError as err: - if "cannot pickle" in str(err): - raise RuntimeError( - "A non-serializable object was passed to the collector workers." - ) from err - except RuntimeError as err: - if "Cowardly refusing to serialize non-leaf tensor" in str(err): - raise RuntimeError( - "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " - "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" - "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" - "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." - ) from err - else: - raise err - except _pickle.PicklingError as err: - if "" in str(err): - raise RuntimeError( - """Can't open a process with doubly cloud-pickled lambda function. -This error is likely due to an attempt to use a ParallelEnv in a -multiprocessed data collector. To do this, consider wrapping your -lambda function in an `torchrl.envs.EnvCreator` wrapper as follows: -`env = ParallelEnv(N, EnvCreator(my_lambda_function))`. 
-This will not only ensure that your lambda function is cloud-pickled once, but -also that the state dict is synchronised across processes if needed.""" - ) from err - pipe_child.close() - self.procs.append(proc) - self.pipes.append(pipe_parent) - - # Worker registration now handled by init_on_sender() after workers are ready - for i, pipe_parent in enumerate(self.pipes): - pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) - try: - msg = pipe_parent.recv() - except EOFError as e: - raise RuntimeError( - f"Worker {i} failed to initialize and closed the connection before sending status. " - f"This typically indicates that the worker process crashed during initialization. " - f"Check the worker process logs for the actual error." - ) from e - if msg != "instantiated": - # Check if it's an error dict from worker - if isinstance(msg, dict) and msg.get("error"): - # Reconstruct the exception from the worker - exc_type_name = msg["exception_type"] - exc_msg = msg["exception_msg"] - traceback_str = msg["traceback"] - - # Try to get the actual exception class - exc_class = None - exc_module = msg["exception_module"] - - if exc_module == "builtins": - # Get from builtins - import builtins - - exc_class = getattr(builtins, exc_type_name, None) - else: - # Try to import from the module - try: - import importlib - - mod = importlib.import_module(exc_module) - exc_class = getattr(mod, exc_type_name, None) - except Exception: - pass - - # Re-raise with original exception type if possible - if exc_class is not None: - raise exc_class( - f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" - ) - else: - # Fall back to RuntimeError if we can't get the original type - raise RuntimeError( - f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" - ) - else: - # Legacy string error message - raise RuntimeError(msg) - - # Initialize all weight sync schemes now that workers are ready - # This calls init_on_sender() for each scheme which: - # 1. Creates transports for all workers - # 2. Creates and configures the sender - # 3. For SharedMemWeightSyncScheme, distributes buffer references to avoid deadlock - if self._weight_sync_schemes: - for model_id, scheme in self._weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_sender"): - scheme.init_on_sender(model_id=model_id, context=self) - # Get the initialized sender - self._weight_senders[model_id] = scheme.get_sender() - # else: keep using legacy _weight_senders initialization from before - - self.queue_out = queue_out - self.closed = False - - _running_free = False - - def start(self): - """Starts the collector(s) for asynchronous data collection. - - The collected data is stored in the provided replay buffer. This method initiates the background collection of - data across multiple processes, allowing for decoupling of data collection and training. - - Raises: - RuntimeError: If no replay buffer is defined during the collector's initialization. - - Example: - >>> import time - >>> from functools import partial - >>> - >>> import tqdm - >>> - >>> from torchrl.collectors import MultiaSyncDataCollector, RandomPolicy - >>> from torchrl.data import LazyTensorStorage, ReplayBuffer - >>> from torchrl.envs import GymEnv, set_gym_backend - >>> import ale_py - >>> - >>> # Set the gym backend to gymnasium - >>> set_gym_backend("gymnasium").set() - >>> - >>> if __name__ == "__main__": - ... # Create a random policy for the Pong environment - ... env_fn = partial(GymEnv, "ALE/Pong-v5") - ... 
policy = RandomPolicy(env_fn().action_spec) - ... - ... # Initialize a shared replay buffer - ... rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True) - ... - ... # Create a multi-async data collector with 16 environments - ... num_envs = 16 - ... collector = MultiaSyncDataCollector( - ... [env_fn] * num_envs, - ... policy=policy, - ... replay_buffer=rb, - ... frames_per_batch=num_envs * 16, - ... total_frames=-1, - ... ) - ... - ... # Progress bar to track the number of collected frames - ... pbar = tqdm.tqdm(total=100_000) - ... - ... # Start the collector asynchronously - ... collector.start() - ... - ... # Track the write count of the replay buffer - ... prec_wc = 0 - ... while True: - ... wc = rb.write_count - ... c = wc - prec_wc - ... prec_wc = wc - ... - ... # Update the progress bar - ... pbar.update(c) - ... pbar.set_description(f"Write Count: {rb.write_count}") - ... - ... # Check the write count every 0.5 seconds - ... time.sleep(0.5) - ... - ... # Stop when the desired number of frames is reached - ... if rb.write_count >= 100_000: - ... break - ... - ... # Shut down the collector - ... collector.async_shutdown() - """ - if self.replay_buffer is None: - raise RuntimeError("Replay buffer must be defined for execution.") - if self.init_random_frames is not None and self.init_random_frames > 0: - raise RuntimeError( - "Cannot currently start() a collector that requires random frames. Please submit a feature request on github." - ) - self._running_free = True - for pipe in self.pipes: - pipe.send((None, "run_free")) - - @contextlib.contextmanager - def pause(self): - """Context manager that pauses the collector if it is running free.""" - if self._running_free: - for pipe in self.pipes: - pipe.send((None, "pause")) - # Make sure all workers are paused - for _ in self.pipes: - idx, msg = self.queue_out.get() - if msg != "paused": - raise ValueError(f"Expected paused, but got {msg=}.") - torchrl_logger.info(f"Worker {idx} is paused.") - self._running_free = False - yield None - for pipe in self.pipes: - pipe.send((None, "restart")) - self._running_free = True - else: - raise RuntimeError("Collector cannot be paused.") - - def __del__(self): - try: - self.shutdown() - except Exception: - # an AttributeError will typically be raised if the collector is deleted when the program ends. - # In the future, insignificant changes to the close method may change the error type. - # We explicitly assume that any error raised during closure in - # __del__ will not affect the program. - pass - - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - """Shuts down all processes. This operation is irreversible. - - Args: - timeout (float, optional): The timeout for closing pipes between workers. - close_env (bool, optional): Whether to close the environment. Defaults to `True`. - raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. - """ - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed."
- ) - try: - self._shutdown_main(timeout) - except Exception as e: - if raise_on_error: - raise e - else: - pass - - def _shutdown_main(self, timeout: float | None = None) -> None: - if timeout is None: - timeout = 10 - try: - if self.closed: - return - _check_for_faulty_process(self.procs) - all_closed = [False] * self.num_workers - rep = 0 - for idx in range(self.num_workers): - if all_closed[idx]: - continue - if not self.procs[idx].is_alive(): - continue - self.pipes[idx].send((None, "close")) - - while not all(all_closed) and rep < 1000: - rep += 1 - for idx in range(self.num_workers): - if all_closed[idx]: - continue - if not self.procs[idx].is_alive(): - all_closed[idx] = True - continue - try: - if self.pipes[idx].poll(timeout / 1000 / self.num_workers): - msg = self.pipes[idx].recv() - if msg != "closed": - raise RuntimeError(f"got {msg} but expected 'close'") - all_closed[idx] = True - else: - continue - except BrokenPipeError: - all_closed[idx] = True - continue - self.closed = True - - self.queue_out.close() - for pipe in self.pipes: - pipe.close() - for proc in self.procs: - proc.join(1.0) - finally: - import torchrl - - num_threads = min( - torchrl._THREAD_POOL_INIT, - torch.get_num_threads() - + self._total_workers_from_env(self.create_env_fn), - ) - torch.set_num_threads(num_threads) - - for proc in self.procs: - if proc.is_alive(): - proc.terminate() - - def async_shutdown(self, timeout: float | None = None): - return self.shutdown(timeout=timeout) - - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Sets the seeds of the environments stored in the DataCollector. - - Args: - seed: integer representing the seed to be used for the environment. - static_seed (bool, optional): if ``True``, the seed is not incremented. - Defaults to False - - Returns: - Output seed. This is useful when more than one environment is - contained in the DataCollector, as the seed will be incremented for - each of these. The resulting seed is the seed of the last - environment. - - Examples: - >>> from torchrl.envs import ParallelEnv - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_fn = lambda: GymEnv("Pendulum-v1") - >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) - >>> out_seed = collector.set_seed(1) # out_seed = 6 - - """ - _check_for_faulty_process(self.procs) - for idx in range(self.num_workers): - self.pipes[idx].send(((seed, static_seed), "seed")) - new_seed, msg = self.pipes[idx].recv() - if msg != "seeded": - raise RuntimeError(f"Expected msg='seeded', got {msg}") - seed = new_seed - self.reset() - return seed - - def reset(self, reset_idx: Sequence[bool] | None = None) -> None: - """Resets the environments to a new initial state. - - Args: - reset_idx: Optional. Sequence indicating which environments have - to be reset. If None, all environments are reset. 
- - """ - _check_for_faulty_process(self.procs) - - if reset_idx is None: - reset_idx = [True for _ in range(self.num_workers)] - for idx in range(self.num_workers): - if reset_idx[idx]: - self.pipes[idx].send((None, "reset")) - for idx in range(self.num_workers): - if reset_idx[idx]: - j, msg = self.pipes[idx].recv() - if msg != "reset": - raise RuntimeError(f"Expected msg='reset', got {msg}") - - def state_dict(self) -> OrderedDict: - """Returns the state_dict of the data collector. - - Each field represents a worker containing its own state_dict. - - """ - for idx in range(self.num_workers): - self.pipes[idx].send((None, "state_dict")) - state_dict = OrderedDict() - for idx in range(self.num_workers): - _state_dict, msg = self.pipes[idx].recv() - if msg != "state_dict": - raise RuntimeError(f"Expected msg='state_dict', got {msg}") - state_dict[f"worker{idx}"] = _state_dict - state_dict.update({"frames": self._frames, "iter": self._iter}) - - return state_dict - - def load_state_dict(self, state_dict: OrderedDict) -> None: - """Loads the state_dict on the workers. - - Args: - state_dict (OrderedDict): state_dict of the form - ``{"worker0": state_dict0, "worker1": state_dict1}``. - - """ - for idx in range(self.num_workers): - self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) - for idx in range(self.num_workers): - _, msg = self.pipes[idx].recv() - if msg != "loaded": - raise RuntimeError(f"Expected msg='loaded', got {msg}") - self._frames = state_dict["frames"] - self._iter = state_dict["iter"] - - def increment_version(self): - """Increment the policy version.""" - if self.policy_version_tracker is not None: - if not hasattr(self.policy_version_tracker, "increment_version"): - raise RuntimeError( - "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." - ) - self.policy_version_tracker.increment_version() - - @property - def policy_version(self) -> str | int | None: - """The current policy version.""" - if not hasattr(self.policy_version_tracker, "version"): - return None - return self.policy_version_tracker.version - - def get_policy_version(self) -> str | int | None: - """Get the current policy version. - - This method exists to support remote calls in Ray actors, since properties - cannot be accessed directly through Ray's RPC mechanism. - - Returns: - The current version number (int) or UUID (str), or None if version tracking is disabled. - """ - return self.policy_version - - def getattr_policy(self, attr): - """Get an attribute from the policy of the first worker. - - Args: - attr (str): The attribute name to retrieve from the policy. - - Returns: - The attribute value from the policy of the first worker. - - Raises: - AttributeError: If the attribute doesn't exist on the policy. - """ - _check_for_faulty_process(self.procs) - - # Send command to first worker (index 0) - self.pipes[0].send((attr, "getattr_policy")) - result, msg = self.pipes[0].recv() - if msg != "getattr_policy": - raise RuntimeError(f"Expected msg='getattr_policy', got {msg}") - - # If the worker returned an AttributeError, re-raise it - if isinstance(result, AttributeError): - raise result - - return result - - def getattr_env(self, attr): - """Get an attribute from the environment of the first worker. - - Args: - attr (str): The attribute name to retrieve from the environment. - - Returns: - The attribute value from the environment of the first worker. - - Raises: - AttributeError: If the attribute doesn't exist on the environment. 
- """ - _check_for_faulty_process(self.procs) - - # Send command to first worker (index 0) - self.pipes[0].send((attr, "getattr_env")) - result, msg = self.pipes[0].recv() - if msg != "getattr_env": - raise RuntimeError(f"Expected msg='getattr_env', got {msg}") - - # If the worker returned an AttributeError, re-raise it - if isinstance(result, AttributeError): - raise result - - return result - - def getattr_rb(self, attr): - """Get an attribute from the replay buffer.""" - return getattr(self.replay_buffer, attr) - - def get_model(self, model_id: str): - """Get model instance by ID (for weight sync schemes). - - Args: - model_id: Model identifier (e.g., "policy", "value_net") - - Returns: - The model instance - - Raises: - ValueError: If model_id is not recognized - """ - if model_id == "policy": - # Return the fallback policy instance - if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: - return self._fallback_policy - elif hasattr(self, "policy") and self.policy is not None: - return self.policy - else: - raise ValueError(f"No policy found for model_id '{model_id}'") - else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") - - def get_cached_weights(self, model_id: str): - """Get cached shared memory weights if available (for weight sync schemes). - - Args: - model_id: Model identifier - - Returns: - Cached TensorDict weights or None if not available - """ - if model_id == "policy" and hasattr(self, "_policy_weights_dict"): - # Get the policy device (first device if list) - policy_device = self.policy_device - if isinstance(policy_device, (list, tuple)): - policy_device = policy_device[0] if len(policy_device) > 0 else None - - # Return cached weights for this device - return self._policy_weights_dict.get(policy_device) - return None - - -@accept_remote_rref_udf_invocation -class MultiSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes synchronously. - - .. aafig:: - - +----------------------------------------------------------------------+ - | "MultiSyncDataCollector" | | - |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | - | "Collector 1" | "Collector 2" | "Collector 3" | Main | - |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| - | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | - |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| - |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | - | | | | | | | | - | "actor" | | | "actor" | | - | | | | | | - | "step" | "step" | "actor" | | | - | | | | | | - | | | | "step" | "step" | | - | | | | | | | - | "actor" | "step" | "step" | "actor" | | - | | | | | | - | | "actor" | | | - | | | | | - | "yield batch of traj 1"------->"collect, train"| - | | | - | "step" | "step" | "step" | "step" | "step" | "step" | | - | | | | | | | | - | "actor" | "actor" | | | | - | | "step" | "step" | "actor" | | - | | | | | | - | "step" | "step" | "actor" | "step" | "step" | | - | | | | | | | - | "actor" | | "actor" | | - | "yield batch of traj 2"------->"collect, train"| - | | | - +----------------------------------------------------------------------+ - - Envs can be identical or different. - - The collection starts when the next item of the collector is queried, - and no environment step is computed in between the reception of a batch of - trajectory and the start of the next collection. 
- This class can be safely used with online RL sota-implementations. - - .. note:: - Python requires multiprocessed code to be instantiated within a main guard: - - >>> from torchrl.collectors import MultiSyncDataCollector - >>> if __name__ == "__main__": - ... # Create your collector here - ... collector = MultiSyncDataCollector(...) - - See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> from torchrl.collectors import MultiSyncDataCollector - >>> if __name__ == "__main__": - ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - ... collector = MultiSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - ... for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - ... collector.shutdown() - ... del collector - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - - """ - - __doc__ += _MultiDataCollector.__doc__ - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
- ) - if hasattr(self, "out_buffer"): - del self.out_buffer - if hasattr(self, "buffers"): - del self.buffers - try: - return super().shutdown(timeout=timeout) - except Exception as e: - if raise_on_error: - raise e - else: - pass - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def frames_per_batch_worker(self, worker_idx: int | None) -> int: - if worker_idx is not None and isinstance(self._frames_per_batch, Sequence): - return self._frames_per_batch[worker_idx] - if self.requested_frames_per_batch % self.num_workers != 0 and rl_warnings(): - warnings.warn( - f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," - f" this results in more frames_per_batch per iteration that requested." - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - frames_per_batch_worker = -( - -self.requested_frames_per_batch // self.num_workers - ) - return frames_per_batch_worker - - @property - def _queue_len(self) -> int: - return self.num_workers - - def iterator(self) -> Iterator[TensorDictBase]: - cat_results = self.cat_results - if cat_results is None: - cat_results = "stack" - - self.buffers = {} - dones = [False for _ in range(self.num_workers)] - workers_frames = [0 for _ in range(self.num_workers)] - same_device = None - self.out_buffer = None - preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 - - while not all(dones) and self._frames < self.total_frames: - _check_for_faulty_process(self.procs) - if self.update_at_each_batch: - self.update_policy_weights_() - - for idx in range(self.num_workers): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - msg = "continue_random" - else: - msg = "continue" - # Debug: sending 'continue' - self.pipes[idx].send((None, msg)) - - self._iter += 1 - - if preempt: - self.interruptor.start_collection() - while self.queue_out.qsize() < int( - self.num_workers * self.preemptive_threshold - ): - continue - self.interruptor.stop_collection() - # Now wait for stragglers to return - while self.queue_out.qsize() < int(self.num_workers): - continue - - recv = collections.deque() - t0 = time.time() - while len(recv) < self.num_workers and ( - (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT) - ): - for _ in range(self.num_workers): - try: - new_data, j = self.queue_out.get(timeout=_TIMEOUT) - recv.append((new_data, j)) - except (TimeoutError, Empty): - _check_for_faulty_process(self.procs) - if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT): - try: - self.shutdown() - finally: - raise RuntimeError( - f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} 
seconds. " - f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." - ) - - for _ in range(self.num_workers): - new_data, j = recv.popleft() - use_buffers = self._use_buffers - if self.replay_buffer is not None: - idx = new_data - workers_frames[idx] = workers_frames[ - idx - ] + self.frames_per_batch_worker(worker_idx=idx) - continue - elif j == 0 or not use_buffers: - try: - data, idx = new_data - self.buffers[idx] = data - if use_buffers is None and j > 0: - self._use_buffers = False - except TypeError: - if use_buffers is None: - self._use_buffers = True - idx = new_data - else: - raise - else: - idx = new_data - - if preempt: - # mask buffers if cat, and create a mask if stack - if cat_results != "stack": - buffers = {} - for worker_idx, buffer in self.buffers.items(): - valid = buffer.get(("collector", "traj_ids")) != -1 - if valid.ndim > 2: - valid = valid.flatten(0, -2) - if valid.ndim == 2: - valid = valid.any(0) - buffers[worker_idx] = buffer[..., valid] - else: - for buffer in self.buffers.values(): - with buffer.unlock_(): - buffer.set( - ("collector", "mask"), - buffer.get(("collector", "traj_ids")) != -1, - ) - buffers = self.buffers - else: - buffers = self.buffers - - # Skip frame counting if this worker didn't send data this iteration - # (happens when reusing buffers or on first iteration with some workers) - if idx not in buffers: - continue - - workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() - - if workers_frames[idx] >= self.total_frames: - dones[idx] = True - - if self.replay_buffer is not None: - yield - self._frames += sum( - [ - self.frames_per_batch_worker(worker_idx) - for worker_idx in range(self.num_workers) - ] - ) - continue - - # we have to correct the traj_ids to make sure that they don't overlap - # We can count the number of frames collected for free in this loop - n_collected = 0 - for idx in buffers.keys(): - buffer = buffers[idx] - traj_ids = buffer.get(("collector", "traj_ids")) - if preempt: - if cat_results == "stack": - mask_frames = buffer.get(("collector", "traj_ids")) != -1 - n_collected += mask_frames.sum().cpu() - else: - n_collected += traj_ids.numel() - else: - n_collected += traj_ids.numel() - - if same_device is None: - prev_device = None - same_device = True - for item in self.buffers.values(): - if prev_device is None: - prev_device = item.device - else: - same_device = same_device and (item.device == prev_device) - - if cat_results == "stack": - stack = ( - torch.stack if self._use_buffers else TensorDict.maybe_dense_stack - ) - if same_device: - self.out_buffer = stack(list(buffers.values()), 0) - else: - self.out_buffer = stack( - [item.cpu() for item in buffers.values()], 0 - ) - else: - if self._use_buffers is None: - torchrl_logger.warning( - "use_buffer not specified and not yet inferred from data, assuming `True`." - ) - elif not self._use_buffers: - raise RuntimeError( - "Cannot concatenate results with use_buffers=False" - ) - try: - if same_device: - self.out_buffer = torch.cat(list(buffers.values()), cat_results) - else: - self.out_buffer = torch.cat( - [item.cpu() for item in buffers.values()], cat_results - ) - except RuntimeError as err: - if ( - preempt - and cat_results != -1 - and "Sizes of tensors must match" in str(err) - ): - raise RuntimeError( - "The value provided to cat_results isn't compatible with the collectors outputs. " - "Consider using `cat_results=-1`." - ) - raise - - # TODO: why do we need to do cat inplace and clone? 
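# A small illustration of the two cat_results code paths above, using plain
# TensorDicts with assumed shapes (two workers, 100 frames each). "stack" adds
# a leading worker dimension, while an integer dim concatenates along the
# existing (time) dimension, as torch.cat does in the branch above.
import torch
from tensordict import TensorDict

worker_buffers = [
    TensorDict({"observation": torch.zeros(100, 3)}, batch_size=[100]) for _ in range(2)
]
stacked = torch.stack(worker_buffers, 0)      # batch_size = [2, 100]
concatenated = torch.cat(worker_buffers, -1)  # batch_size = [200]
print(stacked.batch_size, concatenated.batch_size)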
- if self.split_trajs: - out = split_trajectories(self.out_buffer, prefix="collector") - else: - out = self.out_buffer - if cat_results in (-1, "stack"): - out.refine_names(*[None] * (out.ndim - 1) + ["time"]) - - self._frames += n_collected - - if self.postprocs: - self.postprocs = ( - self.postprocs.to(out.device) - if hasattr(self.postprocs, "to") - else self.postprocs - ) - out = self.postprocs(out) - if self._exclude_private_keys: - excluded_keys = [key for key in out.keys() if key.startswith("_")] - if excluded_keys: - out = out.exclude(*excluded_keys) - yield out - del out - - del self.buffers - self.out_buffer = None - # We shall not call shutdown just yet as user may want to retrieve state_dict - # self._shutdown_main() - - -@accept_remote_rref_udf_invocation -class MultiaSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes asynchronously. - - .. aafig:: - - - +----------------------------------------------------------------------+ - | "MultiConcurrentCollector" | | - |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | - | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | - |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| - | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | - |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| - |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | - | | | | | | | | - | "actor" | | | "actor" | | - | | | | | | - | "step" | "step" | "actor" | | | - | | | | | | - | | | | "step" | "step" | | - | | | | | | | - | "actor | "step" | "step" | "actor" | | - | | | | | | - | "yield batch 1" | "actor" | |"collect, train"| - | | | | | - | "step" | "step" | | "yield batch 2" |"collect, train"| - | | | | | | - | | | "yield batch 3" | |"collect, train"| - | | | | | | - +----------------------------------------------------------------------+ - - Environment types can be identical or different. - - The collection keeps on occurring on all processes even between the time - the batch of rollouts is collected and the next call to the iterator. - This class can be safely used with offline RL sota-implementations. - - .. note:: Python requires multiprocessed code to be instantiated within a main guard: - - >>> from torchrl.collectors import MultiaSyncDataCollector - >>> if __name__ == "__main__": - ... # Create your collector here - - See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> from torchrl.collectors import MultiaSyncDataCollector - >>> if __name__ == "__main__": - ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - ... collector = MultiaSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - ... for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - ... collector.shutdown() - ... 
del collector - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - - """ - - __doc__ += _MultiDataCollector.__doc__ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.out_tensordicts = defaultdict(lambda: None) - self.running = False - - if self.postprocs is not None and self.replay_buffer is None: - postproc = self.postprocs - self.postprocs = {} - for _device in self.storing_device: - if _device not in self.postprocs: - if hasattr(postproc, "to"): - postproc = deepcopy(postproc).to(_device) - self.postprocs[_device] = postproc - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - if hasattr(self, "out_tensordicts"): - del self.out_tensordicts - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." - ) - return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. 
Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: - return self.requested_frames_per_batch - - def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: - new_data, j = self.queue_out.get(timeout=timeout) - use_buffers = self._use_buffers - if self.replay_buffer is not None: - idx = new_data - elif j == 0 or not use_buffers: - try: - data, idx = new_data - self.out_tensordicts[idx] = data - if use_buffers is None and j > 0: - use_buffers = self._use_buffers = False - except TypeError: - if use_buffers is None: - use_buffers = self._use_buffers = True - idx = new_data - else: - raise - else: - idx = new_data - out = self.out_tensordicts[idx] - if not self.replay_buffer and (j == 0 or use_buffers): - # we clone the data to make sure that we'll be working with a fixed copy - out = out.clone() - return idx, j, out - - @property - def _queue_len(self) -> int: - return 1 - - def iterator(self) -> Iterator[TensorDictBase]: - if self.update_at_each_batch: - self.update_policy_weights_() - - for i in range(self.num_workers): - if self.init_random_frames is not None and self.init_random_frames > 0: - self.pipes[i].send((None, "continue_random")) - else: - self.pipes[i].send((None, "continue")) - self.running = True - - workers_frames = [0 for _ in range(self.num_workers)] - while self._frames < self.total_frames: - self._iter += 1 - counter = 0 - while True: - try: - idx, j, out = self._get_from_queue(timeout=_TIMEOUT) - break - except (TimeoutError, Empty): - counter += _TIMEOUT - _check_for_faulty_process(self.procs) - if counter > (_TIMEOUT * _MAX_IDLE_COUNT): - raise RuntimeError( - f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " - f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
- ) - if self.replay_buffer is None: - worker_frames = out.numel() - if self.split_trajs: - out = split_trajectories(out, prefix="collector") - else: - worker_frames = self.frames_per_batch_worker() - self._frames += worker_frames - workers_frames[idx] = workers_frames[idx] + worker_frames - if out is not None and self.postprocs: - out = self.postprocs[out.device](out) - - # the function blocks here until the next item is asked, hence we send the message to the - # worker to keep on working in the meantime before the yield statement - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - msg = "continue_random" - else: - msg = "continue" - self.pipes[idx].send((idx, msg)) - if out is not None and self._exclude_private_keys: - excluded_keys = [key for key in out.keys() if key.startswith("_")] - out = out.exclude(*excluded_keys) - yield out - - # We don't want to shutdown yet, the user may want to call state_dict before - # self._shutdown_main() - self.running = False - - def _shutdown_main(self, *args, **kwargs) -> None: - if hasattr(self, "out_tensordicts"): - del self.out_tensordicts - return super()._shutdown_main(*args, **kwargs) - - def reset(self, reset_idx: Sequence[bool] | None = None) -> None: - super().reset(reset_idx) - if self.queue_out.full(): - time.sleep(_TIMEOUT) # wait until queue is empty - if self.queue_out.full(): - raise Exception("self.queue_out is full") - if self.running: - for idx in range(self.num_workers): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - self.pipes[idx].send((idx, "continue_random")) - else: - self.pipes[idx].send((idx, "continue")) - - -@accept_remote_rref_udf_invocation -class aSyncDataCollector(MultiaSyncDataCollector): - """Runs a single DataCollector on a separate process. - - This is mostly useful for offline RL paradigms where the policy being - trained can differ from the policy used to collect data. In online - settings, a regular DataCollector should be preferred. This class is - merely a wrapper around a MultiaSyncDataCollector where a single process - is being created. - - Args: - create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided, the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. 
note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], optional): a callable that returns - a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - frames_per_batch (int): A keyword-only argument representing the - total number of elements in a batch. - total_frames (int, optional): A keyword-only argument representing the - total number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. - Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (never ending collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. - If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - create_env_kwargs (dict, optional): A dictionary with the - keyword arguments used to create an environment. 
If a list is - provided, each of its elements will be assigned to a sub-collector. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. - If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, - ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` - or ``torchrl.envs.utils.ExplorationType.MEAN``. - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. - update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weights_()` - will be called before (sync) or after (async) each data collection. - Defaults to ``False``. - preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers - that will be allowed to finished collecting their rollout before the rest are forced to end early. - num_threads (int, optional): number of threads for this process. - Defaults to the number of workers. - num_sub_threads (int, optional): number of threads of the subprocesses. - Should be equal to one plus the number of processes launched within - each subprocess (or one if a single process is launched). - Defaults to 1 for safety: if none is indicated, launching multiple - workers may charge the cpu load too much and harm performance. - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. 
- Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. - - """ - - def __init__( - self, - create_env_fn: Callable[[], EnvBase], - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - policy_factory: Callable[[], Callable] | None = None, - frames_per_batch: int, - total_frames: int | None = -1, - device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - create_env_kwargs: Sequence[dict[str, Any]] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - update_at_each_batch: bool = False, - preemptive_threshold: float | None = None, - num_threads: int | None = None, - num_sub_threads: int = 1, - set_truncated: bool = False, - track_policy_version: bool = False, - **kwargs, - ): - super().__init__( - create_env_fn=[create_env_fn], - policy=policy, - policy_factory=policy_factory, - total_frames=total_frames, - create_env_kwargs=[create_env_kwargs] - if create_env_kwargs - else create_env_kwargs, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - init_random_frames=init_random_frames, - postproc=postproc, - split_trajs=split_trajs, - device=device, - policy_device=policy_device, - env_device=env_device, - storing_device=storing_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - update_at_each_batch=update_at_each_batch, - preemptive_threshold=preemptive_threshold, - num_threads=num_threads, - num_sub_threads=num_sub_threads, - set_truncated=set_truncated, - track_policy_version=track_policy_version, - **kwargs, - ) - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - return super().shutdown( - timeout=timeout, close_env=close_env, raise_on_error=raise_on_error - ) - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - -def _main_async_collector( - pipe_parent: connection.Connection, - pipe_child: connection.Connection, - queue_out: queues.Queue, - create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821 - create_env_kwargs: dict[str, Any], - policy: Callable[[TensorDictBase], TensorDictBase], - max_frames_per_traj: int, - frames_per_batch: int, - reset_at_each_iter: bool, - storing_device: torch.device | str | int | None, - env_device: torch.device | str | int | None, - policy_device: torch.device | str | int | None, - idx: int = 0, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - verbose: bool = VERBOSE, - interruptor=None, - set_truncated: bool = 
False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - traj_pool: _TrajectoryPool = None, - trust_policy: bool = False, - compile_policy: bool = False, - cudagraph_policy: bool = False, - no_cuda_sync: bool = False, - policy_factory: Callable | None = None, - collector_class: type | Callable[[], DataCollectorBase] | None = None, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, -) -> None: - if collector_class is None: - collector_class = SyncDataCollector - pipe_parent.close() - # init variables that will be cleared when closing - collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None - - try: - collector_class._ignore_rb = extend_buffer - inner_collector = collector_class( - create_env_fn, - create_env_kwargs=create_env_kwargs, - policy=policy, - policy_factory=policy_factory, - total_frames=-1, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - postproc=postproc, - split_trajs=False, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - return_same_td=replay_buffer is None, - interruptor=interruptor, - set_truncated=set_truncated, - use_buffers=use_buffers, - replay_buffer=replay_buffer, - extend_buffer=False, - traj_pool=traj_pool, - trust_policy=trust_policy, - compile_policy=compile_policy, - cudagraph_policy=cudagraph_policy, - no_cuda_sync=no_cuda_sync, - weight_sync_schemes=weight_sync_schemes, - ) - - # Set up weight receivers for worker process - if weight_sync_schemes: - inner_collector._weight_receivers = {} - inner_collector.pipe = pipe_child # Add pipe attribute for context - for model_id, scheme in weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_worker"): - scheme.init_on_worker(model_id=model_id, context=inner_collector) - receiver = scheme.get_receiver() - else: - # Legacy API - receiver = scheme.create_receiver() - receiver.set_context(inner_collector) - receiver.register_worker_transport(pipe_child) - - model = _resolve_model(inner_collector, model_id) - receiver.register_model(model) - - inner_collector._weight_receivers[model_id] = receiver - else: - inner_collector._weight_receivers = {} - - use_buffers = inner_collector._use_buffers - if verbose: - torchrl_logger.info("Sync data collector created") - dc_iter = iter(inner_collector) - j = 0 - pipe_child.send("instantiated") - except Exception as e: - # Send error information to main process - # We send a dict with the exception info so we can recreate it in the main process - import traceback - - error_info = { - "error": True, - "exception_type": type(e).__name__, - "exception_module": type(e).__module__, - "exception_msg": str(e), - "traceback": traceback.format_exc(), - } - try: - pipe_child.send(error_info) - except Exception: - # If pipe is broken, nothing we can do - pass - return - - has_timed_out = False - counter = 0 - run_free = False - while True: - _timeout = _TIMEOUT if not has_timed_out else 1e-3 - if not run_free and pipe_child.poll(_timeout): - counter = 0 - data_in, msg = pipe_child.recv() - if verbose: - torchrl_logger.info(f"worker {idx} received {msg}") - elif not run_free: - if verbose: - torchrl_logger.info(f"poll failed, j={j}, worker={idx}") - # default is "continue" (after first 
iteration) - # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe - # in that case, the main process probably expects the worker to continue collect data - if has_timed_out: - counter = 0 - # has_timed_out is True if the process failed to send data, which will - # typically occur if main has taken another batch (i.e. the queue is Full). - # In this case, msg is the previous msg sent by main, which will typically be "continue" - # If it's not the case, it is not expected that has_timed_out is True. - if msg not in ("continue", "continue_random"): - raise RuntimeError(f"Unexpected message after time out: msg={msg}") - else: - # if has_timed_out is False, then the time out does not come from the fact that the queue is Full. - # this means that our process has been waiting for a command from main in vain, while main was not - # receiving data. - # This will occur if main is busy doing something else (e.g. computing loss etc). - - counter += _timeout - if verbose: - torchrl_logger.info(f"worker {idx} has counter {counter}") - if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): - raise RuntimeError( - f"This process waited for {counter} seconds " - f"without receiving a command from main. Consider increasing the maximum idle count " - f"if this is expected via the environment variable MAX_IDLE_COUNT " - f"(current value is {_MAX_IDLE_COUNT})." - f"\nIf this occurs at the end of a function or program, it means that your collector has not been " - f"collected, consider calling `collector.shutdown()` before ending the program." - ) - continue - else: - # placeholder, will be checked after - if msg != "continue": - torchrl_logger.info(f"worker {idx} will reset {msg} to 'continue'") - msg = "continue" - if msg == "run_free": - run_free = True - msg = "continue" - if run_free: - # Capture shutdown / update / seed signal, but continue should not be expected - if pipe_child.poll(1e-4): - data_in, msg = pipe_child.recv() - torchrl_logger.info(f"worker {idx} received {msg} while running free") - if msg == "continue": - # Switch back to run_free = False - run_free = False - if msg == "pause": - queue_out.put((idx, "paused"), timeout=_TIMEOUT) - while not pipe_child.poll(1e-2): - continue - data_in, msg = pipe_child.recv() - if msg != "restart": - raise RuntimeError(f"Expected msg='restart', got {msg=}") - msg = "continue" - else: - data_in = None - # TODO: this does not work with random frames - msg = "continue" - # Note: The "continue" message handling has been moved below after update_weights handling - # to allow falling through from update_weights to continue - - if msg == "update": - torchrl_logger.info(f"worker {idx} updating the params...") - inner_collector.update_policy_weights_(policy_weights=data_in) - pipe_child.send((j, "updated")) - has_timed_out = False - continue - - if msg == "register_shared_weights": - # Shared memory lazy registration: main process sends buffer reference - if verbose: - torchrl_logger.info( - f"worker {idx} received shared memory buffer registration" - ) - model_id, shared_buffer = data_in - - # Store the shared buffer reference for this model - # The receiver will use this buffer for all future weight accesses - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - # Update receiver's buffer reference - receiver = inner_collector._weight_receivers[model_id] - # Store the shared buffer - the model's parameters should point to this - if hasattr(receiver, "_shared_weights"): - 
receiver._shared_weights[model_id] = shared_buffer - - # Apply the buffer to the model immediately - # Only apply if the model is an nn.Module (has learnable parameters) - try: - model = receiver._resolve_model_ref() - except (ValueError, AttributeError) as e: - # Model not registered or reference is invalid - if verbose: - torchrl_logger.warning( - f"worker {idx} could not resolve model '{model_id}': {e}" - ) - continue - - if isinstance(model, nn.Module): - receiver.apply_weights(shared_buffer) - else: - if verbose: - torchrl_logger.info( - f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'" - ) - - if verbose: - torchrl_logger.info( - f"worker {idx} registered shared buffer for model '{model_id}'" - ) - else: - torchrl_logger.warning( - f"worker {idx} received shared buffer for unknown model '{model_id}'" - ) - - # Send acknowledgment back to main process - pipe_child.send((None, "registered")) - has_timed_out = False - continue - - if msg == "update_weights": - # New weight update protocol for simplified weight sync system - if verbose: - torchrl_logger.info( - f"worker {idx} received weight update via new protocol" - ) - model_id, weights = data_in - - # Apply weights using the appropriate receiver for this model - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - inner_collector._weight_receivers[model_id].apply_weights(weights) - else: - torchrl_logger.warning( - f"worker {idx} received weights for unknown model '{model_id}'" - ) - - # After applying weights, we continue collecting immediately as if we received - # a "continue" message. This ensures the worker keeps collecting data without - # waiting for an explicit continue from the main process. - has_timed_out = False - msg = "continue" - # Now check if we should continue collecting - - if msg in ("continue", "continue_random"): - # This block handles both explicit continue messages and implicit ones after weight updates - if msg == "continue_random": - inner_collector.init_random_frames = float("inf") - else: - inner_collector.init_random_frames = -1 - - # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the - # main message loop above (msg == "update_weights" case). The receiver.receive() - # pattern is only used for schemes with separate communication channels like - # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore). - # Calling receiver.receive() here would interfere with the pipe-based message protocol. - - next_data = next(dc_iter) - if pipe_child.poll(_MIN_TIMEOUT): - # in this case, main send a message to the worker while it was busy collecting trajectories. - # In that case, we skip the collected trajectory and get the message from main. This is faster than - # sending the trajectory in the queue until timeout when it's never going to be received. 
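# A simplified, parent-side sketch of the pipe protocol handled by the worker
# loop above. `pipe` stands for the parent end of one worker's pipe; payloads
# are (data, msg) tuples, matching what the worker unpacks with
# `data_in, msg = pipe_child.recv()`. The function names here are illustrative.
def push_weights(pipe, model_id, weights):
    # "update_weights" carries (model_id, weights); the worker applies them via
    # its registered receiver and then resumes collecting as if it had received
    # a plain "continue" message (no acknowledgment is sent back).
    pipe.send(((model_id, weights), "update_weights"))

def set_worker_seed(pipe, seed, static_seed=False):
    # "seed" carries (seed, static_seed) and is acknowledged with (new_seed, "seeded").
    pipe.send(((seed, static_seed), "seed"))
    new_seed, msg = pipe.recv()
    assert msg == "seeded"
    return new_seed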
- continue - - if replay_buffer is not None: - if extend_buffer: - next_data.names = None - replay_buffer.extend(next_data) - - if run_free: - continue - - try: - queue_out.put((idx, j), timeout=_TIMEOUT) - if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") - j += 1 - has_timed_out = False - continue - except queue.Full: - if verbose: - torchrl_logger.info(f"worker {idx} has timed out") - has_timed_out = True - continue - - if j == 0 or not use_buffers: - collected_tensordict = next_data - if ( - storing_device is not None - and collected_tensordict.device != storing_device - ): - raise RuntimeError( - f"expected device to be {storing_device} but got {collected_tensordict.device}" - ) - if use_buffers: - # If policy and env are on cpu, we put in shared mem, - # if policy is on cuda and env on cuda, we are fine with this - # If policy is on cuda and env on cpu (or opposite) we put tensors that - # are on cpu in shared mem. - MPS_ERROR = ( - "tensors on mps device cannot be put in shared memory. Make sure " - "the shared device (aka storing_device) is set to CPU." - ) - if collected_tensordict.device is not None: - # placeholder in case we need different behaviors - if collected_tensordict.device.type in ("cpu",): - collected_tensordict.share_memory_() - elif collected_tensordict.device.type in ("mps",): - raise RuntimeError(MPS_ERROR) - elif collected_tensordict.device.type == "cuda": - collected_tensordict.share_memory_() - else: - raise NotImplementedError( - f"Device {collected_tensordict.device} is not supported in multi-collectors yet." - ) - else: - # make sure each cpu tensor is shared - assuming non-cpu devices are shared - def cast_tensor(x, MPS_ERROR=MPS_ERROR): - if x.device.type in ("cpu",): - x.share_memory_() - if x.device.type in ("mps",): - RuntimeError(MPS_ERROR) - - collected_tensordict.apply(cast_tensor, filter_empty=True) - data = (collected_tensordict, idx) - else: - if next_data is not collected_tensordict: - raise RuntimeError( - "SyncDataCollector should return the same tensordict modified in-place." 
- ) - data = idx # flag the worker that has sent its data - try: - queue_out.put((data, j), timeout=_TIMEOUT) - if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") - j += 1 - has_timed_out = False - continue - except queue.Full: - if verbose: - torchrl_logger.info(f"worker {idx} has timed out") - has_timed_out = True - continue - - if msg == "seed": - data_in, static_seed = data_in - new_seed = inner_collector.set_seed(data_in, static_seed=static_seed) - torch.manual_seed(data_in) - np.random.seed(data_in) - pipe_child.send((new_seed, "seeded")) - has_timed_out = False - continue - - elif msg == "reset": - inner_collector.reset() - pipe_child.send((j, "reset")) - continue - - elif msg == "state_dict": - state_dict = inner_collector.state_dict() - # send state_dict to cpu first - state_dict = recursive_map_to_cpu(state_dict) - pipe_child.send((state_dict, "state_dict")) - has_timed_out = False - continue - - elif msg == "load_state_dict": - state_dict = data_in - inner_collector.load_state_dict(state_dict) - del state_dict - pipe_child.send((j, "loaded")) - has_timed_out = False - continue - - elif msg == "getattr_policy": - attr_name = data_in - try: - result = getattr(inner_collector.policy, attr_name) - pipe_child.send((result, "getattr_policy")) - except AttributeError as e: - pipe_child.send((e, "getattr_policy")) - has_timed_out = False - continue - - elif msg == "getattr_env": - attr_name = data_in - try: - result = getattr(inner_collector.env, attr_name) - pipe_child.send((result, "getattr_env")) - except AttributeError as e: - pipe_child.send((e, "getattr_env")) - has_timed_out = False - continue - - elif msg == "close": - del collected_tensordict, data, next_data, data_in - inner_collector.shutdown() - del inner_collector, dc_iter - pipe_child.send("closed") - if verbose: - torchrl_logger.info(f"collector {idx} closed") - break - - else: - raise Exception(f"Unrecognized message {msg}") - - -def _make_meta_params(param): - is_param = isinstance(param, Parameter) - - pd = param.detach().to("meta") - - if is_param: - pd = Parameter(pd, requires_grad=False) - return pd - - -class _TrajectoryPool: - def __init__(self, ctx=None, lock: bool = False): - self.ctx = ctx - self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() - if ctx is None: - self.lock = contextlib.nullcontext() if not lock else mp.RLock() - else: - self.lock = contextlib.nullcontext() if not lock else ctx.RLock() - - def get_traj_and_increment(self, n=1, device=None): - with self.lock: - v = self._traj_id.item() - out = torch.arange(v, v + n).to(device) - self._traj_id.copy_(1 + out[-1].item()) - return out - - -def _map_weight( - weight, - policy_device, -): - - is_param = isinstance(weight, Parameter) - is_buffer = isinstance(weight, Buffer) - weight = weight.data - if weight.device != policy_device: - weight = weight.to(policy_device) - elif weight.device.type in ("cpu",): - weight = weight.share_memory_() - if is_param: - weight = Parameter(weight, requires_grad=False) - elif is_buffer: - weight = Buffer(weight) - return weight +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._runner import _main_async_collector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors._single_async import aSyncDataCollector + +__all__ = [ + "MultiSyncDataCollector", + 
"MultiaSyncDataCollector", + "_MultiDataCollector", + "SyncDataCollector", + "_main_async_collector", + "aSyncDataCollector", + "DataCollectorBase", + # Constants + "_TIMEOUT", + "INSTANTIATE_TIMEOUT", + "_MIN_TIMEOUT", + "_MAX_IDLE_COUNT", + "DEFAULT_EXPLORATION_TYPE", + "_is_osx", + "_Interruptor", + "_InterruptorManager", + "cudagraph_mark_step_begin", +] diff --git a/torchrl/collectors/distributed/default_configs.py b/torchrl/collectors/distributed/default_configs.py index 8da69010242..cf0eb4da0a6 100644 --- a/torchrl/collectors/distributed/default_configs.py +++ b/torchrl/collectors/distributed/default_configs.py @@ -5,6 +5,13 @@ from __future__ import annotations import os +import random +import socket +from datetime import timedelta + +import torch.distributed + +from torchrl._utils import logger as torchrl_logger TCP_PORT = os.environ.get("TCP_PORT", "10003") IDLE_TIMEOUT = os.environ.get("RCP_IDLE_TIMEOUT", 10) @@ -32,3 +39,95 @@ "rpc_timeout": 10_000, "_transports": ["uv"], } + + +def _find_free_port() -> int: + """Find a free port by binding to port 0 and letting the OS choose.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", 0)) + return s.getsockname()[1] + + +def _create_tcpstore_with_retry( + host_name: str, + port: int | None, + world_size: int, + is_master: bool, + timeout: float = 10.0, + max_retries: int = 10, + wait_for_workers: bool = True, +) -> tuple[torch.distributed.TCPStore, int]: + """Create a TCPStore with retry logic for handling port conflicts. + + This function attempts to create a TCPStore, and if the port is already in use, + it will retry with different random ports up to max_retries times. + + Args: + host_name: The hostname for the TCPStore. + port: The initial port to try. If None, a random port will be chosen. + world_size: The world size for the TCPStore. + is_master: Whether this is the master (server) process. + timeout: Timeout in seconds for the TCPStore. + max_retries: Maximum number of retry attempts. + wait_for_workers: Whether the master should wait for workers. + Only used when is_master=True. + + Returns: + A tuple of (TCPStore, actual_port) where actual_port is the port + that was successfully bound. + + Raises: + RuntimeError: If unable to create a TCPStore after max_retries attempts. + """ + last_error = None + + for attempt in range(max_retries): + if port is None or attempt > 0: + # For the first attempt use provided port, for retries find a new free port + current_port = _find_free_port() + else: + current_port = int(port) + + try: + if is_master: + store = torch.distributed.TCPStore( + host_name=host_name, + port=current_port, + world_size=world_size, + is_master=True, + timeout=timedelta(seconds=timeout), + wait_for_workers=wait_for_workers, + ) + else: + store = torch.distributed.TCPStore( + host_name=host_name, + port=current_port, + is_master=False, + timeout=timedelta(seconds=timeout), + ) + torchrl_logger.debug( + f"TCPStore created successfully on {host_name}:{current_port} " + f"(attempt {attempt + 1}/{max_retries})" + ) + return store, current_port + + except (RuntimeError, OSError) as e: + error_msg = str(e).lower() + if "address already in use" in error_msg or "eaddrinuse" in error_msg: + torchrl_logger.debug( + f"Port {current_port} already in use, " + f"retrying ({attempt + 1}/{max_retries})..." 
+ ) + last_error = e + # Add small random delay to reduce collision probability + import time + + time.sleep(random.uniform(0.01, 0.1)) + continue + # For other errors, re-raise immediately + raise + + raise RuntimeError( + f"Failed to create TCPStore after {max_retries} attempts. Last error: {last_error}" + ) diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 4839259e4ca..2fe678c07cc 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -20,23 +20,24 @@ from tensordict.nn import TensorDictModuleBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.distributed.default_configs import ( + _create_tcpstore_with_retry, DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, TCP_PORT, ) -from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories +from torchrl.collectors.utils import _cast, _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update import DistributedWeightSyncScheme from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme SUBMITIT_ERR = None @@ -54,11 +55,11 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): os.environ["MASTER_PORT"] = str(tcpport) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"Rank0 IP address: '{rank0_ip}' \ttcp port: '{tcpport}', backend={backend}." 
) - torchrl_logger.info( - f"node with rank {rank} with world_size {world_size} -- launching distributed" + torchrl_logger.debug( + f"RANK {rank} with world_size {world_size} -- launching distributed" ) torch.distributed.init_process_group( backend, @@ -68,11 +69,21 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - torchrl_logger.info(f"Connected!\nNode with rank {rank} -- creating store") + torchrl_logger.debug(f"Connected!\nRANK {rank} -- creating store") + + # Receive actual store port from master via broadcast (master may have used retry) + store_port_tensor = torch.zeros(1, dtype=torch.int64) + torch.distributed.broadcast(store_port_tensor, src=0) + actual_store_port = int(store_port_tensor.item()) + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- received store port {actual_store_port} from master" + ) + # The store carries instructions for the node _store = torch.distributed.TCPStore( host_name=rank0_ip, - port=tcpport + 1, + port=actual_store_port, world_size=world_size, is_master=False, timeout=timedelta(10), @@ -108,19 +119,20 @@ def _distributed_init_delayed( frames_per_batch = output["frames_per_batch"] collector_kwargs = output["collector_kwargs"] _run_collector( - _store, - sync, - collector_class, - num_workers, - env_make, - policy, - frames_per_batch, - collector_kwargs, + _store=_store, + sync=sync, + collector_class=collector_class, + num_workers=num_workers, + env_make=env_make, + policy=policy, + frames_per_batch=frames_per_batch, + collector_kwargs=collector_kwargs, verbose=verbose, ) def _distributed_init_collection_node( + *, rank, rank0_ip, tcpport, @@ -134,24 +146,27 @@ def _distributed_init_collection_node( policy_factory, frames_per_batch, collector_kwargs, + weight_sync_schemes, verbose=True, ): _store = _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose) _run_collector( - _store, - sync, - collector_class, - num_workers, - env_make, - policy, - policy_factory, - frames_per_batch, - collector_kwargs, + _store=_store, + sync=sync, + collector_class=collector_class, + num_workers=num_workers, + env_make=env_make, + policy=policy, + policy_factory=policy_factory, + frames_per_batch=frames_per_batch, + weight_sync_schemes=weight_sync_schemes, + collector_kwargs=collector_kwargs, verbose=verbose, ) def _run_collector( + *, _store, sync, collector_class, @@ -161,12 +176,13 @@ def _run_collector( policy_factory, frames_per_batch, collector_kwargs, + weight_sync_schemes: dict[str, DistributedWeightSyncScheme], verbose=True, ): rank = torch.distributed.get_rank() if verbose: - torchrl_logger.info( - f"node with rank {rank} -- creating collector of type {collector_class}" + torchrl_logger.debug( + f"RANK {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): env_make = [env_make] * num_workers @@ -177,9 +193,21 @@ def _run_collector( "SyncDataCollector and subclasses can only support a single environment." ) + if issubclass(collector_class, _MultiDataCollector) and ( + (not isinstance(policy_factory, Sequence) and policy_factory is not None) + or (isinstance(policy_factory, Sequence) and any(policy_factory)) + ): + # We build an intermediate policy to get the weights from for weight updates. 
This is slow + # (main -> dist worker -> mp worker), but in some cases there is no alternative + policy = ( + policy_factory[0]() + if isinstance(policy_factory, Sequence) + else policy_factory() + ) + if isinstance(policy, nn.Module): policy_weights = TensorDict.from_module(policy) - policy_weights = policy_weights.data.lock_() + policy_weights = policy_weights.data.apply(_cast, policy_weights).lock_() else: if collector_kwargs.get("weight_updater") is None and ( policy_factory is None @@ -188,50 +216,119 @@ def _run_collector( warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) + torchrl_logger.debug(f"RANK {rank} -- init collector") + # NOTE: + # - `weight_sync_schemes` here are the *distributed* schemes used to send + # weights from the main process to this node. + # - Inner multi-process collectors (e.g., MultiSyncDataCollector) should + # manage their own local weight sync schemes (SharedMem / MP) for their + # sub-workers. + # Therefore, we do NOT pass `weight_sync_schemes` down into + # `collector_class` so that it can set up its own local schemes. collector = collector_class( env_make, - policy, + policy=policy, policy_factory=policy_factory, frames_per_batch=frames_per_batch, total_frames=-1, split_trajs=False, **collector_kwargs, ) + + if weight_sync_schemes is not None: + for model_id, scheme in weight_sync_schemes.items(): + torchrl_logger.debug(f"RANK {rank} -- init receiver for model '{model_id}'") + # Provide both collector context and distributed store / rank so the + # scheme can wire its transport correctly. + scheme.init_on_receiver( + model_id=model_id, + context=collector, + # store=_store, + worker_idx=rank, + ) + torchrl_logger.debug(f"RANK {rank} -- initial weight sync (if any)") + scheme.connect() + torchrl_logger.debug( + f"RANK {rank} -- initial weight sync for '{model_id}' completed" + ) + else: + torchrl_logger.debug( + f"RANK {rank} -- {collector_class.__name__} without weight_sync_schemes \n\n" + ) + total_frames = 0 - if verbose: - torchrl_logger.info(f"node with rank {rank} -- loop") while True: + if verbose: + torchrl_logger.debug(f"RANK {rank} -- waiting for instructions") instruction = _store.get(f"NODE_{rank}_in") if verbose: - torchrl_logger.info( - f"node with rank {rank} -- new instruction: {instruction}" - ) + torchrl_logger.debug(f"RANK {rank} -- new instruction: {instruction}") _store.delete_key(f"NODE_{rank}_in") if instruction == b"continue": _store.set(f"NODE_{rank}_status", b"busy") if verbose: - torchrl_logger.info(f"node with rank {rank} -- new data") + torchrl_logger.debug(f"RANK {rank} -- collecting new data") data = collector.next() total_frames += data.numel() if verbose: - torchrl_logger.info(f"got data, total frames = {total_frames}") - torchrl_logger.info(f"node with rank {rank} -- sending {data}") + torchrl_logger.debug( + f"RANK {rank} -- got data, total frames = {total_frames}" + ) + torchrl_logger.debug( + f"RANK {rank} -- data batch_size={data.batch_size}, " + f"keys={list(data.keys(False, True))}" + ) + torchrl_logger.debug( + f"RANK {rank} -- sending TensorDict payload to rank 0" + ) + torchrl_logger.debug(f"RANK {rank} -- {data=}") + if _store.get("TRAINER_status") == b"alive": data.isend(dst=0) if verbose: - torchrl_logger.info(f"node with rank {rank} -- setting to 'done'") + torchrl_logger.debug(f"RANK {rank} -- setting to 'done'") if not sync: _store.set(f"NODE_{rank}_status", b"done") + if verbose: + torchrl_logger.debug(f"RANK {rank} -- set to 'done'") + elif instruction == b"shutdown": if 
verbose: - torchrl_logger.info(f"node with rank {rank} -- shutting down") + torchrl_logger.debug(f"RANK {rank} -- shutting down") try: collector.shutdown() except Exception: pass _store.set(f"NODE_{rank}_out", b"down") break + elif instruction == b"update_weights": + if verbose: + torchrl_logger.debug(f"RANK {rank} -- updating weights") + + if weight_sync_schemes is not None: + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- using weight sync schemes for update" + ) + # Receive fresh weights from the main process for each model. + # scheme.receive() handles both applying weights locally and + # cascading to sub-collectors via context.update_policy_weights_(). + for model_id, scheme in weight_sync_schemes.items(): + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- receiving weights for model '{model_id}'" + ) + scheme.receive() + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- received and cascaded weights for model '{model_id}'" + ) + + # Acknowledgment is handled by the transport (send_ack in the + # WeightReceiver), so we can continue without touching the + # TCPStore here. + continue if sync: policy_weights.recv(0) else: @@ -424,6 +521,16 @@ class DistributedDataCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.distributed.DistributedWeightUpdater` will be used by default, which handles weight synchronization across distributed workers. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to distributed worker collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via torch.distributed. + If not provided, a :class:`~torchrl.weight_update.DistributedWeightSyncScheme` will be used by default. + This is for propagating weights from the main process to distributed workers. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This is typically used when DistributedDataCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. """ @@ -463,8 +570,12 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, ): + if self._VERBOSE: + torchrl_logger.setLevel("DEBUG") + if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -473,7 +584,6 @@ def __init__( collector_class = SyncDataCollector self.collector_class = collector_class self.env_constructors = create_env_fn - self.policy = policy if not isinstance(policy_factory, Sequence): policy_factory = [policy_factory for _ in range(len(self.env_constructors))] self.policy_factory = policy_factory @@ -482,14 +592,12 @@ def __init__( policy_weights = policy_weights.data.lock_() elif any(policy_factory): policy_weights = None - if weight_updater is None: - raise RuntimeError( - "weight_updater must be passed along with " "a policy_factory." 
- ) else: if not any(policy_factory): warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) + self.policy = policy + self._policy_to_send = policy if not any(policy_factory) else None self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch @@ -564,54 +672,22 @@ def __init__( self.backend = backend - # os.environ['TP_SOCKET_IFNAME'] = 'lo' - - self._init_workers() - self._make_container() - # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Distributed weight sync scheme for distributed collectors - from torchrl.weight_update.weight_sync_schemes import ( - DistributedWeightSyncScheme, - ) + from torchrl.weight_update import DistributedWeightSyncScheme weight_sync_schemes = { "policy": DistributedWeightSyncScheme(backend=backend, sync=self._sync) } if weight_sync_schemes is not None: + torchrl_logger.debug("RANK 0 -- Using weight sync schemes") # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - - # Set up weight senders now that remote collectors exist - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - for i in range(self.num_workers): - rank = i + 1 # Workers are 1-indexed in distributed - transport = scheme.create_transport((self._store, rank)) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(self, "policy") - and self.policy is not None - ): - sender._source_model = self.policy - - self._weight_senders[model_id] = sender - self.weight_updater = None else: + torchrl_logger.debug("RANK 0 -- Using weight updater") # Fall back to legacy weight updater system if weight_updater is None: weight_updater = DistributedWeightUpdater( @@ -622,7 +698,25 @@ def __init__( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + if self._weight_sync_schemes is not None: + # Initialize schemes on the sender (main process) side now that + # worker processes and the store have been created. + for model_id, scheme in self._weight_sync_schemes.items(): + scheme.init_on_sender( + num_workers=self.num_workers, context=self, model_id=model_id + ) + + self._init_workers() + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + + self._make_container() + if self._weight_sync_schemes is not None: + for scheme in self._weight_sync_schemes.values(): + scheme.connect() @property def device(self) -> list[torch.device]: @@ -689,11 +783,10 @@ def _init_master_dist( world_size, backend, ): - if self._VERBOSE: - torchrl_logger.info( - f"launching main node with tcp port '{self.tcp_port}' and " - f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." - ) + torchrl_logger.debug( + f"RANK 0 -- launching main node with tcp port '{self.tcp_port}' and " + f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." 
+ ) os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -705,27 +798,39 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - if self._VERBOSE: - torchrl_logger.info("main initiated! Launching store...") - self._store = torch.distributed.TCPStore( + torchrl_logger.debug("RANK 0 -- main initiated! Launching store...") + # Use retry logic to handle port conflicts + self._store, self._store_port = _create_tcpstore_with_retry( host_name=self.IPAddr, port=int(TCP_PORT) + 1, world_size=self.num_workers + 1, is_master=True, - timeout=timedelta(10), + timeout=10.0, + wait_for_workers=False, # Don't wait - we need to broadcast port first ) - if self._VERBOSE: - torchrl_logger.info("done. Setting status to 'alive'") + torchrl_logger.debug( + f"RANK 0 -- store created on port {self._store_port}. Broadcasting to workers..." + ) + # Broadcast actual store port to all workers + store_port_tensor = torch.tensor([self._store_port], dtype=torch.int64) + torch.distributed.broadcast(store_port_tensor, src=0) + torchrl_logger.debug("RANK 0 -- done. Setting status to 'alive'") self._store.set("TRAINER_status", b"alive") def _make_container(self): - if self._VERBOSE: - torchrl_logger.info("making container") + torchrl_logger.debug("RANK 0 -- making container") env_constructor = self.env_constructors[0] - kwargs = self.collector_kwargs[0] + kwargs = self.collector_kwargs[ + 0 + ].copy() # Create a copy to avoid modifying the original + # Mirror the SyncDataCollector configuration used on the workers so + # that the dummy batch structure matches what remote ranks will send. + # _run_collector always sets return_same_td=True for SyncDataCollector, + # so we must do the same here to ensure structural consistency. 
+ kwargs["return_same_td"] = True pseudo_collector = SyncDataCollector( env_constructor, - policy=self.policy, + policy=self.policy if not self.policy_factory[0] else None, policy_factory=self.policy_factory[0], frames_per_batch=self._frames_per_batch_corrected, total_frames=-1, @@ -734,12 +839,15 @@ def _make_container(self): ) for _data in pseudo_collector: break - if self._VERBOSE: - torchrl_logger.info(f"got data {_data}") - torchrl_logger.info("expanding...") - self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) - if self._VERBOSE: - torchrl_logger.info("locking") + torchrl_logger.debug(f"RANK 0 -- got dummy batch: {_data}") + torchrl_logger.debug("RANK 0 -- expanding...") + self._tensordict_out = ( + _data.expand((self.num_workers, *_data.shape)).clone().to_lazystack(0) + ) + torchrl_logger.debug( + f"RANK 0 -- expanded recv buffer spec: {self._tensordict_out}" + ) + torchrl_logger.debug("RANK 0 -- locking") if self._sync: self._tensordict_out.lock_() self._tensordict_out_unbind = self._tensordict_out.unbind(0) @@ -749,12 +857,10 @@ def _make_container(self): self._tensordict_out = self._tensordict_out.unbind(0) for td in self._tensordict_out: td.lock_() - if self._VERBOSE: - torchrl_logger.info("storage created:") - torchrl_logger.info("shutting down...") + torchrl_logger.debug("RANK 0 -- storage created:") + torchrl_logger.debug("RANK 0 -- shutting down...") pseudo_collector.shutdown() - if self._VERBOSE: - torchrl_logger.info("dummy collector shut down!") + torchrl_logger.debug("RANK 0 -- dummy collector shut down!") del pseudo_collector def _init_worker_dist_submitit(self, executor, i): @@ -764,20 +870,21 @@ def _init_worker_dist_submitit(self, executor, i): TCP_PORT = self.tcp_port job = executor.submit( _distributed_init_collection_node, - i + 1, - self.IPAddr, - int(TCP_PORT), - self._sync, - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self._VERBOSE, + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + sync=self._sync, + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self._policy_to_send, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + weight_sync_schemes=self._weight_sync_schemes, + collector_kwargs=self.collector_kwargs[i], + verbose=self._VERBOSE, ) return job @@ -812,21 +919,22 @@ def _init_worker_dist_mp(self, i): TCP_PORT = self.tcp_port job = _ProcessNoWarn( target=_distributed_init_collection_node, - args=( - i + 1, - self.IPAddr, - int(TCP_PORT), - self._sync, - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self._VERBOSE, + kwargs=dict( # noqa: C408 + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + sync=self._sync, + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self._policy_to_send, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + weight_sync_schemes=self._weight_sync_schemes, + verbose=self._VERBOSE, ), 
) job.start() @@ -839,8 +947,7 @@ def _init_workers(self): IPAddr = socket.gethostbyname(hostname) else: IPAddr = "localhost" - if self._VERBOSE: - torchrl_logger.info(f"Server IP address: {IPAddr}") + torchrl_logger.debug(f"RANK 0 -- Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -855,21 +962,20 @@ def _init_workers(self): self._init_worker_dist_submitit_delayed() else: for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info("Submitting job") + torchrl_logger.debug("RANK 0 -- Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - if self._VERBOSE: - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug( + f"RANK 0 -- job id {job.job_id}" + ) # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - if self._VERBOSE: - torchrl_logger.info("job launched") + torchrl_logger.debug("RANK 0 -- job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) @@ -877,21 +983,21 @@ def iterator(self): yield from self._iterator_dist() def _iterator_dist(self): - if self._VERBOSE: - torchrl_logger.info("iterating...") + torchrl_logger.debug("RANK 0 -- iterating...") total_frames = 0 if not self._sync: for rank in range(1, self.num_workers + 1): - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- receiving {rank=}") trackers.append( self._tensordict_out[i].irecv(src=rank, return_premature=True) ) + torchrl_logger.debug(f"RANK 0 -- trackers: {trackers}") while total_frames < self.total_frames: if self._sync: @@ -912,19 +1018,22 @@ def _iterator_dist(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): + torchrl_logger.debug(f"RANK 0 -- updating weights for {rank=}") self.update_policy_weights_( policy_weights=None, worker_ids=rank ) for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"shutting down rank {rank}.") + torchrl_logger.debug(f"RANK 0 -- shutting down rank {rank}.") self._store.set(f"NODE_{rank}_in", b"shutdown") def _next_sync(self, total_frames): # in the 'sync' case we should update before collecting the data if self.update_after_each_batch: + torchrl_logger.debug( + f"RANK 0 -- updating weights for {total_frames=} in _next_sync." 
+ ) self.update_policy_weights_() else: for j in range(self.num_workers): @@ -932,12 +1041,12 @@ def _next_sync(self, total_frames): if total_frames < self.total_frames: for rank in range(1, self.num_workers + 1): - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- receiving {rank=} in _next_sync.") trackers.append( self._tensordict_out_unbind[i].irecv(src=rank, return_premature=True) ) @@ -958,16 +1067,21 @@ def _next_async(self, total_frames, trackers): while data is None: for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- checking {rank=} in _next_async.") if self._store.get(f"NODE_{rank}_status") == b"done": + torchrl_logger.debug(f"RANK 0 -- receiving {rank=} in _next_async.") for _tracker in trackers[i]: _tracker.wait() + torchrl_logger.debug(f"RANK 0 -- received {rank=} in _next_async.") data = self._tensordict_out[i].clone() if self.update_after_each_batch: + torchrl_logger.debug( + f"RANK 0 -- updating weights for {rank=} in _next_async." + ) self.update_policy_weights_(worker_ids=rank) total_frames += data.numel() if total_frames < self.total_frames: - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers[i] = self._tensordict_out[i].irecv( src=i + 1, return_premature=True @@ -977,34 +1091,6 @@ def _next_async(self, total_frames, trackers): break return data, total_frames - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For distributed collectors, when weights is None and we have a weight sync scheme, - extract fresh weights from the tracked policy model. 
- """ - scheme = ( - self._weight_sync_schemes.get(model_id) - if self._weight_sync_schemes - else None - ) - - if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): - # For distributed collectors, we need TensorDict format for isend/irecv - from tensordict import TensorDict - - return TensorDict.from_module(sender._source_model).data.lock_() - - # Fall back to base class implementation - return super()._extract_weights_if_needed(weights, model_id) - def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): rank = i + 1 @@ -1028,13 +1114,11 @@ def shutdown(self, timeout: float | None = None) -> None: self._store.set("TRAINER_status", b"shutdown") for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"shutting down node with rank={rank}") + torchrl_logger.debug(f"shutting down node with rank={rank}") self._store.set(f"NODE_{rank}_in", b"shutdown") for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"getting status of node {rank}") + torchrl_logger.debug(f"getting status of node {rank}") status = self._store.get(f"NODE_{rank}_out") if status != b"down": raise RuntimeError(f"Expected 'down' but got status {status}.") @@ -1048,13 +1132,16 @@ def shutdown(self, timeout: float | None = None) -> None: self.jobs[i].result() elif self.launcher == "submitit_delayed": pass - if self._VERBOSE: - torchrl_logger.info("collector shut down") + torchrl_logger.debug("collector shut down") class DistributedWeightUpdater(WeightUpdaterBase): """A remote weight updater for synchronizing policy weights across distributed workers. + .. warning:: + This class has been deprecated in favor of the :class:`~torchrl.weight_update.DistributedWeightSyncScheme` + API. + The `DistributedWeightUpdater` class provides a mechanism for updating the weights of a policy across distributed inference workers. It is designed to work with the :class:`~torchrl.collectors.distributed.DistributedDataCollector` to ensure that each worker receives the latest policy weights. 
@@ -1090,7 +1177,7 @@ class DistributedWeightUpdater(WeightUpdaterBase): """ - _VERBOSE = True + _VERBOSE = False def __init__( self, @@ -1135,8 +1222,7 @@ def _push_weights( ) for i in workers: rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"updating weights of {rank}") + torchrl_logger.debug(f"updating weights of {rank}") self._store.set(f"NODE_{rank}_in", b"update_weights") if self._sync: weights.send(rank) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index b8b28345872..9fefe6fd29c 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -16,13 +16,11 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import as_remote, logger as torchrl_logger -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import RayWeightUpdater, WeightUpdaterBase from torchrl.data import ReplayBuffer @@ -74,7 +72,7 @@ def print_remote_collector_info(self): f"{get_node_ip_address()} using gpus {ray.get_gpu_ids()}" ) # torchrl_logger.warning(s) - torchrl_logger.info(s) + torchrl_logger.debug(s) class RayCollector(DataCollectorBase): @@ -267,10 +265,25 @@ class RayCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.RayWeightUpdater` will be used by default, leveraging Ray's distributed capabilities. Consider using a constructor if the updater needs to be serialized. - weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary mapping model identifiers to - :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` instances. - This is the recommended way to configure weight synchronization. If not provided, + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to remote collector workers. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via Ray. + This is the recommended way to configure weight synchronization for propagating weights + from the main process to remote collectors. If not provided, defaults to ``{"policy": RayWeightSyncScheme()}``. + + .. note:: Weight synchronization is lazily initialized. When using ``policy_factory`` + without a central ``policy``, weight sync is deferred until the first call to + :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` with actual weights. + This allows sub-collectors to each have their own independent policies created via + the factory. If you have a central policy and want to sync its weights to remote + collectors, call ``update_policy_weights_(policy)`` before starting iteration. + + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. 
+ This is typically used when RayCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. use_env_creator (bool, optional): if ``True``, the environment constructor functions will be wrapped in :class:`~torchrl.envs.EnvCreator`. This is useful for multiprocessed settings where shared memory needs to be managed, but Ray has its own object storage mechanism, so this is typically not needed. @@ -340,6 +353,7 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, use_env_creator: bool = False, no_cuda_sync: bool | None = None, ): @@ -541,31 +555,42 @@ def check_list_length_consistency(*lists): # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Ray weight sync scheme for Ray collectors - from torchrl.weight_update.weight_sync_schemes import RayWeightSyncScheme + from torchrl.weight_update import RayWeightSyncScheme weight_sync_schemes = {"policy": RayWeightSyncScheme()} if weight_sync_schemes is not None: + torchrl_logger.debug("RayCollector: Using weight sync schemes") # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # Set up weight senders using the new simplified API + # Initialize schemes on the sender (main process) side + # Pass remote collectors as the "workers" for Ray schemes for model_id, scheme in self._weight_sync_schemes.items(): - # Initialize the scheme on the sender (main process) side - # Pass remote collectors as the "workers" for Ray schemes + torchrl_logger.debug( + f"RayCollector: Initializing sender for model '{model_id}'" + ) scheme.init_on_sender( model_id=model_id, remote_collectors=self.remote_collectors, - source_model=self.policy if model_id == "policy" else None, + model=self.policy if model_id == "policy" else None, + context=self, ) - # Get the configured sender from the scheme - sender = scheme.get_sender() - self._weight_senders[model_id] = sender + # Set up receiver schemes on remote collectors + # This enables the remote collectors to receive weight updates + for remote_collector in self.remote_collectors: + torchrl_logger.debug( + f"RayCollector: Registering scheme receiver for remote collector {remote_collector}" + ) + fut = remote_collector.register_scheme_receiver.remote( + self._weight_sync_schemes, synchronize_weights=False + ) + ray.get(fut) self.weight_updater = None # Don't use legacy system else: + torchrl_logger.debug("RayCollector: Using legacy weight updater system") # Fall back to legacy weight updater system if weight_updater is None: weight_updater = RayWeightUpdater( @@ -575,12 +600,113 @@ def check_list_length_consistency(*lists): ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + # Always initialize this flag - legacy system doesn't need lazy init + # but we set it for consistency + self._weight_sync_initialized = False + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + torchrl_logger.debug("RayCollector: Setting up weight receivers...") + self.register_scheme_receiver(weight_recv_schemes) + + if not self._weight_sync_initialized: + self._lazy_initialize_weight_sync() # Print info of all remote workers (fire and forget - no need to wait) for e in self.remote_collectors: e.print_remote_collector_info.remote() + def 
_lazy_initialize_weight_sync(self) -> None: + """Initialize weight synchronization lazily on first update_policy_weights_() call. + + This method performs the initial weight synchronization that was deferred from __init__. + It must be called before collection begins if weights need to be synced from a central policy. + + The synchronization is done here (not in __init__) because: + 1. When using policy_factory, there may be no central policy to sync from + 2. Users may want to train the policy first before syncing weights + 3. Different sub-collectors may have different policies via policy_factory + """ + if self._weight_sync_initialized: + return + + if self._weight_sync_schemes is None: + # Legacy weight updater system doesn't use lazy init + self._weight_sync_initialized = True + return + + torchrl_logger.debug("RayCollector: Performing lazy weight synchronization") + + # Cascade synchronize_weights to remote collectors + torchrl_logger.debug( + "RayCollector: Cascading synchronize_weights to remote collectors" + ) + self._sync_futures = [] + for remote_collector in self.remote_collectors: + for model_id in self._weight_sync_schemes: + self._sync_futures.append( + remote_collector.cascade_execute.remote( + f"_receiver_schemes['{model_id}'].connect" + ) + ) + + # Synchronize weights for each scheme + for model_id, scheme in self._weight_sync_schemes.items(): + torchrl_logger.debug( + f"RayCollector: Synchronizing weights for model '{model_id}'" + ) + scheme.connect() + + # Block sync + torchrl_logger.debug( + "RayCollector: Waiting for weight synchronization to finish" + ) + ray.get(self._sync_futures) + self._weight_sync_initialized = True + torchrl_logger.debug("RayCollector: Weight synchronization complete") + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | nn.Module | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Override to trigger lazy weight sync initialization on first call. + + When using policy_factory without a central policy, weight synchronization + is deferred until this method is called with actual weights. + """ + # Trigger lazy initialization if not already done + if not self._weight_sync_initialized: + self._lazy_initialize_weight_sync() + + # Call parent implementation + return super()._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + # def _send_weights_scheme(self, *, scheme, processed_weights, worker_ids, model_id): + # if not worker_ids: + # worker_ids = list(range(self.num_collectors)) + # futures = [] + # for worker_id in worker_ids: + # torchrl_logger.debug(f"RayCollector: Sending weights to remote worker {worker_id}") + # # Call irecv + # fut = self.remote_collectors[worker_id].cascade_execute.remote(f"_receiver_schemes['{model_id}'].receive") + # futures.append(fut) + # torchrl_logger.debug(f"RayCollector: calling isend") + # scheme.send(weights=processed_weights, worker_ids=worker_ids) + # torchrl_logger.debug(f"RayCollector: Waiting for {len(futures)} irecv calls to finish") + # ray.get(futures) + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: """Extract weights from a model if needed. 
@@ -594,17 +720,13 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: ) if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): + # Extract fresh weights from the scheme's model + model = scheme.model + if model is not None: from torchrl.weight_update.weight_sync_schemes import WeightStrategy - strategy = WeightStrategy(extract_as=scheme.strategy) - return strategy.extract_weights(sender._source_model) + strategy = WeightStrategy(extract_as=scheme.strategy_str) + return strategy.extract_weights(model) # Fall back to base class behavior return super()._extract_weights_if_needed(weights, model_id) @@ -678,9 +800,13 @@ def add_collectors( remote_configs, ): """Creates and adds a number of remote collectors to the set.""" - for env_maker, other_params, remote_config in zip( - create_env_fn, collector_kwargs, remote_configs + for i, (env_maker, other_params, remote_config) in enumerate( + zip(create_env_fn, collector_kwargs, remote_configs) ): + # Add worker_idx to params so remote collectors know their index + other_params = dict(other_params) # Make a copy to avoid mutating original + other_params["worker_idx"] = i + cls = self.collector_class.as_remote(remote_config).remote collector = self._make_collector( cls, @@ -715,6 +841,17 @@ def stop_remote_collectors(self): ) # This will interrupt any running tasks on the actor, causing them to fail immediately def iterator(self): + # Warn if weight sync wasn't initialized before collection starts + if not self._weight_sync_initialized and self._weight_sync_schemes is not None: + warnings.warn( + "RayCollector iteration started before weight synchronization was initialized. " + "Call update_policy_weights_(policy_or_weights) before iterating to sync weights " + "from a central policy to remote collectors. If using policy_factory with " + "independent policies on each collector, you can ignore this warning.", + UserWarning, + stacklevel=2, + ) + def proc(data): # When using RayReplayBuffer, sub-collectors write directly to buffer # and return None, so skip processing @@ -757,7 +894,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: self.collected_frames < self.total_frames and not self._stop_event.is_set() ): if self.update_after_each_batch or self.max_weight_update_interval > -1: - torchrl_logger.info("Updating weights on all workers") + torchrl_logger.debug("Updating weights on all workers") self.update_policy_weights_() # Ask for batches to all remote workers. 
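A rough sketch of the lazy weight-sync contract described in the docstrings above, assuming a central policy (the env, policy and frame counts are illustrative placeholders, not taken from this diff): push the policy's weights once before iterating and again after every optimization step; when each worker builds its own policy via `policy_factory`, the warning emitted in `iterator()` can simply be ignored.

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors.distributed import RayCollector
    from torchrl.envs import GymEnv

    # Central policy living in the training process (placeholder module).
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])

    collector = RayCollector(
        [lambda: GymEnv("Pendulum-v1")] * 2,
        policy,
        frames_per_batch=200,
        total_frames=2_000,
    )
    # Weight sync is deferred, so seed the remote workers before collection starts...
    collector.update_policy_weights_(policy)
    for batch in collector:
        # ...run the optimizer step on `policy` here...
        collector.update_policy_weights_(policy)  # ...then push the refreshed weights
    collector.shutdown()
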
@@ -874,7 +1011,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: yield out_td if self.update_after_each_batch or self.max_weight_update_interval > -1: - torchrl_logger.info(f"Updating weights on worker {collector_index}") + torchrl_logger.debug(f"Updating weights on worker {collector_index}") self.update_policy_weights_(worker_ids=collector_index + 1) # Schedule a new collection task diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 3d86bbc5422..67d0be05046 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -23,14 +23,12 @@ from torch.distributed import rpc from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE +from torchrl.collectors._base import DataCollectorBase -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, @@ -61,11 +59,23 @@ def _rpc_init_collection_node( world_size, visible_device, tensorpipe_options, + backend="gloo", verbose=VERBOSE, ): os.environ["MASTER_ADDR"] = str(rank0_ip) os.environ["MASTER_PORT"] = str(tcp_port) + # Initialize torch.distributed process group for efficient weight transfer + if verbose: + torchrl_logger.debug( + f"init distributed with rank={rank}, world_size={world_size}, backend={backend}" + ) + torch.distributed.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + ) + if isinstance(visible_device, list): pass elif isinstance(visible_device, (str, int, torch.device)): @@ -80,7 +90,7 @@ def _rpc_init_collection_node( **tensorpipe_options, ) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"init rpc with master addr: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" ) rpc.init_rpc( @@ -91,6 +101,7 @@ def _rpc_init_collection_node( world_size=world_size, ) rpc.shutdown() + torch.distributed.destroy_process_group() class RPCDataCollector(DataCollectorBase): @@ -260,6 +271,9 @@ class RPCDataCollector(DataCollectorBase): https://github.com/facebookincubator/submitit Defaults to "submitit". tcp_port (int, optional): the TCP port to be used. Defaults to 10003. + backend (str, optional): the torch.distributed backend to use for weight synchronization. + Must be one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed + documentation for more information. Defaults to ``"gloo"``. visible_devices (list of Union[int, torch.device, str], optional): a list of the same length as the number of nodes containing the device used to pass data to main. @@ -270,6 +284,16 @@ class RPCDataCollector(DataCollectorBase): If not provided, an :class:`~torchrl.collectors.distributed.RPCWeightUpdater` will be used by default, which handles weight synchronization via RPC. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to remote collector workers. 
Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via RPC. + If not provided, an :class:`~torchrl.weight_update.RPCWeightSyncScheme` will be used by default. + This is for propagating weights from the main process to remote collectors. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This is typically used when RPCDataCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. """ @@ -304,13 +328,19 @@ def __init__( max_weight_update_interval: int = -1, launcher: str = "submitit", tcp_port: str | None = None, + backend: str = "gloo", visible_devices: list[torch.device] | None = None, tensorpipe_options: dict[str, Any] | None = None, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, ): + + if self._VERBOSE: + torchrl_logger.setLevel("DEBUG") + if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -407,6 +437,7 @@ def __init__( self.postproc = postproc self.split_trajs = split_trajs + self.backend = backend if tensorpipe_options is None: self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS) @@ -414,50 +445,17 @@ def __init__( self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS).update( tensorpipe_options ) - self._init() # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to RPC weight sync scheme for RPC collectors - from torchrl.weight_update.weight_sync_schemes import RPCWeightSyncScheme + from torchrl.weight_update import RPCWeightSyncScheme weight_sync_schemes = {"policy": RPCWeightSyncScheme()} if weight_sync_schemes is not None: # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - - # Set up weight senders now that remote collectors exist - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - for i in range(self.num_workers): - transport = scheme.create_transport( - ( - self.collector_infos[i], - self.collector_rrefs[i], - self.collector_class, - ) - ) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(self, "policy") - and self.policy is not None - ): - sender._source_model = self.policy - - self._weight_senders[model_id] = sender - self.weight_updater = None else: # Fall back to legacy weight updater system @@ -471,7 +469,22 @@ def __init__( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + self._init() + + if weight_sync_schemes is not None: + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + scheme.init_on_sender( + model_id=model_id, + num_workers=self.num_workers, + context=self, + ) + scheme.connect() + + # Set up weight receivers if provided + 
if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) @property def device(self) -> list[torch.device]: @@ -537,7 +550,18 @@ def _init_master_rpc( self, world_size, ): - """Init RPC on main node.""" + """Init torch.distributed and RPC on main node.""" + # Initialize torch.distributed process group for efficient weight transfer + torchrl_logger.debug( + f"init distributed with rank=0, world_size={world_size}, backend={self.backend}" + ) + torch.distributed.init_process_group( + backend=self.backend, + rank=0, + world_size=world_size, + ) + + # Initialize RPC for control/signaling options = rpc.TensorPipeRpcBackendOptions(**self.tensorpipe_options) if torch.cuda.is_available(): if self.visible_devices: @@ -546,8 +570,7 @@ def _init_master_rpc( options.set_device_map( f"COLLECTOR_NODE_{rank}", {0: self.visible_devices[i]} ) - if self._VERBOSE: - torchrl_logger.info("init rpc") + torchrl_logger.debug("init rpc") rpc.init_rpc( "TRAINER_NODE", rank=0, @@ -578,10 +601,7 @@ def _start_workers( counter += 1 time.sleep(time_interval) try: - if self._VERBOSE: - torchrl_logger.info( - f"trying to connect to collector node {i + 1}" - ) + torchrl_logger.debug(f"trying to connect to collector node {i + 1}") collector_info = rpc.get_worker_info(f"COLLECTOR_NODE_{i + 1}") break except RuntimeError as err: @@ -595,8 +615,22 @@ def _start_workers( env_make = env_constructors[i] if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) - if self._VERBOSE: - torchrl_logger.info("Making collector in remote node") + torchrl_logger.debug("Making collector in remote node") + # When using weight sync schemes together with a policy_factory, the + # main-node `policy` should be used only as a weight source on the + # trainer, and NOT sent to remote collectors (which will build their + # own policies from the factory). This mirrors the behaviour of + # `DistributedDataCollector` with multi-process collectors. 
+ policy_to_send = ( + None + if ( + policy is not None + and policy_factory[i] is not None + and getattr(self, "_weight_sync_schemes", None) is not None + ) + else policy + ) + collector_rref = rpc.remote( collector_infos[i], collector_class, @@ -604,29 +638,40 @@ def _start_workers( [env_make] * num_workers_per_collector if collector_class is not SyncDataCollector else env_make, - policy, + policy_to_send, ), kwargs={ "policy_factory": policy_factory[i], "frames_per_batch": frames_per_batch, "total_frames": -1, "split_trajs": False, + "weight_recv_schemes": self._weight_sync_schemes, + "worker_idx": i, **collector_kwargs[i], }, ) collector_rrefs.append(collector_rref) + # Set up receiver schemes on remote collectors (if using new weight sync system) + # This enables cascading: RPC -> MultiSync -> Sync + if getattr(self, "_weight_sync_schemes", None) is not None: + for i in range(num_workers): + torchrl_logger.debug( + f"Setting up receiver schemes on remote collector {i}" + ) + # Call register_scheme_receiver on the remote collector using rref.rpc_sync() + # This properly dereferences the rref and calls the instance method + collector_rrefs[i].rpc_sync().register_scheme_receiver( + self._weight_sync_schemes + ) + futures = collections.deque(maxlen=self.num_workers) if not self._sync: for i in range(num_workers): - if self._VERBOSE: - torchrl_logger.info("Asking for the first batch") - future = rpc.rpc_async( - collector_infos[i], - collector_class.next, - args=(collector_rrefs[i],), - ) + torchrl_logger.debug("Asking for the first batch") + # Use rref.rpc_async() to properly call instance method + future = collector_rrefs[i].rpc_async().next() futures.append((future, i)) self.futures = futures self.collector_rrefs = collector_rrefs @@ -648,10 +693,10 @@ def _init_worker_rpc(self, executor, i): self.num_workers + 1, visible_device, self.tensorpipe_options, + self.backend, self._VERBOSE, ) - if self._VERBOSE: - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug(f"job id {job.job_id}") # ID of your job return job elif self.launcher == "mp": job = _ProcessNoWarn( @@ -663,6 +708,7 @@ def _init_worker_rpc(self, executor, i): self.num_workers + 1, visible_device, self.tensorpipe_options, + self.backend, self._VERBOSE, ), ) @@ -694,8 +740,7 @@ def _init(self): self.jobs = [] for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info(f"Submitting job {i}") + torchrl_logger.debug(f"Submitting job {i}") job = self._init_worker_rpc( executor, i, @@ -737,10 +782,9 @@ def iterator(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): - if self._VERBOSE: - torchrl_logger.info( - f"Updating policy of worker {j} with wait=False" - ) + torchrl_logger.debug( + f"Updating policy of worker {j} with wait=False" + ) self.update_policy_weights_(worker_ids=[j], wait=False) elif self.max_weight_update_interval > -1: ranks = [ @@ -749,15 +793,13 @@ def iterator(self): if self._batches_since_weight_update[j] > self.max_weight_update_interval ] - if self._VERBOSE: - torchrl_logger.info( - f"Updating policy of workers {ranks} with wait=True" - ) + torchrl_logger.debug( + f"Updating policy of workers {ranks} with wait=True" + ) self.update_policy_weights_(worker_ids=ranks, wait=True) def _next_async_rpc(self): - if self._VERBOSE: - torchrl_logger.info("next async") + torchrl_logger.debug("next async") if not len(self.futures): raise StopIteration( f"The queue is empty, the collector has ran out of data after {self._collected_frames} collected 
frames." @@ -767,31 +809,23 @@ def _next_async_rpc(self): if future.done(): if self.update_after_each_batch: self.update_policy_weights_(worker_ids=(i,), wait=False) - if self._VERBOSE: - torchrl_logger.info(f"future {i} is done") + torchrl_logger.debug(f"future {i} is done") data = future.value() self._collected_frames += data.numel() if self._collected_frames < self.total_frames: - future = rpc.rpc_async( - self.collector_infos[i], - self.collector_class.next, - args=(self.collector_rrefs[i],), - ) + # Use rref.rpc_async() to properly call instance method + future = self.collector_rrefs[i].rpc_async().next() self.futures.append((future, i)) return data self.futures.append((future, i)) def _next_sync_rpc(self): - if self._VERBOSE: - torchrl_logger.info("next sync: futures") + torchrl_logger.debug("next sync: futures") if self.update_after_each_batch: self.update_policy_weights_() for i in range(self.num_workers): - future = rpc.rpc_async( - self.collector_infos[i], - self.collector_class.next, - args=(self.collector_rrefs[i],), - ) + # Use rref.rpc_async() to properly call instance method + future = self.collector_rrefs[i].rpc_async().next() self.futures.append((future, i)) data = [] while len(self.futures): @@ -799,10 +833,9 @@ def _next_sync_rpc(self): # the order is NOT guaranteed: should we change that? if future.done(): data += [future.value()] - if self._VERBOSE: - torchrl_logger.info( - f"got data from {i} // data has len {len(data)} / {self.num_workers}" - ) + torchrl_logger.debug( + f"got data from {i} // data has len {len(data)} / {self.num_workers}" + ) else: self.futures.append((future, i)) data = torch.cat(data) @@ -814,34 +847,6 @@ def _next_sync_rpc(self): self._collected_frames += data.numel() return data - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For RPC collectors, when weights is None and we have a weight sync scheme, - extract fresh weights from the tracked policy model. 
- """ - scheme = ( - self._weight_sync_schemes.get(model_id) - if self._weight_sync_schemes - else None - ) - - if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as=scheme.strategy) - return strategy.extract_weights(sender._source_model) - - # Fall back to base class implementation - return super()._extract_weights_if_needed(weights, model_id) - def set_seed(self, seed: int, static_seed: bool = False) -> int: for worker in self.collector_infos: seed = rpc.rpc_sync(worker, self.collector_class.set_seed, args=(seed,)) @@ -858,25 +863,23 @@ def shutdown(self, timeout: float | None = None) -> None: return if self._shutdown: return - if self._VERBOSE: - torchrl_logger.info("shutting down") + torchrl_logger.debug("shutting down") for future, i in self.futures: # clear the futures while future is not None and not future.done(): - torchrl_logger.info(f"waiting for proc {i} to clear") + torchrl_logger.debug(f"waiting for proc {i} to clear") future.wait() for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info(f"shutting down {i}") - rpc.rpc_sync( - self.collector_infos[i], - self.collector_class.shutdown, - args=(self.collector_rrefs[i],), - timeout=int(IDLE_TIMEOUT), - ) - if self._VERBOSE: - torchrl_logger.info("rpc shutdown") + torchrl_logger.debug(f"shutting down {i}") + # Use rref.rpc_sync() to properly call instance method + self.collector_rrefs[i].rpc_sync(timeout=int(IDLE_TIMEOUT)).shutdown() + torchrl_logger.debug("rpc shutdown") rpc.shutdown(timeout=int(IDLE_TIMEOUT)) + + # Destroy torch.distributed process group + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + if self.launcher == "mp": for job in self.jobs: job.join(int(IDLE_TIMEOUT)) @@ -971,19 +974,13 @@ def push_weights( futures = [] weights = self.policy_weights if weights is None else weights for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"calling update on worker {i}") + torchrl_logger.debug(f"calling update on worker {i}") + # Use rref.rpc_async() to properly call instance method futures.append( - rpc.rpc_async( - self.collector_infos[i], - self.collector_class.update_policy_weights_, - args=(self.collector_rrefs[i], weights), - ) + self.collector_rrefs[i].rpc_async().update_policy_weights_(weights) ) if kwargs.get("wait", True): for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"waiting for worker {i}") + torchrl_logger.debug(f"waiting for worker {i}") futures[i].wait() - if self._VERBOSE: - torchrl_logger.info("got it!") + torchrl_logger.debug("got it!") diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 980b3a4b489..95723782865 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -19,13 +19,11 @@ from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from 
torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, @@ -46,6 +44,7 @@ def _distributed_init_collection_node( + *, rank, rank0_ip, tcpport, @@ -60,13 +59,14 @@ def _distributed_init_collection_node( collector_kwargs, update_interval, total_frames, + weight_sync_schemes=None, verbose=VERBOSE, ): os.environ["MASTER_ADDR"] = str(rank0_ip) os.environ["MASTER_PORT"] = str(tcpport) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"node with rank {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): @@ -78,53 +78,55 @@ def _distributed_init_collection_node( "SyncDataCollector and subclasses can only support a single environment." ) - if isinstance(policy, nn.Module): - policy_weights = TensorDict.from_module(policy) - policy_weights = policy_weights.data.lock_() + torchrl_logger.debug(f"IP address: {rank0_ip} \ttcp port: {tcpport}") + + # Pass weight_recv_schemes to the collector - it will handle init_on_receiver and connect + # The scheme's connect() will call init_process_group as a collective operation + if weight_sync_schemes is not None: + collector_kwargs["weight_recv_schemes"] = weight_sync_schemes else: - if collector_kwargs.get("weight_updater") is None and ( - policy_factory is None - or (isinstance(policy_factory, Sequence) and not any(policy_factory)) - ): - warnings.warn(_NON_NN_POLICY_WEIGHTS) - policy_weights = TensorDict(lock=True) + # No schemes - init process group manually for data.isend to work + if verbose: + torchrl_logger.debug( + f"node with rank {rank} -- launching distributed (no weight schemes)" + ) + torch.distributed.init_process_group( + backend, + rank=rank, + world_size=world_size, + timeout=timedelta(MAX_TIME_TO_CONNECT), + ) + # When policy_factory is provided, the child collector should use it + # instead of the policy (which is only used as a weight source for the parent) collector = collector_class( env_make, - policy, + policy if policy_factory is None else None, frames_per_batch=frames_per_batch, split_trajs=False, total_frames=total_frames, policy_factory=policy_factory, + worker_idx=rank, **collector_kwargs, ) - torchrl_logger.info(f"IP address: {rank0_ip} \ttcp port: {tcpport}") - if verbose: - torchrl_logger.info(f"node with rank {rank} -- launching distributed") - torch.distributed.init_process_group( - backend, - rank=rank, - world_size=world_size, - timeout=timedelta(MAX_TIME_TO_CONNECT), - # init_method=f"tcp://{rank0_ip}:{tcpport}", - ) - if verbose: - torchrl_logger.info(f"node with rank {rank} -- creating store") if verbose: - torchrl_logger.info(f"node with rank {rank} -- loop") - policy_weights.irecv(0) - frames = 0 + torchrl_logger.debug(f"node with rank {rank} -- loop") + + # Collection loop - weight updates are handled by the background thread in the scheme for i, data in enumerate(collector): + torchrl_logger.debug( + f"Sending batch {i} from sync distributed collector on rank {rank}" + ) data.isend(dst=0) - frames += data.numel() - if ( - frames < total_frames - and (i + 1) % update_interval == 0 - and not policy_weights.is_empty() - ): - policy_weights.irecv(0) + torchrl_logger.debug( + f"Sent batch {i} from distributed collector on rank {rank}" + ) + # Cleanup + if weight_sync_schemes is not None: + for scheme in 
weight_sync_schemes.values(): + scheme.shutdown() if not collector.closed: collector.shutdown() del collector @@ -339,6 +341,7 @@ def __init__( if not isinstance(policy_factory, Sequence): policy_factory = [policy_factory] * len(create_env_fn) self.policy_factory = policy_factory + self._policy_to_send = policy if not any(policy_factory) else None self.policy_weights = policy_weights self.num_workers = len(create_env_fn) self.frames_per_batch = frames_per_batch @@ -402,6 +405,28 @@ def __init__( self.backend = backend + # Create weight sync schemes for distributed weight updates + # The scheme creates its own TCPStore for coordination + self._weight_sync_schemes = None + if isinstance(policy, nn.Module): + from torchrl.weight_update import DistributedWeightSyncScheme + + self._weight_sync_schemes = { + "policy": DistributedWeightSyncScheme(backend=backend, sync=False) + } + # Initialize schemes on sender BEFORE starting workers so the store + # exists when workers try to connect + for model_id, scheme in self._weight_sync_schemes.items(): + torchrl_logger.debug( + f"DistributedSyncDataCollector: Initializing scheme for '{model_id}' on sender" + ) + scheme.init_on_sender( + model_id=model_id, + context=self, + num_workers=self.num_workers, + model=policy, + ) + # os.environ['TP_SOCKET_IFNAME'] = 'lo' self._init_workers() @@ -473,7 +498,7 @@ def _init_master_dist( backend, ): TCP_PORT = self.tcp_port - torchrl_logger.info("init master...") + torchrl_logger.debug("init master...") torch.distributed.init_process_group( backend, rank=0, @@ -481,7 +506,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - torchrl_logger.info("done") + torchrl_logger.debug("done") def _make_container(self): env_constructor = self.env_constructors[0] @@ -507,20 +532,22 @@ def _init_worker_dist_submitit(self, executor, i): env_make = CloudpickleWrapper(env_make) job = executor.submit( _distributed_init_collection_node, - i + 1, - self.IPAddr, - int(TCP_PORT), - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self.update_interval, - self.total_frames_per_collector, + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self._policy_to_send, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + update_interval=self.update_interval, + total_frames=self.total_frames_per_collector, + weight_sync_schemes=self._weight_sync_schemes, + verbose=VERBOSE, ) return job @@ -531,21 +558,23 @@ def _init_worker_dist_mp(self, i): env_make = CloudpickleWrapper(env_make) job = _ProcessNoWarn( target=_distributed_init_collection_node, - args=( - i + 1, - self.IPAddr, - int(TCP_PORT), - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self.update_interval, - self.total_frames_per_collector, + kwargs=dict( # noqa: C408 + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + 
num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self._policy_to_send, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + update_interval=self.update_interval, + total_frames=self.total_frames_per_collector, + weight_sync_schemes=self._weight_sync_schemes, + verbose=VERBOSE, ), ) job.start() @@ -555,7 +584,7 @@ def _init_workers(self): hostname = socket.gethostname() IPAddr = socket.gethostbyname(hostname) - torchrl_logger.info(f"Server IP address: {IPAddr}") + torchrl_logger.debug(f"Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -567,20 +596,35 @@ def _init_workers(self): executor = submitit.AutoExecutor(folder="log_test") executor.update_parameters(**self.slurm_kwargs) for i in range(self.num_workers): - torchrl_logger.info("Submitting job") + torchrl_logger.debug("Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug(f"job id {job.job_id}") # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - torchrl_logger.info("job launched") + torchrl_logger.debug("job launched") self.jobs.append(job) - self._init_master_dist(self.num_workers + 1, self.backend) + + # Initialize process group and weight sync + # If we have schemes, they handle init_process_group in connect() + # Otherwise, we need to init manually for data.irecv to work + if self._weight_sync_schemes is not None: + for model_id, scheme in self._weight_sync_schemes.items(): + torchrl_logger.debug( + f"DistributedSyncDataCollector: Connecting scheme '{model_id}' (will init process group)" + ) + scheme.connect() + torchrl_logger.debug( + "DistributedSyncDataCollector: Initial weight sync completed" + ) + else: + # No schemes - init process group manually + self._init_master_dist(self.num_workers + 1, self.backend) def iterator(self): yield from self._iterator_dist() @@ -591,17 +635,20 @@ def _iterator_dist(self): j = -1 while total_frames < self.total_frames: j += 1 - if j % self.update_interval == 0 and not self.policy_weights.is_empty(): - for i in range(self.num_workers): - rank = i + 1 - self.policy_weights.isend(rank) + if j % self.update_interval == 0 and self._weight_sync_schemes is not None: + # Send weight updates via the schemes + # Each scheme handles extracting weights from the policy and sending + for scheme in self._weight_sync_schemes.values(): + scheme.send() trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"Receiving from rank {rank} on main") trackers.append( self._single_tds[i].irecv(src=rank, return_premature=True) ) + torchrl_logger.debug(f"Received from rank {rank} on main") for tracker in trackers: for _tracker in tracker: _tracker.wait() @@ -639,4 +686,8 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: raise NotImplementedError def shutdown(self, timeout: float | None = None) -> None: - pass + # Clean up weight sync schemes + if self._weight_sync_schemes is not None: + for scheme in self._weight_sync_schemes.values(): + scheme.shutdown() + self._weight_sync_schemes = None diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index bc72bda6a4a..457164a4199 100644 --- a/torchrl/collectors/distributed/utils.py +++ 
b/torchrl/collectors/distributed/utils.py @@ -58,8 +58,8 @@ class submitit_delayed_launcher: >>> num_jobs=2 >>> @submitit_delayed_launcher(num_jobs=num_jobs) ... def main(): - ... from torchrl.envs.utils import RandomPolicy - from torchrl.envs.libs.gym import GymEnv + ... from torchrl.modules import RandomPolicy + ... from torchrl.envs.libs.gym import GymEnv ... from torchrl.data import BoundedContinuous ... collector = DistributedDataCollector( ... [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs, @@ -103,7 +102,7 @@ def exec_fun(): executor.update_parameters(**self.submitit_main_conf) main_job = executor.submit(main_func) # listen to output file looking for IP address - torchrl_logger.info(f"job id: {main_job.job_id}") + torchrl_logger.debug(f"job id: {main_job.job_id}") time.sleep(2.0) node = None while not node: @@ -114,11 +113,11 @@ def exec_fun(): except ValueError: time.sleep(0.5) continue - torchrl_logger.info(f"node: {node}") + torchrl_logger.debug(f"node: {node}") # by default, sinfo will truncate the node name at char 20, we increase this to 200 cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1" rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip() - torchrl_logger.info(f"IP: {rank0_ip}") + torchrl_logger.debug(f"IP: {rank0_ip}") world_size = self.num_jobs + 1 # submit jobs diff --git a/torchrl/collectors/llm/base.py b/torchrl/collectors/llm/base.py index e9ba6e9bcdf..408a6ec5e6a 100644 --- a/torchrl/collectors/llm/base.py +++ b/torchrl/collectors/llm/base.py @@ -14,7 +14,7 @@ from torchrl._utils import as_remote, logger as torchrl_logger -from torchrl.collectors import SyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.llm.utils import _QueueAsRB from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer @@ -308,7 +308,7 @@ def _rollout_all(self) -> TensorDictBase: # A simplified version of rollout policy_input = self._shuttle while collected_steps < self.dialog_turns_per_batch: if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Collected {collected_steps} steps over {self.dialog_turns_per_batch} requested." ) env_input = self.policy(policy_input) @@ -341,7 +341,7 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol if self._result_numel >= self.dialog_turns_per_batch: break elif self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Collected {collected_steps} steps with {self._result_numel} elements in the resulting batch, over {self.dialog_turns_per_batch} requested." ) env_input = self.policy(next_output) @@ -385,7 +385,7 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol self._result_numel -= result[-1].numel() result = torch.cat(result, -1) if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Yielding completed trajectory with shape {result.shape}." ) return result @@ -447,7 +447,7 @@ def _rollout_yield_trajs_async( result = self._trajectory_queue.popleft() if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
) return result diff --git a/torchrl/collectors/llm/weight_update/vllm.py b/torchrl/collectors/llm/weight_update/vllm.py index 9b2fe144b0f..ae1161ec77f 100644 --- a/torchrl/collectors/llm/weight_update/vllm.py +++ b/torchrl/collectors/llm/weight_update/vllm.py @@ -17,7 +17,7 @@ from torchrl._utils import logger as torchrl_logger -from torchrl.collectors import WeightUpdaterBase +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.modules.llm.backends import stateless_init_process_group _has_vllm = importlib.util.find_spec("vllm") is not None @@ -103,7 +103,7 @@ def __init__( model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None, vllm_tp_size: int | None = None, ): - torchrl_logger.info(f"=> in {type(self).__name__}.__init__") + torchrl_logger.debug(f"=> in {type(self).__name__}.__init__") self.master_address = master_address self.master_port = master_port self.model_metadata = model_metadata @@ -171,23 +171,23 @@ def _get_model_ref(self): def _init_group(self): import ray - torchrl_logger.info(f"=> in {type(self).__name__}._init_group") + torchrl_logger.debug(f"=> in {type(self).__name__}._init_group") weight_sync_world_size = self.vllm_tp_size + 1 - torchrl_logger.info(f"initializing group with {weight_sync_world_size=}...") - torchrl_logger.info(f"vllm_tp_size={self.vllm_tp_size}") + torchrl_logger.debug(f"initializing group with {weight_sync_world_size=}...") + torchrl_logger.debug(f"vllm_tp_size={self.vllm_tp_size}") model_ref = self._get_model_ref() - torchrl_logger.info(f"model_ref: {model_ref}") + torchrl_logger.debug(f"model_ref: {model_ref}") # Initialize the weight update group - torchrl_logger.info("Calling init_weight_update_group...") + torchrl_logger.debug("Calling init_weight_update_group...") init_weight_update_group_getter = model_ref.collective_rpc.remote( "init_weight_update_group", args=(self.master_address, self.master_port, 1, weight_sync_world_size), ) - torchrl_logger.info("init_weight_update_group remote call succeeded") + torchrl_logger.debug("init_weight_update_group remote call succeeded") - torchrl_logger.info("Calling stateless_init_process_group within updater...") + torchrl_logger.debug("Calling stateless_init_process_group within updater...") self.vllm_comm_group = stateless_init_process_group( self.master_address, self.master_port, @@ -197,9 +197,9 @@ def _init_group(self): ) ray.get(init_weight_update_group_getter) - torchrl_logger.info("init_weight_update_group getter succeeded") + torchrl_logger.debug("init_weight_update_group getter succeeded") - torchrl_logger.info("group initialized") + torchrl_logger.debug("group initialized") self.initialized_group = True def maybe_init_group(self): @@ -239,7 +239,7 @@ def _sync_weights_with_worker( model_ref = self._get_model_ref() # First broadcast metadata - torchrl_logger.info("broadcasting with update_weight_broadcast") + torchrl_logger.debug("broadcasting with update_weight_broadcast") remotes = [] for k, (dtype, shape) in self.model_metadata.items(): remotes.append( @@ -257,7 +257,7 @@ def _sync_weights_with_worker( # # ray.get(remotes) # if self.vllm_comm_group is not True: - torchrl_logger.info("broadcasting...") + torchrl_logger.debug("broadcasting...") for k in self.model_metadata: val = server_weights[k].to(torch.device("cuda:0")) self.vllm_comm_group.broadcast( @@ -269,7 +269,7 @@ def _sync_weights_with_worker( import ray ray.get(remotes) - torchrl_logger.info("done broadcasting") + torchrl_logger.debug("done broadcasting") torch.cuda.synchronize() def 
_get_server_weights(self) -> TensorDictBase | None: diff --git a/torchrl/collectors/llm/weight_update/vllm_v2.py b/torchrl/collectors/llm/weight_update/vllm_v2.py index 0792d7e7de6..cb4b4d6183b 100644 --- a/torchrl/collectors/llm/weight_update/vllm_v2.py +++ b/torchrl/collectors/llm/weight_update/vllm_v2.py @@ -12,7 +12,7 @@ import torch from tensordict import TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.collectors import WeightUpdaterBase +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.modules.llm.backends.vllm import RLvLLMEngine try: @@ -44,7 +44,7 @@ def __init__(self, vllm_engine: RLvLLMEngine): f"vllm_engine must implement RLvLLMEngine interface, got {type(vllm_engine)}" ) - torchrl_logger.info(f"=> in {type(self).__name__}.__init__") + torchrl_logger.debug(f"=> in {type(self).__name__}.__init__") self.vllm_engine = vllm_engine self.initialized_group = None @@ -54,7 +54,7 @@ def __init__(self, vllm_engine: RLvLLMEngine): self.master_port = vllm_engine.get_master_port() self.model_metadata = vllm_engine.get_model_metadata() - torchrl_logger.info( + torchrl_logger.debug( f"Initialized vLLMUpdaterV2 with tp_size={self.vllm_tp_size}" ) @@ -76,7 +76,7 @@ def init( # Initialize the engine's weight update group self.vllm_engine.init_weight_update_group() self.initialized_group = True - torchrl_logger.info("Weight update group initialized") + torchrl_logger.debug("Weight update group initialized") def push_weights( self, weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase @@ -94,12 +94,12 @@ def push_weights( # Delegate to the engine's update_weights method self.vllm_engine.update_weights(weights) - torchrl_logger.info("Weight update completed") + torchrl_logger.debug("Weight update completed") # Call post-hooks to increment policy version - torchrl_logger.info("Calling post-hooks...") + torchrl_logger.debug("Calling post-hooks...") self._call_post_hooks() - torchrl_logger.info("Post-hooks completed") + torchrl_logger.debug("Post-hooks completed") def push_weights_from_transformers(self, transformers_model): """Push weights from a transformers model. 
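For reference, a minimal usage sketch of the vLLMUpdaterV2 flow changed above (illustrative only: the direct module path, the make_rl_vllm_engine() helper, the hf_model object and the argument-less init() call are assumptions, not part of this diff):

    # Sketch: pushing fresh weights from a Hugging Face policy into a vLLM engine.
    from torchrl.collectors.llm.weight_update.vllm_v2 import vLLMUpdaterV2

    engine = make_rl_vllm_engine()               # placeholder; must implement RLvLLMEngine
    updater = vLLMUpdaterV2(vllm_engine=engine)  # reads tp_size, master address/port, metadata
    updater.init()                               # sets up the engine's weight update group
    updater.push_weights_from_transformers(hf_model)  # extracts the state_dict, then push_weights()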
@@ -134,11 +134,11 @@ def push_weights_from_transformers(self, transformers_model): ) t1 = time.time() - torchrl_logger.info(f"Time to extract state_dict: {t1 - t0}") + torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0}") # Convert to iterator for memory efficiency weights_iter = iter(state_dict.items()) self.push_weights(weights_iter) - torchrl_logger.info(f"Time to push weights: {time.time() - t1}") + torchrl_logger.debug(f"Time to push weights: {time.time() - t1}") def push_weights_from_transformers_optimized( self, transformers_model, batch_size=50 @@ -181,7 +181,7 @@ def push_weights_from_transformers_optimized( ) t1 = time.time() - torchrl_logger.info(f"Time to extract state_dict: {t1 - t0:.3f}s") + torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0:.3f}s") # Pre-load all weights to GPU for faster transfer gpu_weights = {} @@ -195,7 +195,7 @@ def push_weights_from_transformers_optimized( # Synchronize to ensure all transfers are complete torch.cuda.synchronize() t2 = time.time() - torchrl_logger.info(f"Time to move weights to GPU: {t2 - t1:.3f}s") + torchrl_logger.debug(f"Time to move weights to GPU: {t2 - t1:.3f}s") # Transfer weights (optionally in batches) if batch_size > 0: @@ -203,7 +203,7 @@ def push_weights_from_transformers_optimized( for i in range(0, len(weight_items), batch_size): batch = weight_items[i : i + batch_size] self.push_weights(iter(batch)) - torchrl_logger.info( + torchrl_logger.debug( f"Transferred batch {i // batch_size + 1}/{(len(weight_items) + batch_size - 1) // batch_size}" ) else: @@ -211,7 +211,7 @@ def push_weights_from_transformers_optimized( self.push_weights(iter(gpu_weights.items())) t3 = time.time() - torchrl_logger.info( + torchrl_logger.debug( f"Time to push weights: {t3 - t2:.3f}s, total time: {t3 - t0:.3f}s" ) @@ -252,14 +252,14 @@ def register_collector(self, collector): # noqa: F821 # This avoids N^2 complexity where each weight update calls increment_version # on all collectors N times (once per registered collector) if len(self.post_hooks) == 0: - torchrl_logger.info("Registering policy version increment post-hook") + torchrl_logger.debug("Registering policy version increment post-hook") self.register_post_hook(self._increment_all_collector_versions) return result def _increment_all_collector_versions(self): """Increment version for all registered collectors efficiently.""" - torchrl_logger.info( + torchrl_logger.debug( f"Incrementing policy version for {len(self.collectors)} collectors..." ) for i, collector in enumerate(self.collectors): @@ -272,7 +272,7 @@ def _increment_all_collector_versions(self): torchrl_logger.warning( f"Failed to increment version for collector {i + 1}: {e}" ) - torchrl_logger.info("All collector versions incremented") + torchrl_logger.debug("All collector versions incremented") @classmethod def get_model_metadata(cls, model) -> dict[str, tuple[torch.dtype, torch.Size]]: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 1f8b2668938..ef6aa60aad2 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -4,12 +4,16 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations -from collections.abc import Callable +import contextlib +from collections.abc import Callable, Sequence import torch +from pyvers import implement_for -from tensordict import NestedKey, pad, set_lazy_legacy, TensorDictBase - +from tensordict import NestedKey, pad, set_lazy_legacy, TensorDict, TensorDictBase +from tensordict.utils import Buffer +from torch import multiprocessing as mp, nn as nn +from torch.nn import Parameter _NON_NN_POLICY_WEIGHTS = ( "The policy is not an nn.Module. TorchRL will assume that the parameter set is empty and " @@ -257,3 +261,142 @@ def nest(*x): [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0 ) return td + + +@implement_for("torch", "2.5.0") +def _cast( + p: nn.Parameter | torch.Tensor, + param_maybe_buffer: nn.Parameter | torch.Tensor | None = None, +) -> nn.Parameter | torch.Tensor: + if param_maybe_buffer is None: + param_maybe_buffer = p + p = p.data + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + if isinstance(param_maybe_buffer, Buffer): + return Buffer(p) + if p.requires_grad: + raise RuntimeError(f"Cannot cast tensor {p} with gradients") + return p + + +def _make_meta_policy(policy: nn.Module): + """Create context manager that temporarily puts policy parameters on meta device. + + This is used with weight sync schemes to send policy structure without weights. + The actual weights are distributed by the schemes. + + Args: + policy: Policy module to temporarily modify. + + Returns: + A context manager that temporarily replaces policy parameters with meta device versions. + On exit, the original parameters are restored to the policy. + """ + param_and_buf = TensorDict.from_module(policy, as_module=True) + return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy) + + +@implement_for("torch", None, "2.5.0") +def _cast( # noqa + p: nn.Parameter | torch.Tensor, + param_maybe_buffer: nn.Parameter | torch.Tensor | None = None, +) -> nn.Parameter | torch.Tensor: + if param_maybe_buffer is None: + param_maybe_buffer = p + p = p.data + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + if p.requires_grad: + raise RuntimeError(f"Cannot cast tensor {p} with gradients") + return p + + +def _map_to_cpu_if_needed(x): + """Map tensors on exotic devices (MPS, NPU, etc.) to CPU. + + CPU and CUDA tensors are kept as-is since they can be shared across processes. + Only exotic devices that don't support multiprocessing are mapped to CPU. + """ + if isinstance(x, torch.Tensor): + # CPU and CUDA can be shared across processes + if x.device.type in ("cpu", "cuda"): + return x + # Exotic devices (MPS, NPU, etc.) 
need to be mapped to CPU + return x.cpu() + return x + + +def _make_meta_params(param): + is_param = isinstance(param, Parameter) + + pd = param.detach().to("meta") + + if is_param: + pd = Parameter(pd, requires_grad=False) + return pd + + +class _TrajectoryPool: + def __init__(self, ctx=None, lock: bool = False): + self.ctx = ctx + self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() + if ctx is None: + self.lock = contextlib.nullcontext() if not lock else mp.RLock() + else: + self.lock = contextlib.nullcontext() if not lock else ctx.RLock() + + def get_traj_and_increment(self, n=1, device=None): + with self.lock: + v = self._traj_id.item() + out = torch.arange(v, v + n).to(device) + self._traj_id.copy_(1 + out[-1].item()) + return out + + +def _map_weight( + weight, + policy_device, +): + + is_param = isinstance(weight, Parameter) + is_buffer = isinstance(weight, Buffer) + weight = weight.data + if weight.device != policy_device: + weight = weight.to(policy_device) + elif weight.device.type in ("cpu",): + weight = weight.share_memory_() + if is_param: + weight = Parameter(weight, requires_grad=False) + elif is_buffer: + weight = Buffer(weight) + return weight + + +def _make_policy_factory( + *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None +): + has_policy_factory = policy_factory is not None and ( + (isinstance(policy_factory, Sequence) and any(policy_factory)) + or not isinstance(policy_factory, Sequence) + ) + if policy is not None and has_policy_factory: + raise ValueError("policy cannot be used with policy_factory") + elif has_policy_factory: + if isinstance(policy_factory, Sequence): + return policy_factory + else: + policy = policy_factory() + + if weight_sync_scheme is not None: + # Initialize the receiver on the worker side + weight_sync_scheme.init_on_receiver( + model=policy, + model_id="policy", + worker_idx=worker_idx, + ) + # Synchronize initial weights + weight_sync_scheme.connect(worker_idx=worker_idx) + return policy diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py index 97fa62d6a2b..82f0e6e52ca 100644 --- a/torchrl/collectors/weight_update.py +++ b/torchrl/collectors/weight_update.py @@ -578,7 +578,7 @@ def _maybe_map_weights(self, server_weights: Any) -> Any: return server_weights def _sync_weights_with_worker(self, worker_id: int, server_weights: Any) -> Any: - torchrl_logger.info(f"syncing weights with worker {worker_id}") + torchrl_logger.debug(f"syncing weights with worker {worker_id}") c = self.remote_collectors[worker_id] c.update_policy_weights_.remote(policy_weights=server_weights) self._batches_since_weight_update[worker_id] = 0 diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index a0ded99c892..421dc53df69 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -103,7 +103,7 @@ class MultiStep(nn.Module): within the replay buffer instead. 
Examples: - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy + >>> from torchrl.modules import RandomPolicy >>> >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.data.postprocs import MultiStep >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 4669ce103ff..7f5f32dd20e 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -899,7 +899,6 @@ def set( self._init(tree_map(lambda x: x[0], data)) else: self._init(data) - assert self.initialized if is_tensor_collection(data): self._storage[cursor] = data @@ -944,7 +943,6 @@ def set( # noqa: F811 self._init(data[0]) else: self._init(data) - assert self.initialized if not isinstance(cursor, (*INT_CLASSES, slice)): if not isinstance(cursor, torch.Tensor): diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index d4cbbd71db4..320e870980a 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -123,7 +123,6 @@ get_available_libraries, make_composite_from_td, MarlGroupMapType, - RandomPolicy, set_exploration_type, step_mdp, terminated_or_truncated, @@ -208,7 +207,6 @@ "PinMemoryTransform", "R3MTransform", "RandomCropTensorDict", - "RandomPolicy", "RemoveEmptySpecs", "RenameTransform", "Resize", diff --git a/torchrl/envs/async_envs.py b/torchrl/envs/async_envs.py index f93fc5fe2cd..263704dd621 100644 --- a/torchrl/envs/async_envs.py +++ b/torchrl/envs/async_envs.py @@ -474,7 +474,6 @@ def _setup(self) -> None: self._current_step_reset = 0 num_threads = self.num_envs - assert num_threads > 0 self.threads = [] for i in range(num_threads): # thread = threading.Thread(target=_env_exec, kwargs={"i": i, "env_or_factory": self.env_maker[i], "input_queue": self.input_queue[i], "step_queue": self.step_queue, "reset_queue": self.reset_queue}) @@ -541,7 +540,6 @@ def async_step_recv(self, min_get: int = 1) -> TensorDictBase: ) r = self._wait_for_one_and_get(self.step_queue, min_get) self._current_step = self._current_step - len(r) - assert self._current_step >= 0 r, idx = self._sort_results(r) self._busy.difference_update(idx) return self._stack_func(r) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b0993c12242..0ba2c019303 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -2701,7 +2701,6 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() - mp_event.set() if consolidate: try: child_pipe.send( @@ -2713,6 +2712,9 @@ def _run_worker_pipe_direct( raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err else: child_pipe.send(cur_td) + # Set event after successfully sending through pipe to avoid race condition + # where event is set but pipe send fails (BrokenPipeError) + mp_event.set() del cur_td @@ -2726,7 +2728,6 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() - mp_event.set() if consolidate: try: next_td = next_td.consolidate( @@ -2735,6 +2736,9 @@ def _run_worker_pipe_direct( except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err child_pipe.send(next_td) + # Set event after successfully sending through pipe to avoid race condition + # where event is set but pipe send fails (BrokenPipeError) + mp_event.set() del next_td diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index 
6a17125b1d4..94c9bfa2aed 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -906,9 +906,9 @@ def execute(self, prompt: str) -> dict[str, Any]: except queue.Empty: pass - if not start_found: - timeout_val -= 0.1 - time.sleep(0.1) + # Always sleep a bit to avoid busy-waiting and give subprocess time + timeout_val -= 0.01 + time.sleep(0.01) except Exception as e: return { @@ -1007,8 +1007,10 @@ def __init__(self, pool_size: int = 32, timeout: float = 10.0): self.processes = [ PersistentPythonProcess(timeout=timeout) for _ in range(pool_size) ] + # Create a lock for each process to prevent concurrent access + self.process_locks = [threading.Lock() for _ in range(pool_size)] self.next_idx = 0 - self._lock = threading.Lock() + self._selection_lock = threading.Lock() def execute(self, code: str) -> dict: """Execute Python code using next available process (round-robin). @@ -1019,12 +1021,14 @@ def execute(self, code: str) -> dict: Returns: dict: Execution result with keys 'success', 'stdout', 'stderr', 'returncode'. """ - # Simple round-robin - Ray handles the queuing via max_concurrency - with self._lock: - process = self.processes[self.next_idx] + # Select a process using round-robin + with self._selection_lock: + process_idx = self.next_idx self.next_idx = (self.next_idx + 1) % self.pool_size - return process.execute(code) + # Lock the selected process for the duration of execution + with self.process_locks[process_idx]: + return self.processes[process_idx].execute(code) def cleanup(self): """Cleanup all processes in the pool.""" diff --git a/torchrl/envs/transforms/module.py b/torchrl/envs/transforms/module.py index 288af9054cc..c6038a91032 100644 --- a/torchrl/envs/transforms/module.py +++ b/torchrl/envs/transforms/module.py @@ -6,16 +6,19 @@ from collections.abc import Callable from contextlib import nullcontext -from typing import overload +from typing import overload, TYPE_CHECKING import torch from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase +from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import TensorSpec from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform from torchrl.envs.transforms.transforms import Transform +if TYPE_CHECKING: + from torchrl.weight_update import WeightSyncScheme __all__ = ["ModuleTransform", "RayModuleTransform"] @@ -25,8 +28,46 @@ class RayModuleTransform(RayTransform): This transform creates a Ray actor that wraps a ModuleTransform, allowing module execution in a separate Ray worker process. + + Args: + weight_sync_scheme: Optional weight synchronization scheme for updating + the module's weights from a parent collector. When provided, the scheme + is initialized on the receiver side (the Ray actor) and can receive + weight updates via torch.distributed. + **kwargs: Additional arguments passed to RayTransform and ModuleTransform. 
+ + Example: + >>> from torchrl.weight_update import RayModuleTransformScheme + >>> scheme = RayModuleTransformScheme() + >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme) + >>> # The scheme can then be registered with a collector for weight updates """ + def __init__(self, *, weight_sync_scheme=None, **kwargs): + self._weight_sync_scheme = weight_sync_scheme + super().__init__(**kwargs) + + # After actor is created, initialize the scheme on the receiver side + if weight_sync_scheme is not None: + # Store transform reference in the scheme for sender initialization + weight_sync_scheme._set_transform(self) + + weight_sync_scheme.init_on_sender() + + # Initialize receiver in the actor + torchrl_logger.debug( + "Setting up weight sync scheme on sender -- sender will do the remote call" + ) + weight_sync_scheme.connect() + + @property + def in_keys(self): + return self._ray.get(self._actor._getattr.remote("in_keys")) + + @property + def out_keys(self): + return self._ray.get(self._actor._getattr.remote("out_keys")) + def _create_actor(self, **kwargs): import ray @@ -240,6 +281,25 @@ def _update_weights_tensordict(self, params: TensorDictBase) -> None: def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: self.module.load_state_dict(state_dict) + def _init_weight_sync_scheme(self, scheme: WeightSyncScheme, model_id: str) -> None: + """Initialize weight sync scheme on the receiver side (called in Ray actor). + + This method is called by RayModuleTransform after the actor is created + to set up the receiver side of the weight synchronization scheme. + + Args: + scheme: The weight sync scheme instance (e.g., RayModuleTransformScheme). + model_id: Identifier for the model being synchronized. + """ + torchrl_logger.debug(f"Initializing weight sync scheme for {model_id=}") + scheme.init_on_receiver(model_id=model_id, context=self) + torchrl_logger.debug(f"Setup weight sync scheme for {model_id=}") + scheme.connect() + self._weight_sync_scheme = scheme + + def _receive_weights_scheme(self): + self._weight_sync_scheme.receive() + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: if self.observation_spec_transform is not None: if isinstance(self.observation_spec_transform, TensorSpec): diff --git a/torchrl/envs/transforms/ray_service.py b/torchrl/envs/transforms/ray_service.py index 0da863863fa..5b3c91fce84 100644 --- a/torchrl/envs/transforms/ray_service.py +++ b/torchrl/envs/transforms/ray_service.py @@ -200,9 +200,7 @@ def __init__( actor_name: Name of the Ray actor (for reuse) **kwargs: Additional arguments passed to Transform """ - super().__init__( - in_keys=kwargs.get("in_keys", []), out_keys=kwargs.get("out_keys", []) - ) + super().__init__(in_keys=kwargs.get("in_keys"), out_keys=kwargs.get("out_keys")) self._num_cpus = num_cpus self._num_gpus = num_gpus diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ca9ab70f184..0abf29ff5dd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -241,21 +241,43 @@ class Transform(nn.Module): def __init__( self, - in_keys: Sequence[NestedKey] = None, + in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, ): super().__init__() - self.in_keys = in_keys - self.out_keys = out_keys - self.in_keys_inv = in_keys_inv - self.out_keys_inv = out_keys_inv + if 
in_keys is not None: + self.in_keys = in_keys + if out_keys is not None: + self.out_keys = out_keys + if in_keys_inv is not None: + self.in_keys_inv = in_keys_inv + if out_keys_inv is not None: + self.out_keys_inv = out_keys_inv self._missing_tolerance = False # we use __dict__ to avoid having nn.Module placing these objects in the module list self.__dict__["_container"] = None self.__dict__["_parent"] = None + def _getattr(self, val, *args, **kwargs): + if args: + if len(args) > 1: + raise TypeError( + f"Expected at most 1 positional argument, got {len(args)}" + ) + default = args[0] + return getattr(self, val, default) + if kwargs: + try: + default = kwargs.pop("default") + except KeyError: + raise TypeError("Only 'default' keyword argument is supported") + if args: + raise TypeError("Got two values for keyword argument 'default'") + return getattr(self, val, default) + return getattr(self, val) + def _ready(self): # Used to block ray until the actor is ready, see RayTransform return True @@ -3501,7 +3523,7 @@ class CatFrames(ObservationTransform): gives the complete picture, together with the usage of a :class:`torchrl.data.ReplayBuffer`: Examples: - >>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.envs import UnsqueezeTransform, CatFrames + >>> from torchrl.modules import RandomPolicy >>> >>> >>> from torchrl.envs import UnsqueezeTransform, CatFrames >>> from torchrl.collectors import SyncDataCollector >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension >>> env = TransformedEnv( @@ -8800,7 +8822,7 @@ class Reward2GoTransform(Transform): append the `inv` method of the transform. Examples: - >>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.modules import RandomPolicy >>> >>> >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs.libs.gym import GymEnv >>> t = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"]) >>> env = GymEnv("Pendulum-v1") diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 2c02399c6e7..6bf247f2ce5 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -47,6 +47,7 @@ Unbounded, ) from torchrl.data.utils import check_no_exclusive_keys, CloudpickleWrapper +from torchrl.modules.tensordict_module.exploration import RandomPolicy # noqa __all__ = [ "exploration_type", @@ -59,7 +60,6 @@ "check_marl_grouping", ] - ACTION_MASK_ERROR = RuntimeError( "An out-of-bounds actions has been provided to an env with an 'action_mask' output. " "If you are using a custom policy, make sure to take the action mask into account when computing the output. " @@ -1672,34 +1672,6 @@ def is_compatible(policy): ) -class RandomPolicy: - """A random policy for data collectors. - - This is a wrapper around the action_spec.rand method. 
- - Args: - action_spec: TensorSpec object describing the action specs - - Examples: - >>> from tensordict import TensorDict - >>> from torchrl.data.tensor_specs import Bounded - >>> action_spec = Bounded(-torch.ones(3), torch.ones(3)) - >>> actor = RandomPolicy(action_spec=action_spec) - >>> td = actor(TensorDict()) # selects a random action in the cube [-1; 1] - """ - - def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): - super().__init__() - self.action_spec = action_spec.clone() - self.action_key = action_key - - def __call__(self, td: TensorDictBase) -> TensorDictBase: - if isinstance(self.action_spec, Composite): - return td.update(self.action_spec.rand()) - else: - return td.set(self.action_key, self.action_spec.rand()) - - class _PolicyMetaClass(abc.ABCMeta): def __call__(cls, *args, **kwargs): # no kwargs diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index a349aba6635..dc8b213d492 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -92,6 +92,7 @@ VmapModule, WorldModelWrapper, ) +from .tensordict_module.exploration import RandomPolicy from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip @@ -183,4 +184,5 @@ "recurrent_mode", "reset_noise", "set_recurrent_mode", + "RandomPolicy", ] diff --git a/torchrl/modules/llm/backends/vllm/vllm_async.py b/torchrl/modules/llm/backends/vllm/vllm_async.py index 39b808cebf6..1647fe37d87 100644 --- a/torchrl/modules/llm/backends/vllm/vllm_async.py +++ b/torchrl/modules/llm/backends/vllm/vllm_async.py @@ -15,6 +15,7 @@ import asyncio import os import random +import time import uuid from collections.abc import Iterator, Sequence from concurrent.futures import ThreadPoolExecutor, wait @@ -1257,8 +1258,6 @@ def _update_weights_with_nccl_broadcast_simple( Args: weights_dict: Dictionary of parameter names to weight tensors """ - import time - if not hasattr(self, "_nccl_master_group") or self._nccl_master_group is None: raise RuntimeError( "NCCL master group not initialized. This is a bug in the setup process." diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 9739ce5e592..ad73278955c 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -4,12 +4,15 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +from typing import TYPE_CHECKING + import torch from tensordict import TensorDict, TensorDictBase - -from torchrl.envs.common import EnvBase from torchrl.modules.planners.common import MPCPlannerBase +if TYPE_CHECKING: + from torchrl.envs.common import EnvBase + class CEMPlanner(MPCPlannerBase): """CEMPlanner Module. diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 35703e6cad7..cc97838ece5 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -5,13 +5,16 @@ from __future__ import annotations import abc +from typing import TYPE_CHECKING import torch from tensordict import TensorDictBase -from torchrl.envs.common import EnvBase from torchrl.modules import SafeModule +if TYPE_CHECKING: + from torchrl.envs.common import EnvBase + class MPCPlannerBase(SafeModule, metaclass=abc.ABCMeta): """MPCPlannerBase abstract Module. 
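The planner modules (cem.py and common.py above, mppi.py below) move the EnvBase import behind TYPE_CHECKING. A minimal sketch of that deferred-import pattern, as a generic illustration rather than code from this diff:

    # EnvBase is only needed for annotations, so importing it under TYPE_CHECKING
    # removes the runtime dependency (and a potential circular import) while
    # postponed evaluation of annotations keeps the type hint usable.
    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from torchrl.envs.common import EnvBase  # evaluated by type checkers only


    def rollout_once(env: EnvBase) -> None:  # hypothetical helper, for illustration
        env.reset()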
diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index e4b33ced697..77d65e16849 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -4,13 +4,17 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +from typing import TYPE_CHECKING + import torch from tensordict import TensorDict, TensorDictBase from torch import nn -from torchrl.envs.common import EnvBase from torchrl.modules.planners.common import MPCPlannerBase +if TYPE_CHECKING: + from torchrl.envs.common import EnvBase + class MPPIPlanner(MPCPlannerBase): """MPPI Planner Module. diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index add36202bba..75c3edec9a5 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -29,6 +29,7 @@ EGreedyWrapper, OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, + RandomPolicy, ) from torchrl.modules.tensordict_module.probabilistic import ( SafeProbabilisticModule, @@ -70,6 +71,7 @@ "AdditiveGaussianWrapper", "EGreedyModule", "EGreedyWrapper", + "RandomPolicy", "OrnsteinUhlenbeckProcessModule", "OrnsteinUhlenbeckProcessWrapper", "SafeProbabilisticModule", diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 1a0520466db..4f8abaa225e 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -8,13 +8,13 @@ import numpy as np import torch -from tensordict import TensorDictBase +from tensordict import NestedKey, TensorDictBase from tensordict.nn import ( TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, ) -from tensordict.utils import expand_as_right, expand_right, NestedKey +from tensordict.utils import expand_as_right, expand_right from torch import nn from torchrl.data.tensor_specs import Composite, TensorSpec @@ -743,3 +743,31 @@ def add_sample( def current_sigma(self, n_steps: torch.Tensor) -> torch.Tensor: sigma = (self.m * n_steps + self.c).clamp_min(self.sigma_min) return sigma + + +class RandomPolicy: + """A random policy for data collectors. + + This is a wrapper around the action_spec.rand method. 
+ + Args: + action_spec: TensorSpec object describing the action specs + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.data.tensor_specs import Bounded + >>> action_spec = Bounded(-torch.ones(3), torch.ones(3)) + >>> actor = RandomPolicy(action_spec=action_spec) + >>> td = actor(TensorDict()) # selects a random action in the cube [-1; 1] + """ + + def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() + self.action_spec = action_spec.clone() + self.action_key = action_key + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + if isinstance(self.action_spec, Composite): + return td.update(self.action_spec.rand()) + else: + return td.set(self.action_key, self.action_spec.rand()) diff --git a/torchrl/testing/modules.py b/torchrl/testing/modules.py new file mode 100644 index 00000000000..84dffae8485 --- /dev/null +++ b/torchrl/testing/modules.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import torch +from torch import nn + + +class BiasModule(nn.Module): + """Simple bias module to check weight synchronization correctness.""" + + def __init__(self, value: float = 0.0): + super().__init__() + self.bias = nn.Parameter(torch.tensor(value, dtype=torch.float)) + + def forward(self, x): + return x + self.bias + + +class NonSerializableBiasModule(BiasModule): + """Bias module that intentionally fails to serialize. + + This is used in tests to simulate a policy that cannot be pickled. + """ + + def __getstate__(self): + # Simulate a non-serializable policy by raising on pickling + raise RuntimeError("NonSerializableBiasModule cannot be pickled") diff --git a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py index 4417e5c2cb3..ed128429d76 100644 --- a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py +++ b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py @@ -48,17 +48,12 @@ class SharedMemWeightSyncSchemeConfig(ConfigBase): Weight synchronization using shared memory for in-place weight updates. Workers automatically see weight updates without explicit message passing. - - By default, uses lazy registration (auto_register=True) which makes it seamless - to use with Hydra configs - models are automatically registered on first weight send. 
""" _target_: str = "torchrl.weight_update.SharedMemWeightSyncScheme" _partial_: bool = False - policy_weights: Any = None # dict[str, TensorDictBase] | None strategy: str = "tensordict" # "tensordict" or "state_dict" - auto_register: bool = True # Enable lazy registration by default def __post_init__(self) -> None: """Post-initialization hook for shared memory weight sync scheme configurations.""" diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 25f3ffa6357..4ae2e81de75 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -1881,7 +1881,7 @@ def _update_with_map(self): and destination in self.collector._weight_sync_schemes ): scheme = self.collector._weight_sync_schemes[destination] - strategy = WeightStrategy(extract_as=scheme.strategy) + strategy = WeightStrategy(extract_as=scheme.strategy_str) weights = strategy.extract_weights(source_module) else: # Fallback: use TensorDict extraction if no scheme found diff --git a/torchrl/weight_update/__init__.py b/torchrl/weight_update/__init__.py index 556064a6113..e50d9752721 100644 --- a/torchrl/weight_update/__init__.py +++ b/torchrl/weight_update/__init__.py @@ -3,43 +3,26 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .weight_sync_schemes import ( - DistributedTransport, - DistributedWeightSyncScheme, - MPTransport, - MultiProcessWeightSyncScheme, - NoWeightSyncScheme, - RayActorTransport, - RayModuleTransformReceiver, - RayModuleTransformScheme, - RayModuleTransformSender, - RayTransport, - RayWeightSyncScheme, - RPCTransport, - RPCWeightSyncScheme, - SharedMemTransport, - SharedMemWeightSyncScheme, - TransportBackend, - WeightReceiver, - WeightSender, - WeightStrategy, - WeightSyncScheme, -) +from ._distributed import DistributedTransport, DistributedWeightSyncScheme +from ._mp import MPTransport, MultiProcessWeightSyncScheme +from ._noupdate import NoWeightSyncScheme +from ._ray import RayModuleTransformScheme, RayTransport, RayWeightSyncScheme +from ._rpc import RPCTransport, RPCWeightSyncScheme +from ._shared import SharedMemTransport, SharedMemWeightSyncScheme +from .weight_sync_schemes import TransportBackend, WeightStrategy, WeightSyncScheme __all__ = [ + # Base classes "TransportBackend", + "WeightStrategy", + "WeightSyncScheme", + # Transports "MPTransport", "SharedMemTransport", "RayTransport", - "RayActorTransport", "RPCTransport", "DistributedTransport", - "WeightStrategy", - "WeightSender", - "WeightReceiver", - "RayModuleTransformSender", - "RayModuleTransformReceiver", - "WeightSyncScheme", + # Schemes "MultiProcessWeightSyncScheme", "SharedMemWeightSyncScheme", "NoWeightSyncScheme", diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py new file mode 100644 index 00000000000..c9cad578c53 --- /dev/null +++ b/torchrl/weight_update/_distributed.py @@ -0,0 +1,847 @@ +from __future__ import annotations + +import random +import socket +import time +import weakref +from datetime import timedelta +from typing import Any + +import torch +from tensordict import TensorDictBase +from torchrl._utils import logger as torchrl_logger + +from torchrl.weight_update.utils import _resolve_model + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightStrategy, + WeightSyncScheme, +) + + +class DistributedWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed. 
+ + This scheme uses torch.distributed primitives (send/recv) to synchronize + weights across distributed workers. Each worker gets its own transport, + following the same pattern as multiprocess collectors. + + The scheme can create its own TCPStore for coordination if one is not provided. + Use `get_store_info()` after `init_on_sender()` to get connection details for workers. + + Args: + backend (str): The distributed backend ("gloo", "nccl", etc.) + sync (bool): If True, weight updates are synchronous (blocking receive). + If False, a background thread monitors the store and applies weight + updates automatically. Defaults to True. + timeout (float): Timeout in seconds for TCPStore operations. + Defaults to 3600.0 (1 hour). + """ + + def __init__( + self, + backend: str = "gloo", + sync: bool = True, + timeout: float = 3600.0, + ): + super().__init__() + self.backend = backend + self.sync = sync + self._timeout = timeout + self._store = None + self._store_info = None + self._num_workers = None + + def __getstate__(self): + """Custom serialization - exclude non-picklable objects.""" + state = super().__getstate__() + # TCPStore cannot be pickled - remove it but keep _store_info + state["_store"] = None + + # Thread and Event cannot be pickled + state["_background_thread"] = None + state["_stop_event"] = None + + # Transports contain references to store/groups - exclude them + # The receiver will create its own transport in init_on_receiver + state["_sender_transports"] = {} + state["_receiver_transport"] = None + return state + + def __setstate__(self, state): + """Custom deserialization.""" + super().__setstate__(state) + + def _init_on_sender_impl( + self, + *, + model_id: str, + context: Any = None, + num_workers: int, + model: Any = None, + weights: Any = None, + **kwargs, + ) -> None: + if kwargs: + raise RuntimeError(f"Unexpected kwargs: {kwargs.keys()}") + self.model_id = model_id + self._num_workers = num_workers + + # Attach context so we can resolve the model and prepare + # weights on demand via scheme.prepare_weights(). + weights_buffer = None + if context is not None: + self.context = context + if weights is not None: + self.weights = weights + weights_buffer = weights + if model is not None: + self.model = model + else: + # resolve model + try: + model = self.model + except (AttributeError, ValueError): + pass + + if weights_buffer is None and model is not None: + weights_buffer = self._get_weights_buffer_from_model(model) + + # Get base tcp_port from context if available to avoid port conflicts. + # The DistributedDataCollector uses tcp_port for init and tcp_port+1 for its store, + # so we use tcp_port+2 for the weight sync scheme's store. + base_tcp_port = ( + getattr(context, "tcp_port", None) if context is not None else None + ) + self._store = self._make_store( + is_master=True, num_workers=num_workers, base_tcp_port=base_tcp_port + ) + + for i in range(num_workers): + rank = i + 1 # Workers are 1-indexed in distributed + transport = self.create_transport( + store=self._store, + rank=rank, + weights_buffer=weights_buffer, + sync=self.sync, + ) + self._register_worker_sender(worker_idx=i, transport=transport) + + def _make_store( + self, + is_master: bool, + num_workers: int | None = None, + store_info: dict | None = None, + base_tcp_port: int | str | None = None, + max_retries: int = 10, + ) -> torch.distributed.TCPStore: + """Create a TCPStore for weight synchronization. + + Args: + is_master: If True, creates the store as master (server). 
+ If False, connects as client. + num_workers: Number of workers (required for master). + store_info: Dictionary with 'host' and 'port' keys (required for client). + base_tcp_port: Base TCP port from the collector. If provided, the store + will use base_tcp_port + 2 to avoid conflicts with the collector's + stores (which use base_tcp_port and base_tcp_port + 1). + max_retries: Maximum number of retry attempts for handling port conflicts. + + Returns: + The created TCPStore. + """ + if is_master: + # Create as master (server) + if num_workers is None: + raise ValueError( + "num_workers is required when creating store as master" + ) + + hostname = socket.gethostname() + host = socket.gethostbyname(hostname) + + # Use base_tcp_port + 2 if available (to avoid conflicts with collector's + # tcp_port and tcp_port + 1), otherwise find a free port dynamically. + initial_port = int(base_tcp_port) + 2 if base_tcp_port is not None else None + + last_error = None + for attempt in range(max_retries): + if initial_port is None or attempt > 0: + # Find a free port dynamically + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", 0)) + self._store_port = s.getsockname()[1] + else: + self._store_port = initial_port + + try: + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Creating TCPStore on {host}:{self._store_port} " + f"(attempt {attempt + 1}/{max_retries})" + ) + store = torch.distributed.TCPStore( + host_name=host, + port=self._store_port, + is_master=True, + timeout=timedelta(seconds=self._timeout), + wait_for_workers=False, # Don't block - workers may not be started yet + ) + self._store_info = {"host": host, "port": self._store_port} + torchrl_logger.debug( + f"DistributedWeightSyncScheme: TCPStore created successfully: {self._store_info}" + ) + return store + except (RuntimeError, OSError) as e: + error_msg = str(e).lower() + if ( + "address already in use" in error_msg + or "eaddrinuse" in error_msg + ): + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Port {self._store_port} already in use, " + f"retrying ({attempt + 1}/{max_retries})..." + ) + last_error = e + # Add small random delay to reduce collision probability + time.sleep(random.uniform(0.01, 0.1)) + continue + # For other errors, re-raise immediately + raise + + raise RuntimeError( + f"DistributedWeightSyncScheme: Failed to create TCPStore after {max_retries} attempts. " + f"Last error: {last_error}" + ) + else: + # Connect as client + if store_info is None: + raise ValueError("store_info is required when connecting as client") + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Connecting to TCPStore at " + f"{store_info['host']}:{store_info['port']}" + ) + store = torch.distributed.TCPStore( + host_name=store_info["host"], + port=store_info["port"], + is_master=False, + timeout=timedelta(seconds=self._timeout), + ) + return store + + def _init_on_receiver_impl( + self, + *, + model_id: str, + context: Any = None, + store_info: dict | None = None, + worker_idx: int | None = None, + **kwargs, + ) -> None: + """Initialize scheme on the worker (receiver) side. + + Expected kwargs (as provided by collectors): + - model_id: str # e.g. 
"policy" + - context: Any # collector / inner collector + - store: TCPStore | None # distributed TCP store + - store_info: dict | None # {"host": ..., "port": ...} to create store + - rank: int | None # worker rank (1-indexed) + """ + if context is None: + raise ValueError( + "DistributedWeightSyncScheme.init_on_receiver requires a 'context' " + "providing access to the model to be synchronized." + ) + if worker_idx is None: + raise RuntimeError("rank was not provided.") + if kwargs: + raise RuntimeError(f"Unexpected kwargs: {kwargs.keys()}") + + # Store model_id and context on scheme + self.model_id = model_id + self.context = context + + # Get or create store + # Priority: provided store > provided store_info > self._store_info (from serialization) + # Connect to master's TCPStore as client + info = self._store_info + if info is None: + raise RuntimeError( + "TCPStore info not available. init_on_sender() must be called first on the sender side, before passing the scheme to the receiver." + ) + self._store = self._make_store(is_master=False, store_info=info) + + if (model := getattr(self, "model", None)) is not None: + self.model = model + weights_buffer = self._get_weights_buffer_from_model(model) + else: + raise RuntimeError("Couldn't find weights") + self._receiver_transport = self.create_transport( + store=self._store, + rank=worker_idx, + weights_buffer=weights_buffer, + sync=self.sync, + ) + + # Store worker_idx for synchronize_weights + self._worker_idx = worker_idx + # Note: Background thread for async mode is started in connect() after init_process_group + + def _wait_for_instruction(self, timeout: float | None = None) -> str | None: + """Block until an instruction arrives via TCPStore. + + Args: + timeout: Maximum time to wait for instruction (seconds). + None means block indefinitely. + + Returns: + The instruction string (e.g., "receive", "stop"), or None if + stop event is set or timeout expires. + """ + key = f"NODE_{self._worker_idx}_in" + start_time = time.monotonic() + + while True: + if self._stop_event is not None and self._stop_event.is_set(): + return None + + try: + instruction = self._store.get(key) + self._store.delete_key(key) + # Decode bytes to string + return ( + instruction.decode() + if isinstance(instruction, bytes) + else instruction + ) + except RuntimeError: + # Key doesn't exist yet, continue polling + pass + + # Check timeout + if timeout is not None: + elapsed = time.monotonic() - start_time + if elapsed >= timeout: + return None + + time.sleep(0.01) + + def _send_instruction( + self, + instruction: str = "receive", + worker_ids: int | list[int] | None = None, + ) -> None: + """Send instruction to receiver(s) via TCPStore. + + Args: + instruction: The instruction to send (default: "receive"). + worker_ids: Which workers to send to (None = all workers). + """ + if self._store is None: + raise RuntimeError( + "Store not initialized. init_on_sender() must be called first." 
+ ) + + if worker_ids is None: + target_workers = list(range(self._num_workers)) if self._num_workers else [] + elif isinstance(worker_ids, int): + target_workers = [worker_ids] + else: + target_workers = list(worker_ids) + + # Map instruction to TCPStore format + store_instruction = ( + b"update_weights" if instruction == "receive" else instruction.encode() + ) + + for worker_idx in target_workers: + rank = worker_idx + 1 # Workers are 1-indexed in distributed + self._store.set(f"NODE_{rank}_in", store_instruction) + + def _send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender via TCPStore. + + Args: + message: The acknowledgment message (default: "updated"). + """ + if self._store is None or self._worker_idx is None: + return + self._store.set(f"NODE_{self._worker_idx}_out", message.encode()) + + def _wait_for_ack( + self, + worker_ids: int | list[int] | None = None, + timeout: float | None = None, + ) -> None: + """Wait for acknowledgment from receiver(s) via TCPStore. + + Args: + worker_ids: Which workers to wait for (None = all workers). + timeout: Maximum time to wait (seconds). None means block indefinitely. + """ + if self._store is None: + return + + if worker_ids is None: + target_workers = list(range(self._num_workers)) if self._num_workers else [] + elif isinstance(worker_ids, int): + target_workers = [worker_ids] + else: + target_workers = list(worker_ids) + + for worker_idx in target_workers: + rank = worker_idx + 1 + key = f"NODE_{rank}_out" + try: + status = self._store.get(key) + if status != b"updated": + torchrl_logger.warning( + f"Unexpected ack from worker {worker_idx}: {status}" + ) + self._store.delete_key(key) + except Exception as e: + torchrl_logger.warning( + f"Timeout waiting for ack from worker {worker_idx}: {e}" + ) + + def _background_receive_loop(self): + """Background thread loop that waits for instructions and receives weights. + + This loop: + 1. Waits for an instruction via TCPStore + 2. Receives weights via torch.distributed + 3. Sends an acknowledgment back + 4. 
Repeats until stop event is set + """ + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Background receiver started for worker {self._worker_idx}" + ) + while not self._stop_event.is_set(): + try: + instruction = self._wait_for_instruction() + if instruction is None: + continue + if instruction in ("receive", "update_weights"): + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {self._worker_idx} " + "received 'receive' instruction" + ) + + # Receive weights via torch.distributed + weights = self._receiver_transport.receive_weights( + model=self.model, + strategy=self._strategy, + ) + + if weights is not None: + # Cascade weight update to sub-collectors if context supports it + model_id = self._model_id or "policy" + if self.context is not None and hasattr( + self.context, "update_policy_weights_" + ): + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Cascading weight update to sub-collectors for {model_id=}" + ) + self.context.update_policy_weights_( + model_id=model_id, policy_or_weights=weights + ) + + # Send acknowledgment + self._send_ack("updated") + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {self._worker_idx} " + "received and applied weights" + ) + + elif instruction == "stop": + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {self._worker_idx} received 'stop' instruction" + ) + break + else: + torchrl_logger.warning( + f"DistributedWeightSyncScheme: Unknown instruction: {instruction}" + ) + + except Exception as e: + if not self._stop_event.is_set(): + torchrl_logger.warning( + f"DistributedWeightSyncScheme: Background receiver error: {e}" + ) + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Background receiver stopped for worker {self._worker_idx}" + ) + + def _setup_connection_and_weights_on_sender_impl( + self, *, worker_idx: int | None = None, weights: Any | None = None + ) -> None: + """Send initial weights to all workers during connect(). + + If the sender has a stateful model (weights available), send them + to all workers so they start with the correct weights. + + Note: This uses direct torch.distributed send/recv without TCPStore + signaling to avoid interfering with the main collection loop. 
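+
+        A ``STATELESS_MODEL`` key is written to the TCPStore (``b"1"`` when the
+        sender has no model or weights, ``b"0"`` otherwise) so receivers know
+        whether to expect the initial broadcast.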
+ """ + # Initialize torch.distributed process group if not already done + # This is a collective operation - all workers must call it + if not torch.distributed.is_initialized(): + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Initializing process group on sender " + f"(world_size={self._num_workers + 1})" + ) + torch.distributed.init_process_group( + backend=self.backend, + rank=0, # Sender is always rank 0 + world_size=self._num_workers + 1, + timeout=timedelta(seconds=self._timeout), + ) + + # Check if we have weights to send + if weights is None and getattr(self, "model", None) is None: + torchrl_logger.debug( + "DistributedWeightSyncScheme: No model on sender, skipping initial weight sync" + ) + self._store.set("STATELESS_MODEL", b"1") + return + + self._store.set("STATELESS_MODEL", b"0") + # Prepare weights from model + weights = self._get_weights_buffer_from_model(self.model) + if weights is None or weights.is_empty(): + torchrl_logger.debug( + "DistributedWeightSyncScheme: Empty weights, skipping initial weight sync" + ) + return + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Sending initial weights to {self._num_workers} workers" + ) + + # Send to all workers using direct torch.distributed (no TCPStore signaling) + for i, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and i != worker_idx: + continue + transport.send_initial_weights(weights) + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Receive initial weights from sender during connect(). + + The receiver always has a model that needs weights, so we block + waiting for the initial weights from the sender. + """ + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self._worker_idx + + # Initialize torch.distributed process group if not already done + # This is a collective operation - sender and all workers must call it + if not torch.distributed.is_initialized(): + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Initializing process group on worker {worker_idx} " + f"(world_size={self._num_workers + 1})" + ) + torch.distributed.init_process_group( + backend=self.backend, + rank=worker_idx, + world_size=self._num_workers + 1, + timeout=timedelta(seconds=self._timeout), + ) + + if self._receiver_transport is None: + torchrl_logger.warning( + "DistributedWeightSyncScheme: No receiver transport, skipping initial weight sync" + ) + return + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {worker_idx} waiting for STATELESS_MODEL key" + ) + stateless_model = self.receiver_transport._store.get("STATELESS_MODEL") + if stateless_model not in (b"0", b"1"): + raise RuntimeError(f"Invalid STATELESS_MODEL value: {stateless_model}") + if stateless_model == b"1": + torchrl_logger.debug( + "DistributedWeightSyncScheme: Skipping initial weight sync on receiver because of stateless model." 
+ ) + else: + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {worker_idx} waiting for initial weights" + ) + + # Receive initial weights (blocking, no TCPStore coordination) + weights = self._receiver_transport.receive_initial_weights() + if weights is not None and self.model is not None: + self._strategy.apply_weights(self.model, weights, inplace=False) + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {worker_idx} received and applied initial weights" + ) + + # Start background receiver thread AFTER initial weight sync is complete + # This prevents the background thread from consuming the initial sync messages + if self._background_thread is None: + self._start_background_receiver() + + def shutdown(self) -> None: + """Stop background receiver thread and clean up.""" + if self._stop_event is not None: + self._stop_event.set() + if self._background_thread is not None: + self._background_thread.join(timeout=5.0) + if self._background_thread.is_alive(): + torchrl_logger.warning( + "DistributedWeightSyncScheme: Background thread did not stop gracefully" + ) + self._background_thread = None + self._stop_event = None + + @property + def model(self) -> Any | None: + """Get the model associated with this scheme. + + Returns: + The model if set, None otherwise. + """ + if self._model_ref is not None: + return self._model_ref() + if self._model_id is not None: + model = _resolve_model(self.context, self._model_id) + if model is None: + if self._model_id == "policy": + torchrl_logger.debug( + f"Creating policy from factory and setting in collector {type(self.context)}" + ) + model = self.context.policy_factory[0]() + self.context.policy = model + torchrl_logger.debug(f"{self.context.policy=}") + else: + raise AttributeError( + f"Model {self._model_id} was `None` in context {self.context}" + ) + self._model_ref = weakref.ref(model) + return model + + @model.setter + def model(self, value: Any): + """Set the model for this scheme. + + Args: + value: The model to set. If None, the setter is a no-op. + """ + if value is None: + return + self._model_ref = weakref.ref(value) + + def create_transport(self, **kwargs) -> TransportBackend: + """Create distributed transport for a specific worker.""" + return DistributedTransport(**kwargs) + + +class DistributedTransport: + """torch.distributed transport for communicating with a single distributed worker. + + This transport handles weight updates for ONE specific distributed worker via + torch.distributed send/recv. Multiple transports are created for multiple workers, + following the same pattern as multiprocess collectors. + """ + + def __init__( + self, + *, + weights_buffer: TensorDictBase, + store: torch.distributed.Store = None, + rank: int | None = None, + sync: bool = True, + ): + """Initialize the DistributedTransport. + + Args: + weights_buffer (TensorDictBase): a tensor buffer of weights. + store (torch.distributed.Store): A (TCP)Store for communication. + rank (int): Worker rank (1-indexed). + sync (bool): Whether to use synchronous weight updates. 
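+
+        Example:
+            A minimal, illustrative sketch (assumes an existing TCPStore ``store``
+            and a TensorDict ``buffer`` mirroring the policy weights):
+
+            >>> transport = DistributedTransport(
+            ...     weights_buffer=buffer,
+            ...     store=store,
+            ...     rank=1,
+            ...     sync=True,
+            ... )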
+ """ + self._store = store + self._rank = rank + self._sync = sync + self._weights_buffer = weights_buffer + + def send_weights(self, weights: Any) -> None: + """Send weights to the distributed worker.""" + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + torchrl_logger.debug("RANK 0 -- Setting weight sync instructions to store") + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + torchrl_logger.debug(f"RANK 0 -- Send {type(weights)=} to rank {self._rank}") + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + # Wait for acknowledgment + torchrl_logger.debug("RANK 0 -- Receiving acknowledgement from store") + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def send_weights_async(self, weights: Any) -> None: + """Send weights to distributed worker without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. + """ + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + torchrl_logger.debug( + f"RANK 0 -- Setting weight sync instructions to store for rank {self._rank}" + ) + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + torchrl_logger.debug( + f"RANK 0 -- Send {type(weights)=} to rank {self._rank} with sync={self._sync}" + ) + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + torchrl_logger.debug(f"RANK 0 -- Weights successfully sent to {self._rank}") + + def wait_ack(self) -> None: + """Wait for acknowledgment from distributed worker.""" + if self._store is None or self._rank is None: + return + + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any | None: + r"""Receive weights via torch.distributed and apply them to the model. + + The surrounding collector loop is responsible for checking the TCPStore + for the \"update_weights\" instruction. When this method is called we + assume that a weight update has been requested and the sender has + already performed the corresponding ``send()``. + + Args: + timeout: Maximum time to wait for weights (seconds). If None, + blocks until weights are received. + weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. + + Returns: + The received weights, or None if timeout expires. 
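+
+        Example:
+            Illustrative sketch on the worker side (assumes ``policy`` and
+            ``strategy`` are supplied by the owning scheme):
+
+            >>> weights = transport.receive_weights(
+            ...     timeout=10.0, model=policy, strategy=strategy
+            ... )
+            >>> if weights is None:
+            ...     print("timed out waiting for weights")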
+ """ + if self._store is None or self._rank is None: + return None + + # Use provided weights buffer or fallback to stored one + weights_buffer = weights if weights is not None else self._weights_buffer + + # Receive weights via torch.distributed into the buffer + if self._sync or timeout is None: + # Blocking receive - no timeout support + if self._sync: + torchrl_logger.debug(f"Rank {self._rank} -- calling recv") + weights_buffer.recv(src=0) + else: + torchrl_logger.debug(f"Rank {self._rank} -- calling irecv") + weights_buffer.irecv(src=0) + else: + # Non-blocking receive with timeout support + torchrl_logger.debug( + f"Rank {self._rank} -- calling irecv with premature return" + ) + futures = weights_buffer.irecv(src=0, return_premature=True) + if futures: + start_time = time.monotonic() + while True: + # Check if all futures are complete + all_complete = all(f.is_completed() for f in futures) + if all_complete: + break + # Check timeout + elapsed = time.monotonic() - start_time + if elapsed >= timeout: + # Timeout expired before receiving all weights + return None + # Small sleep to avoid busy-waiting + time.sleep(0.001) + + # Apply weights if model and strategy provided + if model is not None and strategy is not None: + strategy.apply_weights(model, weights_buffer) + + torchrl_logger.debug(f"Rank {self._rank} -- closing receive_weights") + return weights_buffer + + def send_initial_weights(self, weights: Any) -> None: + """Send initial weights during connect() without TCPStore signaling. + + This is used for the initial weight sync during connect() to avoid + interfering with the main collection loop's TCPStore-based coordination. + """ + if self._rank is None: + return + + torchrl_logger.debug( + f"DistributedTransport: Sending initial weights to rank {self._rank}" + ) + # Note: No TCPStore signaling for initial sync - just direct send/recv + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + def receive_initial_weights(self) -> Any: + """Receive initial weights during connect() without TCPStore signaling. + + This is used for the initial weight sync during connect() to avoid + interfering with the main collection loop's TCPStore-based coordination. + + Returns: + The received weights TensorDict. 
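+
+        Example:
+            Illustrative usage on the worker side, once the process group has
+            been initialized (blocking call):
+
+            >>> initial_weights = transport.receive_initial_weights()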
+ """ + torchrl_logger.debug( + "DistributedTransport: Receiving initial weights from rank 0" + ) + if self._sync: + self._weights_buffer.recv(src=0) + else: + self._weights_buffer.irecv(src=0) + return self._weights_buffer + + def setup_connection_and_weights_on_sender(self) -> None: + """No-op for DistributedTransport - handled by scheme.""" + + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any: + """No-op for DistributedTransport - handled by scheme.""" + return None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py new file mode 100644 index 00000000000..9a56ee1c06c --- /dev/null +++ b/torchrl/weight_update/_mp.py @@ -0,0 +1,651 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import torch +from tensordict import TensorDictBase +from torch import multiprocessing as mp, nn +from torchrl.weight_update._shared import SharedMemWeightSyncScheme +from torchrl.weight_update.utils import _resolve_model + +from torchrl.weight_update.weight_sync_schemes import TransportBackend + + +class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): + """Weight synchronization for multiprocess operations using queues. + + This scheme creates transports that communicate via multiprocessing queues. + Unlike the parent SharedMemWeightSyncScheme which uses shared memory for in-place + updates, this scheme sends actual weight copies through queues to workers. + + A background thread on the receiver side listens for "receive" instructions + from the sender. When an instruction arrives, the thread receives the weights + from the weight queue and applies them to the model. + + It follows the same two-phase pattern as SharedMemWeightSyncScheme: + + 1. **init_on_sender()**: Stores the recipe for creating device-specific weights + (model reference, devices, mapping functions) without creating actual copies + 2. **synchronize_weights()**: Creates device-specific weight copies on-demand, + sends them sequentially to workers via queues, allowing garbage collection + between workers to minimize memory usage + + This approach avoids holding multiple weight copies in memory simultaneously, + which is especially beneficial for large models with many workers. + + Synchronization flow: + - **init_on_sender()**: Store configuration and register worker queues + - **synchronize_weights()**: Create and send initial weights on-demand + - **init_on_receiver()**: Create receiver that reads from queue + - **send()**: Extract and send weight updates, wait for acknowledgments + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + Can be "tensordict" or "state_dict". + sync: If True (default), send() blocks until receiver acknowledges. + If False, send() returns immediately (use send_async/wait_async). + + Example: + >>> # Basic usage with collector + >>> scheme = MultiProcessWeightSyncScheme() + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... 
) + >>> # scheme.collect() is called automatically by collector + >>> # Weights are created on-demand and sent to workers efficiently + + Note: + The on-demand weight creation means that synchronize_weights() will be + slower than if weights were pre-computed, but memory usage is significantly + reduced, especially when workers use different devices or when the model + is large. + """ + + def __init__(self, strategy: str = "tensordict", sync: bool = True): + """Initialize the MultiProcessWeightSyncScheme. + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + sync: If True (default), send() blocks until receiver acknowledges. + """ + super().__init__(strategy, sync=sync) + # Override parent's shared transport - we don't use shared memory + self._shared_transport = None + + def _init_on_sender_impl( + self, + *, + model_id: str | None = None, + context: Any = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + This method stores the configuration needed to create device-specific weight + copies during synchronization. Weight copies are created on-demand during + `synchronize_weights()` to reduce memory usage. + + Similar to `SharedMemWeightSyncScheme`, this follows a two-phase pattern: + 1. `init_on_sender()`: Store the recipe for creating weights + 2. `synchronize_weights()`: Create and send weights on-demand + + Args: + model_id: Identifier for the model being synchronized (e.g., "policy"). + Required when using context. + context: Optional context object (e.g., collector) providing: + - num_workers: Number of worker processes + - policy_device: List of devices for each worker + When provided, model_id is used to resolve the model from context. + weights: Pre-extracted weights as TensorDict. Mutually exclusive with + model and context. Used when weights are already available. + model: Model to extract weights from. Mutually exclusive with weights + and context. + params_map: Pre-computed mapping of worker_idx to device-specific weights. + Most explicit option. When provided, all other parameters must be None. + devices: List of devices for each worker. Used with weights or model to + automatically create device-specific copies. Length must equal num_workers. + device_map_fn: Custom function (worker_idx, weights) -> device_weights. + Allows full control over device mapping. Requires num_workers. + num_workers: Number of workers. Required with device_map_fn, inferred + from devices length otherwise. + **kwargs: Reserved for future use. + + Examples: + Simple usage with collector context (most common): + + >>> scheme = MultiProcessWeightSyncScheme() + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, + ... policy=policy, + ... frames_per_batch=100, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.init_on_sender() is called automatically by collector + + Direct initialization with explicit devices: + + >>> scheme = MultiProcessWeightSyncScheme() + >>> weights = TensorDict.from_module(policy) + >>> scheme.init_on_sender( + ... weights=weights, + ... devices=[torch.device("cpu"), torch.device("cuda:0")], + ... num_workers=2, + ... 
) + + Advanced: Pre-computed params_map: + + >>> weights_cpu = TensorDict.from_module(policy) + >>> weights_cuda = weights_cpu.to("cuda") + >>> scheme.init_on_sender( + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... num_workers=3, + ... ) + """ + # Get params_map from parent class logic + params_map_result = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Store the mapping recipe for later use in synchronize_weights + # Don't store params_map directly to save memory - we'll recompute on demand + # Note: We don't store context directly to avoid pickle issues - + # it's available via _context_ref + self._device_mapping_info = { + "model_id": model_id, + "weights": weights, + "model": model, + "params_map": params_map, + "devices": devices, + "device_map_fn": device_map_fn, + "num_workers": num_workers + if num_workers is not None + else len(params_map_result), + } + + # Create per-worker queues for weight distribution + # Each worker gets its own queue for receiving weights + all_workers = list(params_map_result.keys()) + if not hasattr(self, "_weight_init_queues"): + self._weight_init_queues = {} + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() + # Create instruction queues for background receiver + if worker_idx not in self._instruction_queues: + self._instruction_queues[worker_idx] = mp.Queue() + # Create ack queues for synchronous mode + if worker_idx not in self._ack_queues: + self._ack_queues[worker_idx] = mp.Queue() + + # Store model_id and context on scheme + self.model_id = model_id + if context is not None: + self.context = context + + # Register workers with their queues + for worker_idx in all_workers: + queue = self._weight_init_queues[worker_idx] + ack_queue = self._ack_queues[worker_idx] + # Create MPTransport for this worker with ack queue + transport = MPTransport(weight_queue=queue, ack_queue=ack_queue) + self._register_worker_sender(worker_idx=worker_idx, transport=transport) + + def _init_on_receiver_impl( + self, + *, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing worker_idx and model + **kwargs: Alternative to context (worker_idx, model, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + worker_idx = getattr(context, "worker_idx", None) + if hasattr(context, "get_model"): + model = context.get_model(model_id) + else: + model = _resolve_model(context, model_id) + else: + worker_idx = kwargs.get("worker_idx") + model = kwargs.get("model") + + if worker_idx is None: + raise ValueError("worker_idx must be provided via context or kwargs") + + # Get the queue for this worker + if worker_idx not in self._weight_init_queues: + raise ValueError( + f"Worker {worker_idx} not registered. init_on_sender() must be called first." 
+ ) + + queue = self._weight_init_queues[worker_idx] + ack_queue = self._ack_queues.get(worker_idx) + + # Store on scheme directly + self.model_id = model_id + if context is not None: + self.context = context + + # Store instruction and ack queue references for this worker + if worker_idx in self._instruction_queues: + self._receiver_instruction_queue = self._instruction_queues[worker_idx] + if worker_idx in self._ack_queues: + self._receiver_ack_queue = self._ack_queues[worker_idx] + + # Create transport with the worker's queue and ack queue + transport = MPTransport(weight_queue=queue, ack_queue=ack_queue) + self._register_transport_receiver(transport=transport) + + if model is not None: + self.model = model + + # Store worker_idx for synchronize_weights + self.worker_idx = worker_idx + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights synchronously to workers. + + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends weights to the weight queue + 3. Sends "receive" instruction to workers' background threads + 4. If sync=True, waits for acknowledgments from those workers + + Args: + weights: Weights to send. Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: If sync=True (default), this is a blocking call that ensures + specified workers are updated before returning. + """ + from torchrl._utils import logger as torchrl_logger + + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + if not self.synchronized_on_sender: + raise RuntimeError("Must be synchronized on sender before sending weights") + + model_id = self.model_id + context = self.context + + # Let the scheme prepare the weights + prepared_weights = self.prepare_weights( + weights=weights, + model_id=model_id, + strategy=self._strategy, + context=context, + ) + + transports = list(self._iterate_transports(worker_ids)) + + # Send weights to all workers first via queue (non-blocking) + torchrl_logger.debug("Sending weights to queues") + for transport in transports: + if hasattr(transport, "send_weights_async"): + # For MPTransport, pass model_id; other transports don't need it + transport.send_weights_async(prepared_weights, model_id=model_id) + else: + # Fallback for transports that don't support async send + transport.send_weights(prepared_weights) + + # Send instruction to workers' background threads to receive the weights + torchrl_logger.debug("Sending 'receive' instruction to workers") + self._send_instruction(instruction="receive", worker_ids=worker_ids) + + # Wait for all acknowledgments if in synchronous mode + if self.sync: + torchrl_logger.debug("Waiting for acknowledgments from workers") + self._wait_for_ack(worker_ids=worker_ids) + + def _setup_connection_and_weights_on_sender_impl( + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, + ) -> None: + """Synchronize weights with workers before collection starts. + + Computes device-specific weight copies on-demand and sends them to workers + sequentially via queues. This is called once after workers are initialized + but before they start collecting data. 
+ + Unlike send(), this does not wait for acknowledgments since workers are still + in their initialization phase. + + This approach creates weight copies on-demand and sends them sequentially, + allowing garbage collection between workers to reduce memory usage. + + Raises: + RuntimeError: If init_on_sender() was not called first. + """ + # Get the device mapping info stored during init_on_sender + if not hasattr(self, "_device_mapping_info"): + raise RuntimeError( + "synchronize_weights() requires init_on_sender() to be called first" + ) + + mapping_info = self._device_mapping_info + + # Get context from weakref + context = self.context + + # Compute params_map on-demand + # Extract with explicit type casting for type checker + model_id = mapping_info["model_id"] + weights = mapping_info["weights"] + model = mapping_info["model"] + params_map_arg = mapping_info["params_map"] + devices = mapping_info["devices"] + device_map_fn = mapping_info["device_map_fn"] + num_workers = mapping_info["num_workers"] + + params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map_arg, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Send to workers sequentially via queues (no ACK - workers are still initializing) + # This allows GC to clean up each worker's weights before creating the next + for i, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and i != worker_idx: + continue + worker_weights = params_map[i] + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(worker_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"Transport {type(transport)} does not support async send for synchronization" + ) + + # Clean up the mapping info after synchronization + delattr(self, "_device_mapping_info") + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Receive initial weights and start background receiver thread. + + This method: + 1. Receives initial weights from the sender via queue + 2. Applies them to the model + 3. Starts a background thread that listens for "receive" instructions + + Args: + worker_idx: The worker index. + """ + from torchrl._utils import logger as torchrl_logger + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self._worker_idx + + if worker_idx is None: + raise RuntimeError( + "worker_idx must be provided for _setup_connection_and_weights_on_receiver_impl." + ) + + # Receive initial weights from queue via transport + if self._receiver_transport is None: + raise RuntimeError("Receiver transport not set.") + + weights = self._receiver_transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx, + weights=self.weights, + model=self.model, + strategy=self._strategy, + ) + + # Store received weights for later use + if weights is not None: + self._receiver_weights = weights + + # Apply weights to model + if weights is not None and self.model is not None: + self._strategy.apply_weights(self.model, weights, inplace=False) + torchrl_logger.debug( + f"MultiProcessWeightSyncScheme: Worker {worker_idx} applied initial weights" + ) + + # Start background receiver thread + self._start_background_receiver() + + def _background_receive_loop(self): + """Background thread loop that waits for instructions and receives weights. + + This loop: + 1. Waits for a "receive" instruction from the sender + 2. 
Receives weights from the weight queue + 3. Applies them to the model + 4. Sends an acknowledgment back to the sender + 5. Repeats until stop event is set or "stop" instruction received + """ + from torchrl._utils import logger as torchrl_logger + + torchrl_logger.debug( + f"MultiProcessWeightSyncScheme: Background receiver started for worker {self._worker_idx}" + ) + while not self._stop_event.is_set(): + try: + instruction = self._wait_for_instruction() + if instruction is None: + # Stop event was set or timeout + continue + if instruction == "receive": + torchrl_logger.debug( + f"MultiProcessWeightSyncScheme: Worker {self._worker_idx} received 'receive' instruction" + ) + + # Receive weights from transport (blocking) + if self._receiver_transport is not None: + weights = self._receiver_transport.receive_weights( + model=self.model, + strategy=self._strategy, + ) + + if weights is not None: + torchrl_logger.debug( + f"MultiProcessWeightSyncScheme: Worker {self._worker_idx} received and applied weights" + ) + + # Cascade weight update to sub-collectors if context supports it + model_id = self._model_id or "policy" + if self.context is not None and hasattr( + self.context, "update_policy_weights_" + ): + torchrl_logger.debug( + f"MultiProcessWeightSyncScheme: Cascading weight update to sub-collectors for {model_id=}" + ) + self.context.update_policy_weights_( + model_id=model_id, policy_or_weights=weights + ) + + # Send acknowledgment + self._send_ack("updated") + + elif instruction == "stop": + torchrl_logger.debug( + f"MultiProcessWeightSyncScheme: Worker {self._worker_idx} received 'stop' instruction" + ) + break + else: + torchrl_logger.warning( + f"MultiProcessWeightSyncScheme: Unknown instruction: {instruction}" + ) + except Exception as e: + if not self._stop_event.is_set(): + torchrl_logger.warning( + f"MultiProcessWeightSyncScheme: Background receiver error: {e}" + ) + + torchrl_logger.debug( + f"MultiProcessWeightSyncScheme: Background receiver stopped for worker {self._worker_idx}" + ) + + def create_transport(self, **kwargs) -> TransportBackend: + """Create an MPTransport using the provided queue. + + Note: + This is used internally by init_on_sender/init_on_receiver. + """ + queue = kwargs.get("queue") + return MPTransport(weight_queue=queue, ack_queue=None) + + +class MPTransport: + """Multiprocessing transport using queues. + + This transport uses queues for weight distribution and synchronization. + Similar to SharedMemTransport's queue-based approach, MPTransport uses + queues to send initial weights to workers during synchronization. + + Initialization flow: + - synchronize_weights() extracts weights and sends to all workers via queues + - Workers receive the initial weights via setup_connection_and_weights_on_receiver() + - Subsequent updates use send_weights_async() followed by acknowledgments + + Args: + weight_queue (mp.Queue): The queue to use for sending weights. + ack_queue (mp.Queue): The queue to use for receiving acknowledgments. + timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. + """ + + def __init__(self, weight_queue, ack_queue=None, timeout: float = 10.0): + self.timeout = timeout + self.weight_queue = weight_queue + self.ack_queue = ack_queue + + def send_weights_async(self, weights: Any, model_id: str = "policy") -> None: + """Send weights through the queue without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. 
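+
+        Example:
+            Illustrative sketch (assumes ``weights`` is a TensorDict of policy
+            parameters); the payload is enqueued as
+            ``((model_id, weights), "update_weights")``:
+
+            >>> transport.send_weights_async(weights, model_id="policy")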
+ """ + # Send in format expected by worker loop: ((model_id, weights), "update_weights") + self.weight_queue.put(((model_id, weights), "update_weights")) + + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any | None: + """Receive weights from the queue (used in worker process). + + This method only handles weight update messages. Other messages + (like "close", "continue", etc.) are ignored and should be handled + by the main worker loop. + + Args: + timeout: Maximum time to wait for weights (seconds). + None means use the transport's default timeout. + weights: Ignored (weights come from queue). + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. + + Returns: + The received weights, or None if no data available. + """ + # Use transport's default timeout if not specified + if timeout is None: + timeout = self.timeout + data_in, msg = self.weight_queue.get(timeout=timeout) + if msg == "update_weights": + # data_in is (model_id, weights) - we ignore model_id, scheme knows it + _model_id, received_weights = data_in + + # Apply weights to model if provided + if model is not None and strategy is not None: + strategy.apply_weights(model, received_weights) + + return received_weights + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") + + def setup_connection_and_weights_on_sender(self) -> None: + """No-op for MPTransport - weights are sent via scheme's synchronize_weights(). + + The actual sending happens in MultiProcessWeightSyncScheme._setup_connection_and_weights_on_sender_impl(), which: + 1. Extracts weights from the context (e.g., collector.policy) + 2. Calls send_weights_async() on all worker transports + 3. Sends initial weights through queues to all workers + + This is similar to SharedMemTransport.setup_connection_and_weights_on_sender() which + sends shared memory buffer references via queues. + """ + + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any: + """Receive initial weights from sender during worker initialization. + + This method blocks waiting for the initial weights to be sent from the main process + via queue. Similar to SharedMemTransport.setup_connection_and_weights_on_receiver() which receives + shared memory buffer references via queues, this receives the actual weights via queues. + + The received weights are then applied to the worker's model by the scheme's synchronize_weights(). + + Args: + worker_idx: The worker index (used for logging/debugging). + weights: Ignored (weights come from queue). + model: Ignored. + strategy: Ignored. + + Returns: + The received weights if available, None otherwise (weights will come later via receive()). 
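+
+        Example:
+            Illustrative sketch on the worker side (blocks until the sender has
+            pushed the initial weights through the queue):
+
+            >>> initial_weights = transport.setup_connection_and_weights_on_receiver(
+            ...     worker_idx=0,
+            ... )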
+ """ + # Wait for initial weights (blocking) + data_in, msg = self.weight_queue.get(timeout=self.timeout) + if msg == "update_weights": + # data_in is (model_id, weights), extract just the weights + _, received_weights = data_in + return received_weights + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py new file mode 100644 index 00000000000..8fe5625cb27 --- /dev/null +++ b/torchrl/weight_update/_noupdate.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import Any + +from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme + + +class NoWeightSyncScheme(WeightSyncScheme): + """No-op weight synchronization scheme. + + This scheme disables weight synchronization entirely. + """ + + def _init_on_sender_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (not used) + **kwargs: Optional parameters (not used) + """ + # Store model_id directly on scheme (no-op) + self.model_id = model_id + + def _init_on_receiver_impl( + self, + *, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (not used) + **kwargs: Optional parameters (not used) + """ + # Store model_id directly on scheme (no-op) + self.model_id = model_id + + def create_transport(self, **kwargs) -> TransportBackend: + """Create a no-op transport. + + Note: + This is used internally by init_on_sender/init_on_receiver. 
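+
+        Example:
+            Illustrative sketch (assumes ``scheme`` is a ``NoWeightSyncScheme``
+            instance); every call on the returned transport is a no-op:
+
+            >>> transport = scheme.create_transport()
+            >>> transport.send_weights(weights)  # silently does nothing
+            >>> transport.receive_weights()      # always returns None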
+ """ + # Return a dummy transport that does nothing + class NoOpTransport: + def send_weights(self, weights: Any) -> None: + pass + + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any | None: + return None + + def check_connection(self) -> bool: + return True + + def setup_connection_and_weights_on_sender(self) -> None: + pass + + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any: + return None + + return NoOpTransport() + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """No-op send - does nothing.""" + + def receive(self, timeout: float | None = None) -> None: + """No-op receive - always returns None.""" + return None + + def connect(self, *, worker_idx: int | None = None) -> None: + """No-op synchronize - does nothing.""" + if self._initialized_on_sender: + self.synchronized_on_sender = True + elif self._initialized_on_receiver: + self.synchronized_on_receiver = True diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py new file mode 100644 index 00000000000..73c3a6894c8 --- /dev/null +++ b/torchrl/weight_update/_ray.py @@ -0,0 +1,1104 @@ +from __future__ import annotations + +import os +import socket + +import time +import weakref +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Literal + +import torch +from tensordict import TensorDict +from tensordict.base import TensorDictBase + +from torchrl._utils import logger as torchrl_logger +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightStrategy, + WeightSyncScheme, +) + +# Default timeout for torch.distributed operations +_DIST_TIMEOUT = timedelta(seconds=60) + + +@dataclass +class ConnectionInfo: + """Connection info for Ray distributed computing. + + Uses dataclass instead of UserDict to avoid Ray signature introspection + issues with UserDict's __class_getitem__ in Python 3.11+ + (ValueError: no signature found for builtin type GenericAlias). + """ + + master_addr: str + master_port: int + world_size: int + stateful_model: bool + + def get(self, key: str, default: Any = None) -> Any: + """Get a connection info value by key name. + + Args: + key (str): The attribute name to retrieve. + default: The default value if the attribute does not exist. + Defaults to None. + + Returns: + The value of the attribute, or the default if not found. + """ + return getattr(self, key, default) + + +class RayTransport: + """Ray transport for communicating with a single Ray actor. + + This transport handles weight updates for ONE specific remote actor + using torch.distributed for efficient weight transfer. Ray is used for + signaling/coordination, while the actual weight data is transferred via + torch.distributed send/recv operations. + + Multiple transports are created for multiple actors, following the + same pattern as multiprocess collectors. + + Args: + remote_actor: The Ray actor handle for the remote collector/transform. + worker_idx (int, optional): The worker index for this remote actor. + Defaults to 0. + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". + connection_info_name (str): Name of the Ray actor storing connection info. + Defaults to "connection_info". 
+ model_id (str, optional): The model identifier for weight synchronization. + """ + + def __init__( + self, + *, + remote_actor=None, + worker_idx: int | None = None, + backend: str = "gloo", + connection_info_name: str = "connection_info", + model_id: str | None = None, + ): + """Initialize the RayTransport. + + Args: + remote_actor: The Ray actor handle for the remote collector/transform. + worker_idx (int, optional): The worker index for this remote actor. + Defaults to 0. + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". + connection_info_name (str): Name of the Ray actor storing connection info. + Defaults to "connection_info". + model_id (str, optional): The model identifier for weight synchronization. + """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayTransport") + self._remote_actor = remote_actor + self._worker_idx = worker_idx if worker_idx is not None else 0 + self._backend = backend + self._connection_info_name = connection_info_name + self._model_id = model_id + + # Distributed state + self._dist_initialized = False + self._weights_buffer: TensorDictBase | None = None + self._stateful_model: bool = True + + # Async operation state + self._pending_future = None + self._pending_isend = None + + # Model reference (set by scheme on receiver side) + self._model = None + + @property + def _rank(self) -> int: + """Get the torch.distributed rank for this worker. + + Returns: + int: The rank (worker_idx + 1, since sender is rank 0). + """ + return self._worker_idx + 1 # Sender is rank 0, workers are 1-indexed + + def set_model(self, model: Any) -> None: + """Set the model for receiving weights. + + Args: + model: The model to receive weights into. + """ + self._model = model + + # ======================================================================== + # Sending Weights (Sender Side) + # ======================================================================== + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote actor via torch.distributed. + + This method: + 1. Signals the remote actor to start receiving via Ray remote call + 2. Sends weights via torch.distributed.isend + 3. Waits for both to complete + + Args: + weights: The weights to send (typically a TensorDict). + """ + if self._remote_actor is None: + return + + # Step 1: Signal the remote actor via Ray to start receiving (async) + future = self._remote_actor._receive_weights_scheme.remote() + + # Step 2: Send weights via torch.distributed (async) + torchrl_logger.debug(f"RayTransport: Sending weights to rank {self._rank}") + weights.isend(dst=self._rank) + + # Step 3: Wait for the Ray call to complete (receiver has applied weights) + self.ray.get(future) + + def send_weights_async(self, weights: Any) -> None: + """Send weights to Ray actor without waiting for completion. + + Use :meth:`wait_ack` to wait for completion after sending to all actors. + + Args: + weights: The weights to send (typically a TensorDict). 
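+
+        Example:
+            Illustrative sketch (assumes ``weights`` is a TensorDict); pair each
+            async send with a :meth:`wait_ack`:
+
+            >>> transport.send_weights_async(weights)
+            >>> transport.wait_ack()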
+ """ + if self._remote_actor is None: + return + + # Step 1: Signal the actor via Ray to start receiving (async) + torchrl_logger.debug( + f"RayTransport: Sending weights async to rank {self._rank}" + ) + self._pending_future = self._remote_actor._receive_weights_scheme.remote() + + # Step 2: Send weights via torch.distributed (async) + self._pending_isend = weights.isend(dst=self._rank, return_early=True) + torchrl_logger.debug("RayTransport: Async send initiated") + + def wait_ack(self) -> None: + """Wait for Ray actor to finish applying weights. + + Raises: + RuntimeError: If no pending future exists (i.e., :meth:`send_weights_async` + was not called before this method). + """ + if self._pending_future is not None: + torchrl_logger.debug( + f"RayTransport: Waiting for ack from rank {self._rank}" + ) + self.ray.get(self._pending_future) + torchrl_logger.debug( + f"RayTransport: Ack received from rank {self._rank}. Waiting for isend to complete." + ) + if self._pending_isend is not None: + for fut in self._pending_isend: + fut.wait() + self._pending_future = None + self._pending_isend = None + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") + + # ======================================================================== + # Receiving Weights (Receiver Side) + # ======================================================================== + + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any | None: + """Receive weights from sender via torch.distributed. + + Args: + timeout: Maximum time to wait for weights (seconds). If None, + blocks until weights are received. + weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. + + Returns: + The received weights, or None if timeout expires. 
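+
+        Example:
+            Illustrative sketch on the worker side (assumes ``policy`` is an
+            ``nn.Module`` and ``strategy`` comes from the owning scheme):
+
+            >>> weights = transport.receive_weights(model=policy, strategy=strategy)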
+ """ + from torchrl.collectors.utils import _cast + + # Use provided weights buffer or fallback to stored one + weights_buffer = weights if weights is not None else self._weights_buffer + if weights_buffer is None: + if model is None: + raise RuntimeError("No model available to receive weights") + if isinstance(model, torch.nn.Module): + weights_buffer = TensorDict.from_module(model) + weights_buffer = weights_buffer.data.apply(_cast, weights_buffer) + else: + weights_buffer = TensorDict(lock=True) + + # Cache the weights buffer for future use + if self._weights_buffer is None: + self._weights_buffer = weights_buffer + + # Receive weights from rank 0 + torchrl_logger.debug( + f"RayTransport: Receiving weights from rank 0: {type(weights_buffer)=}" + ) + + if timeout is None: + # Blocking receive + weights_buffer.irecv(src=0) + else: + # Non-blocking receive with timeout support + futures = weights_buffer.irecv(src=0, return_premature=True) + if futures: + start_time = time.monotonic() + while True: + # Check if all futures are complete + all_complete = all(f.is_completed() for f in futures) + if all_complete: + break + # Check timeout + elapsed = time.monotonic() - start_time + if elapsed >= timeout: + # Timeout expired before receiving all weights + torchrl_logger.debug( + f"RayTransport: Timeout ({timeout}s) expired waiting for weights" + ) + return None + # Small sleep to avoid busy-waiting + time.sleep(0.001) + + # Apply weights to model + if not isinstance(model, torch.nn.Module): + if not weights_buffer.is_empty(): + raise RuntimeError( + f"Cannot cast weights to model type: {type(model)} with weights: {weights_buffer}." + ) + torchrl_logger.debug("RayTransport: No weights to apply to model") + return None + + if strategy is not None: + strategy.apply_weights(model, weights_buffer) + else: + weights_buffer.to_module(model) + + torchrl_logger.debug("RayTransport: Weights applied to model") + return weights_buffer + + # ======================================================================== + # Connection Setup + # ======================================================================== + + def setup_connection_and_weights_on_sender(self) -> None: + """Initialize torch.distributed on sender side for this worker's rank. + + This is called by the scheme after it has created the connection info + Ray actor. The actual ``init_process_group`` happens in the scheme since + it's a collective operation that needs to happen for rank 0. + + Note: + This method exists for interface compatibility but the real work + happens in the scheme's :meth:`_setup_distributed_connection_sender`. + """ + # The scheme handles the collective init_process_group for rank 0. + # This method exists for interface compatibility but the real work + # happens in the scheme's _setup_distributed_connection_sender. + + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + strategy: WeightStrategy | None = None, + model: Any | None = None, + weights: Any | None = None, + ) -> Any: + """Join torch.distributed process group and receive initial weights. + + This method: + 1. Retrieves connection info from the shared Ray actor + 2. Initializes torch.distributed process group with rank=worker_idx+1 + 3. Receives weights if model is stateful + + Args: + worker_idx (int): The worker index for this transport. + strategy (WeightStrategy, optional): The weight transmission strategy. + model (nn.Module or compatible, optional): The model to receive weights for. 
+ weights (TensorDict, optional): Pre-allocated buffer for receiving weights. + + Returns: + The received weights (TensorDict) if model is stateful, None otherwise. + """ + if self._dist_initialized: + # Already initialized, just receive weights if stateful + if self._stateful_model: + result = self.receive_weights( + weights=weights, model=model, strategy=strategy + ) + return result[1] if result else None + return None + + self._worker_idx = worker_idx + rank = self._rank + + # Wait for connection info actor to be available + i = 0 + while True: + try: + remote_connection_info = self.ray.get_actor(self._connection_info_name) + except ValueError: + i += 1 + time.sleep(0.1) + if i % 50 == 0: + torchrl_logger.debug( + f"RayTransport: Waiting for connection info (attempt {i}) on {worker_idx=}/{rank=}" + ) + continue + break + + master_addr = self.ray.get(remote_connection_info.get.remote("master_addr")) + master_port = self.ray.get(remote_connection_info.get.remote("master_port")) + world_size = self.ray.get(remote_connection_info.get.remote("world_size")) + stateful_model = self.ray.get( + remote_connection_info.get.remote("stateful_model") + ) + self._stateful_model = stateful_model + + torchrl_logger.debug( + f"RayTransport: Worker {worker_idx} joining process group with " + f"rank={rank}, master_addr={master_addr}, master_port={master_port} -- blocking" + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Initialize process group on receiver + torch.distributed.init_process_group( + backend=self._backend, + rank=rank, + world_size=world_size, + ) + torchrl_logger.debug(f"RayTransport: Worker {worker_idx} joined process group") + self._dist_initialized = True + + # Receive initial weights if model is stateful + if self._stateful_model: + return self.receive_weights(model=model, weights=weights, strategy=strategy) + return None + + +class RayWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for Ray distributed computing. + + This scheme uses torch.distributed to synchronize weights across distributed + workers (Ray actors). The process group is initialized during the first + ``synchronize_weights()`` call, with the sender as rank 0 and workers as + rank ``worker_idx + 1``. + + Each remote collector gets its own transport, following the same pattern + as multiprocess collectors. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Defaults to "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". + """ + + @property + def connection_info_name(self) -> str: + """Get the name of the Ray actor storing connection info. + + Returns a unique name based on model_id to avoid collisions when + multiple schemes are used with different models. + + Returns: + The connection info actor name. + """ + if self._model_id is not None: + return f"connection_info_{self._model_id}" + return "connection_info" + + def __init__( + self, + strategy: Literal["tensordict", "state_dict"] = "tensordict", + backend: str = "gloo", + ): + """Initialize the RayWeightSyncScheme. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Defaults to "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". 
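+
+        Example:
+            Illustrative construction; the scheme is typically handed to a
+            Ray-based collector, e.g. via ``weight_sync_schemes={"policy": scheme}``:
+
+            >>> scheme = RayWeightSyncScheme(strategy="tensordict", backend="gloo")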
+ """ + super().__init__(strategy) + self._backend = backend + self._dist_initialized = False + self._remote_collectors: list | None = None + self._num_workers: int = 0 + + @property + def model(self) -> Any | None: + """Get the model associated with this scheme. + + Returns: + The model if set, None otherwise. + """ + if self._model_ref is not None: + return self._model_ref() + if self._model_id is not None: + model = _resolve_model(self.context, self._model_id) + if model is None: + if self._model_id == "policy": + torchrl_logger.debug( + f"Creating policy from factory and setting in collector {type(self.context)}" + ) + model = self.context.policy_factory[0]() + self.context.policy = model + torchrl_logger.debug(f"{self.context.policy=}") + else: + raise AttributeError( + f"Model {self._model_id} was `None` in context {self.context}" + ) + self._model_ref = weakref.ref(model) + return model + + @model.setter + def model(self, value: Any): + """Set the model for this scheme. + + Args: + value: The model to set. If None, the setter is a no-op. + """ + if value is None: + return + self._model_ref = weakref.ref(value) + + def create_transport( + self, + *, + remote_actor=None, + worker_idx: int | None = None, + # Legacy parameter name for backwards compatibility + remote_collector=None, + **kwargs, + ) -> TransportBackend: + """Create Ray-based transport for a specific remote actor. + + Args: + remote_actor: The Ray actor handle for the remote collector/transform. + worker_idx: The worker index for this remote actor. + remote_collector: Legacy alias for remote_actor. + **kwargs: Additional transport configuration. + + Returns: + RayTransport configured for this specific remote actor. + """ + # Support legacy parameter name + if remote_actor is None: + remote_actor = remote_collector + + return RayTransport( + remote_actor=remote_actor, + worker_idx=worker_idx, + backend=self._backend, + connection_info_name=self.connection_info_name, + model_id=self._model_id, + ) + + def _init_on_sender_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + This method sets up the torch.distributed connection info and shares it + with all remote collectors so they can join the process group. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing remote_collectors + **kwargs: Alternative to context (remote_collectors, source_model, etc.) 
+ """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayWeightSyncScheme") + + # Extract parameters from context or kwargs + if context is not None: + remote_collectors = getattr(context, "remote_collectors", None) + num_workers = getattr(context, "num_workers", None) or getattr( + context, "num_collectors", None + ) + else: + remote_collectors = kwargs.get("remote_collectors") + num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors") + + if remote_collectors is None: + raise ValueError("remote_collectors must be provided via context or kwargs") + if num_workers is None: + num_workers = len(remote_collectors) if remote_collectors else 0 + + # Store model_id and context on scheme + self.model_id = model_id + + # Store remote collectors and num_workers for synchronize_weights + self._remote_collectors = list(remote_collectors) + self._num_workers = int(num_workers) + + # Register each Ray actor with explicit transport kwargs + for worker_idx, remote_collector in enumerate(remote_collectors): + transport = self.create_transport( + remote_actor=remote_collector, + worker_idx=worker_idx, + ) + self._register_worker_sender( + worker_idx=worker_idx, + transport=transport, + ) + + # Set context with weak reference to avoid circular refs + if context is not None: + self.context = context + + # Store source model reference if provided for automatic weight extraction + model = kwargs.get("model") + if model is not None: + self.model = model + + # Note: Distributed connection setup is deferred to synchronize_weights + # because _receiver_schemes on workers won't exist until register_scheme_receiver is called + + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the remote collector) + **kwargs: Optional parameters (worker_idx, model, etc.) + """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayWeightSyncScheme") + + # Store model_id and context on scheme + self.model_id = model_id + self.context = context + + # Extract worker_idx from context or kwargs + if context is not None: + worker_idx = getattr(context, "worker_idx", None) + else: + worker_idx = kwargs.get("worker_idx") + + self._worker_idx = worker_idx + + # Resolve the target model on this worker + model = kwargs.get("model") + if model is not None: + self.model = model + # get the weights to possibly instantiate a copy of the model (policy factory with multi-collector) + self.weights # noqa + + # Create and register transport for receiver side + # Note: create_transport returns TransportBackend but we know it's RayTransport + transport = self.create_transport( + remote_actor=None, # Receiver doesn't need actor handle + worker_idx=worker_idx, + ) + if isinstance(transport, RayTransport): + transport.set_model(model) + self._register_transport_receiver(transport=transport) + + def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: + """Set up torch.distributed connection info and share with remote collectors. + + This method: + 1. Gets master address and finds an available port + 2. Stores connection info in Ray's object store as a named actor + 3. Initializes torch.distributed process group with rank=0 + + Args: + timeout: Maximum time in seconds to wait for workers to be ready. 
+ Default is 300 seconds (5 minutes). + """ + if self._dist_initialized: + return + + if self._remote_collectors is None or self._num_workers == 0: + raise RuntimeError( + "_setup_distributed_connection() requires remote_collectors to be set" + ) + + # Get master address (hostname/IP) + hostname = socket.gethostname() + try: + master_addr = socket.gethostbyname(hostname) + except socket.gaierror: + master_addr = "127.0.0.1" + + # Find an available port + master_port = self._find_free_port() + world_size = self._num_workers + 1 # +1 for the sender (rank 0) + + torchrl_logger.debug( + f"RayWeightSyncScheme: Setting up distributed connection with " + f"master_addr={master_addr}, master_port={master_port}, world_size={world_size}" + ) + + try: + self.weights + stateful_model = True + except (AttributeError, RuntimeError, ValueError): + stateful_model = False + self._stateful_model = stateful_model + + # Connection info to share with workers via named Ray actor + RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options( + name=self.connection_info_name + ) + self._connection_info_actor = RemoteConnectionInfo.remote( + master_addr=master_addr, + master_port=master_port, + world_size=world_size, + stateful_model=stateful_model, + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Initialize process group on sender (rank 0) + # Note: Workers will call init_process_group in their transport's + # setup_connection_and_weights_on_receiver. The init_process_group is + # a collective operation, so all ranks must call it together. + torchrl_logger.debug( + "RayWeightSyncScheme: Initializing process group on sender (rank 0) -- blocking." + ) + torch.distributed.init_process_group( + backend=self._backend, + rank=0, + world_size=world_size, + timeout=_DIST_TIMEOUT, + ) + self._dist_initialized = True + + torchrl_logger.debug( + "RayWeightSyncScheme: Distributed connection setup complete -- all workers at rendez-vous" + ) + + def _setup_connection_and_weights_on_sender_impl( + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, + ) -> None: + """Set up distributed connection and send initial weights to all workers. + + This method: + 1. Sets up torch.distributed process group (waits for workers if needed) + 2. Sends initial weights to all workers via their transports + + The distributed setup is done here (not in ``init_on_sender``) because + workers need to have ``register_scheme_receiver`` called first. + + Args: + worker_idx (int, optional): Not used in this implementation. + weights (optional): Not used in this implementation (weights are + extracted from the model). + """ + # Set up distributed connection (with wait for workers to be ready) + if not self._dist_initialized: + torchrl_logger.debug( + "RayWeightSyncScheme: Setting up distributed connection (sender)" + ) + self._setup_distributed_connection_sender() + + # Send the initial weights + if self._stateful_model: + self._send_weights_distributed() + + def _send_weights_distributed(self) -> None: + """Send weights to all workers via torch.distributed. + + Raises: + RuntimeError: If no weights are available to send. 
+ """ + # Extract weights from model + weights = self.weights + if weights is None: + raise RuntimeError("No weights available to send") + + # Send weights to each worker (ranks 1 to num_workers) + futures = [] + for worker_idx in range(self._num_workers): + rank = worker_idx + 1 + torchrl_logger.debug(f"RayWeightSyncScheme: Sending weights to rank {rank}") + futures.extend(weights.isend(dst=rank, return_early=True)) + # Wait for all sends to complete + for future in futures: + future.wait() + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Join torch.distributed process group and receive initial weights. + + Delegates to the transport's :meth:`~RayTransport.setup_connection_and_weights_on_receiver`. + + Args: + worker_idx (int, optional): The worker index. If None, uses the stored + ``_worker_idx`` or defaults to 0. + """ + if worker_idx is None: + worker_idx = self._worker_idx + if worker_idx is None: + worker_idx = 0 # Default to worker 0 + + transport = self.receiver_transport + if transport is not None: + # Transport handles joining process group and receiving weights + transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx, + model=self.model, + weights=self.weights, + strategy=self._strategy, + ) + self._dist_initialized = True + + @staticmethod + def _find_free_port() -> int: + """Find a free port on the local machine. + + Returns: + int: An available port number. + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + +class RayModuleTransformScheme(RayWeightSyncScheme): + """Weight synchronization for RayModuleTransform. + + This scheme uses torch.distributed to synchronize weights between + a trainer/collector and a RayModuleTransform actor. The sender is rank 0, + the transform's actor is rank 1. + + This enables updating the weights of a module running inside a RayModuleTransform + from a parent collector or training loop. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Default is "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Default is "gloo". + + Example: + >>> # Create scheme and transform + >>> scheme = RayModuleTransformScheme() + >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme) + >>> + >>> # Create env with transform + >>> env = TransformedEnv(base_env, transform) + >>> + >>> # Pass scheme to parent collector + >>> collector = SomeCollector( + ... env, policy, + ... weight_sync_schemes={"transform_module": scheme} + ... ) + >>> + >>> # Update weights + >>> collector.update_policy_weights_(model_id="transform_module") + """ + + def __init__( + self, + strategy: Literal["tensordict", "state_dict"] = "tensordict", + backend: str = "gloo", + ): + """Initialize the RayModuleTransformScheme. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Defaults to "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". + """ + super().__init__(strategy, backend) + self._ray_transform = None + + def _set_transform(self, ray_transform) -> None: + """Store reference to the RayModuleTransform. + + Called by RayModuleTransform when the scheme is passed to it. + + Args: + ray_transform: The RayModuleTransform instance. 
+ """ + self._ray_transform = ray_transform + + def _init_on_sender_impl( + self, + model_id: str | None = None, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Uses the stored transform reference (set via _set_transform) to + create transport for the transform's actor. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the collector) + **kwargs: Optional parameters (ray_transform, model, etc.) + """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayModuleTransformScheme") + + # Get transform reference - either stored via _set_transform or from kwargs + ray_transform = self._ray_transform + if ray_transform is None: + ray_transform = kwargs.get("ray_transform") + if ray_transform is None: + raise ValueError( + "ray_transform must be set via _set_transform() or provided in kwargs. " + "Pass the scheme to RayModuleTransform constructor to set it automatically." + ) + + # Store model_id + self.model_id = model_id + + # Single worker (the transform's actor) + self._num_workers = 1 + + # Create transport for the transform's actor + # The actor handle is ray_transform._actor + transport = self.create_transport( + remote_actor=ray_transform._actor, + worker_idx=0, + ) + self._register_worker_sender( + worker_idx=0, + transport=transport, + ) + + # Set context if provided + if context is not None: + self.context = context + + # Store source model reference if provided for automatic weight extraction + model = kwargs.get("model") + if model is not None: + self.model = model + + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the transform's actor (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: The ModuleTransform instance (the actor's underlying class) + **kwargs: Optional parameters (worker_idx, model, etc.) + """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayModuleTransformScheme") + + # Store model_id and context + self.model_id = model_id + self.context = context + + # Single transform actor is always worker_idx=0 + self._worker_idx = kwargs.get("worker_idx", 0) + + # Resolve the target model from context (ModuleTransform has a .module attribute) + model = kwargs.get("model") + if model is None and context is not None: + model = getattr(context, "module", None) + if model is not None: + self.model = model + + # Create and register transport for receiver side + # Note: create_transport returns TransportBackend but we know it's RayTransport + transport = self.create_transport( + remote_actor=None, + worker_idx=self._worker_idx, + ) + if isinstance(transport, RayTransport): + transport.set_model(model) + self._register_transport_receiver(transport=transport) + + def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: + """Set up torch.distributed for the single transform actor. + + Overrides parent to work with a single RayModuleTransform instead of + multiple remote collectors. + + Args: + timeout (float): Maximum time in seconds to wait for connection setup. + Defaults to 300.0 (5 minutes). + + Raises: + RuntimeError: If ``ray_transform`` is not set. + """ + if self._dist_initialized: + return + + if self._ray_transform is None: + raise RuntimeError( + "_setup_distributed_connection() requires ray_transform to be set. 
" + "Did you pass the scheme to RayModuleTransform?" + ) + + # Get master address (hostname/IP) + hostname = socket.gethostname() + try: + master_addr = socket.gethostbyname(hostname) + except socket.gaierror: + master_addr = "127.0.0.1" + + # Find an available port + master_port = self._find_free_port() + world_size = 2 # Sender (rank 0) + Transform (rank 1) + + torchrl_logger.debug( + f"RayModuleTransformScheme: Setting up distributed connection with " + f"master_addr={master_addr}, master_port={master_port}, world_size={world_size}" + ) + + # Check if model has weights + try: + w = self.weights + stateful_model = w is not None + except (AttributeError, RuntimeError, ValueError): + stateful_model = False + self._stateful_model = stateful_model + + # Connection info to share with the transform's actor + RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options( + name=self.connection_info_name + ) + self._connection_info_actor = RemoteConnectionInfo.remote( + master_addr=master_addr, + master_port=master_port, + world_size=world_size, + stateful_model=stateful_model, + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Now initialize process group on sender (rank 0) + # The receiver is concurrently joining via the Ray call above + torchrl_logger.debug( + "RayModuleTransformScheme: Initializing process group on sender (rank 0) -- blocking." + ) + torch.distributed.init_process_group( + backend=self._backend, + rank=0, + world_size=world_size, + timeout=_DIST_TIMEOUT, + ) + self._dist_initialized = True + + torchrl_logger.debug( + "RayModuleTransformScheme: Distributed connection setup complete" + ) + + def _setup_connection_and_weights_on_sender_impl( + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, + ) -> None: + """Set up distributed connection and send initial weights. + + Args: + worker_idx (int, optional): The worker index. Not used for + RayModuleTransformScheme as there is only one transform actor. + weights (optional): Pre-extracted weights to send. If None, weights + are extracted from the model. + """ + torchrl_logger.debug( + "RayModuleTransformScheme: Signaling receiver to join process group" + ) + receiver_future = self._ray_transform._actor._init_weight_sync_scheme.remote( + scheme=self, model_id=self.model_id + ) + + if not self._dist_initialized: + torchrl_logger.debug( + "RayModuleTransformScheme: Setting up distributed connection (sender)" + ) + self._setup_distributed_connection_sender() + + if self._stateful_model: + torchrl_logger.debug( + "RayModuleTransformScheme: Sending first batch of weights (sender)" + ) + self._send_weights_distributed(weights=weights) + + torchrl_logger.debug("Waiting for receiver to join process group...") + self.ray.get(receiver_future) + + def _send_weights_distributed(self, weights: Any | None = None) -> None: + """Send weights to the transform actor via torch.distributed. + + Args: + weights (optional): Pre-extracted weights to send. If None, weights + are extracted from the model via :attr:`weights`. + + Raises: + RuntimeError: If no weights are available to send. 
+ """ + if weights is None: + weights = self.weights + if weights is None: + raise RuntimeError("No weights available to send") + + # Send weights to the transform (rank 1) + torchrl_logger.debug("RayModuleTransformScheme: Sending weights to rank 1") + futures = weights.isend(dst=1, return_early=True) + for future in futures: + future.wait() diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py new file mode 100644 index 00000000000..7bc829599c5 --- /dev/null +++ b/torchrl/weight_update/_rpc.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +import time +import weakref +from typing import Any + +from torchrl._utils import logger as torchrl_logger + +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightStrategy, + WeightSyncScheme, +) + + +class RPCWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed.rpc. + + This scheme uses RPC calls to synchronize weights across distributed + workers. Each remote collector gets its own transport, following the + same pattern as multiprocess collectors. + """ + + def _init_on_sender_impl( + self, + *, + model_id: str, + context: Any = None, + num_workers: int, + ) -> None: + # Store model_id and context on scheme + self.model_id = model_id + if context is not None: + self.context = context + else: + raise RuntimeError(f"Expected a context for {type(self).__name__}.") + collector_infos = getattr(self.context, "collector_infos", None) + collector_rrefs = getattr(self.context, "collector_rrefs", None) + collector_class = getattr(self.context, "collector_class", None) + if ( + collector_infos is None + or collector_rrefs is None + or collector_class is None + ): + raise RuntimeError( + "RPCWeightSyncScheme requires a context with the following attributes: " + "(context.collector_infos, context.collector_rrefs, context.collector_class)" + ) + + # Create transports for each remote collector + # worker_rank is i+1 because rank 0 is the main/trainer process + for i in range(num_workers): + worker_rank = i + 1 + transport = self.create_transport( + collector_info=collector_infos[i], + collector_rref=collector_rrefs[i], + collector_class=collector_class, + worker_rank=worker_rank, + ) + self._register_worker_sender(worker_idx=i, transport=transport) + + def _init_on_receiver_impl( + self, *, model_id: str, context: Any = None, worker_idx: int | None = None + ) -> None: + """Initialize scheme on the worker (receiver) side. + + Expected kwargs (as provided by collectors): + - model_id: str # e.g. "policy" + - context: Any # collector / inner collector + - worker_idx: int | None # worker index (optional) + """ + if context is None: + raise ValueError( + "RPCWeightSyncScheme.init_on_receiver requires a 'context' " + "providing access to the model to be synchronized." 
+ ) + + # Store model_id and context on scheme + self.model_id = model_id + self.worker_idx = worker_idx + self.context = context + # Access weights to set up missing elements + self.weights # noqa + + self._receiver_transport = RPCTransport(worker_rank=worker_idx) + + @property + def model(self) -> Any | None: + if self._model_ref is not None: + return self._model_ref() + if self._model_id is not None: + model = _resolve_model(self.context, self._model_id) + if model is None: + if self._model_id == "policy": + torchrl_logger.debug( + f"Creating policy from factory and setting in collector {type(self.context)}" + ) + model = self.context.policy_factory[0]() + self.context.policy = model + torchrl_logger.debug(f"{self.context.policy=}") + else: + raise AttributeError( + f"Model {self._model_id} was `None` in context {self.context}" + ) + self._model_ref = weakref.ref(model) + return model + + @model.setter + def model(self, value: Any): + if value is None: + return + self._model_ref = weakref.ref(value) + + def create_transport( + self, + *, + collector_info=None, + collector_rref=None, + collector_class=None, + worker_rank=None, + **kwargs, + ) -> TransportBackend: + """Create RPC-based transport for a specific remote collector. + + Args: + collector_info: RPC worker info for the remote collector. + collector_rref: RPC remote reference to the collector. + collector_class: Class of the remote collector. + worker_rank: The torch.distributed rank of the remote worker. + **kwargs: Additional transport configuration. + + Returns: + RPCTransport configured for this specific remote collector. + """ + return RPCTransport( + collector_info=collector_info, + collector_rref=collector_rref, + collector_class=collector_class, + worker_rank=worker_rank, + ) + + +class RPCTransport: + """RPC transport for communicating with a single RPC remote collector. + + This transport handles weight updates for ONE specific remote collector via + torch.distributed primitives (send/recv) with RPC used for signaling. + Multiple transports are created for multiple collectors, following the same + pattern as the DistributedDataCollector. + """ + + def __init__( + self, + collector_info=None, + collector_rref=None, + collector_class=None, + worker_rank=None, + ): + self._collector_info = collector_info + self._collector_rref = collector_rref + self._collector_class = collector_class + self._worker_rank = worker_rank # The torch.distributed rank of this worker + self._pending_future = None + self._pending_send = None + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector using torch.distributed. + + Uses torch.distributed.send() for the actual weight transfer and RPC + for signaling the remote collector to receive. + + Order is critical to avoid deadlock: + 1. Signal receiver via RPC to start recv() (non-blocking) + 2. 
Send weights via torch.distributed (blocking until recv completes)
+        """
+        if self._collector_info is None or self._collector_rref is None:
+            return
+        if self._worker_rank is None:
+            raise RuntimeError("worker_rank must be set for RPC transport")
+
+        # Step 1: Signal the remote collector via RPC to start receiving (async)
+        # Use rref.rpc_async() to properly call the instance method on the remote object
+        future = self._collector_rref.rpc_async()._receive_weights_scheme()
+
+        # Step 2: Send weights via torch.distributed (blocks until receiver calls recv())
+        weights.send(self._worker_rank)
+
+        # Step 3: Wait for RPC to complete (receiver has applied weights)
+        future.wait()
+
+    def send_weights_async(self, weights: Any) -> None:
+        """Send weights to remote collector asynchronously.
+
+        Uses torch.distributed.isend() for the actual weight transfer and RPC
+        for signaling. Use wait_ack() to wait for completion.
+
+        Order is critical to avoid deadlock:
+        1. Signal receiver via RPC to start recv() (non-blocking)
+        2. Send weights via torch.distributed.isend() (non-blocking)
+        3. wait_ack() waits for both to complete
+        """
+        if self._collector_info is None or self._collector_rref is None:
+            return
+        if self._worker_rank is None:
+            raise RuntimeError("worker_rank must be set for RPC transport")
+
+        # Step 1: Signal the remote collector via RPC to start receiving (async)
+        # Use rref.rpc_async() to properly call the instance method on the remote object
+        self._pending_future = (
+            self._collector_rref.rpc_async()._receive_weights_scheme()
+        )
+
+        # Step 2: Send weights asynchronously via torch.distributed
+        # Store the work handles so wait_ack() can wait for the send to finish
+        self._pending_send = weights.isend(self._worker_rank, return_early=True)
+
+    def wait_ack(self) -> None:
+        """Wait for both the RPC call and the distributed send to complete."""
+        # Wait for the distributed isend() work handles to complete
+        if self._pending_send is not None:
+            for work in self._pending_send:
+                work.wait()
+            self._pending_send = None
+        # Wait for the RPC call to complete (receiver has applied the weights)
+        if self._pending_future is not None:
+            self._pending_future.wait()
+            self._pending_future = None
+
+    def receive_weights(
+        self,
+        timeout: float | None = None,
+        *,
+        weights: Any = None,
+        model: Any = None,
+        strategy: WeightStrategy | None = None,
+    ) -> Any | None:
+        """Receive weights from sender using torch.distributed.
+
+        Args:
+            timeout: Maximum time to wait for weights (seconds). If None,
+                blocks until weights are received.
+            weights: Pre-allocated weight buffer to receive into.
+            model: The model to apply weights to.
+            strategy: Strategy for applying weights to the model.
+
+        Returns:
+            The received weights, or None if timeout expires.
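+
+        The sketch below is illustrative only; ``weight_buffer``, ``policy`` and
+        ``strategy`` stand in for a pre-allocated TensorDict, the target module,
+        and a :class:`WeightStrategy` created elsewhere.
+
+        Example:
+            >>> received = transport.receive_weights(
+            ...     timeout=5.0,
+            ...     weights=weight_buffer,
+            ...     model=policy,
+            ...     strategy=strategy,
+            ... )
+            >>> if received is None:
+            ...     print("timed out before all weights arrived")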
+ """ + if weights is None: + return None + + if timeout is None: + # Blocking receive + weights.recv(0) + else: + # Non-blocking receive with timeout support + futures = weights.irecv(src=0, return_premature=True) + if futures: + start_time = time.monotonic() + while True: + # Check if all futures are complete + all_complete = all(f.is_completed() for f in futures) + if all_complete: + break + # Check timeout + elapsed = time.monotonic() - start_time + if elapsed >= timeout: + # Timeout expired before receiving all weights + return None + # Small sleep to avoid busy-waiting + time.sleep(0.001) + + # Apply the received weights to the model + if model is not None and strategy is not None: + strategy.apply_weights(model, weights) + + return weights + + def setup_connection_and_weights_on_sender(self) -> None: + """No-op for RPCTransport - weights are sent via send_weights().""" + + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any: + """No-op for RPCTransport - weights are received via receive().""" + return None diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py new file mode 100644 index 00000000000..d7b5c50898b --- /dev/null +++ b/torchrl/weight_update/_shared.py @@ -0,0 +1,907 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import torch +import torch.distributed + +from tensordict import TensorDict, TensorDictBase + +from torch import multiprocessing as mp, nn + +from torchrl._utils import logger as torchrl_logger + +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightStrategy, + WeightSyncScheme, +) + + +class SharedMemTransport: + """Shared memory transport for in-place weight updates. + + This transport uses queue-based buffer distribution for initialization, then + updates shared memory tensors directly for subsequent weight updates. + Workers automatically see weight updates without explicit communication. + + Initialization flow: + - Shared memory buffers are created and sent to workers via per-worker queues + - Workers receive the buffer reference and apply weights to their models + - Subsequent updates are pure in-place shared memory (zero-copy) + + Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. + + """ + + def __init__(self): + self._params_map = None # a dict[worker_idx, TensorDictBase] map + self._weight_queues = ( + None # Dict of per-worker queues for distributing shared weights + ) + self._unique_weights = None + + @property + def unique_weights(self) -> list[TensorDictBase]: + """Get the unique weights. + + Returns: + The unique weights. + """ + if self._unique_weights is None: + raise RuntimeError("Unique weights not set. 
Call register_weights() first.") + return self._unique_weights + + def register_weights( + self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] + ) -> None: + """Initialize per-worker queues for shared memory buffer distribution.""" + from torchrl.collectors.utils import _cast + + self._weight_queues = init_queues + self._params_map = params_map + # Create set of the unique weights + self._unique_weights = [] + for weights in params_map.values(): + if id(weights) in [id(w) for w in self._unique_weights]: + continue + weights = weights.data.apply(_cast, weights) + self._unique_weights.append(weights) + + def setup_connection_and_weights_on_sender(self) -> None: + """Send shared memory buffer reference to workers via their per-worker queues. + + Both CPU and CUDA tensors maintain shared references through queues. + Each worker reads from its own dedicated queue, to avoid race conditions. + + """ + torchrl_logger.debug("Sending shared memory weights to workers.") + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + for worker_idx, queue in self._weight_queues.items(): + weights = self._params_map[worker_idx] + queue.put(weights) + + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int | None = None, + weights: Any = None, + model: Any = None, + strategy: Any = None, + timeout: float = 10.0, + ) -> TensorDictBase: + """Receive shared memory buffer reference from sender via their per-worker queues. + + Each worker reads from its own dedicated queue, to avoid race conditions. + + Args: + worker_idx: The worker index. + weights: Ignored (weights come from queue). + model: Ignored. + strategy: Ignored. + timeout: Timeout for reading from queue. + + Returns: + The shared memory weights TensorDict. + """ + torchrl_logger.debug( + f"Receiving shared memory weights from worker {worker_idx}." + ) + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + if worker_idx not in self._weight_queues: + raise RuntimeError(f"Worker {worker_idx} not registered in queues.") + + # Read from dedicated queue for this worker + worker_queue = self._weight_queues[worker_idx] + received_weights = worker_queue.get(timeout=timeout) + return received_weights + + def send_weights(self, weights: Any) -> None: + """Update weights in-place in shared memory. + + Args: + weights: New weights to send. Can be a TensorDictBase or dict. + + Raises: + ValueError: If weights type is unsupported. + """ + # Update shared memory in-place (workers see this automatically) + if isinstance(weights, dict): + weights = TensorDict(weights) + if not isinstance(weights, TensorDictBase): + raise ValueError(f"Unsupported weights type: {type(weights)=}") + # Unflatten if needed to match shared buffer structure + weights_to_update = weights + if any("." in key for key in weights.keys()): + weights_to_update = weights.unflatten_keys(".") + + # Detach weights to allow in-place updates (gradients are not needed for weight sync) + weights_to_update = weights_to_update.detach() + + if self._unique_weights is None: + raise RuntimeError("Unique weights not set. Call register_weights() first.") + for buffer in self._unique_weights: + if buffer.requires_grad: + raise RuntimeError( + "Gradients should not be required for shared memory buffers." 
+ ) + if weights_to_update.requires_grad: + raise RuntimeError("Gradients should not be required for weights.") + buffer.update_(weights_to_update, non_blocking=True) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any | None: + """Apply shared memory weights to the model. + + For shared memory, weights are already available (passed via the weights arg). + This method applies them to the model, matching the pattern of other transports. + + Args: + timeout: Ignored (shared memory access is instant). + weights: The shared memory buffer containing current weights. + model: The model to apply weights to. + strategy: Strategy for applying weights. + + Returns: + The applied weights, or None if not applied. + """ + # Apply weights to model if provided (same pattern as other transports) + if model is not None and strategy is not None and weights is not None: + torchrl_logger.debug( + f"Applying shared memory weights {type(weights)=} to model {model} with {strategy=}." + ) + strategy.apply_weights(model, weights) + return weights + torchrl_logger.debug( + f"Not applying shared memory weights {type(weights)=} to model {model} with {strategy=}." + ) + return None + + def send_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + +class SharedMemWeightSyncScheme(WeightSyncScheme): + """Weight synchronization using shared memory. + + This scheme uses shared memory for in-place weight updates. Workers + automatically see weight updates without explicit message passing. + + A background thread on the receiver side listens for "receive" instructions + from the sender. When an instruction arrives, the thread applies the current + shared memory weights to the model and sends an acknowledgment. + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + sync: If True (default), send() blocks until receiver acknowledges. + If False, send() returns immediately (use send_async/wait_async). 
+ + Example: + >>> # Basic usage + >>> scheme = SharedMemWeightSyncScheme() + >>> # Weights are initialized via init_on_sender() + """ + + def __init__( + self, + strategy: str = "tensordict", + sync: bool = True, + ): + super().__init__(strategy) + self.sync = sync + # Create a single shared transport for all workers + self.shared_transport = SharedMemTransport() + + # Create per-worker queues to avoid race conditions + # Each worker gets its own queue for weight initialization + self._weight_init_queues = {} # worker_idx -> Queue + + # Instruction queues: sender puts "receive" instruction, receiver's background thread reads + self._instruction_queues: dict[int, mp.Queue] = {} # worker_idx -> Queue + + # Acknowledgment queues: receiver puts "updated" ack, sender reads for sync mode + self._ack_queues: dict[int, mp.Queue] = {} # worker_idx -> Queue + + # Receiver's instruction queue reference (set during init_on_receiver) + self._receiver_instruction_queue: mp.Queue | None = None + self._receiver_ack_queue: mp.Queue | None = None + + def _init_on_sender_impl( + self, + *, + model_id: str | None = None, + context: Any = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ) -> None: + """Initialize on the main process (sender side). + + We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers + share the same device, the entry in the dict will be the same. + To do this, we need to know the number of workers, their assigned device, and have access to the parameters. + If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided + explicitly. + + In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the + devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that + will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing device_to_workers mapping and model access + weights: Pre-extracted weights as TensorDict (for policy factory usage) + model: Model to extract weights from + params_map: Direct mapping of worker_idx to weights on device (most explicit) + devices: List of devices for each worker + device_map_fn: Custom function to map worker_idx and weights to device-specific weights + num_workers: Number of workers (required with device_map_fn) + + Examples: + Simple usage with collector context (stateful policy): + + >>> policy = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.init_on_sender() is called automatically by collector + + Pre-initialized usage (policy factory): + + >>> policy_on_main = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> # Must initialize before collector creation when using policy_factory + >>> scheme.init_on_sender( + ... 
model_id="policy", + ... weights=TensorDict.from_module(policy_on_main), + ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], + ... num_workers=2, + ... ) + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy_factory=[make_stateful_policy], + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + + Direct params_map usage (advanced): + + >>> weights_cpu = TensorDict.from_module(policy).share_memory_() + >>> weights_cuda = weights_cpu.to("cuda").share_memory_() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> scheme.init_on_sender( + ... model_id="policy", + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... ) + """ + # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init + # the weights on the workers. + # Scenarios: + # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating + # the transport and registering the workers etc. + # - The user provides a model or its params and a device map. We need to create the map from the params + # explicitly. + # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need + # to collect the model from the context. + params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Create per-worker queues if not already created + # Collect all unique worker indices + all_workers = list(params_map.keys()) + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() + # Create instruction queues for background receiver + if worker_idx not in self._instruction_queues: + self._instruction_queues[worker_idx] = mp.Queue() + # Create ack queues for synchronous mode + if worker_idx not in self._ack_queues: + self._ack_queues[worker_idx] = mp.Queue() + + # Set worker info in transport + self.shared_transport.register_weights(params_map, self._weight_init_queues) + + # Store model_id and context on scheme + self.model_id = model_id + if context is not None: + self.context = context + + def _get_params_map( + self, + context: Any = None, + model_id: str | None = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ): + """Get the params_map for init_on_sender().""" + # Import _cast locally to avoid circular imports + from torchrl.collectors.utils import _cast + + if params_map is not None: + # Sanity check: params_map must be a dict[int, TensorDictBase] + # All other args must be None + if ( + not isinstance(params_map, dict) + or not all(isinstance(v, int) for v in params_map.keys()) + or not all(isinstance(v, TensorDictBase) for v in params_map.values()) + ): + raise ValueError("params_map must be a dict[int, TensorDictBase]") + if model_id is not None or weights is not None or model is not None: + raise ValueError( + "model_id, weights, and model cannot be provided if params_map is provided" + ) + if context is not None: + raise ValueError("context cannot be provided if params_map is provided") + if 
devices is not None: + raise ValueError("devices cannot be provided if params_map is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if params_map is provided" + ) + if num_workers is not None: + raise ValueError( + "num_workers cannot be provided if params_map is provided" + ) + return params_map + elif context is not None: + if devices is not None: + raise ValueError("devices cannot be provided if context is provided") + # Sanity check: model_id must be provided if context is provided + # All other args must be None + if model_id is None: + raise ValueError("model_id must be provided if context is provided") + if model is not None: + raise ValueError("model cannot be provided if context is provided") + if weights is not None: + raise ValueError("weights cannot be provided if context is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if context is provided" + ) + # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later + devices = context.policy_device + if num_workers is not None and num_workers != len(devices): + raise ValueError( + "num_workers cannot be provided if context is provided" + ) + # Get the weights + model = _resolve_model(context, model_id) + if model is None: + if model_id == "policy": + # we need to get a copy of the weights from the factory + model = context.policy_factory[0]() + weights = TensorDict.from_module(model) + elif model is not None: + if weights is not None: + raise ValueError("weights cannot be provided if model is provided") + weights = TensorDict.from_module(model) + if weights is not None: + weights = weights.data.apply(_cast, weights) + # To make the map, we need the list of devices, or the map fn + if devices is not None: + # Get the unique devices + devices_set = set(devices) + weights_devices = ( + {p.device for p in weights.values(True, True)} + if weights is not None + else set() + ) + if len(weights_devices) == 1: + weights_device = weights_devices.pop() + else: + weights_device = None + + # Create device map with proper Parameter handling using _cast + # _cast ensures Parameters stay as Parameters (with requires_grad=False) + device_map = {} + for d in devices_set: + if d != weights_device: + # Move to device and apply _cast to preserve Parameter/Buffer types + weights_on_device = weights.to(d) + weights_on_device = weights_on_device.apply(_cast, weights) + device_map[d] = weights_on_device + else: + # Already on correct device, just apply _cast + device_map[d] = weights.apply(_cast, weights) + + # Create the map + params_map = { + worker_idx: device_map[device] + for worker_idx, device in enumerate(devices) + } + return params_map + if device_map_fn is not None: + return { + worker_idx: device_map_fn(worker_idx, weights) + for worker_idx in range(num_workers) + } + raise ValueError( + "Either params_map, model_id + context or model/weights + devices must be provided." + ) + + def _init_on_receiver_impl( + self, + *, + model_id: str | None = None, + context: Any = None, + model: Any = None, + worker_idx: int | None = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Reads from the worker's dedicated queue to receive shared weights, + then registers them in the transport. The receiver then applies these weights + to the model. 
+ + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing model and worker_idx + model: Model being synchronized + worker_idx: Worker index + **kwargs: Alternative to context (model, worker_idx, timeout, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + if model_id is None: + raise ValueError("model_id is required when context is provided") + if hasattr(context, "get_model"): + model = context.get_model(model_id) + elif model is None: + model = _resolve_model(context, model_id) + worker_idx = getattr(context, "worker_idx", worker_idx) + + # Store on scheme directly + self.model_id = model_id + if context is not None: + self.context = context + + # Register the model + if model is not None: + self.model = model + + # Store worker_idx for synchronize_weights + self.worker_idx = worker_idx + + # Store references to instruction and ack queues for this worker + # These are created by init_on_sender and passed via pickle + if worker_idx is not None: + if worker_idx in self._instruction_queues: + self._receiver_instruction_queue = self._instruction_queues[worker_idx] + if worker_idx in self._ack_queues: + self._receiver_ack_queue = self._ack_queues[worker_idx] + + self.create_transport() + + def _wait_for_instruction(self, timeout: float | None = None) -> str | None: + """Block until an instruction arrives from the sender. + + Args: + timeout: Maximum time to wait for instruction (seconds). + None means block indefinitely. + + Returns: + The instruction string (e.g., "receive", "stop"), or None if + stop event is set or timeout expires. + """ + if self._receiver_instruction_queue is None: + raise RuntimeError( + "Instruction queue not set. init_on_receiver() must be called first." + ) + + try: + # Check stop event periodically while waiting + while True: + if self._stop_event is not None and self._stop_event.is_set(): + return None + try: + # Use short timeout to allow checking stop event + instruction = self._receiver_instruction_queue.get(timeout=0.1) + return instruction + except Exception: + # Queue.Empty - continue waiting + if timeout is not None: + timeout -= 0.1 + if timeout <= 0: + return None + except Exception as e: + torchrl_logger.warning(f"Error waiting for instruction: {e}") + return None + + def _send_instruction( + self, + instruction: str = "receive", + worker_ids: int | list[int] | None = None, + ) -> None: + """Send instruction to receiver(s) to trigger weight reception. + + Args: + instruction: The instruction to send (default: "receive"). + worker_ids: Which workers to send to (None = all workers). + """ + if not self._instruction_queues: + raise RuntimeError( + "Instruction queues not created. init_on_sender() must be called first." + ) + + if worker_ids is None: + target_workers = list(self._instruction_queues.keys()) + elif isinstance(worker_ids, int): + target_workers = [worker_ids] + else: + target_workers = list(worker_ids) + + for worker_idx in target_workers: + if worker_idx not in self._instruction_queues: + raise ValueError(f"Worker {worker_idx} not registered") + self._instruction_queues[worker_idx].put(instruction) + + def _send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender after receiving weights. + + Args: + message: The acknowledgment message (default: "updated"). 
+ """ + if self._receiver_ack_queue is not None: + self._receiver_ack_queue.put(message) + + def _wait_for_ack( + self, + worker_ids: int | list[int] | None = None, + timeout: float | None = None, + ) -> None: + """Wait for acknowledgment from receiver(s). + + Args: + worker_ids: Which workers to wait for (None = all workers). + timeout: Maximum time to wait (seconds). None means block indefinitely. + """ + if not self._ack_queues: + return # No ack queues, nothing to wait for + + if worker_ids is None: + target_workers = list(self._ack_queues.keys()) + elif isinstance(worker_ids, int): + target_workers = [worker_ids] + else: + target_workers = list(worker_ids) + + for worker_idx in target_workers: + if worker_idx not in self._ack_queues: + raise ValueError(f"Worker {worker_idx} not registered") + try: + ack = self._ack_queues[worker_idx].get(timeout=timeout) + if ack != "updated": + torchrl_logger.warning( + f"Unexpected ack from worker {worker_idx}: {ack}" + ) + except Exception as e: + torchrl_logger.warning( + f"Timeout waiting for ack from worker {worker_idx}: {e}" + ) + + def create_transport(self, **kwargs) -> TransportBackend: + """Create shared memory transport. + + Returns the shared transport instance that all workers will use. + Since this is shared memory, there's only one transport shared by all workers. + + Note: + This is used internally by init_on_sender/init_on_receiver. + """ + return self.shared_transport + + def prepare_weights( + self, + weights: Any, + model_id: str, + strategy: WeightStrategy, + context: Any = None, + ) -> Any: + """Prepare weights for SharedMemWeightSyncScheme. + + When weights=None, we extract fresh weights from the model and update + the shared memory buffer in-place so workers see the change. + + Args: + weights: Raw weights input + model_id: The model identifier + strategy: WeightStrategy for extracting/converting weights + context: Optional context (e.g., collector) for cache lookup + + Returns: + Shared memory weights ready to send + """ + # If weights are explicitly provided, use them directly + if weights is not None: + fresh_weights = super().prepare_weights( + weights, model_id, strategy, context + ) + else: + # Extract fresh weights from the model (base class handles this) + fresh_weights = super().prepare_weights(None, model_id, strategy, context) + + if fresh_weights is None: + return None + + # Update the shared memory buffer in-place so workers see the change + if self._shared_transport is not None and self.shared_transport.unique_weights: + torchrl_logger.debug("Updating shared memory buffer in-place") + shared_weights = self.shared_transport.unique_weights[0] + # In-place update of shared memory buffer with fresh weights + shared_weights.data.update_(fresh_weights.data) + return shared_weights + + torchrl_logger.debug("No shared transport, returning fresh weights") + # If no shared transport, just return the fresh weights + return fresh_weights + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights via shared memory (in-place update). + + For SharedMemWeightSyncScheme: + 1. prepare_weights() updates the shared memory buffer in-place + 2. _send_instruction() tells workers to apply the new weights + 3. If sync=True, waits for acknowledgments from all workers + + Args: + weights: Weights to send (can be None to extract from model). + worker_ids: Which workers to notify (None = all workers). 
+ """ + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + if not self.synchronized_on_sender: + raise RuntimeError("Must be synchronized on sender before sending weights") + + # prepare_weights updates the shared buffer in-place + torchrl_logger.debug( + "Sending weights via shared memory -- calling prepare_weights()" + ) + self.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=self.context, + ) + + # Send instruction to workers' background threads to apply the weights + torchrl_logger.debug("Sending 'receive' instruction to workers") + self._send_instruction(instruction="receive", worker_ids=worker_ids) + + # Wait for acknowledgments if in synchronous mode + if self.sync: + torchrl_logger.debug("Waiting for acknowledgments from workers") + self._wait_for_ack(worker_ids=worker_ids) + + @property + def weights(self) -> Any | None: + """Get the current weights from shared memory. + + For SharedMemWeightSyncScheme: + - On sender side: weights are in transport's _unique_weights + - On receiver side: weights are in _receiver_shared_weights (stored during connect()) + + Returns: + The weights TensorDict if available, None otherwise. + """ + # On receiver side, use the stored shared buffer reference + if ( + hasattr(self, "_receiver_shared_weights") + and self._receiver_shared_weights is not None + ): + return self._receiver_shared_weights + + # On sender side, get from the shared transport + if self._shared_transport is not None and self.shared_transport.unique_weights: + return self.shared_transport.unique_weights[0] + + # Fall back to parent implementation + return super().weights + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Synchronize weights on receiver side for shared memory. + + Reads the shared memory buffer from the queue and applies it to the model. + Then starts a background thread that listens for "receive" instructions + from the sender and applies weights when instructed. + + If a receiver_transport is set (e.g., for MultiProcessWeightSyncScheme), + defers to the base class implementation. + """ + # If receiver_transport is set (e.g., MultiProcess subclass), use base behavior + if self._receiver_transport is not None: + return super()._setup_connection_and_weights_on_receiver_impl( + worker_idx=worker_idx + ) + + # SharedMem-specific: use shared_transport + if self._shared_transport is None: + raise RuntimeError( + "SharedMemWeightSyncScheme requires shared_transport to be set." + ) + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self.worker_idx + + if worker_idx is None: + raise RuntimeError( + "worker_idx must be provided for _setup_connection_and_weights_on_receiver_impl." + ) + + # Read shared memory buffer from queue + weights = self._shared_transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx + ) + + # Store the shared buffer reference for later receive() calls + # This is the actual shared memory buffer that the sender updates + self._receiver_shared_weights = weights + + # Apply weights to model + if weights is not None and self.model is not None: + self._strategy.apply_weights(self.model, weights, inplace=False) + + # Start background receiver thread that listens for instructions + self._start_background_receiver() + + def _background_receive_loop(self): + """Background thread loop that waits for instructions and applies weights. 
+ + This loop: + 1. Waits for a "receive" instruction from the sender + 2. Applies the current shared memory weights to the model + 3. Sends an acknowledgment back to the sender + 4. Repeats until stop event is set or "stop" instruction received + """ + torchrl_logger.debug( + f"SharedMemWeightSyncScheme: Background receiver started for worker {self._worker_idx}" + ) + while not self._stop_event.is_set(): + try: + instruction = self._wait_for_instruction() + if instruction is None: + # Stop event was set or timeout + continue + if instruction == "receive": + torchrl_logger.debug( + f"SharedMemWeightSyncScheme: Worker {self._worker_idx} received 'receive' instruction" + ) + # Apply the current shared memory weights to the model + # The weights are already updated in shared memory by the sender + if ( + self._receiver_shared_weights is not None + and self.model is not None + ): + self._strategy.apply_weights( + self.model, self._receiver_shared_weights, inplace=True + ) + torchrl_logger.debug( + f"SharedMemWeightSyncScheme: Worker {self._worker_idx} applied weights" + ) + + # Cascade weight update to sub-collectors if context supports it + model_id = self._model_id or "policy" + if self.context is not None and hasattr( + self.context, "update_policy_weights_" + ): + torchrl_logger.debug( + f"SharedMemWeightSyncScheme: Cascading weight update to sub-collectors for {model_id=}" + ) + self.context.update_policy_weights_( + model_id=model_id, + policy_or_weights=self._receiver_shared_weights, + ) + + # Send acknowledgment + self._send_ack("updated") + elif instruction == "stop": + torchrl_logger.debug( + f"SharedMemWeightSyncScheme: Worker {self._worker_idx} received 'stop' instruction" + ) + break + else: + torchrl_logger.warning( + f"SharedMemWeightSyncScheme: Unknown instruction: {instruction}" + ) + except Exception as e: + if not self._stop_event.is_set(): + torchrl_logger.warning( + f"SharedMemWeightSyncScheme: Background receiver error: {e}" + ) + + torchrl_logger.debug( + f"SharedMemWeightSyncScheme: Background receiver stopped for worker {self._worker_idx}" + ) + + def __getstate__(self): + """Prepare the scheme for pickling.""" + state = super().__getstate__() + # mp.Queue objects can be pickled and shared across processes + # Keep them in state so workers have access + return state + + def shutdown(self) -> None: + """Stop the background receiver thread and clean up.""" + # Signal all workers to stop + if self._instruction_queues: + for worker_idx in self._instruction_queues: + try: + self._instruction_queues[worker_idx].put("stop") + except Exception: + pass + + # Stop local background thread if running + if self._stop_event is not None: + self._stop_event.set() + if self._background_thread is not None: + self._background_thread.join(timeout=5.0) + if self._background_thread.is_alive(): + torchrl_logger.warning( + "SharedMemWeightSyncScheme: Background thread did not stop gracefully" + ) + self._background_thread = None + self._stop_event = None diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py index 2482f250d0e..e8435352f43 100644 --- a/torchrl/weight_update/llm/vllm_double_buffer.py +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -117,20 +117,31 @@ def send_weights(self, model_id: str, weights: Any) -> None: weights.memmap(self.remote_addr, num_threads=self.num_threads) logger.info(f"Weights written successfully to {self.remote_addr}") - def receive_weights(self, timeout: float = 1.0) -> TensorDict: + def 
receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any | None: """Reads the weights from the shared directory. Args: - timeout: Not used for file-based transport (kept for API compatibility). + timeout: Ignored (file-based transport is instant). + weights: Ignored. + model: Ignored. + strategy: Ignored. Returns: TensorDict with flattened keys containing the weights. """ + # Timeout is ignored since file-based transport doesn't involve waiting logger.info(f"Reading weights from {self.local_addr}") - weights = TensorDict.load_memmap(self.local_addr) - weights = weights.flatten_keys(".") + received_weights = TensorDict.load_memmap(self.local_addr) + received_weights = received_weights.flatten_keys(".") logger.info(f"Weights read successfully from {self.local_addr}") - return weights + return received_weights def check_connection(self) -> bool: """Check if the transport is ready. @@ -187,13 +198,11 @@ def __init__( self.num_threads = num_threads self.strategy_name = strategy - def create_transport( - self, pipe_or_context: Any = None - ) -> VLLMDoubleBufferTransport: + def create_transport(self, **kwargs) -> VLLMDoubleBufferTransport: """Create transport for double-buffered storage. Args: - pipe_or_context: Not used for file-based transport (kept for API compatibility). + **kwargs: Not used for file-based transport (kept for API compatibility). Returns: A VLLMDoubleBufferTransport instance. @@ -301,7 +310,7 @@ def __init__(self, scheme: VLLMDoubleBufferSyncScheme, vllm_engine): f"Initialized double-buffer receiver reading from {self._scheme.local_addr}" ) - def apply_weights(self, weights: TensorDict) -> None: + def apply_weights(self, weights: TensorDict, inplace: bool = True) -> None: """Apply weights to vLLM engine using RPC. This method uses RPC to tell all vLLM workers to load weights from @@ -310,7 +319,10 @@ def apply_weights(self, weights: TensorDict) -> None: Args: weights: TensorDict with flattened keys containing weights. + inplace: Whether to apply weights in place. Default is `True`. """ + if not inplace: + raise ValueError("Cannot apply weights out of place for vLLM double-buffer") logger.info("Applying weights to vLLM engine via RPC") # Convert TensorDict to list of (name, tensor) tuples @@ -357,6 +369,7 @@ def poll_and_apply(self, timeout: float = 180.0) -> bool: Returns: True if weights were successfully read and applied, False otherwise. """ - weights = self._transport.receive_weights(timeout=timeout) + # timeout is not used by file-based transport but kept for API compatibility + weights = self._transport.receive_weights() self.apply_weights(weights) return True diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index 840a9883d14..c9907b8f17a 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -71,8 +71,6 @@ def init_all_workers_group(self, metadata): **Current Implementation (Ray Backend)** -The test suite in ``test_weightsync.py`` demonstrates the Ray-based RPC: - .. 
code-block:: python # Trainer actor (provides RPC endpoint) @@ -101,6 +99,8 @@ def init_all_workers_group(self, metadata): from __future__ import annotations +import time + from typing import Any, Literal import torch @@ -189,13 +189,13 @@ def init_all_workers_group( if self.rank == 0: # Trainer side - initialize process group - torchrl_logger.info( + torchrl_logger.debug( f"Initializing trainer collective group: rank={self.rank}, world_size={self.world_size}, device={self.device}" ) # Ray sets CUDA_VISIBLE_DEVICES, so we always use device 0 # Set CUDA device before initializing NCCL to avoid segfaults torch.cuda.set_device(self.device) - torchrl_logger.info(f"Set CUDA device to {self.device}") + torchrl_logger.debug(f"Set CUDA device to {self.device}") self._comm_group = stateless_init_process_group( self.master_address, @@ -204,13 +204,13 @@ def init_all_workers_group( self.world_size, device=self.device, ) - torchrl_logger.info("Trainer collective group initialized successfully") + torchrl_logger.debug("Trainer collective group initialized successfully") else: # vLLM worker side - initialize through engine if self.vllm_engine is None: raise ValueError("vllm_engine must be provided for worker ranks") - torchrl_logger.info( + torchrl_logger.debug( "Initializing vLLM worker collective group through engine" ) # Call vLLM engine's init method - it returns futures for all workers @@ -224,18 +224,17 @@ def init_all_workers_group( import ray ray.get(refs) - torchrl_logger.info( + torchrl_logger.debug( f"All {len(refs)} vLLM workers have dispatched NCCL init RPCs" ) # Small delay to ensure worker background threads have entered the NCCL collective # This prevents a race where the trainer starts NCCL before workers are ready - import time time.sleep(0.2) self._comm_group = True # Mark as initialized - torchrl_logger.info( + torchrl_logger.debug( "vLLM workers should now be blocked in NCCL collective, ready for trainer" ) @@ -283,7 +282,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: else: weights_dict = weights - torchrl_logger.info( + torchrl_logger.debug( f"Broadcasting {len(weights_dict)} weights for model '{model_id}'" ) @@ -314,14 +313,27 @@ def send_weights(self, model_id: str, weights: Any) -> None: del tensor torch.cuda.synchronize() - torchrl_logger.info(f"Broadcast complete for model '{model_id}'") + torchrl_logger.debug(f"Broadcast complete for model '{model_id}'") - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any | None: """Receive weights from broadcaster. This should only be called from worker ranks (rank > 0). This method is called by vLLM engine internally through collective operations. + Args: + timeout: Ignored (vLLM handles synchronization internally). + weights: Ignored. + model: Ignored. + strategy: Ignored. + Returns: None - vLLM handles weight application internally via collectives. """ @@ -441,7 +453,7 @@ def __init__( s.bind(("", 0)) self.master_port = s.getsockname()[1] - def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: + def create_transport(self, **kwargs) -> VLLMCollectiveTransport: """Create transport for collective communication. For vLLM, this creates a transport but requires additional setup via init_all_workers_group(). 
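
To make the receiver-side contract in the preceding hunks concrete, here is a small sketch of how the unified `receive_weights()` / `apply_weights()` pair is meant to be driven. The helper name `refresh_worker_weights` and the `transport`/`receiver` arguments are placeholders introduced for illustration, not names from this PR.

    def refresh_worker_weights(transport, receiver) -> bool:
        # Ask the transport for new weights; the timeout argument is accepted
        # everywhere but ignored by the file-based double-buffer transport.
        weights = transport.receive_weights(timeout=None)
        if weights is None:
            # Collective (NCCL) transports apply weights inside the broadcast
            # itself and return None, so there is nothing left to do here.
            return False
        # Double-buffer transports return a TensorDict that still has to be
        # pushed into the engine; in-place application is the only supported mode.
        receiver.apply_weights(weights, inplace=True)
        return True
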
@@ -449,7 +461,7 @@ def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: is more complex and typically handled by sender/receiver initialization. Args: - pipe_or_context: Not used for vLLM (kept for API compatibility). + **kwargs: Not used for vLLM (kept for API compatibility). Returns: A VLLMCollectiveTransport instance (needs init_all_workers_group() to be called). @@ -546,7 +558,7 @@ def init_all_workers_group( device=self._scheme.device, vllm_engine=vllm_engine, ) - torchrl_logger.info( + torchrl_logger.debug( f"Initializing transport from sender with world_size={world_size}" ) self._transport.init_all_workers_group(model_metadata) @@ -642,14 +654,18 @@ def init_all_workers_group( device=self._scheme.device, vllm_engine=self._vllm_engine, ) - torchrl_logger.info( + torchrl_logger.debug( f"Initializing transport from receiver with world_size={world_size}." ) self._transport.init_all_workers_group(model_metadata) - def apply_weights(self, weights: Any) -> None: + def apply_weights(self, weights: Any, inplace: bool = True) -> None: """Apply weights to vLLM engine. + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + Note: For vLLM, weights are applied automatically during the collective broadcast operation. This method is a no-op but kept for API consistency. """ diff --git a/torchrl/weight_update/utils.py b/torchrl/weight_update/utils.py new file mode 100644 index 00000000000..ebfe9739474 --- /dev/null +++ b/torchrl/weight_update/utils.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import re +from typing import Any + + +def _resolve_attr(context: Any, attr_path: str) -> Any: + """Resolve an attribute path like 'policy' or 'env.value_net' to actual object. + + Also processes getitem notation like 'env.transform[0]' or '_receiver_schemes["model_id"]' + to actual object. + + Args: + context: The context object (collector or inner_collector). + attr_path: A string address like "policy", "env.value_net", or + "_receiver_schemes['model_id']". + + Returns: + The object at the specified address. 
+ + Examples: + >>> _resolve_attr(collector, "policy") # -> collector.policy + >>> _resolve_attr(collector, "env.value_net") # -> collector.env.value_net + >>> _resolve_attr(collector, "_receiver_schemes['model_id']") # -> collector._receiver_schemes['model_id'] + """ + # Pattern to match subscript access: attr[key] or attr["key"] or attr['key'] or attr[0] + subscript_pattern = re.compile(r"^([^\[]+)(.*)$") + + parts = attr_path.split(".") + obj = context + for i, part in enumerate(parts): + if "[" in part: + match = subscript_pattern.match(part) + if match: + key = match.group(1) + subscripts_str = match.group(2) + + # Get the base attribute + if key: + try: + obj = getattr(obj, key) + except AttributeError: + raise AttributeError( + f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + + # Parse and apply all subscripts + # Match each [xxx] where xxx can be int, 'string', or "string" + subscript_matches = re.findall(r"\[([^\]]+)\]", subscripts_str) + for subscript in subscript_matches: + # Try to parse as int first + try: + index = int(subscript) + obj = obj[index] + except ValueError: + # It's a string key - remove quotes if present + if (subscript.startswith("'") and subscript.endswith("'")) or ( + subscript.startswith('"') and subscript.endswith('"') + ): + subscript = subscript[1:-1] + obj = obj[subscript] + else: + try: + obj = getattr(obj, part) + except AttributeError: + raise AttributeError( + f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + return obj + + +# Alias for backwards compatibility +_resolve_model = _resolve_attr diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 42d13108a0f..75ab16563b4 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -5,38 +5,28 @@ from __future__ import annotations import abc - +import threading +import warnings import weakref -from collections.abc import Iterator -from typing import Any, Literal, Protocol +from collections import defaultdict +from collections.abc import Callable, Iterator +from typing import Any, Literal, overload, Protocol -from tensordict import TensorDict, TensorDictBase +import torch +from tensordict import TensorDict, TensorDictBase from torch import nn +from torchrl._utils import logger as torchrl_logger __all__ = [ "TransportBackend", - "MPTransport", - "SharedMemTransport", - "RayTransport", - "RayActorTransport", - "RPCTransport", - "DistributedTransport", "WeightStrategy", - "WeightSender", - "WeightReceiver", - "RayModuleTransformSender", - "RayModuleTransformReceiver", "WeightSyncScheme", - "MultiProcessWeightSyncScheme", - "SharedMemWeightSyncScheme", - "NoWeightSyncScheme", - "RayWeightSyncScheme", - "RayModuleTransformScheme", - "RPCWeightSyncScheme", - "DistributedWeightSyncScheme", ] +from torchrl.weight_update.utils import _resolve_model + + # ============================================================================ # Transport Layer Abstraction # ============================================================================ @@ -45,641 +35,66 @@ class TransportBackend(Protocol): """Abstract interface for different communication mechanisms.""" - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: """Send weights to the receiver.""" ... - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the sender. 
Returns (model_id, weights) or None if timeout.""" - ... - - def check_connection(self) -> bool: - """Check if the connection is still alive.""" - ... - - -class MPTransport: - """Multiprocessing transport using pipes. - - Args: - pipe_connection (mp.Pipe): The pipe connection to use for communication. - timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. - """ - - def __init__(self, pipe_connection, timeout: float = 10.0): - self.timeout = timeout - self.pipe = pipe_connection - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights through the pipe. - - Sends weights and waits for acknowledgment to ensure delivery. - """ - self.send_weights_async(model_id, weights) - self.wait_ack() - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights through the pipe without waiting for acknowledgment. - - Use wait_ack() to wait for acknowledgment after sending to all workers. - """ - self.pipe.send(((model_id, weights), "update_weights")) - - def wait_ack(self) -> None: - """Wait for acknowledgment from worker.""" - self.check_ack("updated") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the pipe (used in worker process). - - This method only handles weight update messages. Other messages - (like "close", "continue", etc.) are ignored and should be handled - by the main worker loop. - - Returns: - Tuple of (model_id, weights) if weights were received, None if no data available - or if a non-weight message was received. - """ - if self.pipe.poll(timeout): - data_in, msg = self.pipe.recv() - if msg == "update_weights": - model_id, weights = data_in - return model_id, weights - else: - # Not a weight update message - put it back and return None - # This allows the main worker loop to handle other messages - # Note: We can't actually "put it back", so we'll just return None - # and the message is lost. This is why receive() should only be called - # when we're expecting weight updates, not in the main message loop. - return None - # No data available - return None instead of raising TimeoutError - # This allows non-blocking checks in the worker loop - return None - - def send_ack(self, message: str = "updated") -> None: - """Send acknowledgment back to sender.""" - self.pipe.send((None, message)) - - def check_ack(self, message: str = "updated") -> None: - """Check for acknowledgment.""" - _, msg = self.pipe.recv() - if msg != message: - raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") - - def check_connection(self) -> bool: - return not self.pipe.closed - - -class SharedMemTransport: - """Shared memory transport for in-place weight updates. - - This transport updates shared memory tensors directly without message passing. - Workers automatically see weight updates without explicit communication. - - The transport supports lazy registration with pipe-based buffer distribution: - - On first weight send for a model, creates shared memory and sends buffer via pipes - - Workers receive the buffer reference and update their local references - - Subsequent updates are pure in-place shared memory (zero-copy) - - This hybrid approach solves the chicken-and-egg problem: workers can start before - weights are available, and they'll receive the shared buffer references when ready. - - Args: - policy_weights: Dictionary mapping model_id to shared TensorDict weights. - Can be empty if using lazy registration. 
- auto_register: Whether to automatically register models on first weight send. - Default is True. Set to `False` to require explicit registration via - register_weights(). - """ - - def __init__( + def receive_weights( self, - policy_weights: dict[str, TensorDictBase] | None = None, - auto_register: bool = True, - ): - self._policy_weights = policy_weights if policy_weights is not None else {} - self._auto_register = auto_register - self._pipes = [] # List of pipes to send initial buffer references - # Track which model_ids have been sent to workers - self._registered_with_workers = set() - - def register_pipe(self, pipe: Any) -> None: - """Register a pipe for sending buffer references on first weight send. - - Args: - pipe: Pipe connection to a worker process. - """ - if pipe not in self._pipes: - self._pipes.append(pipe) - - def register_weights(self, model_id: str, weights: TensorDictBase) -> None: - """Register a shared memory weights TensorDict for a model. - - This method allows explicit registration of shared weights. It's optional - when auto_register=True (the default), but required when auto_register=False. - - If pipes are registered and this model hasn't been sent to workers yet, - this will trigger sending the buffer reference to all workers. If pipes - aren't registered yet, weights are stored and will be sent when pipes - become available (during init_on_sender). - """ - if not isinstance(weights, TensorDictBase): - raise ValueError(f"Weights must be a TensorDictBase, got {type(weights)}") - - is_new_registration = model_id not in self._policy_weights - if is_new_registration: - self._policy_weights[model_id] = weights - else: - raise RuntimeError("Re-registering weights is not supported.") - - # If this is a new registration and we have pipes, send buffer to workers - # If pipes aren't available yet, defer sending until init_on_sender is called - if self._pipes: - if model_id not in self._registered_with_workers: - self._send_buffer_to_workers(model_id, weights) - else: - raise RuntimeError( - f"Model '{model_id}' has already been registered with workers." - ) - - def _send_buffer_to_workers( - self, model_id: str, buffer: TensorDictBase, timeout: float = 10.0 - ) -> None: - """Send shared memory buffer reference to all workers via pipes. - - This is called once per model_id when lazy registration occurs. - Workers receive the buffer and update their local references. - - Note: We send buffer.data to avoid gradient tracking issues when crossing - process boundaries. The .data attribute gives us the underlying tensors - without autograd metadata. - """ - for pipe in self._pipes: - # Send special registration message with the shared buffer - # Use .data to strip gradient information (can't serialize non-leaf tensors with requires_grad) - pipe.send(((model_id, buffer.data), "register_shared_weights")) - - # Wait for acknowledgments from all workers - for pipe in self._pipes: - if not pipe.poll(timeout): - raise TimeoutError("Timeout waiting for acknowledgment from worker") - _, msg = pipe.recv() - if msg != "registered": - raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'") - - self._registered_with_workers.add(model_id) - - def send_weights(self, model_id: str, weights: Any) -> None: - """Update weights in-place in shared memory. - - If the model is not registered and auto_register=True, it will be automatically - registered by creating a shared memory copy of the provided weights. 
The shared - buffer reference is sent to all workers via pipes on first registration, then - subsequent updates are pure in-place shared memory. + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any | None: + """Receive weights from the sender and apply them to the model. Args: - model_id: Identifier for the model whose weights to update. - weights: New weights to send. Can be a TensorDictBase or dict. - - Raises: - KeyError: If model is not registered and auto_register=False. - ValueError: If weights type is unsupported for auto-registration. - """ - if model_id not in self._policy_weights: - if not self._auto_register: - raise KeyError( - f"Model '{model_id}' not registered in SharedMemTransport. " - f"Available models: {list(self._policy_weights.keys())}. " - f"Either register the model using register_weights() or enable auto_register." - ) - - # Auto-register on first send - if isinstance(weights, dict): - weights = TensorDict(weights) - if not isinstance(weights, TensorDictBase): - raise ValueError( - f"Cannot auto-register model '{model_id}' with weights type: {type(weights)}. " - f"Supported types for auto-registration: TensorDictBase, dict. " - f"Please manually register shared weights using register_weights()." - ) - # Unflatten keys if they're flat (e.g., 'module.0.weight' -> nested structure) - # This is necessary for to_module() to work properly - weights_to_share = weights - # Check if keys are flattened by looking for dots in key names - if any("." in key for key in weights_to_share.keys()): - weights_to_share = weights_to_share.unflatten_keys(".") - shared_buffer = weights_to_share.share_memory_() - - self._policy_weights[model_id] = shared_buffer - - # Send buffer reference to all workers if we have pipes - if self._pipes and model_id not in self._registered_with_workers: - self._send_buffer_to_workers(model_id, shared_buffer) - - shared_weights = self._policy_weights[model_id] - - # Update shared memory in-place (workers see this automatically) - if isinstance(weights, dict): - weights = TensorDict(weights) - if not isinstance(weights, TensorDictBase): - raise ValueError(f"Unsupported weights type: {type(weights)}") - # Unflatten if needed to match shared buffer structure - weights_to_update = weights - if any("." in key for key in weights.keys()): - weights_to_update = weights.unflatten_keys(".") - shared_weights.data.update_(weights_to_update.data) - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """No-op for shared memory - weights are already visible.""" - return None - - def send_ack(self, message: str = "updated") -> None: - """No-op for shared memory - no acknowledgment needed.""" - - def check_ack(self, message: str = "updated") -> None: - """No-op for shared memory - no acknowledgment needed.""" - - def check_connection(self) -> bool: - """Shared memory is always 'connected'.""" - return True - - -class RayTransport: - """Ray transport for communicating with a single Ray collector actor. - - This transport handles weight updates for ONE specific remote collector. - Multiple transports are created for multiple collectors, following the - same pattern as multiprocess collectors. 
- """ - - def __init__( - self, - remote_collector=None, - tensor_transport: Literal["object_store", "nixl"] = "object_store", - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayTransport") - self._remote_collector = remote_collector - self._tensor_transport = tensor_transport - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the remote collector via Ray. + timeout: Maximum time to wait for weights (seconds). + None means no timeout (blocking). Some transports may not + support timeout and will raise ValueError if specified. + weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. + Returns: + The received/applied weights, or None if timeout/no weights available. """ - if self._remote_collector is None: - return - - # Put weights in Ray's object store for efficient distribution - # Ray will automatically deduplicate if the same weights are sent to multiple actors - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - # Send to the remote collector and wait for completion - # This ensures weights are applied before we continue - future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - self.ray.wait([future], num_returns=1) + ... - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. + def setup_connection_and_weights_on_sender(self) -> None: + """Synchronize weights on sender side before collection starts. - Use wait_ack() to wait for completion after sending to all workers. + This is called once after workers are initialized to send the initial + weights. This can be a no-op (weights are sent via + send_weights). """ - if self._remote_collector is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - self._pending_future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - - def wait_ack(self) -> None: - """Wait for the remote collector to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.wait([self._pending_future], num_returns=1) - del self._pending_future - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray workers typically don't receive weights through this transport.""" - return None - - def check_connection(self) -> bool: - """Check if Ray is initialized.""" - return self.ray.is_initialized() - - -class RayActorTransport: - """Ray transport for communicating with Ray actors (not collectors). - - This transport is designed for updating models hosted within Ray actors, - such as RayModuleTransform instances. It directly calls the actor's - update_weights method rather than going through collector update methods. - """ + ... 
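
As a quick orientation for the sender side of this protocol, the call order implied by the docstring above is sketched below; `scheme` and `new_weights` are placeholders assumed for illustration, not names from this diff.

    # Hypothetical sender-side flow (illustrative only):
    transport = scheme.create_transport()                 # build one transport per scheme
    transport.setup_connection_and_weights_on_sender()    # one-time handshake; may be a no-op
    transport.send_weights(new_weights)                   # later updates go through send_weights()
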
- def __init__( + def setup_connection_and_weights_on_receiver( self, - actor_ref=None, - update_method: str = "tensordict", - tensor_transport: Literal["object_store", "nixl"] = "object_store", - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayActorTransport") - - self._actor_ref = actor_ref - self._update_method = update_method - self._tensor_transport = tensor_transport - - def set_actor(self, actor_ref): - """Set the Ray actor reference to communicate with.""" - self._actor_ref = actor_ref - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the Ray actor.""" - if self._actor_ref is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - if self._update_method == "tensordict": - self.ray.get( - self._actor_ref._update_weights_tensordict.remote(params=weights_ref) - ) - elif self._update_method == "state_dict": - self.ray.get( - self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights to Ray actor without waiting for completion. - - Use wait_ack() to wait for completion after sending to all actors. - """ - if self._actor_ref is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - if self._update_method == "tensordict": - self._pending_future = self._actor_ref._update_weights_tensordict.remote( - params=weights_ref - ) - elif self._update_method == "state_dict": - self._pending_future = self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") - - def wait_ack(self) -> None: - """Wait for Ray actor to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.get(self._pending_future) - del self._pending_future - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray actor workers receive weights through direct method calls.""" - return None - - def send_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" - - def check_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" - - def check_connection(self) -> bool: - """Check if Ray is initialized and actor exists.""" - if not self.ray.is_initialized(): - return False - if self._actor_ref is None: - return False - return True - - -class RPCTransport: - """RPC transport for communicating with a single RPC remote collector. - - This transport handles weight updates for ONE specific remote collector via - torch.distributed.rpc. Multiple transports are created for multiple collectors, - following the same pattern as multiprocess collectors. - """ - - def __init__(self, collector_info=None, collector_rref=None, collector_class=None): - self._collector_info = collector_info - self._collector_rref = collector_rref - self._collector_class = collector_class - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the remote collector via RPC. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. 
- """ - if self._collector_info is None or self._collector_rref is None: - return - - from torch.distributed import rpc - - # Send weights to the remote collector and wait for completion - rpc.rpc_sync( - self._collector_info, - self._collector_class.update_policy_weights_, - args=(self._collector_rref, weights), - ) - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. - - Use wait_ack() to wait for completion after sending to all workers. - """ - if self._collector_info is None or self._collector_rref is None: - return - - from torch.distributed import rpc - - # Send weights asynchronously - self._pending_future = rpc.rpc_async( - self._collector_info, - self._collector_class.update_policy_weights_, - args=(self._collector_rref, weights), - ) - - def wait_ack(self) -> None: - """Wait for the RPC call to complete.""" - if hasattr(self, "_pending_future"): - self._pending_future.wait() - del self._pending_future - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """RPC workers typically don't receive weights through this transport.""" - return None - - def check_connection(self) -> bool: - """Check if RPC is initialized.""" - from torch.distributed import rpc - - return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True - - -class DistributedTransport: - """torch.distributed transport for communicating with a single distributed worker. - - This transport handles weight updates for ONE specific distributed worker via - torch.distributed send/recv. Multiple transports are created for multiple workers, - following the same pattern as multiprocess collectors. - """ - - def __init__(self, store=None, rank=None, sync=True): - """Initialize the DistributedTransport. - - Args: - store: TCPStore for communication. - rank: Worker rank (1-indexed). - sync: Whether to use synchronous weight updates. - """ - self._store = store - self._rank = rank - self._sync = sync - self._weights_buffer = None # TensorDict buffer for receiving weights - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the distributed worker. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. - """ - if self._store is None or self._rank is None: - return - - # Instruct worker to expect weight update - self._store.set(f"NODE_{self._rank}_in", b"update_weights") - - # Send weights via torch.distributed - if self._sync: - weights.send(self._rank) - else: - weights.isend(self._rank) - - # Wait for acknowledgment - status = self._store.get(f"NODE_{self._rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{self._rank}_out") - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights to distributed worker without waiting for acknowledgment. - - Use wait_ack() to wait for acknowledgment after sending to all workers. 
- """ - if self._store is None or self._rank is None: - return - - # Instruct worker to expect weight update - self._store.set(f"NODE_{self._rank}_in", b"update_weights") - - # Send weights via torch.distributed - if self._sync: - weights.send(self._rank) - else: - weights.isend(self._rank) - - def wait_ack(self) -> None: - """Wait for acknowledgment from distributed worker.""" - if self._store is None or self._rank is None: - return - - status = self._store.get(f"NODE_{self._rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{self._rank}_out") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights via torch.distributed, using TCPStore for signaling. + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any: + """Synchronize weights on worker side before collection starts. - This implements the RPC-like pattern: - 1. Check TCPStore for signal (non-blocking) - 2. If signal present, receive weights via torch.distributed - 3. Clean up signal and send acknowledgment + This is called once in each worker after initialization to receive + the initial weights. This is a no-op (weights are received via + receive_weights). Args: - timeout: Timeout for receiving (currently not used for TCPStore check) + worker_idx: The worker index. + weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. Returns: - Tuple of (model_id, weights) if weights were received, None otherwise. + The received weights (for SharedMemTransport) or None. """ - if self._store is None or self._rank is None: - return None - - try: - # Non-blocking check of TCPStore "mailbox" for signal - msg = self._store.get(f"NODE_{self._rank}_in") - - if msg == b"update_weights": - # Initialize weights buffer on first use - if self._weights_buffer is None: - self._weights_buffer = TensorDict() - - # Receive weights via torch.distributed - # recv() and irecv() update the TensorDict in place - if self._sync: - self._weights_buffer.recv(src=0) - else: - # irecv() blocks until weights are received - self._weights_buffer.irecv(src=0) - - # Clean up the signal - self._store.delete_key(f"NODE_{self._rank}_in") - - # Note: Acknowledgment is sent separately via send_ack() if transport supports it - # This matches the pattern in WeightReceiver.receive() - - # Return model_id and received weights - # For distributed transport, we use "policy" as default model_id - return ("policy", self._weights_buffer) - else: - raise ValueError(f"Expected 'update_weights' but got {msg}") - except KeyError: - # No message in store - no weights available - return None - - return None - - def send_ack(self, message: str = "updated") -> None: - """Send acknowledgment back to sender via TCPStore. - - Args: - message: Acknowledgment message to send (default: "updated") - """ - if self._store is None or self._rank is None: - return - - self._store.set(f"NODE_{self._rank}_out", message.encode()) - - def check_connection(self) -> bool: - """Check if torch.distributed is initialized.""" - import torch.distributed - - return torch.distributed.is_initialized() + ... 
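
To ground the slimmed-down protocol above, a minimal custom backend might look like the sketch below. This is purely illustrative: the class name `QueueTransport` and the use of an in-process `queue.Queue` are assumptions, not part of this PR; real backends (pipes, shared memory, Ray, NCCL) replace the queue with their own channel.

    from __future__ import annotations

    import queue
    from typing import Any


    class QueueTransport:
        """Toy TransportBackend: sender and receiver share one in-process queue."""

        def __init__(self) -> None:
            self._q: queue.Queue = queue.Queue()

        def send_weights(self, weights: Any) -> None:
            # Push the latest weights for the receiver to pick up.
            self._q.put(weights)

        def receive_weights(
            self,
            timeout: float | None = None,
            *,
            weights: Any = None,
            model: Any = None,
            strategy: Any = None,
        ) -> Any | None:
            try:
                received = self._q.get(timeout=timeout)
            except queue.Empty:
                return None
            if model is not None and strategy is not None:
                # Apply to the model when the caller supplied one.
                strategy.apply_weights(model, received)
            return received

        def setup_connection_and_weights_on_sender(self) -> None:
            # Nothing to hand-shake for an in-process queue.
            return None

        def setup_connection_and_weights_on_receiver(
            self,
            *,
            worker_idx: int,
            weights: Any = None,
            model: Any = None,
            strategy: Any = None,
        ) -> Any:
            # The first blocking read delivers the initial weights.
            return self.receive_weights(
                timeout=None, weights=weights, model=model, strategy=strategy
            )

A scheme would hand such a transport to its sender and receiver sides; only the four methods shown above are required by the protocol as defined in this file.
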
# ============================================================================ @@ -703,13 +118,18 @@ class WeightStrategy: """ def __init__(self, extract_as: Literal["tensordict", "state_dict"] = "tensordict"): + if extract_as == "state_dict": + warnings.warn( + "state_dict strategy is experimental. Use tensordict strategy for safer weight updates.", + UserWarning, + ) if extract_as not in ("tensordict", "state_dict"): raise ValueError( f"extract_as must be 'tensordict' or 'state_dict', got {extract_as}" ) self.extract_as = extract_as - def extract_weights(self, source: Any) -> Any: + def extract_weights(self, source: Any) -> TensorDictBase | dict | None: """Extract weights from source model in the specified format. Args: @@ -731,10 +151,11 @@ def extract_weights(self, source: Any) -> Any: # Convert state_dict to TensorDict return TensorDict(source, batch_size=[]) else: - raise ValueError( + torchrl_logger.warning( f"Unsupported source type for TensorDict extraction: {type(source)}" ) - else: # state_dict + return TensorDict(lock=True) + elif self.extract_as == "state_dict": # state_dict # Extract as state_dict if isinstance(source, nn.Module): return source.state_dict() @@ -742,13 +163,20 @@ def extract_weights(self, source: Any) -> Any: return source elif isinstance(source, TensorDictBase): # Convert TensorDict to state_dict - return source.to_dict() + return source.flatten_keys().to_dict() else: - raise ValueError( - f"Unsupported source type for state_dict extraction: {type(source)}" + torchrl_logger.warning( + f"Unsupported source type for TensorDict extraction: {type(source)}" ) + return {} + else: + raise ValueError( + f"Unknown extract_as: {self.extract_as}. Must be 'tensordict' or 'state_dict'." + ) - def apply_weights(self, destination: Any, weights: Any) -> None: + def apply_weights( + self, destination: Any, weights: Any, inplace: bool = True + ) -> None: """Apply weights to destination model. The format is automatically detected from the weights type: @@ -761,6 +189,7 @@ def apply_weights(self, destination: Any, weights: Any) -> None: - TensorDictBase: TensorDict - dict: State dictionary weights: The weights to apply (dict or TensorDictBase). + inplace: Whether to apply weights in place. """ if weights is None: return @@ -771,29 +200,40 @@ def apply_weights(self, destination: Any, weights: Any) -> None: if any("." in key for key in weights.keys()): weights = weights.unflatten_keys(".") if isinstance(destination, nn.Module): - destination = TensorDict.from_module(destination) + # Do not update in-place + if not inplace: + weights.to_module(destination) + return + else: + destination = TensorDict.from_module(destination) elif isinstance(destination, dict): + if not inplace: + raise ValueError("Cannot update state_dict out of place") destination = TensorDict(destination) if any(isinstance(key, str) and "." in key for key in destination.keys()): destination = destination.unflatten_keys(".") - if isinstance(weights, TensorDictBase): - # Apply TensorDict format - if isinstance(destination, TensorDictBase): - try: - destination.data.update_(weights.data) - except Exception as e: - raise KeyError( - f"Error updating destination: {e}. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}" - ) - else: - raise ValueError( - f"Unsupported destination type for TensorDict: {type(destination)}" - ) - else: + if not isinstance(weights, TensorDictBase): raise ValueError( - f"Unsupported weights type: {type(weights)}. Expected dict or TensorDictBase." 
+ f"Unsupported weights type: {type(weights)}. Must be dict or TensorDictBase." ) + if not isinstance(destination, TensorDictBase): + if not weights.is_empty(): + raise ValueError( + "Non-empty weights are associated with a non-dict, non-td, non-Module destination." + ) + return + + try: + if not inplace: + destination.update(weights) + else: + destination.data.update_(weights.data) + except Exception as e: + raise KeyError( + f"Error updating destination. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}" + ) from e + return def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrategy: @@ -813,917 +253,574 @@ def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrate # ============================================================================ -# Sender (Trainer/Main Process Side) +# Weight Synchronization Schemes # ============================================================================ -class WeightSender: - """Sends weights for ONE model to workers. +class WeightSyncScheme(metaclass=abc.ABCMeta): + """Configuration for how to synchronize ONE model across workers. + + A scheme manages synchronization of ONE model across workers. + The collector maintains a dict of {model_id: scheme} pairs. - A single sender can broadcast to all workers or send to specific workers. - Created and managed by WeightSyncScheme. Users should not instantiate directly. + This class directly handles both sender and receiver functionality, + with behavior determined by whether init_on_sender() or init_on_receiver() + was called. """ - _transport: TransportBackend | None - _transports: dict[int, TransportBackend] + _model_id: str | None = None - def __init__(self, scheme: WeightSyncScheme): - self._scheme = scheme - self._transports: dict[int, TransportBackend] = {} # worker_idx -> transport - self._transport: TransportBackend | None = None - self._model_id = "policy" # Default model ID - self._strategy = _get_strategy(scheme.strategy) - self._context_ref = None # weakref to collector for model resolution - self._pending_async = False # Track if async send is pending + # Transport management + _sender_transports: dict[int, TransportBackend] | None + _receiver_transport: TransportBackend | None + _shared_transport: TransportBackend | None - def _set_context(self, context: Any, model_id: str | None = None) -> None: - """Set the context object (collector) for model resolution (internal). + # Context and model references + _context_ref: weakref.ReferenceType[Any] | None + _model_ref: weakref.ReferenceType[Any] | None - This is now handled by init_on_sender(). Only kept for internal use. + # Strategy + _strategy: WeightStrategy - Args: - context: The collector instance. - model_id: Optional model identifier (for compatibility with RayModuleTransformSender). - """ - self._context_ref = weakref.ref(context) - if model_id is not None: - self._model_id = model_id + # Worker index (for receiver side) + _worker_idx: int | None - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """Register a worker's communication pipe (internal). + # Background thread + _background_thread = None + _stop_event = None - This is now handled by init_on_sender(). Only kept for internal use. 
+ def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): + self.strategy_str = strategy + self._strategy = _get_strategy(strategy) + self._initialized_on_sender = False + self._initialized_on_receiver = False - Args: - worker_idx: The worker index. - pipe_or_context: The pipe connection for this worker. - """ - if worker_idx not in self._transports: - self._transports[worker_idx] = self._scheme.create_transport( - pipe_or_context - ) + # Transport management + self._sender_transports = None # worker_idx -> transport + self._receiver_transport = None + self._shared_transport = None - def _iterate_transports( - self, worker_ids: int | list[int] | None = None - ) -> Iterator[TransportBackend]: - """Iterate over transports for specified workers.""" - if worker_ids is None: - # All workers - if not self._transports: - yield self._transport - else: - yield from self._transports.values() - else: - # Specific workers - if isinstance(worker_ids, int): - worker_ids = [worker_ids] - for worker_id in worker_ids: - if worker_id in self._transports: - yield self._transports[worker_id] - else: - raise ValueError(f"Worker {worker_id} not registered") - - def send( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> None: - """Send weights synchronously to workers. - - This method: - 1. Prepares weights (extracts from model if weights=None) - 2. Sends to specified workers (or all if worker_ids=None) - 3. Waits for acknowledgments from those workers - 4. Returns when workers have applied the weights - - Args: - weights: Weights to send. Can be: - - None: Extract from model via context.get_model(model_id) - - nn.Module: Extract weights from module - - TensorDict: Use directly - - dict: Convert to TensorDict - worker_ids: Which workers to send to: - - None: Send to all workers (default) - - int: Send to single worker - - list[int]: Send to specific workers - - Note: This is a blocking call that ensures specified workers are updated - before returning. - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send() while an async send is pending. Call wait_async() first." 
- ) + # Context and model references + self._context_ref = None + self._model_ref = None - model_id = getattr(self, "_model_id", "policy") - context = self._context_ref() if self._context_ref is not None else None + # Worker index + self._worker_idx = None - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=model_id, - strategy=self._strategy, - context=context, - ) + # ======================================================================== + # Initialization + # ======================================================================== - transports = list(self._iterate_transports(worker_ids)) + @property + def strategy(self) -> WeightStrategy: + return self._strategy - # Send to all workers first (non-blocking if transport supports it) - for transport in transports: - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(model_id, prepared_weights) - else: - # Fallback for transports that don't support async send - transport.send_weights(model_id, prepared_weights) - - # Wait for all acknowledgments - for transport in transports: - if hasattr(transport, "wait_ack"): - transport.wait_ack() + @strategy.setter + def strategy(self, value: WeightStrategy) -> None: + self._strategy = value - def send_async( + @overload + def init_on_sender( self, - weights: Any = None, - worker_ids: int | list[int] | None = None, + *, + model_id: str, + context: Any, ) -> None: - """Send weights asynchronously to workers (non-blocking). - - This initiates the send but returns immediately without waiting - for workers to acknowledge. You must call wait_async() before - the next send_async() or send() call. - - Args: - weights: Same as send() - worker_ids: Same as send() - - Raises: - RuntimeError: If a previous send_async() is still pending - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send_async() again while a previous send is pending. Call wait_async() first." - ) - - model_id = getattr(self, "_model_id", "policy") - context = self._context_ref() if self._context_ref is not None else None - - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=model_id, - strategy=self._strategy, - context=context, - ) + ... - # Store transports for wait_async - self._pending_transports = list(self._iterate_transports(worker_ids)) + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, + ) -> None: + ... - # Send to all workers (non-blocking) - for transport in self._pending_transports: - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(model_id, prepared_weights) - else: - raise RuntimeError( - f"transport of type {type(transport)} does not support async send." - ) + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... - self._pending_async = True + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... - def wait_async(self) -> None: - """Wait for a pending async send to complete. + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... - Blocks until all workers have acknowledged the previous send_async(). - This must be called after send_async() before any subsequent sends. 
+ @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... - Raises: - RuntimeError: If no async send is pending - """ - if not self._pending_async: - raise RuntimeError("No async send is pending. Call send_async() first.") + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... - # Wait for all acknowledgments - for transport in self._pending_transports: - if hasattr(transport, "wait_ack"): - transport.wait_ack() + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + ) -> None: + ... - self._pending_async = False - self._pending_transports = None + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... - # Legacy method - kept for backward compatibility - def update_weights(self, weights: Any) -> None: - """Send weights to ALL workers for this model (legacy). + @overload + def init_on_sender(self): + ... - Args: - weights: Weights to send (can be None, nn.Module, TensorDict, etc.). + def init_on_sender( + self, + *args, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). - Note: - This is the legacy method. Use send() instead. + This method is called once in the collector's _run_processes() method, + after workers have been started and are ready to receive messages. """ - self.send(weights=weights) - - def __getstate__(self): - """Pickle support: discard context weakref.""" - state = self.__dict__.copy() - state["_context_ref"] = None - state["_pending_async"] = False - state["_pending_transports"] = None - return state - - def __setstate__(self, state): - """Pickle support: restore state without context.""" - self.__dict__.update(state) - - -# ============================================================================ -# Receiver (Worker Process Side) -# ============================================================================ - - -class WeightReceiver: - """Receives weights for ONE model in ONE worker. - - Created and managed by WeightSyncScheme. Users should not instantiate directly. - """ - - def __init__(self, scheme: WeightSyncScheme): - self._scheme = scheme - self._context_ref = None # weakref to inner_collector - self._transport = None # lazy - self._model_ref = None - self._strategy = _get_strategy(scheme.strategy) - - def _set_context(self, context: Any) -> None: - """Set the context object (inner_collector) for resolving references (internal). + self._initialized_on_sender = True + try: + result = self._init_on_sender_impl(*args, **kwargs) + except Exception: + self._initialized_on_sender = False + raise + return result - This is now handled by init_on_worker(). Only kept for internal use. + def _init_on_sender_impl(self, *args, **kwargs): + raise NotImplementedError - Args: - context: The inner collector instance in the worker process. - """ - self._context_ref = weakref.ref(context) + @property + def initialized_on_sender(self): + return getattr(self, "_initialized_on_sender", False) - def _register_model(self, model_ref: Any) -> None: - """Register the model to apply weights to (internal). + @property + def initialized_on_receiver(self): + return getattr(self, "_initialized_on_receiver", False) - This is now handled by init_on_worker(). 
Only kept for internal use. + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... - Args: - model_ref: Either a direct object reference or a string path like 'policy' or 'env.value_net'. - """ - self._model_ref = model_ref + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + worker_idx: int = ..., + model: Any | None = None, + **kwargs, + ) -> None: + ... - def _register_worker_transport(self, pipe: Any) -> None: - """Register this worker's communication pipe (internal). + def init_on_receiver( + self, + *, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). - This is now handled by init_on_worker(). Only kept for internal use. + This method is called once in each worker's initialization. Args: - pipe: The pipe connection for this worker. + model_id: Identifier for the model being synchronized + context: Optional context object (e.g., inner collector) + **kwargs: Alternative to context (model, etc.) """ - self._transport = self._scheme.create_transport(pipe) + self._initialized_on_receiver = True + try: + result = self._init_on_receiver_impl( + model_id=model_id, context=context, **kwargs + ) + except Exception: + self._initialized_on_receiver = False + raise + return result - def receive(self, timeout: float = 0.001) -> bool: - """Check for and apply new weights (non-blocking). + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + raise NotImplementedError - This method is called in the worker's main loop to check if - new weights have been sent. If weights are available, they - are applied to the registered model immediately. + # ======================================================================== + # Context and Model Management + # ======================================================================== - Args: - timeout: Maximum time to wait for weights (seconds). - Use 0 for immediate return. + @property + def context(self) -> Any | None: + """Get the context object (e.g., collector), if available. Returns: - True if weights were received and applied - False if no weights were available - - Note: For SharedMemWeightSyncScheme, this always returns False - since workers automatically see updates via shared memory. + The context object if available, None otherwise. """ - if self._transport is None: - return False - - # Try to receive weights - result = self._transport.receive_weights(timeout=timeout) - if result is None: - return False - - model_id, weights = result - - # Apply weights to the model - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) - - # Send acknowledgment if transport supports it - if hasattr(self._transport, "send_ack"): - self._transport.send_ack("updated") - - return True + if self._context_ref is not None: + return self._context_ref() + return None - def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model (legacy). + @context.setter + def context(self, context: Any) -> None: + """Set the context object for resolving references. Args: - weights: The weights to apply. - - Note: - This is the legacy method. Use receive() in the worker loop instead. + context: The context object to resolve references from. 
""" - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) - - # Send acknowledgment if transport supports it - if hasattr(self._transport, "send_ack"): - self._transport.send_ack("updated") - - def _resolve_model_ref(self) -> Any: - """Resolve model reference to actual object.""" - if isinstance(self._model_ref, str): - if self._context_ref is None: - raise ValueError("Context is required to resolve string references") - context = self._context_ref() - if context is None: - raise ValueError("Context has been garbage collected") - return _resolve_model(context, self._model_ref) - return self._model_ref - - def __getstate__(self): - """Pickle support: discard context weakref.""" - state = self.__dict__.copy() - state["_context_ref"] = None - return state - - def __setstate__(self, state): - """Pickle support: restore state without context.""" - self.__dict__.update(state) - - -class RayModuleTransformSender(WeightSender): - """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. - - This sender handles weight updates for models hosted within Ray actors. - Unlike the base WeightSender which uses pipes for multiprocessing, - this sender directly communicates with Ray actors via their remote methods. - - For Ray actors, there is typically only one shared actor instance, so we - store a single transport rather than per-worker transports. - """ + if context is not None: + self._context_ref = weakref.ref(context) + else: + self._context_ref = None - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - self._actor_ref = None - self._single_transport = None - self._context_ref = None - self._model_id_str = None + @property + def model_id(self) -> str | None: + """Get the model ID for this scheme. - def _set_context(self, context: Any, model_id: str) -> None: - """Set context for lazy actor resolution (internal). + Returns: + The model ID if set, None otherwise. + """ + return self._model_id - This is now handled by init_on_sender(). Only kept for internal use. + @model_id.setter + def model_id(self, model_id: str) -> None: + """Set the model ID for this scheme. Args: - context: The collector instance. - model_id: String path to the Ray actor (e.g., "env.transform[0]"). + model_id: The model ID to set. """ - self._context_ref = weakref.ref(context) - self._model_id_str = model_id + self._model_id = model_id - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """For Ray actors, worker registration is a no-op (internal). + @property + def worker_idx(self) -> int | None: + """Get the worker index for this scheme. - Ray actors are shared across all workers, so we don't need per-worker - transports. The actor reference is resolved lazily on first use. + Returns: + The worker index if set, None otherwise. """ + return self._worker_idx - def update_weights(self, weights: Any) -> None: - """Send weights to the Ray actor. + @worker_idx.setter + def worker_idx(self, worker_idx: int | None) -> None: + """Set the worker index for this scheme. Args: - weights: Weights to send. + worker_idx: The worker index to set. 
""" - if self._single_transport is None: - self._initialize_transport() - - if self._single_transport is not None: - model_id = getattr(self, "_model_id", "policy") - self._single_transport.send_weights(model_id, weights) - - def _initialize_transport(self) -> None: - """Lazily initialize the transport by resolving the actor reference.""" - if self._context_ref is None or self._model_id_str is None: - return - - context = self._context_ref() - if context is None: - return - - model = _resolve_model(context, self._model_id_str) - if hasattr(model, "_actor"): - self._actor_ref = model._actor - self._single_transport = self._scheme.create_transport(model) - elif type(model).__name__ == "ActorHandle": - self._actor_ref = model - self._single_transport = self._scheme.create_transport(model) - - -class RayModuleTransformReceiver(WeightReceiver): - """Specialized receiver for RayModuleTransform actors. - - This receiver handles weight updates within Ray actors. - Since Ray actors receive weights through direct method calls, - this receiver primarily validates and applies weights locally. - """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - - def _register_worker_transport(self, actor_or_context: Any) -> None: - """Register the Ray actor's transport (internal). + if self.initialized_on_sender and worker_idx is not None: + raise RuntimeError( + "Worker index cannot be set after initialization on sender" + ) + self._worker_idx = worker_idx - This is now handled by init_on_worker(). Only kept for internal use. + @property + def model(self) -> Any | None: + """Get the model object, if available. - Args: - actor_or_context: Either a Ray actor reference or a context object. + Returns: + The model object if available, None otherwise. """ - self._transport = self._scheme.create_transport(actor_or_context) - - def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model. + if self._model_ref is not None: + return self._model_ref() + if self._model_id is not None: + model = _resolve_model(self.context, self._model_id) + if model is None: + raise AttributeError( + f"Model {self._model_id} was `None` in context {self.context}" + ) + self._model_ref = weakref.ref(model) + return model - For Ray actors, weights are applied directly to the module - within the actor's process space. + @model.setter + def model(self, model: Any) -> None: + """Set the model object for applying weights. Args: - weights: The weights to apply. + model: The model object to apply weights to. """ - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) + if model is not None: + self._model_ref = weakref.ref(model) + else: + self._model_ref = None + @property + def weights(self) -> Any | None: + """Get the current weights, if available. -# ============================================================================ -# Weight Synchronization Schemes -# ============================================================================ + Returns: + The weights as TensorDict if available, None otherwise. 
+ """ + if (weights := getattr(self, "_weights", None)) is not None: + return weights + model = self.model + if model is not None: + return self._strategy.extract_weights(model) + return None + @weights.setter + def weights(self, value: Any): + self._weights = value -class WeightSyncScheme(metaclass=abc.ABCMeta): - """Configuration for how to synchronize ONE model across workers. + def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase: + from torchrl.collectors.utils import _cast - A scheme manages synchronization of ONE model across workers. - The collector maintains a dict of {model_id: scheme} pairs. - """ + if isinstance(model, torch.nn.Module): + td = TensorDict.from_module(model) + td = td.data.apply(_cast, td) + return td + # Return an empty TD + return TensorDict() - def __init__(self, strategy: Literal["state_dict", "tensordict"] = "state_dict"): - self.strategy = strategy - self._sender = None - self._receiver = None - self._initialized_on_sender = False - self._initialized_on_worker = False + # ======================================================================== + # Transport Management + # ======================================================================== - def init_on_sender( + def _register_worker_sender( self, - model_id: str, - context: Any = None, - **kwargs, + *, + worker_idx: int, + transport: TransportBackend | None = None, + **transport_kwargs, ) -> None: - """Initialize on the main process (sender side). - - This method is called once in the collector's _run_processes() method, - after workers have been started and are ready to receive messages. + """Register a worker's communication. Args: - model_id: Identifier for the model being synchronized - context: Optional context object (e.g., collector) providing: - - .pipes: list[mp.Connection] - - .get_model(model_id: str) -> nn.Module - - .get_cached_weights(model_id: str) -> TensorDict | None - - .num_workers: int - **kwargs: Alternative to context (pipes, num_workers, model, cached_weights, etc.) + worker_idx: The worker index. + transport: Optional pre-created transport. + **transport_kwargs: Transport-specific configuration. """ - raise NotImplementedError + if self._sender_transports is None: + if self._shared_transport is not None: + raise RuntimeError( + "Cannot register transports on sender after shared transport is set" + ) + self._sender_transports = {} + if worker_idx not in self._sender_transports: + if transport is not None: + self._sender_transports[worker_idx] = transport + else: + self._sender_transports[worker_idx] = self.create_transport( + **transport_kwargs + ) - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, + def _register_transport_receiver( + self, transport: TransportBackend | None = None, **transport_kwargs ) -> None: - """Initialize on worker process (receiver side). - - This method is called once in each worker's initialization. + """Register a single transport (for receiver side). Args: - model_id: Identifier for the model being synchronized - context: Optional context object (e.g., inner collector) providing: - - .pipe: mp.Connection - - .get_model(model_id: str) -> nn.Module - **kwargs: Alternative to context (pipe, model, etc.) - """ - raise NotImplementedError - - def get_sender(self) -> WeightSender: - """Get the sender instance. - - Returns: - Sender instance for sending weights to workers - - Raises: - RuntimeError: If init_on_sender() hasn't been called yet + transport: Optional pre-created transport. 
+ **transport_kwargs: Transport-specific configuration. """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - f"Must call init_on_sender() before get_sender() on {type(self).__name__}" - ) - return self._sender - - def get_receiver(self) -> WeightReceiver: - """Get the receiver instance. - - Returns: - Receiver instance for receiving weights in this worker - - Raises: - RuntimeError: If init_on_worker() hasn't been called yet - """ - if not self._initialized_on_worker or self._receiver is None: - raise RuntimeError( - f"Must call init_on_worker() before get_receiver() on {type(self).__name__}" - ) - return self._receiver - - def __getstate__(self): - """Prepare the scheme for pickling by excluding non-serializable runtime state. - - Sender and receiver objects contain pipes, weak references, and other - non-serializable resources that should not be pickled. These will be - re-initialized when needed after unpickling. - """ - state = self.__dict__.copy() - # Remove non-serializable runtime state - state["_sender"] = None - state["_receiver"] = None - state["_initialized_on_sender"] = False - state["_initialized_on_worker"] = False - return state + if transport is not None: + self._receiver_transport = transport + else: + self._receiver_transport = self.create_transport(**transport_kwargs) - def __setstate__(self, state): - """Restore the scheme from pickling.""" - self.__dict__.update(state) + def _iterate_transports( + self, worker_ids: int | list[int] | None = None + ) -> Iterator[TransportBackend]: + """Iterate over transports for specified workers.""" + if worker_ids is None: + # All workers + if not self.sender_transports: + if self.receiver_transport is not None: + yield self.receiver_transport + else: + # Make sure transports are sorted + for k in sorted(self.sender_transports.keys()): + yield self.sender_transports[k] + else: + # Specific workers + if isinstance(worker_ids, int): + worker_ids = [worker_ids] + for worker_id in worker_ids: + if worker_id in self.sender_transports: + yield self.sender_transports[worker_id] + else: + raise ValueError(f"Worker {worker_id} not registered") - # Legacy methods - kept for backward compatibility @abc.abstractmethod - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport(self, **kwargs) -> TransportBackend: """Create transport for communication. Args: - pipe_or_context: Either a pipe connection or context object to extract pipe from. + **kwargs: Transport-specific configuration parameters. Returns: A transport backend instance. + + Note: + This is used internally by init_on_sender/init_on_receiver. """ ... - def create_sender(self) -> WeightSender: - """Create a sender for this scheme (legacy). + @property + def sender_transports(self) -> dict[int, TransportBackend]: + """Get the sender transports. Returns: - WeightSender instance configured for this scheme. + The sender transports. """ - return WeightSender(self) + if self._shared_transport is not None: + return defaultdict(lambda: self._shared_transport) + return self._sender_transports - def create_receiver(self) -> WeightReceiver: - """Create a receiver for this scheme (legacy). + @property + def receiver_transport(self) -> TransportBackend | None: + """Get the receiver transport. Returns: - WeightReceiver instance configured for this scheme. + The receiver transport. 
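`_register_worker_sender` keys one transport per worker index, and `_iterate_transports` either walks all registered workers in sorted order (broadcast) or only the requested indices. A toy illustration of that registry and iteration order, with an invented `DummyTransport`:

```python
# Toy illustration of the per-worker transport registry and the iteration
# order used when broadcasting; DummyTransport is invented for this sketch.
class DummyTransport:
    def __init__(self, worker_idx):
        self.worker_idx = worker_idx


sender_transports = {}
for worker_idx in (2, 0, 1):                       # registration order is arbitrary
    sender_transports.setdefault(worker_idx, DummyTransport(worker_idx))


def iterate(worker_ids=None):
    if worker_ids is None:
        # Broadcast: iterate workers in sorted order, as in _iterate_transports.
        for k in sorted(sender_transports):
            yield sender_transports[k]
    else:
        if isinstance(worker_ids, int):
            worker_ids = [worker_ids]
        for k in worker_ids:
            if k not in sender_transports:
                raise ValueError(f"Worker {k} not registered")
            yield sender_transports[k]


assert [t.worker_idx for t in iterate()] == [0, 1, 2]
assert [t.worker_idx for t in iterate(1)] == [1]
```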
""" - return WeightReceiver(self) - - def prepare_weights( - self, - weights: Any, - model_id: str, - strategy: WeightStrategy, - context: Any = None, - ) -> Any: - """Prepare weights for sending. - - This method handles weight extraction, conversion, and any scheme-specific - preparation (e.g., cache lookups for SharedMemWeightSyncScheme). + if self._shared_transport is not None: + return self._shared_transport + return self._receiver_transport - Args: - weights: Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.) - model_id: The model identifier (e.g., "policy") - strategy: WeightStrategy for extracting/converting weights - context: Optional context (e.g., collector) for model resolution + @property + def shared_transport(self) -> TransportBackend | None: + """Get the shared transport. Returns: - Prepared weights ready to send via transport + The shared transport. """ - # Default implementation: extract from model or pass through - if weights is None and context is not None: - # Try to resolve and extract from model in context - try: - model = _resolve_model(context, model_id) - return strategy.extract_weights(model) - except (AttributeError, KeyError): - pass - # Try fallback policy - if model_id == "policy" and hasattr(context, "_fallback_policy"): - if context._fallback_policy is not None: - return strategy.extract_weights(context._fallback_policy) - return None - - if isinstance(weights, nn.Module): - return strategy.extract_weights(weights) - elif isinstance(weights, str): - # String reference to model - if context is not None: - model = _resolve_model(context, weights) - return strategy.extract_weights(model) - raise ValueError( - f"Cannot resolve string reference '{weights}' without context" + if self._receiver_transport is not None: + raise RuntimeError( + "Receiver transport and shared transport cannot be used together" ) - else: - # Already extracted weights (TensorDict, dict, etc.) - return weights - - -class MultiProcessWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for multiprocess operations using pipes. - - This scheme creates transports that communicate via multiprocessing pipes. - """ - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipes and num_workers - **kwargs: Alternative to context (pipes, num_workers, etc.) 
- """ - # Extract parameters from context or kwargs - if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) - else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") - - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 - - # Create sender and register all workers - sender = WeightSender(self) - sender._model_id = model_id - if context is not None: - sender._context_ref = weakref.ref(context) - - for worker_idx, pipe in enumerate(pipes): - sender._register_worker(worker_idx, pipe) - - self._sender = sender - self._initialized_on_sender = True + if self._sender_transports is not None: + raise RuntimeError( + "Sender transports and shared transport cannot be used together" + ) + return self._shared_transport - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). + @shared_transport.setter + def shared_transport(self, shared_transport: TransportBackend | None) -> None: + """Set the shared transport. Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) + shared_transport: The shared transport to set. """ - # Extract parameters from context or kwargs - if context is not None: - pipe = getattr(context, "pipe", None) - if hasattr(context, "get_model"): - model = context.get_model(model_id) - else: - model = None - else: - pipe = kwargs.get("pipe") - model = kwargs.get("model") - - if pipe is None: - raise ValueError("pipe must be provided via context or kwargs") - - # Create receiver and register model - receiver = WeightReceiver(self) - if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._register_worker_transport(pipe) - if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) + self._shared_transport = shared_transport - self._receiver = receiver - self._initialized_on_worker = True + # ======================================================================== + # Sending Weights (Sender Side) + # ======================================================================== - def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe (legacy).""" - return MPTransport(pipe) - - -class SharedMemWeightSyncScheme(WeightSyncScheme): - """Weight synchronization using shared memory. - - This scheme mimics the old WeightUpdater behavior by using shared memory - for in-place weight updates. Workers automatically see weight updates - without explicit message passing. - - By default, this scheme uses lazy registration: models are automatically - registered on the first weight send. This makes it seamless to use with - configuration systems like Hydra where schemes are created before models - are available. - - Args: - policy_weights: Dictionary mapping model_id to shared TensorDict weights. - Can be empty if using lazy registration (auto_register=True). - strategy: The weight transmission strategy (default: "tensordict"). - auto_register: Whether to automatically register models on first weight send. - Default is True. Set to False to require explicit registration via - register_shared_weights(). 
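With a shared transport, every worker index resolves to the same transport object, which is why `sender_transports` falls back to a `defaultdict` and why mixing the shared and per-worker modes is rejected. A small standalone illustration of that lookup behaviour (the transport objects here are placeholders):

```python
from collections import defaultdict

# Stand-in for a single transport shared by every worker (e.g. shared memory).
shared = object()

# With a shared transport, per-worker lookups all return the same object.
per_worker_view = defaultdict(lambda: shared)
assert per_worker_view[0] is per_worker_view[7] is shared

# Per-worker transports, by contrast, are distinct objects keyed by index;
# the properties above enforce that the two modes are never mixed.
per_worker = {idx: object() for idx in range(3)}
assert len({id(t) for t in per_worker.values()}) == 3
```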
- - Example: - >>> # With auto-registration (default) - works with Hydra configs - >>> scheme = SharedMemWeightSyncScheme() - >>> # Models are auto-registered on first weight send - - >>> # With explicit registration - >>> scheme = SharedMemWeightSyncScheme(auto_register=False) - >>> shared_weights = TensorDict.from_module(model).share_memory_() - >>> scheme.register_shared_weights("policy", shared_weights) - """ - - def __init__( - self, - policy_weights: dict[str, TensorDictBase] | None = None, - strategy: str = "tensordict", - auto_register: bool = True, - ): - super().__init__(strategy) - self.policy_weights = policy_weights if policy_weights is not None else {} - self.auto_register = auto_register - # Create a single shared transport for all workers - self._shared_transport = SharedMemTransport( - self.policy_weights, auto_register=auto_register - ) - - def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> None: - """Register shared memory weights for a model. - - This method allows explicit registration of shared weights. It's optional - when auto_register=True (the default), but required when auto_register=False. - - Args: - model_id: Identifier for the model. - weights: Shared memory TensorDict containing the model's weights. - """ - # Don't set self.policy_weights[model_id] here - register_weights does that - # (self.policy_weights and transport._policy_weights are the same dict) - self._shared_transport.register_weights(model_id, weights) - - def init_on_sender( + def send( self, - model_id: str, - context: Any = None, - **kwargs, + weights: Any = None, + worker_ids: int | list[int] | None = None, ) -> None: - """Initialize on the main process (sender side). - - For SharedMemWeightSyncScheme, this handles: - 1. Getting cached shared memory weights from context - 2. Pre-registering the weights with the transport - 3. Distributing buffer references to all workers (avoiding later deadlock) - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipes, cached_weights - **kwargs: Alternative to context (pipes, cached_weights, etc.) 
- """ - # Extract parameters from context or kwargs - if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) - # Try to get cached shared memory weights - if hasattr(context, "get_cached_weights"): - cached_weights = context.get_cached_weights(model_id) - else: - cached_weights = None - else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") - cached_weights = kwargs.get("cached_weights") - - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 - - # Register pipes with shared transport for lazy buffer distribution - for pipe in pipes: - self._shared_transport.register_pipe(pipe) - - # If we have cached shared memory weights, pre-register them - if cached_weights is not None: - # Check if already registered to avoid re-registration error - if model_id not in self.policy_weights: - self.register_shared_weights(model_id, cached_weights) - - # Send buffer references for any weights that were pre-registered - # before pipes were available (e.g., via explicit register_shared_weights call) - if model_id in self.policy_weights: - if model_id not in self._shared_transport._registered_with_workers: - self._shared_transport._send_buffer_to_workers( - model_id, self.policy_weights[model_id] - ) - - # Create sender with the shared transport - sender = WeightSender(self) - sender._model_id = model_id - sender._transport = self._shared_transport # Use shared transport - if context is not None: - sender._context_ref = weakref.ref(context) - - self._sender = sender - self._initialized_on_sender = True + """Send weights synchronously to workers. - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) + weights: Weights to send. Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. 
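A sketch of the `send()` control flow described above: prepare the weights once, fan them out to the selected transports (asynchronously where supported), then block until every targeted worker has acknowledged. `QueueTransport` below is an invented stand-in, not a torchrl transport:

```python
import queue


class QueueTransport:
    """Toy transport: a queue for weights and a queue for acknowledgements."""

    def __init__(self):
        self._weights = queue.Queue()
        self._acks = queue.Queue()

    def send_weights_async(self, weights):
        self._weights.put(weights)            # non-blocking fan-out

    def wait_ack(self):
        return self._acks.get(timeout=1.0)    # block until the worker acked

    # Receiver side of the toy transport.
    def apply_and_ack(self, model):
        model.update(self._weights.get(timeout=1.0))
        self._acks.put("updated")


transports = {0: QueueTransport(), 1: QueueTransport()}
models = {0: {}, 1: {}}
prepared = {"weight": 1.0}                    # stand-in for extracted weights

for t in transports.values():                 # 1. fan out to all workers
    t.send_weights_async(prepared)
for idx, t in transports.items():             # 2. toy receivers apply weights
    t.apply_and_ack(models[idx])
for t in transports.values():                 # 3. block on acknowledgements
    t.wait_ack()
assert models[0] == models[1] == prepared
```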
""" - # Extract parameters from context or kwargs - if context is not None: - getattr(context, "pipe", None) - if hasattr(context, "get_model"): - model = context.get_model(model_id) - else: - model = None - else: - model = kwargs.get("model") + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + if not self.synchronized_on_sender: + raise RuntimeError("Must be synchronized on sender before sending weights") - # For shared memory, we don't need the pipe in the receiver - # The transport is shared and workers see updates automatically + context = self.context - # Create receiver with the shared transport - receiver = WeightReceiver(self) - if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._transport = self._shared_transport # Use shared transport - if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) + # Let the scheme prepare the weights + torchrl_logger.debug("Preparing weights") + prepared_weights = self.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) - self._receiver = receiver - self._initialized_on_worker = True + transports = list(self._iterate_transports(worker_ids)) - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create shared memory transport and register pipe for lazy buffer distribution (legacy). + if not transports: + raise RuntimeError("No transports available.") - For lazy registration to work, we register each worker's pipe with the transport. - On first weight send, the transport will send buffer references via these pipes. + # Send to all workers first (non-blocking if transport supports it) + torchrl_logger.debug(f"Sending over transports {transports}") + for transport in transports: + if hasattr(transport, "send_weights_async"): + torchrl_logger.debug( + f"Sending {type(prepared_weights)=} through {type(transport)=} asynchronously." + ) + transport.send_weights_async(prepared_weights) + else: + # Fallback for transports that don't support async send + torchrl_logger.debug( + f"Sending {type(prepared_weights)=} through {type(transport)=} synchronously." + ) + transport.send_weights(prepared_weights) - Returns the shared transport instance that all workers will use. - Since this is shared memory, there's only one transport shared by all workers. - """ - # Register the pipe for lazy buffer distribution - if pipe_or_context is not None: - self._shared_transport.register_pipe(pipe_or_context) - return self._shared_transport + # Wait for all acknowledgments + torchrl_logger.debug("Waiting for acknowledgement") + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() def prepare_weights( self, @@ -1732,447 +829,405 @@ def prepare_weights( strategy: WeightStrategy, context: Any = None, ) -> Any: - """Prepare weights for SharedMemWeightSyncScheme. + """Prepare weights for sending. - For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights - from the context (collector) to avoid extracting fresh (non-shared) weights. + This method handles weight extraction, conversion, and any scheme-specific + preparation (e.g., cache lookups for SharedMemWeightSyncScheme). Args: - weights: Raw weights input - model_id: The model identifier + weights: Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.) 
+ model_id: The model identifier (e.g., "policy") strategy: WeightStrategy for extracting/converting weights - context: Optional context (e.g., collector) for cache lookup + context: Optional context (e.g., collector) for model resolution Returns: - Shared memory weights ready to send + Prepared weights ready to send via transport """ - # If no weights provided, check for cached shared memory weights in collector + # Default implementation: extract from model or pass through if weights is None and context is not None: - if model_id == "policy" and hasattr(context, "_policy_weights_dict"): - policy_device = ( - context.policy_device - if not isinstance(context.policy_device, (list, tuple)) - else context.policy_device[0] - ) - cached_weights = context._policy_weights_dict.get(policy_device) - if cached_weights is not None: - return cached_weights - - # Fall back to default behavior - return super().prepare_weights(weights, model_id, strategy, context) + # Try to resolve and extract from model in context + try: + model = _resolve_model(context, model_id) + return strategy.extract_weights(model) + except (AttributeError, KeyError): + pass + # Try fallback policy + if model_id == "policy" and hasattr(context, "_fallback_policy"): + if context._fallback_policy is not None: + return strategy.extract_weights(context._fallback_policy) + return None + if isinstance(weights, nn.Module): + return strategy.extract_weights(weights) + elif isinstance(weights, str): + # String reference to model + if context is not None: + model = _resolve_model(context, weights) + return strategy.extract_weights(model) + raise ValueError( + f"Cannot resolve string reference '{weights}' without context" + ) + else: + # Already extracted weights (TensorDict, dict, etc.) + return weights -class NoWeightSyncScheme(WeightSyncScheme): - """No-op weight synchronization scheme. + # ======================================================================== + # Receiving Weights (Receiver Side) + # ======================================================================== - This scheme disables weight synchronization entirely. - """ + def receive(self, timeout: float | None = None) -> TensorDictBase | None: + """Check for and apply new weights (non-blocking). - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). + This method is called in the worker's main loop to check if + new weights have been sent. If weights are available, they + are applied to the registered model immediately, and the update + is cascaded to any sub-collectors via context.update_policy_weights_(). Args: - model_id: Identifier for the model being synchronized - context: Optional context object (not used) - **kwargs: Optional parameters (not used) - """ - # Create a no-op sender - sender = WeightSender(self) - sender._model_id = model_id - - self._sender = sender - self._initialized_on_sender = True + timeout: Maximum time to wait for weights (seconds). + None means no timeout (blocking). Some transports may not + support timeout and will raise ValueError if specified. - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). + Returns: + The received weights if available, None otherwise. 
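The dispatch in `prepare_weights()` can be restated compactly: `None` resolves the model from the context, an `nn.Module` has its weights extracted, a string is treated as a model reference, and anything else is assumed to be already-extracted weights. A simplified, self-contained restatement with a toy `extract()` strategy and an invented `Ctx` context:

```python
import torch
from torch import nn


def extract(module: nn.Module) -> dict:
    # Toy strategy: a detached copy of the module's state_dict.
    return {k: v.detach().clone() for k, v in module.state_dict().items()}


def prepare(weights, model_id, context):
    if weights is None and context is not None:
        return extract(getattr(context, model_id))     # resolve from context
    if isinstance(weights, nn.Module):
        return extract(weights)                        # extract from module
    if isinstance(weights, str):
        return extract(getattr(context, weights))      # string reference
    return weights                                     # already extracted


class Ctx:
    policy = nn.Linear(3, 2)


assert set(prepare(None, "policy", Ctx)) == {"weight", "bias"}
assert set(prepare(Ctx.policy, "policy", None)) == {"weight", "bias"}
assert prepare({"weight": torch.zeros(2, 3)}, "policy", None).keys() == {"weight"}
```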
- Args: - model_id: Identifier for the model being synchronized - context: Optional context object (not used) - **kwargs: Optional parameters (not used) + Note: For SharedMemWeightSyncScheme, this always returns None + since workers automatically see updates via shared memory. """ - # Create a no-op receiver - receiver = WeightReceiver(self) - receiver._model_ref = model_id - - self._receiver = receiver - self._initialized_on_worker = True - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Returns None as no transport is needed (legacy).""" - # Return a dummy transport that does nothing - class NoOpTransport: - def send_weights(self, model_id: str, weights: Any) -> None: - pass - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - return None - - def check_connection(self) -> bool: - return True - - return NoOpTransport() - + if not self.initialized_on_receiver: + raise RuntimeError( + "Must be initialized on receiver before receiving weights" + ) + if not self.synchronized_on_receiver: + raise RuntimeError( + "Must be synchronized on receiver before receiving weights" + ) -class RayWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for Ray distributed computing. + # Determine which transport to use + if self._receiver_transport is not None: + transport = self._receiver_transport + elif self._shared_transport is not None: + # Use shared transport directly (e.g., SharedMemWeightSyncScheme) + transport = self._shared_transport + else: + return None - This scheme uses Ray's object store and remote calls to synchronize weights - across distributed workers (Ray actors). + # Try to receive weights - transport handles receiving and applying + torchrl_logger.debug(f"Calling receive_weights on transport {transport}") + result = transport.receive_weights( + timeout=timeout, + weights=self.weights, + model=self.model, + strategy=self._strategy, + ) + if result is None: + return None - Each remote collector gets its own transport, following the same pattern - as multiprocess collectors. - """ + weights = result + model_id = self._model_id or "policy" + torchrl_logger.debug(f"Received weights for {model_id=}") - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create Ray-based transport for a specific remote collector. + # Cascade weight update to sub-collectors if context supports it + if self.context is not None and hasattr(self.context, "update_policy_weights_"): + torchrl_logger.debug( + f"Cascading weight update to sub-collectors for {model_id=}" + ) + self.context.update_policy_weights_( + model_id=model_id, policy_or_weights=weights + ) - Args: - pipe_or_context: The Ray actor handle for the remote collector. + # Send acknowledgment if transport supports it + if hasattr(transport, "send_ack"): + torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") + transport.send_ack("updated") - Returns: - RayTransport configured for this specific remote collector. - """ - return RayTransport(remote_collector=pipe_or_context) + return weights - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). + def apply_weights(self, weights: TensorDictBase, inplace: bool = True) -> None: + """Apply weights to the model. Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing remote_collectors - **kwargs: Alternative to context (remote_collectors, source_model, etc.) 
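In a worker, `receive()` is typically polled inside the main loop with a short timeout. The sketch below mimics that loop shape with an invented `FakeScheme` that only reproduces the `receive()` return contract (weights when an update arrived, `None` otherwise):

```python
# FakeScheme mimics the receive() contract; it is not the real scheme class.
class FakeScheme:
    def __init__(self, updates):
        self._updates = list(updates)

    def receive(self, timeout=None):
        # Returns the received weights, or None when nothing is pending.
        return self._updates.pop(0) if self._updates else None


applied = []
scheme = FakeScheme(updates=[{"step": 1}, {"step": 2}])

for _ in range(5):                        # worker main loop
    weights = scheme.receive(timeout=0.001)
    if weights is not None:
        applied.append(weights)           # in the real loop the model is updated
    # ... collect a batch of data here ...

assert applied == [{"step": 1}, {"step": 2}]
```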
+ weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. """ - # Extract parameters from context or kwargs - if context is not None: - remote_collectors = getattr(context, "remote_collectors", None) - num_workers = getattr(context, "num_workers", None) or getattr( - context, "num_collectors", None + if not self.initialized_on_receiver: + if self.initialized_on_sender: + raise RuntimeError("apply_weights() called on a sender side.") + raise RuntimeError( + "apply_weights() called before init_on_receiver has been called." ) - else: - remote_collectors = kwargs.get("remote_collectors") - num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors") - if remote_collectors is None: - raise ValueError("remote_collectors must be provided via context or kwargs") - if num_workers is None: - num_workers = len(remote_collectors) if remote_collectors else 0 + if self._model_ref is None: + raise ValueError("No model registered") - # Create sender and register all workers (Ray actors) - sender = WeightSender(self) - sender._model_id = model_id + model = self.model + self._strategy.apply_weights(model, weights, inplace=inplace) - # Register each Ray actor - _register_worker will create the transport - for worker_idx, remote_collector in enumerate(remote_collectors): - sender._register_worker(worker_idx, remote_collector) + # Send acknowledgment if transport supports it + if self.receiver_transport is not None and hasattr( + self.receiver_transport, "send_ack" + ): + self.receiver_transport.send_ack("updated") - # Set context with weak reference to avoid circular refs - if context is not None: - sender._set_context(weakref.ref(context), model_id) + # ======================================================================== + # Synchronization + # ======================================================================== - # Store source model reference if provided for automatic weight extraction - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model + @overload + def connect(self, *, worker_idx: int | None = None) -> None: + ... - self._sender = sender - self._initialized_on_sender = True + @overload + def connect(self, *, weights: Any | None = None) -> None: + ... - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, + def connect( + self, *, worker_idx: int | None = None, weights: Any | None = None ) -> None: - """Initialize on worker process (receiver side). + """Method to be called once the workers have started. - For Ray workers, weight updates are handled via remote method calls, - so this is typically a no-op. The receiver is created but doesn't - need special initialization. + Triggers a rendez-vous for the workers to receive their copy of the weights. - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the remote collector) - **kwargs: Optional parameters (pipe, model, etc.) + Dispatches to _setup_connection_and_weights_on_sender_impl() or _setup_connection_and_weights_on_receiver_impl() + based on which initialization was performed. 
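The `connect()` rendezvous is one-shot and side-aware: the sender calls it once (optionally with initial weights), each receiver calls it once with its `worker_idx`, and a second call or a wrong-side argument is rejected. A toy restatement of those guards, using an invented `ToyScheme`:

```python
# ToyScheme only reproduces the connect() dispatch guards for illustration.
class ToyScheme:
    def __init__(self, side):
        self.side = side                  # "sender" or "receiver"
        self.synchronized = False

    def connect(self, *, worker_idx=None, weights=None):
        if self.synchronized:
            raise RuntimeError("Cannot synchronize twice.")
        if self.side == "sender" and worker_idx is not None:
            raise RuntimeError("worker_idx is a receiver-side argument.")
        if self.side == "receiver" and weights is not None:
            raise RuntimeError("weights is a sender-side argument.")
        self.synchronized = True


sender = ToyScheme("sender")
sender.connect(weights={"w": 0.0})        # once, after workers have started

receiver = ToyScheme("receiver")
receiver.connect(worker_idx=0)            # once per worker

try:
    sender.connect()                      # a second call is rejected
except RuntimeError:
    pass
```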
""" - # Create receiver - receiver = WeightReceiver(self) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "policy", None) if context else None - ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_worker = True + if self.synchronized_on_receiver or self.synchronized_on_sender: + raise RuntimeError("Cannot synchronize weights on sender twice.") + if self._initialized_on_sender: + torchrl_logger.debug("Synchronizing weights on sender") + if worker_idx is not None: + # Safety check, we can consider removing this in the future. + raise RuntimeError( + "Cannot specify worker_idx on sender side during synchronization." + ) + self.synchronized_on_sender = True + try: + self._setup_connection_and_weights_on_sender_impl(weights=weights) + except Exception: + self.synchronized_on_sender = False + raise + elif self._initialized_on_receiver: + torchrl_logger.debug(f"Synchronizing weights on receiver -- {worker_idx=}") + if weights is not None: + # safety check: weights are passed to sender, not receiver for initial sync + raise RuntimeError( + "Cannot specify weights on receiver side during synchronization." + ) + self.synchronized_on_receiver = True + try: + self._setup_connection_and_weights_on_receiver_impl( + worker_idx=worker_idx + ) + except Exception: + self.synchronized_on_receiver = False + raise + else: + raise RuntimeError( + "Neither init_on_sender nor init_on_receiver have been called." + ) + def _setup_connection_and_weights_on_sender_impl( + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, + ) -> None: + """Synchronize weights on sender side. -class RayModuleTransformScheme(WeightSyncScheme): - """Weight synchronization for RayModuleTransform actors. + Default implementation uses transport's setup_connection_and_weights_on_sender(). + Subclasses may override for custom behavior. + """ + if self._shared_transport is not None: + # We only need to synchronize once + self.shared_transport.setup_connection_and_weights_on_sender() + return - This scheme is designed specifically for updating models hosted within - Ray actors, such as RayModuleTransform instances. It creates a transport - that directly calls the actor's weight update methods. + idx = -1 + for idx, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and idx != worker_idx: + continue + transport.setup_connection_and_weights_on_sender() + if idx == -1: + raise RuntimeError("No transports available.") - Args: - strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). - Default is "tensordict". - """ + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Synchronize weights on receiver side. - def __init__(self, strategy: str = "tensordict"): - super().__init__(strategy) + Default implementation uses transport's setup_connection_and_weights_on_receiver(). + Subclasses may override for custom behavior. + """ + if self.receiver_transport is None: + return - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create RayActorTransport for the given actor. + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self._worker_idx - Args: - pipe_or_context: Either a Ray actor reference or a context object - from which to extract the actor reference. 
+ # Call transport's synchronize method with all relevant kwargs + weights = self.receiver_transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx, + weights=self.weights, + model=self.model, + strategy=self._strategy, + ) - Returns: - RayActorTransport configured with the actor reference. - """ - actor_ref = self._extract_actor_ref(pipe_or_context) - return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) + # Apply weights to model if received (SharedMemTransport case) + # For other transports (MPTransport, etc.), weights is None and synchronization + # happens later via receive(), so this is a no-op + if weights is not None: + model = self.model + self._strategy.apply_weights(model, weights, inplace=False) - def _extract_actor_ref(self, pipe_or_context: Any) -> Any: - """Extract the Ray actor reference from the context. + @property + def synchronized_on_sender(self): + return getattr(self, "_synchronized_on_sender", False) - Args: - pipe_or_context: Either a direct actor reference or an object - with an `_actor` attribute. + @synchronized_on_sender.setter + def synchronized_on_sender(self, value: bool): + self._synchronized_on_sender = value - Returns: - The Ray actor reference. - """ - if hasattr(pipe_or_context, "_actor"): - return pipe_or_context._actor - return pipe_or_context + @property + def synchronized_on_receiver(self): + return getattr(self, "_synchronized_on_receiver", False) - def create_sender(self) -> RayModuleTransformSender: - """Create a specialized sender for Ray actor communication.""" - return RayModuleTransformSender(self) + @synchronized_on_receiver.setter + def synchronized_on_receiver(self, value: bool): + self._synchronized_on_receiver = value - def create_receiver(self) -> RayModuleTransformReceiver: - """Create a specialized receiver for Ray actor communication.""" - return RayModuleTransformReceiver(self) + # ======================================================================== + # Background Receiver + # ======================================================================== - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). + def _start_background_receiver(self): + """Start daemon thread that monitors for weight update instructions. - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing actor references - **kwargs: Alternative to context (actors, actor_refs, source_model, etc.) + The background thread runs _background_receive_loop() which waits for + instructions via _wait_for_instruction() and calls receive() when + an instruction arrives. """ - # Extract actor references from context or kwargs - if context is not None: - # Could be actor_refs, actors, or remote_collectors - actor_refs = ( - getattr(context, "actor_refs", None) - or getattr(context, "actors", None) - or getattr(context, "remote_collectors", None) - ) - else: - actor_refs = ( - kwargs.get("actor_refs") - or kwargs.get("actors") - or kwargs.get("remote_collectors") + if not self.initialized_on_receiver: + raise RuntimeError( + "_start_background_receiver must be called on the receiver side." 
) + self._stop_event = threading.Event() + self._background_thread = threading.Thread( + target=self._background_receive_loop, + daemon=True, + name=f"WeightReceiver-{self._worker_idx}", + ) + self._background_thread.start() + torchrl_logger.debug( + f"{type(self).__name__}: Started background receiver thread for worker {self._worker_idx}" + ) - if actor_refs is None: - raise ValueError( - "actor_refs (or actors) must be provided via context or kwargs" - ) + def _background_receive_loop(self): + """Background thread loop that waits for instructions and receives weights. - # Create specialized sender - sender = self.create_sender() - sender._model_id = model_id + Default implementation uses _wait_for_instruction() and receive(). + Subclasses may override for custom behavior. + """ + while not self._stop_event.is_set(): + try: + instruction = self._wait_for_instruction() + if instruction is None: + # Stop signal received + break + if instruction == "receive": + self.receive() + elif instruction == "stop": + break + else: + torchrl_logger.warning(f"Unknown instruction: {instruction}") + except Exception as e: + if not self._stop_event.is_set(): + torchrl_logger.warning(f"Background receiver error: {e}") - # Register all actors - _register_worker will create the transport - for worker_idx, actor_ref in enumerate(actor_refs): - sender._register_worker(worker_idx, actor_ref) + def _wait_for_instruction(self, timeout: float | None = None) -> str | None: + """Block until an instruction arrives from the sender. - # Set context with weak reference - if context is not None: - sender._set_context(weakref.ref(context), model_id) + This method should be overridden by subclasses to implement + scheme-specific instruction waiting (e.g., queue.get(), store polling). - # Store source model if provided - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model + Args: + timeout: Maximum time to wait for instruction (seconds). + None means block indefinitely. - self._sender = sender - self._initialized_on_sender = True + Returns: + The instruction string (e.g., "receive", "stop"), or None if + stop event is set or timeout expires. + """ + raise NotImplementedError( + f"{type(self).__name__} must implement _wait_for_instruction()" + ) - def init_on_worker( + def _send_instruction( self, - model_id: str, - context: Any = None, - **kwargs, + instruction: str = "receive", + worker_ids: int | list[int] | None = None, ) -> None: - """Initialize on worker process (receiver side). + """Send instruction to receiver(s) to trigger weight reception. + + This method should be overridden by subclasses to implement + scheme-specific instruction sending (e.g., queue.put(), store.set()). Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the actor itself) - **kwargs: Optional parameters (actor_ref, model, etc.) + instruction: The instruction to send (default: "receive"). + worker_ids: Which workers to send to (None = all workers). 
""" - # Create specialized receiver - receiver = self.create_receiver() - - # Extract actor reference if needed - actor_ref = kwargs.get("actor_ref") or context - if actor_ref is not None: - # Register the transport for this actor - transport = self.create_transport(actor_ref) - receiver._register_worker_transport(transport) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "_actor_module", None) or getattr(context, "module", None) - if context - else None + raise NotImplementedError( + f"{type(self).__name__} must implement _send_instruction()" ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_worker = True + def _send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender after receiving weights. -class RPCWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for torch.distributed.rpc. - - This scheme uses RPC calls to synchronize weights across distributed - workers. Each remote collector gets its own transport, following the - same pattern as multiprocess collectors. - """ - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create RPC-based transport for a specific remote collector. + Called by the background receiver after successfully applying weights. + Subclasses should override to implement scheme-specific acknowledgment. Args: - pipe_or_context: A tuple of (collector_info, collector_rref, collector_class) - for the remote collector. - - Returns: - RPCTransport configured for this specific remote collector. + message: The acknowledgment message (default: "updated"). """ - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3: - collector_info, collector_rref, collector_class = pipe_or_context - return RPCTransport( - collector_info=collector_info, - collector_rref=collector_rref, - collector_class=collector_class, - ) - # If just passed the info directly - return RPCTransport(collector_info=pipe_or_context) - - -class DistributedWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for torch.distributed. + # Default: use transport's send_ack if available + transport = self._receiver_transport or self._shared_transport + if transport is not None and hasattr(transport, "send_ack"): + transport.send_ack(message) - This scheme uses torch.distributed primitives (send/recv) to synchronize - weights across distributed workers. Each worker gets its own transport, - following the same pattern as multiprocess collectors. - - Args: - backend (str): The distributed backend ("gloo", "nccl", etc.) - sync (bool): Whether to use synchronous weight updates - """ - - def __init__(self, backend: str = "gloo", sync: bool = True): - super().__init__() - self.backend = backend - self.sync = sync + def _wait_for_ack( # noqa: B027 + self, + worker_ids: int | list[int] | None = None, + timeout: float | None = None, + ) -> None: + """Wait for acknowledgment from receiver(s). - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create distributed transport for a specific worker. + Called by send() in synchronous mode to block until receivers confirm. + Subclasses should override to implement scheme-specific waiting. Args: - pipe_or_context: A tuple of (store, rank) for the worker. - - Returns: - DistributedTransport configured for this specific worker. 
+ worker_ids: Which workers to wait for (None = all workers). + timeout: Maximum time to wait (seconds). None means block indefinitely. """ - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2: - store, rank = pipe_or_context - return DistributedTransport(store=store, rank=rank, sync=self.sync) - # Fallback - shouldn't normally happen - return DistributedTransport() - - -# ============================================================================ -# Helper Functions -# ============================================================================ + # Default: no-op (subclasses implement scheme-specific waiting) + def __getstate__(self): + """Prepare the scheme for pickling by excluding non-serializable runtime state.""" + state = self.__dict__.copy() + # Remove non-serializable runtime state + state["_context_ref"] = None + state["_model_ref"] = None -def _resolve_model(context: Any, model_id: str) -> Any: - """Resolve model_id like 'policy' or 'env.value_net' to actual object. + state["_initialized_on_sender"] = False + state["_initialized_on_receiver"] = False - Also processes getitem notation like 'env.transform[0]' to actual object. + state["_synchronized_on_sender"] = False + state["_synchronized_on_receiver"] = False - Args: - context: The context object (collector or inner_collector). - model_id: A string address like "policy" or "env.value_net". + state["_background_thread"] = None + state["_stop_event"] = None - Returns: - The object at the specified address. + return state - Examples: - _resolve_model(collector, "policy") # -> collector.policy - _resolve_model(collector, "env.value_net") # -> collector.env.value_net - """ - parts = model_id.split(".") - obj = context - for i, part in enumerate(parts): - if "[" in part: - key, *indices = part.split("[") - indices = [int(index[:-1]) for index in indices] - try: - obj = getattr(obj, key) - except AttributeError: - raise AttributeError( - f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" - ) - for index in indices: - obj = obj[index] - else: - try: - obj = getattr(obj, part) - except AttributeError: - raise AttributeError( - f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" - ) - return obj + def __setstate__(self, state): + """Restore the scheme from pickling.""" + self.__dict__.update(state) diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py index 7b6dd82e7b0..bc958476235 100644 --- a/tutorials/sphinx-tutorials/getting-started-3.py +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -60,7 +60,7 @@ from torchrl.collectors import SyncDataCollector from torchrl.envs import GymEnv -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy torch.manual_seed(0) diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 0ece54926f1..5c103ca8271 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -637,7 +637,7 @@ def assert0(x): ToTensorImage, TransformedEnv, ) -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy env = TransformedEnv( GymEnv("CartPole-v1", from_pixels=True),