# PIM attention offload
What this demonstrates: moving attention compute off the GPU and onto a PIM (processing-in-memory) device, while keeping the rest of the layer on GPU.

Attention's memory-bandwidth bottleneck makes it a natural fit for PIM, which puts compute units inside DRAM. LLMServingSim models attention as a separate device-bound block; turning on `--enable-attn-offloading` swaps the GPU attention kernel for a PIM attention kernel inside the trace.

This is compute disaggregation: attention runs on PIM, the rest of the layer runs on the GPU (the simulator's NPU device), and the simulator coordinates the hand-off.
## Prerequisites

- Simulator container set up
- Bundled RTXPRO6000 profile for meta-llama/Llama-3.1-8B
- A PIM device config from `configs/pim/` (e.g., `DDR4_8GB_3200_pim`): these are DRAMSim3 INI files describing the PIM substrate.
## Cluster config

`configs/cluster/single_node_pim_instance.json`: note the `pim_config` field on the node's `cpu_mem`:
```json
{
  "num_nodes": 1,
  "link_bw": 16,
  "link_latency": 20000,
  "nodes": [
    {
      "num_instances": 1,
      "cpu_mem": {
        "mem_size": 512,
        "mem_bw": 256,
        "mem_latency": 0,
        "pim_config": "DDR4_8GB_3200_pim"
      },
      "instances": [
        {
          "model_name": "meta-llama/Llama-3.1-8B",
          "hardware": "RTXPRO6000",
          "npu_mem": {"mem_size": 96, "mem_bw": 1597, "mem_latency": 0},
          "pd_type": null,
          "tp_size": 1
        }
      ],
      "power": { "...": "see provided config for full power model" }
    }
  ]
}
```
The PIM hookup:

- `cpu_mem.pim_config: "DDR4_8GB_3200_pim"` points at `configs/pim/DDR4_8GB_3200_pim/` (DRAMSim3 INI files describing PIM-side compute and memory).
- The GPU instance is otherwise unchanged. PIM attention is selected at runtime via the CLI flag, not the config.
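A missing substrate directory is an easy mistake when hand-editing the cluster config. The lookup can be sketched as below; `resolve_pim_config` is a hypothetical helper for illustration, not part of the simulator:

```python
import os
import tempfile

def resolve_pim_config(cluster_cfg, pim_root):
    """Map each node's cpu_mem.pim_config to its configs/pim/<name>/ directory."""
    paths = []
    for node in cluster_cfg["nodes"]:
        name = node["cpu_mem"].get("pim_config")
        if name is None:
            continue  # node runs without a PIM substrate attached
        path = os.path.join(pim_root, name)
        if not os.path.isdir(path):
            raise FileNotFoundError(f"no PIM substrate directory at {path}")
        paths.append(path)
    return paths

# Demo against a throwaway directory standing in for configs/pim/.
pim_root = tempfile.mkdtemp()
os.makedirs(os.path.join(pim_root, "DDR4_8GB_3200_pim"))
cluster = {"nodes": [{"cpu_mem": {"pim_config": "DDR4_8GB_3200_pim"}}]}
resolved = resolve_pim_config(cluster, pim_root)
```

Running this kind of check before launching a long simulation fails fast instead of mid-run.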
## Run

```shell
python -m serving \
  --cluster-config 'configs/cluster/single_node_pim_instance.json' \
  --dtype float16 --block-size 16 \
  --enable-attn-offloading \
  --dataset 'workloads/example_trace.jsonl' \
  --output 'outputs/pim_offload_run.csv' \
  --log-level WARNING
```
`--enable-attn-offloading` is the switch that swaps the trace's attention kernel from the NPU profile to the PIM profile. The rest (`qkv_proj`, `o_proj`, `mlp`) still runs on the GPU.
## Expected output

```
[INFO] step=10 batch=8 prompt_t=1.1k tok/s decode_t=520 tok/s npu_mem=63.4 GB pim_busy=72%
[INFO] step=11 batch=8 prompt_t=1.1k tok/s decode_t=540 tok/s npu_mem=63.4 GB pim_busy=78%
```
`pim_busy` reports the PIM device's utilization; when this plateaus near 100%, you've found the PIM bottleneck for the workload.
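Eyeballing the log works for short runs; for longer ones, the utilization series can be scraped and checked for a plateau. A minimal sketch, assuming only the `pim_busy=NN%` token shown in the sample output (the helper names are ours):

```python
import re

PIM_BUSY = re.compile(r"pim_busy=(\d+)%")

def pim_busy_series(log_lines):
    """Pull the pim_busy percentage out of each step log line."""
    vals = []
    for line in log_lines:
        m = PIM_BUSY.search(line)
        if m:
            vals.append(int(m.group(1)))
    return vals

def plateaued(vals, window=5, threshold=95):
    """True once the last `window` readings all sit at or above `threshold`%."""
    tail = vals[-window:]
    return len(tail) == window and all(v >= threshold for v in tail)
```

Feed it the run's log and `plateaued` tells you whether the workload has saturated the PIM device.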
## What's interesting

- Decode TPOT often improves because attention on long KV caches is memory-bandwidth-bound. PIM has very different bandwidth characteristics than GPU HBM, and on decode-heavy workloads with long contexts the PIM path can win even with slower per-op throughput.
- Prefill TTFT can regress because attention during prefill is more compute-bound; PIM's narrower compute per channel doesn't help there. Pair with sub-batch interleaving (see Advanced) to overlap prefill compute on GPU with decode attention on PIM.
- NPU memory drops: the KV cache lives in PIM memory now, freeing ~10–30 GB of NPU memory for weights or larger batches.
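The freed-memory range can be sanity-checked with a back-of-envelope calculation. Assuming Llama-3.1-8B's published architecture (32 layers, 8 KV heads via grouped-query attention, head dim 128) and float16 KV entries:

```python
def kv_cache_bytes(tokens, layers=32, kv_heads=8, head_dim=128, bytes_per_elem=2):
    """Footprint of the K and V tensors across all layers for `tokens` cached tokens."""
    return 2 * layers * kv_heads * head_dim * bytes_per_elem * tokens

# Per-token cost in float16: 2 * 32 * 8 * 128 * 2 = 131072 bytes (128 KiB).
per_token = kv_cache_bytes(1)

# A batch of 8 sequences at 16k context each: 8 * 16384 tokens -> 16 GiB off the NPU.
freed_gib = kv_cache_bytes(8 * 16384) / 2**30
```

At 128 KiB per token, realistic decode batches with long contexts land squarely in the ~10–30 GB range quoted above.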
## Related examples

- Sub-batch interleaving: the natural follow-on. Overlap GPU and PIM work to recover the prefill regression.
- Prefill/decode split: an alternative way to specialize on decode-heavy workloads, but with whole-instance granularity instead of per-layer.
## Where to learn more

- The PIM device model lives in `serving/core/pim_model.py`; the trace generator emits `PIM {channel}` / `PIM END` markers around the offloaded attention block (see Trace file format).
- DRAMSim3 INI files in `configs/pim/<name>/` configure the PIM substrate. Add a new substrate by dropping a new directory there and pointing `pim_config` at it.
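The add-a-substrate steps can be scripted: copy an existing substrate directory as a starting point, then repoint the cluster config at it. A sketch, with hypothetical helper names and a throwaway directory standing in for `configs/`:

```python
import json
import shutil
import tempfile
from pathlib import Path

def clone_substrate(pim_root, base, new_name):
    """Copy an existing substrate directory as the starting point for a new one."""
    dst = Path(pim_root) / new_name
    shutil.copytree(Path(pim_root) / base, dst)
    return dst  # edit the DRAMSim3 INI files in here afterwards

def point_cluster_at(cluster_json, new_name):
    """Rewrite every node's cpu_mem.pim_config to name the new substrate."""
    path = Path(cluster_json)
    cfg = json.loads(path.read_text())
    for node in cfg["nodes"]:
        node["cpu_mem"]["pim_config"] = new_name
    path.write_text(json.dumps(cfg, indent=2))

# Demo against a temporary tree.
root = Path(tempfile.mkdtemp())
(root / "DDR4_8GB_3200_pim").mkdir()
(root / "DDR4_8GB_3200_pim" / "dram.ini").write_text("[dram]\n")
new_dir = clone_substrate(root, "DDR4_8GB_3200_pim", "my_pim")

cluster_json = root / "cluster.json"
cluster_json.write_text(json.dumps(
    {"nodes": [{"cpu_mem": {"pim_config": "DDR4_8GB_3200_pim"}}]}))
point_cluster_at(cluster_json, "my_pim")
```

Cloning a known-good substrate keeps the INI files structurally valid while you tune timing parameters.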