PIM attention offload

What this demonstrates: moving attention compute off the GPU and onto a PIM (processing-in-memory) device, while keeping the rest of the layer on GPU.

Attention's memory-bandwidth bottleneck makes it a natural fit for PIM, which puts compute units inside DRAM. LLMServingSim models attention as a separate device-bound block; turning on --enable-attn-offloading swaps the GPU attention kernel for a PIM attention kernel inside the trace.

This is compute disaggregation: attention runs on PIM while the rest of the layer runs on the GPU (which the configs call the NPU), and the simulator coordinates the hand-off.

Prerequisites

  • Simulator container set up
  • Bundled RTXPRO6000 profile for meta-llama/Llama-3.1-8B
  • A PIM device config from configs/pim/ (e.g., DDR4_8GB_3200_pim): these are DRAMSim3 INI files describing the PIM substrate.

Cluster config

The cluster config is configs/cluster/single_node_pim_instance.json; note the pim_config field on the node's cpu_mem:

configs/cluster/single_node_pim_instance.json (excerpt)
{
  "num_nodes": 1,
  "link_bw": 16,
  "link_latency": 20000,
  "nodes": [
    {
      "num_instances": 1,
      "cpu_mem": {
        "mem_size": 512,
        "mem_bw": 256,
        "mem_latency": 0,
        "pim_config": "DDR4_8GB_3200_pim"
      },
      "instances": [
        {
          "model_name": "meta-llama/Llama-3.1-8B",
          "hardware": "RTXPRO6000",
          "npu_mem": {"mem_size": 96, "mem_bw": 1597, "mem_latency": 0},
          "pd_type": null,
          "tp_size": 1
        }
      ],
      "power": { "...": "see provided config for full power model" }
    }
  ]
}

The PIM hookup:

  • cpu_mem.pim_config: "DDR4_8GB_3200_pim": points at configs/pim/DDR4_8GB_3200_pim/ (DRAMSim3 INI files describing PIM-side compute and memory).
  • The GPU instance is otherwise unchanged. PIM attention is selected at runtime via the CLI flag, not the config.
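
If you adapt this config, a quick sanity check that pim_config resolves to a real substrate directory saves a failed run. A minimal sketch (the helper below is ours, not a simulator API):

import json
from pathlib import Path

# Hypothetical helper: verify each node's cpu_mem.pim_config names an
# existing directory under configs/pim/.
def check_pim_config(cluster_path: str) -> None:
    cluster = json.loads(Path(cluster_path).read_text())
    for node in cluster["nodes"]:
        name = node["cpu_mem"].get("pim_config")
        if name is None:
            continue                      # node has no PIM substrate attached
        substrate = Path("configs/pim") / name
        if not substrate.is_dir():
            raise FileNotFoundError(f"pim_config '{name}' not found at {substrate}")
        print(f"ok: pim_config '{name}' -> {substrate}/")

check_pim_config("configs/cluster/single_node_pim_instance.json")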

Run

python -m serving \
  --cluster-config 'configs/cluster/single_node_pim_instance.json' \
  --dtype float16 --block-size 16 \
  --enable-attn-offloading \
  --dataset 'workloads/example_trace.jsonl' \
  --output 'outputs/pim_offload_run.csv' \
  --log-level INFO
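
For an apples-to-apples comparison, run the same command without the flag as a GPU-attention baseline (the output filename here is arbitrary):

python -m serving \
  --cluster-config 'configs/cluster/single_node_pim_instance.json' \
  --dtype float16 --block-size 16 \
  --dataset 'workloads/example_trace.jsonl' \
  --output 'outputs/gpu_baseline_run.csv' \
  --log-level INFO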

--enable-attn-offloading is the switch that swaps the trace's attention kernel from the NPU profile to the PIM profile. The rest (qkv_proj, o_proj, mlp) still runs on the GPU.
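
Conceptually, the swap is a per-op device decision at trace-generation time. The real logic lives in the trace generator; the sketch below is a stand-in with hypothetical names (Op, assign_devices), not LLMServingSim's API:

from dataclasses import dataclass

# Illustrative only: these names are hypothetical, not the simulator's API.
@dataclass
class Op:
    name: str              # e.g. "qkv_proj", "attention", "o_proj", "mlp"
    device: str = "npu"    # default: stay on the NPU profile

def assign_devices(layer_ops: list, enable_attn_offloading: bool) -> list:
    """Route attention to PIM when offloading is on; leave the rest on the NPU."""
    for op in layer_ops:
        if enable_attn_offloading and op.name == "attention":
            op.device = "pim"   # emitted between PIM {channel} / PIM END markers
    return layer_ops

layer = [Op("qkv_proj"), Op("attention"), Op("o_proj"), Op("mlp")]
for op in assign_devices(layer, enable_attn_offloading=True):
    print(f"{op.name} -> {op.device}")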

Expected output

[INFO] step=10 batch=8 prompt_t=1.1k tok/s decode_t=520 tok/s npu_mem=63.4 GB pim_busy=72%
[INFO] step=11 batch=8 prompt_t=1.1k tok/s decode_t=540 tok/s npu_mem=63.4 GB pim_busy=78%

pim_busy reports the PIM device's utilization; when it plateaus near 100%, you've found the PIM bottleneck for the workload.
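
A quick way to check for saturation, assuming you've captured the simulator's stdout to a file and that the lines look like the sample above (this parser is ours, not a simulator utility):

import re
import sys

# Hypothetical helper: scan captured output for pim_busy readings and
# flag saturation over the last few steps.
busy = []
with open(sys.argv[1]) as f:              # e.g. a log captured via `... | tee`
    for line in f:
        m = re.search(r"pim_busy=(\d+)%", line)
        if m:
            busy.append(int(m.group(1)))

tail = busy[-10:]                         # last 10 readings
if tail and sum(tail) / len(tail) >= 95:
    print(f"PIM-bound: mean pim_busy over last {len(tail)} steps is "
          f"{sum(tail) / len(tail):.0f}%")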

What's interesting

  • Decode TPOT often improves because attention over long KV caches is memory-bandwidth-bound. PIM has very different bandwidth characteristics than GPU HBM, and on decode-heavy workloads with long contexts the PIM path can win even with slower per-op throughput (see the back-of-envelope numbers after this list).
  • Prefill TTFT can regress because attention during prefill is more compute-bound, and PIM's narrower per-channel compute doesn't help there. Pair with sub-batch interleaving (see Advanced) to overlap prefill compute on the GPU with decode attention on PIM.
  • NPU memory drops: KV cache lives in PIM memory now, freeing ~10–30 GB of NPU memory for weights or larger batches.
  • Sub-batch interleaving: the natural follow-on. Overlap GPU and PIM work to recover the prefill regression.
  • Prefill/decode split: alternative way to specialize on decode-heavy workloads, but with whole-instance granularity instead of per-layer.
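
Back-of-envelope numbers for the decode bullet, using Llama-3.1-8B's shape (32 layers, 8 KV heads under GQA, head dim 128) and the 1597 GB/s npu_mem bandwidth from the cluster config. The arithmetic is ours, assuming batch size 1 and an example 8k context:

# Each decoded token must stream the whole KV cache through the memory system.
layers, kv_heads, head_dim = 32, 8, 128   # Llama-3.1-8B (GQA)
bytes_per_elem = 2                        # float16
ctx = 8192                                # example context length

kv_bytes_per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem  # K and V
print(kv_bytes_per_token / 1024, "KiB per cached token")                # 128.0

kv_read_per_step = kv_bytes_per_token * ctx       # whole cache, every step
npu_bw = 1597e9                                   # B/s, from npu_mem.mem_bw
print(f"{kv_read_per_step / 1e9:.2f} GB per decode step -> "
      f"{kv_read_per_step / npu_bw * 1e3:.2f} ms floor at 1597 GB/s")

At batch 1 that is ~1.07 GB and a ~0.67 ms bandwidth floor per token before any compute is done, which is why moving the KV reads next to DRAM can pay off.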

Where to learn more

  • The PIM device model lives in serving/core/pim_model.py; the trace generator emits PIM {channel} / PIM END markers around the offloaded attention block (see Trace file format).
  • DRAMSim3 INI files in configs/pim/<name>/ configure the PIM substrate. Add a new substrate by dropping a new directory there and pointing pim_config at it, as sketched below.
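
For example (the new directory name is arbitrary):

# Clone the bundled substrate as a starting point, then edit its DRAMSim3 INIs
cp -r configs/pim/DDR4_8GB_3200_pim configs/pim/my_pim_substrate
# ...tune geometry/timings in configs/pim/my_pim_substrate/...
# then reference it from the cluster config:
#   "cpu_mem": { ..., "pim_config": "my_pim_substrate" }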