How do the FSDP actor and the vLLM rollout share the same GPUs in a colocated verl GRPO run?

Configuration setup

  • 4× GPU, single node
  • Qwen3-8B (~8.98B)
  • actor: FSDP2 (fully_shard)
  • colocated / hybrid engine
  • vLLM TP=1 (4 DP replicas)
  • actor param_offload=False
  • actor optimizer_offload=False
  • ref param_offload=True
  • gpu_memory_utilization=0.25
  • free_cache_engine=True
  • vLLM ≥0.8.5 → sleep_level=2

TL;DR: The actor (FSDP) and the rollout (vLLM) live on the same GPUs. They cannot both hold their full footprint at once, so verl time-shares each GPU: vLLM is put to sleep while the actor trains (its weights and KV pool are unmapped and returned to the CUDA driver through vLLM's CuMemAllocator) and woken when it generates. With actor offload off, the FSDP actor and optimizer stay resident the whole time, so vLLM gets only a fraction of each GPU, and its memory is the part that churns.

How to read the bars: each section below shows one GPU (rank 0, 80 GB) at that phase. Heights are illustrative, scaled to tell the story, not measured.

Legend:
  • FSDP actor params (sharded ¼)
  • Optimizer state (fp32 master + Adam m,v)
  • Gradients + activations
  • Ref policy (offloaded → CPU)
  • vLLM weights (full model, TP=1)
  • vLLM KV-cache pool
  • Free / reserved
  • Weight all-gather (transient, during sync)

1 · Initialization

GPU rank 0 · 80 GB

(1) Actor loading. FSDP2 builds a sharded model

Each of the N GPUs permanently holds only 1/N of every parameter, without any rank ever materializing the whole model on its GPU.

  1. Build the model. Each rank picks its init context locally (no cross-rank comms):
    • rank 0 → cpu_init_weights → builds on CPU and loads the real checkpoint weights.
    • ranks 1…N-1 → init_empty_weights (accelerate) → builds on the meta device with zero-byte tensors, only registering the shape and dtype.
    verl/utils/fsdp_utils.py · get_init_weight_context_manager()
    verl/workers/engine/fsdp/transformer_impl.py · FSDPEngine._build_module()
  2. Shard into DTensors, then broadcast the weights in. Two separate things: sharding the empty slots, then streaming the weights into them.
    Shard. fully_shard wraps each decoder layer (still on CPU) so every parameter becomes a DTensor with placement Shard(0), one logical tensor split along dim 0 across the GPU mesh, each rank owning a 1/N slice (.shape stays global, .to_local() is the local slice).
    Put the slots on GPU. The sharded module moves to GPU: rank 0 via .to(device) (its real 1/N shard), ranks 1…N-1 via .to_empty(device) (an empty 1/N slot). Note the broadcast source is a separate full CPU snapshot (full_state), held only on rank 0.
    Broadcast, one parameter at a time. rank 0 lifts a full parameter from its CPU snapshot up to a GPU buffer and broadcasts it. Every rank (rank 0 included) transiently allocates a full-size GPU buffer to receive it, then copies out only its own 1/N slice into its shard slot.
    Release. The transient full-param GPU buffer is freed on all ranks after each parameter; rank 0's CPU snapshot is dropped only once the whole loop finishes.
    verl/utils/fsdp_utils.py · apply_fsdp2() (→ fully_shard()), fsdp2_load_full_state_dict() (→ _broadcast_state_dict())
    verl/workers/engine/fsdp/transformer_impl.py:402–419
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    full_state = module.state_dict()                          # snapshot (rank 0 = real weights)
    apply_fsdp2(module, fsdp_kwargs, self.engine_config)        # shard each layer → Shard(0) DTensors
    fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy)  # broadcast → fill shards
  3. Build the optimizer, after the model is loaded. AdamW is constructed over module.parameters(), already the sharded DTensors, so its state is naturally 1/N-sized. The m/v buffers (≈ 2× fp32 of the shard) are allocated lazily on the first step(); they appear during the first actor update, not at init.
    verl/workers/engine/fsdp/transformer_impl.py · FSDPEngine._build_optimizer() → build_optimizer()
    ordering in _build_model_optimizer()
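
The shard-then-broadcast loop above can be sketched as a toy simulation in plain Python. This is not verl's or PyTorch's code: nested lists stand in for tensors, shard_dim0 and load_by_broadcast are hypothetical helper names, and dim 0 is assumed to divide evenly across ranks (the real implementation also has to handle uneven splits). It only illustrates the memory pattern: one full parameter is live at a time, and each rank keeps just its 1/N slice.

```python
from typing import Dict, List, Tuple

def shard_dim0(rows: List[list], world_size: int, rank: int) -> List[list]:
    """Return this rank's contiguous slice along dim 0 (even split assumed)."""
    assert len(rows) % world_size == 0, "toy sketch assumes dim 0 divides evenly"
    per_rank = len(rows) // world_size
    return rows[rank * per_rank:(rank + 1) * per_rank]

def load_by_broadcast(
    full_state: Dict[str, List[list]], world_size: int
) -> Tuple[List[Dict[str, List[list]]], int]:
    """Simulate the per-parameter broadcast.

    Every rank transiently receives the full parameter, copies out its own
    slice, and drops the full-size buffer before the next parameter.
    """
    shards: List[Dict[str, List[list]]] = [dict() for _ in range(world_size)]
    peak_transient_rows = 0
    for name, full_param in full_state.items():       # one parameter at a time
        for rank in range(world_size):
            recv_buffer = list(full_param)             # transient full-size copy
            peak_transient_rows = max(peak_transient_rows, len(recv_buffer))
            shards[rank][name] = shard_dim0(recv_buffer, world_size, rank)
            del recv_buffer                            # released before the next one
    return shards, peak_transient_rows

# Hypothetical two-parameter "checkpoint" held by rank 0.
full_state = {
    "embed.weight": [[float(i)] * 2 for i in range(8)],   # 8 rows
    "layer0.weight": [[float(i)] * 2 for i in range(4)],  # 4 rows
}
shards, peak = load_by_broadcast(full_state, world_size=4)
```

Here peak tracks the largest single parameter rather than the sum of all parameters, which is the point of broadcasting one parameter at a time.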

(2) vLLM loading (simplified, TP=1)

With tensor-parallel size 1, vLLM does not shard. Each GPU gets its own complete copy of the model.

  1. Spin up the replicas. verl works out how many replicas to run (when TP=1 with 4 GPUs, it will create 4 replicas); Ray launches one worker process per replica and pins it to a GPU; verl's rollout code, running inside each, then constructs the vLLM engine, each holding the full model.
    verl/workers/rollout/vllm_rollout/vllm_rollout.py
    verl/workers/rollout/replica.py
  2. Allocate weights as “dummy”. Default load_format="dummy": each replica allocates full-model weight buffers and fills them with random values (no disk read).
    verl/workers/config/rollout.py · load_format = "dummy"
    vllm/model_executor/model_loader/weight_utils.py · initialize_dummy_weights()
  3. Profile, then reserve the KV-cache pool. vLLM runs a profiling pass: one dummy forward at the worst-case (max) batch, used only to measure peak GPU memory (weights + activations). It then takes whatever is left within the gpu_memory_utilization budget and carves it into fixed-size blocks (each holding ~16 tokens of KV, all layers), which it physically reserves up front as the empty KV pool (it prints #GPU blocks: N). The space is committed at init; the actual K/V values are only written during generation. The engine is created with enable_sleep_mode=True so that the pool can be released on sleep.
    verl/workers/rollout/vllm_rollout/vllm_rollout.py · engine init (gpu_memory_utilization, enable_sleep_mode)
    KV sizing is vLLM-internal (determine_num_available_blocks)
  4. Real weights arrive by sync. On the first rollout, the actor all-gathers its shards to full tensors and streams them over CUDA IPC into each replica. At TP=1 it's a straight whole-tensor copy, no resharding needed. See section 2, Wake vLLM and sync weights.
    actor side verl/workers/engine/fsdp/transformer_impl.py · get_per_tensor_param()
    rollout side verl/workers/rollout/vllm_rollout/vllm_rollout.py · update_weights() → update_weights_from_ipc()
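
The KV-pool sizing in step 3 comes down to simple arithmetic, sketched below. This is a simplified model of the idea, not vLLM's actual sizing code: the function names are hypothetical, the layer and head counts are assumed values chosen for illustration (read the real ones from the model config), and the 1 GB activation peak is a made-up placeholder rather than a measurement.

```python
def kv_block_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   block_size: int = 16, dtype_bytes: int = 2) -> int:
    """Bytes for one KV block: K and V, every layer, block_size tokens."""
    return 2 * num_layers * num_kv_heads * head_dim * block_size * dtype_bytes

def num_kv_blocks(total_gpu_bytes: int, gpu_memory_utilization: float,
                  profiled_peak_bytes: int, block_bytes: int) -> int:
    """Blocks that fit in what the budget leaves after the profiled peak."""
    budget = total_gpu_bytes * gpu_memory_utilization
    available = budget - profiled_peak_bytes
    return max(0, int(available // block_bytes))

# Assumed model dimensions (illustrative, not taken from the source).
block = kv_block_bytes(num_layers=36, num_kv_heads=8, head_dim=128)

# 80 GB GPU, 25% budget, bf16 weights of ~8.98B params plus a
# hypothetical 1 GB activation peak from the profiling forward.
weights_bytes = 8_980_000_000 * 2
activation_peak_bytes = 1_000_000_000          # placeholder value
blocks = num_kv_blocks(80 * 10**9, 0.25,
                       weights_bytes + activation_peak_bytes, block)
kv_tokens = blocks * 16                        # tokens the pool can hold
```

Under these assumptions the full-model weights consume most of the 0.25 budget, which is why the KV pool in the diagram is small relative to the weights.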

2 · Wake vLLM + sync weights

GPU rank 0 · 80 GB

Each step's rollout begins by waking vLLM, that is, re-mapping the GPU memory it released when it went to sleep, and then pushing the actor's fresh weights into it. The sleep/wake cycle is gated on free_cache_engine=True.

  1. Wake the weight buffers.
    [Background Knowledge] Allocation model. Each vLLM tensor reserves a fixed virtual address (via CUDA's VMM API) with physical GPU pages mapped behind it. The torch tensors and model structure only ever reference the virtual address.
    [Background Knowledge] What sleep(level=2) did. After vLLM rollout and before the actor starts training, vLLM called sleep(level=2). It unmapped and freed the physical pages (handing that GPU memory back so the actor could train), but kept the virtual reservations and the tensor objects. So the model structure survived; it just had no bytes behind it.
    What wake_up does now. It maps fresh physical pages back onto those same virtual addresses. The weight tensors instantly become valid again (same shapes, same addresses), but the new pages hold undefined leftover bytes (no re-init, not the old weights; possibly the actor's stale grad/activation data).
    So the weight tensors now exist and are addressable but contain garbage.
    verl/workers/rollout/vllm_rollout/vllm_rollout.py · resume() → wake_up() · release() → sleep()
  2. Gather the actor's fresh weights one parameter at a time. get_per_tensor_param processes the parameters one at a time through a "lazy generator". For each parameter, param.full_tensor() calls all_gather. It combines the shards from all ranks so that every rank (not just rank 0) ends up with the full tensor, which is then cast to bf16. Only one parameter is live at a time (the pink block in the GPU memory diagram above). The whole model is never unsharded at once, so the transient stays small.
    verl/workers/engine/fsdp/transformer_impl.py · get_per_tensor_param() → DTensor.full_tensor()
    what .full_tensor() runs under the hood
    param.full_tensor()                       ← what you call (DTensor convenience method)
      └─ redistribute Shard(0) → Replicate        ← DTensor: every rank needs all shards
           └─ all_gather(...)                  ← the actual collective function it invokes
                └─ ncclAllGather                 ← low-level NCCL/GPU primitive that moves the bytes
  3. Stream into vLLM over CUDA IPC. The full weights now have to cross from the actor process into its colocated vLLM process: same GPU, but not the same address space.
    Why a special transfer is needed. The FSDP actor and the vLLM engine run in separate OS processes (separate Ray workers / the vLLM engine subprocess), even though they sit on the same physical GPU. A CUDA tensor in process A is invisible to process B (each has its own CUDA context and address space), so the freshly all-gathered weight can't be handed to vLLM's buffer directly.
    Why not just copy it. The naive route is actor GPU → CPU → vLLM GPU, and that host round-trip is slow. CUDA IPC (Inter-Process Communication) avoids it: one process exports a handle to its GPU memory, the other maps and reads it: a direct GPU→GPU copy, no host hop.
    update_weights() is the orchestrator (rollout side). The entry point verl calls with the per-tensor generator from step 2. It pulls tensors from the generator, batched by update_weights_bucket_megabytes (e.g. WEIGHT_BUCKET_MB=512), and dispatches the transfer to the vLLM worker via RPC, passing a use_shm flag (IPC vs. shared-memory fallback).
    update_weights_from_ipc() is the mover (in each vLLM worker). Imports the actor's CUDA IPC handle, maps the same GPU memory into its own process, and copies it GPU→GPU into the model's weight buffers, overwriting the dummy/garbage values with the current policy.
    Each bucket's transient (~0.5 GB, defined by WEIGHT_BUCKET_MB) is freed before the next, so the sync never costs the full model.
    verl/workers/rollout/vllm_rollout/vllm_rollout.py · update_weights() → update_weights_from_ipc()
    how update_weights() dispatches the transfer
    update_weights(generator)                ← rollout side: batch by WEIGHT_BUCKET_MB, dispatch
       └─ update_weights_from_ipc(use_shm)   ← in each vLLM worker:
            ├─ import CUDA IPC handle         ← map the actor's GPU memory into this process
            └─ load_weights()                ← copy GPU→GPU into vLLM's weight buffers
       (fallback: shared memory if IPC unsupported → use_shm)
  4. Wake the KV pool. resume(["kv_cache"]) → wake_up(tags=["kv_cache"]) re-allocates the (empty) KV-cache pool.
    verl/workers/rollout/vllm_rollout/vllm_rollout.py · resume()
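
The gather-then-bucket flow of steps 2 and 3 can be sketched in plain Python. This is a toy model under stated assumptions, not verl's implementation: lists stand in for tensors, list concatenation stands in for the all-gather, a dict assignment stands in for the GPU→GPU copy over CUDA IPC, and per_tensor_full, bucketed, and sync are hypothetical names. The bucketing rule (flush when the next tensor would overflow the bucket) is also an assumption.

```python
from typing import Dict, Iterator, List, Tuple

BYTES_PER_ELEM = 2  # bf16

def per_tensor_full(
    shards_by_rank: List[Dict[str, List[float]]]
) -> Iterator[Tuple[str, List[float]]]:
    """Lazily yield (name, full tensor); concatenation stands in for all_gather."""
    for name in shards_by_rank[0]:
        full: List[float] = []
        for rank_shards in shards_by_rank:
            full.extend(rank_shards[name])
        yield name, full                      # only this parameter is live

def bucketed(weights, bucket_bytes: int):
    """Group tensors into buckets of at most bucket_bytes.

    A single tensor larger than the limit still gets a bucket of its own.
    """
    bucket, used = [], 0
    for name, tensor in weights:
        size = len(tensor) * BYTES_PER_ELEM
        if bucket and used + size > bucket_bytes:
            yield bucket
            bucket, used = [], 0
        bucket.append((name, tensor))
        used += size
    if bucket:
        yield bucket

def sync(shards_by_rank, dest: Dict[str, List[float]], bucket_bytes: int) -> int:
    """Copy every bucket into dest and return the peak transient size in bytes."""
    peak = 0
    for bucket in bucketed(per_tensor_full(shards_by_rank), bucket_bytes):
        peak = max(peak, sum(len(t) * BYTES_PER_ELEM for _, t in bucket))
        for name, tensor in bucket:
            dest[name] = list(tensor)         # stands in for the GPU→GPU copy
    return peak

# Two ranks, three hypothetical parameters.
shards_by_rank = [
    {"a": [1.0, 2.0], "b": [5.0, 6.0], "c": [9.0]},
    {"a": [3.0, 4.0], "b": [7.0, 8.0], "c": [10.0]},
]
vllm_weights: Dict[str, List[float]] = {}     # "garbage" after wake_up
peak_bytes = sync(shards_by_rank, vllm_weights, bucket_bytes=16)
```

Because both stages are generators, the transient is bounded by the bucket size instead of the model size, which mirrors the claim that the sync never costs the full model.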

3 · Generate (rollout)

GPU rank 0 · 80 GB
  1. How the rollout is batched.
    How many sequences a step generates. data.train_batch_size (TRAIN_BATCH_SIZE=8) is the number of prompts per training step; actor_rollout_ref.rollout.n (ROLLOUT_N=8) is the number of responses per prompt (the GRPO group). So vLLM generates TRAIN_BATCH_SIZE × ROLLOUT_N = 64 sequences this step.
    Data parallel, across replicas. The 8 prompts are split across the 4 DP vLLM replicas, so ~2 prompts each, ×8 responses ≈ 16 sequences per replica. Each replica generates its share independently, and the results are gathered afterward.
    Continuous batching, within a replica. vLLM does not run a fixed batch. Its scheduler runs as many sequences concurrently as fit, bounded by max_num_seqs, max_num_batched_tokens, and the KV-pool size (set by gpu_memory_utilization and MAX_MODEL_LEN). As sequences finish, queued ones are admitted, so the running batch flexes over time.
    Per prompt: parallel. The ROLLOUT_N responses for one prompt come from SamplingParams(n=ROLLOUT_N).
    verl/workers/rollout/vllm_rollout/vllm_rollout.py · generate_sequences() · vLLM SamplingParams(n=...), max_num_seqs, max_num_batched_tokens
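
The batching arithmetic above can be checked with a few lines of Python. rollout_plan is a hypothetical helper written for this note, and the even split with the remainder going to the first replicas is an assumption: the actual dispatch in verl may balance prompts differently.

```python
from typing import Dict, List, Union

def rollout_plan(train_batch_size: int, rollout_n: int,
                 num_replicas: int) -> Dict[str, Union[int, List[int]]]:
    """Split prompts across DP replicas and count the sequences each generates."""
    base, extra = divmod(train_batch_size, num_replicas)
    # Assumed policy: the first `extra` replicas take one additional prompt.
    prompts = [base + (1 if i < extra else 0) for i in range(num_replicas)]
    return {
        "total_sequences": train_batch_size * rollout_n,
        "prompts_per_replica": prompts,
        "sequences_per_replica": [p * rollout_n for p in prompts],
    }

# Values from this run: TRAIN_BATCH_SIZE=8, ROLLOUT_N=8, 4 DP replicas.
plan = rollout_plan(train_batch_size=8, rollout_n=8, num_replicas=4)
```

Whether a replica actually runs all of its sequences at once still depends on the continuous-batching limits (max_num_seqs, max_num_batched_tokens, KV-pool size).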

4 · Sleep vLLM

GPU rank 0 · 80 GB
verl/workers/rollout/vllm_rollout/vllm_rollout.py · release() → vLLM sleep()
how release() puts vLLM to sleep
release()                    ← verl trigger (gated on free_cache_engine)
   └─ sleep(level)             ← vLLM engine: free its GPU memory
        └─ CuMemAllocator      ← unmaps/frees the physical pages (keeps the virtual addresses)

level 1:  offload weights to CPU, drop KV   (weights restored from CPU on wake)
level 2:  discard weights AND KV            (weights re-synced from the actor on wake)  ← this run, vLLM ≥ 0.8.5
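
The level-2 sleep and wake cycle can be modelled as a small state machine. The class below is a toy written for this note, not vLLM's CuMemAllocator: ToySleepPool and its methods are hypothetical, and sizes are in arbitrary units. It only captures the bookkeeping described above: the virtual reservation survives sleep, the physical backing does not, and woken weights hold garbage until the sync rewrites them.

```python
from typing import Dict, Iterable, Optional

class ToySleepPool:
    """Virtual reservations persist; physical backing comes and goes."""

    def __init__(self) -> None:
        self.reserved: Dict[str, int] = {}   # tag -> reserved size (virtual)
        self.mapped: Dict[str, bool] = {}    # tag -> physical pages mapped?
        self.valid: Dict[str, bool] = {}     # tag -> contents meaningful?

    def allocate(self, tag: str, size: int) -> None:
        self.reserved[tag] = size
        self.mapped[tag] = True
        self.valid[tag] = True

    def sleep_level2(self) -> None:
        """Unmap everything and discard contents; keep the reservations."""
        for tag in self.reserved:
            self.mapped[tag] = False
            self.valid[tag] = False

    def wake_up(self, tags: Optional[Iterable[str]] = None) -> None:
        """Map fresh pages; contents stay invalid until rewritten."""
        for tag in (tags if tags is not None else list(self.reserved)):
            self.mapped[tag] = True

    def write(self, tag: str) -> None:
        """Stand-in for the weight sync overwriting the garbage bytes."""
        assert self.mapped[tag], "cannot write to unmapped memory"
        self.valid[tag] = True

    def physical_size(self) -> int:
        return sum(n for tag, n in self.reserved.items() if self.mapped[tag])

pool = ToySleepPool()
pool.allocate("weights", 18)
pool.allocate("kv_cache", 2)

pool.sleep_level2()
asleep = pool.physical_size()                 # nothing mapped while training

pool.wake_up(tags=["weights"])
garbage_after_wake = not pool.valid["weights"]
pool.write("weights")                         # weight sync from the actor
pool.wake_up(tags=["kv_cache"])
awake = pool.physical_size()
```

The two-stage wake (weights first, KV pool last) follows the order of section 2: the KV pool is only re-allocated after the weight sync, so the sync transient and the pool do not have to coexist.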

5 · Actor update (train)

GPU rank 0 · 80 GB
  1. Ref log-probs come from a CPU-offloaded forward. The ref policy is forward-only (no gradients, no optimizer). Under FSDP2 it gets a CPUOffloadPolicy(pin_memory=True), so its sharded params stay pinned on CPU.
    When it runs. For the KL term, verl runs one no_grad forward of the ref to get the ref log-probs of the rollout responses.
    How it moves. FSDP2 gathers each layer's shards CPU→GPU (bf16) just in time, computes, then frees them. It is per-layer, so only about one layer is on the GPU at a time, not the whole ~9B-parameter model. After the forward, nothing of the ref stays resident.
    verl/workers/engine/fsdp/transformer_impl.py · CPUOffloadPolicy (set for forward_only) · ref log-probs in RayPPOTrainer.fit() → compute_ref_log_prob()
  2. One FSDP2 forward + backward + step. The actor update is a sharded training loop. Per FSDP unit (one decoder layer):
    Forward (layers in order). all-gather the layer's shards into a full bf16 layer, compute, then reshard (free the full copy, keep the 1/N shard) because reshard_after_forward=True. With enable_gradient_checkpointing=True, inner activations are not saved, only the layer input.
    Backward (layers in reverse). all-gather the shards again, recompute the layer's activations (checkpointing), compute gradients, then reduce-scatter the gradients so each rank keeps only its 1/N grad shard; reshard the params.
    Optimizer step. Each rank updates only its own 1/N master shard (fp32) with its 1/N grad shard and its 1/N Adam m/v. This is purely local, with no communication.
    Memory. Resident: the sharded master + Adam state. Transient: the per-layer full gather (about one layer), the gradients, and the (checkpointed) activations. Actor offload is off, so the actor shards never leave the GPU (_context_switch is a no-op); vLLM being asleep is what makes room.
    verl/utils/fsdp_utils.py · apply_fsdp2() / fully_shard() · verl/workers/engine/base.py · _context_switch() · actor update in RayPPOTrainer.fit() → update_actor()
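
A back-of-the-envelope estimate of the resident actor state helps explain why vLLM has to sleep. The sketch below assumes what the legend describes (an fp32 master shard plus fp32 Adam m and v, all sharded 1/N) and uses the ~8.98B parameter count from the configuration; the function names are hypothetical, 1 GB is taken as 10^9 bytes, and the result is an estimate, not a measurement.

```python
def actor_resident_bytes(num_params: int, world_size: int,
                         master_bytes: int = 4, adam_states: int = 2,
                         state_bytes: int = 4) -> float:
    """Per-rank bytes for the sharded master weights plus Adam m and v."""
    shard_params = num_params / world_size
    return shard_params * (master_bytes + adam_states * state_bytes)

def headroom_bytes(gpu_bytes: int, resident: float, vllm_budget: float) -> float:
    """Memory left on one GPU after the resident actor and vLLM's share."""
    return gpu_bytes - resident - vllm_budget

GPU_BYTES = 80 * 10**9
resident = actor_resident_bytes(num_params=8_980_000_000, world_size=4)

# While training, vLLM is asleep and holds no physical memory.
train_headroom = headroom_bytes(GPU_BYTES, resident, vllm_budget=0)

# While generating, vLLM may use up to gpu_memory_utilization=0.25 of the GPU.
rollout_headroom = headroom_bytes(GPU_BYTES, resident, vllm_budget=0.25 * GPU_BYTES)
```

The training headroom is what has to absorb the gradients, the checkpointed activations, and the per-layer gather; keeping vLLM's roughly 20 GB awake during the update would cut directly into it.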

Standard colocated verl GRPO (FSDP2 actor + vLLM hybrid engine). Code refs are verl HEAD in .venv-verl.