
Inference Engines 2/N: Batching and PagedAttention

I recommend you start from the beginning, from part 0/N.

Review of Phase 1

We left off in a pretty good place. Our decode loop had grown up quite a bit! Instead of rereading the entire sequence from scratch every step, it now carries a KV cache forward, and thanks to this KV cache, our ms/tok stays flat regardless of how many tokens we generate.

|  new_tokens | tok/s | ms/tok | TTFT (ms) | peak MB |
|------------:|------:|-------:|----------:|--------:|
| 32          | 30.5  | 32.78  | 34        | 7741    |
| 128         | 30.7  | 32.57  | 34        | 7754    |
| 512         | 30.4  | 32.89  | 34        | 7843    |
| 1024        | 30.2  | 33.11  | 34        | 7882    |

Beautiful, constant, steady. The quadratic pain is gone. :)

But I left you with a cliffhanger the other day, didn't I?

Right now, we're serving a single prompt, from a single user, in one lonely decode loop. But real inference engines handle many sequences at once, and that changes a lot of things.

Suddenly, we worry for our tidy little KV Cache. The one that grows by one row each step, oh, that scrappy little grow-one-row-at-a-time cache, how will it survive the influx of user requests? How do you allocate memory for a cache when you don't know how long each sequence will be? What happens when one sequence finishes early and leaves a hole in the cache, a hole too small to use, yet too large to ignore?

Well, the cliffhanger has gone on long enough. Let's open the gates of Nantucket and let the users in. As for how our little KV cache will survive, well, we can worry about that later (and by later, I mean soon enough, within this same post).

New Benchmarking Questions

Before we write a single line of code, let's agree on what we're measuring, now that we're handling multiple sequences from multiple users.

In Phase 0 and 1, our metrics revolved around a single sequence, for a single user's experience. The two biggest numbers that we were preoccupied with were ms/tok for that user's sequence ("how fast are their tokens arriving?") and TTFT for that user ("how long did they wait for their first token?").

But if we're going to serve many sequences at once, we also need to think at the system level, to make sure we're getting what we paid for in our hardware.

As much as we care that our users aren't sitting around feeling bored, we also care that our GPU isn't sitting around feeling bored. If we have 8 users each getting 30 tok/s, that looks fine on a per-user basis. But maybe we merely asked our GPU to serve one user at a time, when our GPU was very much capable of serving all 8 users at once in one batch, and so our GPU is only using 15% of its brainpower and is thus bored.

So, we introduce a new metric, total throughput. Total throughput, aka tok/s summed across all sequences, tells us how hard our hardware is working.

Total throughput: tokens per second summed across all sequences running in parallel. This is the number that tells us how efficiently we're using our hardware.

However, we can't optimize total throughput alone, at the expense of the user experience. Suppose we batch up 32 user sequences together and our total throughput jumps from 30 tok/s to 600 tok/s. The hardware utilization number looks great! But if each user is now waiting significantly longer than they were before, is that truly an improvement?

So, in addition to total throughput, we also track the average ms/tok across all sequences. The basic idea is that we hold onto ms/tok, our trusty proxy for individual user happiness from Phase 0 and 1, and average it across all sequences as a proxy for collective user happiness.

Average ms/tok across all sequences: so we can see if we're making individual sequences slower in exchange for higher total throughput.

You will find these new metrics, total throughput and average ms/tok, added to our benchmark harness.
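
To make the two metrics concrete, here's roughly how they fall out of per-sequence counts (illustrative numbers and names, not the harness's actual code):

# Hypothetical per-sequence stats: (tokens generated, seconds spent decoding)
seq_stats = [(128, 4.2), (128, 4.3), (96, 3.1)]

total_tokens = sum(toks for toks, _ in seq_stats)
wall_seconds = 4.3   # sequences decode in parallel, so wall time ~= the slowest one

total_throughput = total_tokens / wall_seconds   # tok/s across the whole system
avg_ms_per_tok = sum(1000 * secs / toks          # per-user latency, averaged
                     for toks, secs in seq_stats) / len(seq_stats)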

Phase 2, Step 1: Naive Batching

Commit: 5053b2e

Okay, so we go from one prompt, by one hypothetical user...

def generate(
    lm: LoadedModel,
    prompt: str,                # <---- a prompt (singular!)
    max_new_tokens: int = 64
) -> tuple[str, list[float]]:

to a list of prompts, coming at us from many users at once.

def generate_batch(
    lm: LoadedModel,
    prompts: list[str],         # <----  a list of prompts (plural!)
    max_new_tokens: int = 64,
)

The simplest thing we can do is naive batching: we go from a singular tensor of some length...

tok(prompt, return_tensors="pt")

to a batch of tensors, padded to the length of the longest prompt in the list.

tok(prompts, return_tensors="pt", padding=True)

Concretely, what this looks like is, suppose we had a list of prompts like this:

Prompt 0 (4 tok):   ["The", "quick", "brown", "fox" ]
Prompt 1 (6 tok):   ["If",  "you", "had", "one", "wish", ","]
Prompt 2 (1 tok):   ["bork"]                         
Prompt 3 (7 tok):   ["Once", "upon", "a", "time", ",", "the", "dog"]  

We would pad the prompts like so:

Prompt 0 (4 tok):   [PAD, PAD, PAD, "The", "quick", "brown", "fox"]
Prompt 1 (6 tok):   [PAD, "If", "you", "had", "one", "wish", ","]
Prompt 2 (1 tok):   [PAD, PAD, PAD, PAD, PAD, PAD, "bork"]
Prompt 3 (7 tok):   ["Once", "upon", "a", "time", ",", "the", "dog"]  
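
One wrinkle worth noting: the diagram above pads on the left, so that the newest real token always sits at the end of each row, which is what decoder-only generation wants. HuggingFace tokenizers pad on the right by default, so a sketch of the setup looks like this (assuming the tokenizer ships without a pad token, as many decoder-only ones do):

if tok.pad_token is None:
    tok.pad_token = tok.eos_token   # reuse EOS as the pad token
tok.padding_side = "left"           # pad on the left, as in the diagram above

batch = tok(prompts, return_tensors="pt", padding=True)
batch["input_ids"]       # (n_prompts, longest_prompt_len), PADs on the left
batch["attention_mask"]  # 1 for real tokens, 0 for PADs, so attention skips them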

Run It Yourself

Now, when it comes to benchmarking, it makes sense that we'd want to sweep across batch size, starting with a single prompt and seeing how many our GPU can work on at once before users start to notice.

Let's use the prompt, "The transformer architecture revolutionized NLP because", and see how many instances of this prompt the GPU can process at once in a batch.

SWEEP_PROMPT = "The transformer architecture revolutionized NLP because"
SWEEP_NEW_TOKENS = 128
SWEEP_BATCH_SIZES = [1, 2, 4, 8, 16]

If you may, please run the sweep:

modal run modal_run.py::sweep
| batch_size | total_tok/s | ms/tok (per seq) | TTFT (ms) | Peak MB |
|-----------:|------------:|-----------------:|----------:|--------:|
| 1          | 25.7        | 38.89            | 40.3      | 7754    |
| 2          | 51.2        | 39.08            | 41.1      | 7774    |
| 4          | 102.4       | 39.08            | 41.6      | 7843    |
| 8          | 206.1       | 38.81            | 41.9      | 7897    |
| 16         | 411.4       | 38.89            | 41.1      | 8050    |

Look at that! Our total throughput scales almost perfectly linearly with batch size; when we're processing 16x the sequences, we get 16x the tokens per second, and yet our users are hardly waiting any longer (ms/tok barely budges).

So, we're getting 16 users' worth of output for roughly the same latency cost as serving a single user. This is the charm of batching; since the GPU does the same amount of work per step regardless, we might as well pack that work full and keep our GPU busy.

Not All Prompts Are Equal

But hold on, not so fast. Do you see a problem with the way we're measuring our batching implementation?

All the prompts in the batch are the same length! In particular, all prompts in the batch are the same guy, "The transformer architecture revolutionized NLP because"!

Of course our batching would look great when all sequences are roughly the same length.

What happens when they're not? Let's stress-test this with four prompts that will very obviously produce outputs of wildly different lengths, and cap the run at 200 tokens.

WASTE_DEMO_PROMPTS = [
    "What is the capital of France?",
    "Explain in detail how the attention mechanism works in a transformer.",
    "Is Python interpreted or compiled? Answer in one sentence.",
    "Write a short story about a robot who learns to paint.",
]
WASTE_DEMO_MAX_TOKENS = 200

Read through these prompts and ask yourself: how many words would you write for "What is the capital of France?" How about for "Write a short story..."?

Like you, the LLM isn't going to treat all four of these prompts equally.

Yet, our naive batching strategy has no idea about any of this; it just runs every sequence in the batch for max_new_tokens = 200 steps, no matter what.

To see just how bad this gets, we wrote a small waste demo that tracks, for each prompt in the batch, how many decode steps were needed versus how many were wasted spinning on padding.

In writing this waste demo, we discovered that Qwen, unlike our GPT-2... doesn't really believe in the stop token. Left to its own devices, it will just keep going on and on forever, probably. So I had to add a small scaffolding layer in model.py to show it a worked example of what a stop token looks like and teach it to recognize one as a reason to quit.

Under my policy of no surprises, I figured I would point this out for the sake of the few lines in model.py that would otherwise look like magic six months from now.

If you may, please run the waste demo:

modal run modal_run.py::waste_demo

You should see the following table:

batch ran for 199 decode steps total

prompt                         steps needed    steps wasted   waste%
---------------------------------------------------------------------
What is the capital of France?      10             189        95.0%
Explain in detail how the a...     199             0          0.0%
Is Python interpreted or co...      25             174        87.4%
Write a short story about a...     199             0          0.0%

Well, that sucks! Prompts 0 and 2 finished in an instant, yet our GPU stayed stuck burning cycles on their PAD tokens, precious cycles that could've been spent processing the tokens for a new prompt and making a new user happy.

Phase 2, Step 2: Scheduler

Commit: 53defdc

Intuitively, when we finish serving a prompt, we want to remove it from the batch and bring in a new prompt that needs serving.

In order to do this, we need to know whether each sequence is WAITING, being worked on (PREFILLING or DECODING), or DONE.

from enum import Enum, auto

class Status(Enum):
    WAITING = auto()
    PREFILLING = auto()
    DECODING = auto()
    DONE = auto()

So let's create a Scheduler for managing the statuses of each sequence. In Step 3, we'll convert this Scheduler from Naive Batching to Continuous Batching by responding to changes in status.
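
Here's a minimal sketch of the shape this Scheduler takes (simplified from the actual commit; max_running and the waiting/running pools are illustrative names):

from collections import deque
from dataclasses import dataclass

@dataclass
class Sequence:
    seq_id: int
    prompt: str
    status: Status = Status.WAITING

class Scheduler:
    def __init__(self, prompts: list[str], max_running: int):
        self.waiting = deque(Sequence(i, p) for i, p in enumerate(prompts))
        self.running: list[Sequence] = []        # PREFILLING or DECODING
        self.max_running = max_running

    def admit(self) -> None:
        # move WAITING sequences into the running set while there's room
        while self.waiting and len(self.running) < self.max_running:
            seq = self.waiting.popleft()
            seq.status = Status.PREFILLING
            self.running.append(seq)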

Phase 2, Step 3: Continuous Batching

Commit: 77505a6

Conceptually, our solution to the waste problem is quite simple: don't wait for the whole batch to finish. The moment we're DONE with a sequence, we pull the next WAITING sequence off the queue and start it immediately.

This is called "continuous batching," as our scheduler runs continuously, swapping sequences in and out as they complete, rather than waiting for the full batch to complete before admitting new work.
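
With the Scheduler from Step 2, the change amounts to a few lines of per-step housekeeping (a sketch, reusing the illustrative waiting/running pools and admit() from above):

def refill(self) -> None:
    # evict finished sequences, freeing their batch slots...
    self.running = [s for s in self.running if s.status is not Status.DONE]
    # ...and immediately backfill from the waiting queue
    self.admit()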

However, there's a little nuance behind continuous batching that'll affect how we treat our KV Cache.

Recall our four prompts from the waste demo. Early in the run, all four are decoding, though as we saw, they won't need anywhere near the same number of steps to finish. (Token counts below are shown for illustration only; the scheduler has no way to know them in advance.)

[DECODING] "What is the capital of France?" (~10 tokens to go)
[DECODING] "Explain in detail how the attention mechanism works..." (~199 tokens to go)
[DECODING] "Is Python interpreted or compiled? Answer in one..." (~25 tokens to go)
[DECODING] "Write a short story about a robot who learns to paint." (~199 tokens to go)

Suppose "What is the capital of France" hits its stop token and we mark it DONE:

[DONE]      "What is the capital of France?"
[DECODING] "Explain in detail how the attention mechanism works..." (~189 tokens to go)
[DECODING] "Is Python interpreted or compiled? Answer in one..." (~15 tokens to go)
[DECODING] "Write a short story about a robot who learns to paint." (~189 tokens to go)

That slot is now open. In static batching, it would sit around twiddling its thumbs and burning PAD tokens until the rest of the batch finishes.

In continuous batching, we immediately pull the next waiting sequence off the queue, "Tell me a joke about cats.", and start it right away:

[PREFILLING]  "Tell me a joke about cats."                          (?? tokens to go)
[DECODING] "Explain in detail how the attention mechanism works..." (~189 tokens to go)
[DECODING] "Is Python interpreted or compiled? Answer in one..." (~15 tokens to go)
[DECODING] "Write a short story about a robot who learns to paint." (~189 tokens to go)

The new issue we have, see, is that this new prompt is at a different step than the older prompts in the batch.

In naive batching, every sequence in the batch is at the same decode step at all times. They prefilled together, they decode together, and because they're always on the same step, they can all feed into a single batched forward pass with a single shared KV cache, since the K and V tensors are just stacked along the batch dimension.
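
Concretely, in shapes (per layer; the same layout we'll meet again when we build the block pool):

# One shared cache per layer; every sequence is at the same step, so:
#   keys.shape   == (batch_size, n_kv_heads, seq_len, head_dim)
#   values.shape == (batch_size, n_kv_heads, seq_len, head_dim)
# Each decode step appends one position along seq_len for all sequences at once.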

However, with continuous batching, we can no longer do a single forward pass on all of them at once with a shared KV cache, because the cache length for "Tell me a joke about cats." is 0, while the other sequences in the batch already have dozens of tokens in their caches, each a different amount.

The simplest solution we'll go with for this blog post is to give each sequence its own KV cache, no sharing. So let's go ahead and create a distinct KV cache for each sequence, thereby removing the shared KV cache from our previous step:

@dataclass
class Sequence:
    seq_id: int
    prompt: str
    status: Status = Status.WAITING
    kv_cache: object = None             # <---- owned per sequence

Phase 2, Step 4: Comparing Naive with Continuous Batching

Commit: e7a7985

Run It Yourself: Latency Demo

If you may, please run our latency demo, which compares our naive and continuous batching implementations on a queue of prompts run in batches of at most 4 at a time, where each batch has a mix of short and long prompts drawn from the lists below, and caps the output length at 2048 tokens.

SHORT_PROMPTS = [
    "What is the capital of France?",
    "Is Python interpreted or compiled? Answer in one sentence.",
    "What is 2 + 2?",
    "Name one fruit.",
]
LONG_PROMPTS = [
    "Explain in detail how the attention mechanism works in a transformer.",
    "Write a short story about a robot who learns to paint.",
    "Describe the history of the Roman Empire in significant detail.",
    "What are the key differences between supervised and unsupervised learning?",
]
LATENCY_DEMO_MAX_RUNNING = 4
LATENCY_DEMO_MAX_TOKENS = 2048

Upon running the latency demo...

modal run modal_run.py::latency_demo                       

You should see a table like this:

 seq  prompt                                   static ms     cont ms   steps   speedup
--------------------------------------------------------------------------------------
   0  What is the capital of France?              45288ms       1261ms       9     35.9×
   1  Explain in detail how the attenti...        45288ms     117809ms    1269      0.4×
   2  Is Python interpreted or compiled...        45288ms       3048ms      24     14.9×
   3  Write a short story about a robot...        45288ms      53120ms     443      0.9×
   4  What is 2 + 2?                             115074ms       4022ms      23     28.6×
   5  Describe the history of the Roman...       115074ms     128150ms    1592      0.9×
   6  Name one fruit.                            115074ms       4303ms       2     26.7×
   7  What are the key differences betw...       115074ms      98620ms     913      1.2×

Awesome, we do see an improvement! For prompts that expect short responses, like "What is the capital of France?" and "What is 2 + 2?" and "Name one fruit.", our users wait far less time under continuous batching than they do under naive (static) batching.

Phase 3: The Case for PagedAttention

Commit: b8eaf65

There are two problems that our Phase 2 scheduler suffers from, by virtue of giving each sequence the privilege of its very own personal KV cache.

First, memory waste. When we go from one shared KV cache to one-KV-cache-per-sequence, a perennial question we have to ask is this: how much memory do we provide for each KV cache? The issue with LLM inference is that, for any given prompt, whether it's "What is the capital of France?" or "Write me a dissertation...", we don't know how many tokens will end up in the LLM's response.

HuggingFace's KV cache implementations take two approaches to this problem.

modal run modal_run.py::wall
concur   static_kv  dynamic_kv   avg_tok   wasted  
--------------------------------------------------
     1       0.18G       0.19G    640.0t   68.8%  
     2       0.35G       0.25G    437.2t   78.6%  
     4       0.70G       0.62G    535.4t   73.9%  
     8       1.41G       1.24G    535.4t   73.9%  
    16       2.81G       2.47G    535.4t   73.9%  
    32       5.62G       4.95G    535.4t   73.9%  

HuggingFace's first approach is StaticCache, which reserves MAX_SEQ_LENGTH_DOOMSDAY_PREPARATION slots (variable has been renamed for narrative flair) for each sequence upfront. This is, of course, a worst-case reservation, because we have no idea how long the sequence will actually run. As seen in the table, this kind of doomsday reservation can be quite wasteful; if we reserve MAX_SEQ_LENGTH_DOOMSDAY_PREPARATION=2048 slots but we only use 535 of those slots (by virtue of only needing to generate 535 tokens), we waste ~74% of that memory on slots that go unused. At 32 concurrent sequences, that's 5.62GB reserved for KV caches!
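
The waste column is just the unused fraction of that reservation:

reserved = 2048      # MAX_SEQ_LENGTH_DOOMSDAY_PREPARATION slots per sequence
used     = 535.4     # average tokens actually generated (from the table above)
waste    = 1 - used / reserved   # ~= 0.739, the 73.9% in the table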

To ameliorate this doomsday-preparation guessing, HuggingFace has a second approach, the DynamicCache, which grows the KV tensor one row at a time as each new token is generated. This is, in fact, the same grow-one-row-at-a-time behavior we first observed back in Inference Engines 1/N: KV Cache, and it's a bit better, as it doesn't waste memory reserving slots that go unused.

But DynamicCache can be costly in its own way. Every decode step requires reallocating and copying the entire KV tensor into a new, slightly larger buffer, as your KV tensors grow from [1, 8, 9, 128] to [1, 8, 10, 128] to [1, 8, 11, 128], and so on. At long sequence lengths, this copying overhead adds up.

And what about fragmentation? As sequences grow long, those buffers get large, and finding a single free contiguous block of GPU memory big enough to hold the whole tensor becomes increasingly difficult. Your GPU might have plenty of free memory in aggregate, but if it's been carved up by many sequences growing and shrinking at different rates, that free memory is scattered across many small gaps, none of the gaps individually large enough to fit your tensor. So the memory you need might exist, just not all in one place.

Can we do better? Is there some way we can grow the KV cache incrementally, like in DynamicCache, without the issues of fragmentation?

Second, no batched forward pass. The latency for our current continuous batching implementation isn't too great, since each sequence runs its own forward pass. So at batch size 16, we're running 16 sequential forward passes per decode step instead of one big batched one.

The reason we lost our batched forward pass in continuous batching is that we gave each sequence its own KV cache, each of which can be a different length at any given time, and we lack a mechanism to reconcile these different lengths into a single forward pass.

There actually does exist a mechanism to batch multiple sequences at different KV lengths into a single forward pass. We'll return to it at the end of this phase, in the gather kernel step.

Enter: Paging

PagedAttention borrows its solution from how operating systems manage RAM, with a technique that you may have guessed is called paging.

The OS, like our inference engine, is serving many processes at once, each of which needs a block of memory that might grow over time. And like us, the OS doesn't know upfront how much memory each process will need.

If our OS were to literally hand out one giant contiguous block of RAM to each process, that would be disastrous. Imagine if every time a process needed more memory, our OS had to guess the worst-case size upfront and reserve it all immediately, leaving enormous swaths of RAM sitting unused (like StaticCache). Or, shudder, imagine the OS had to find a brand new contiguous block every time a process needs more memory, copy it all over, and throw away the old block (like DynamicCache). The horror!

Instead, the OS cuts physical memory up into small fixed-size pages, and hands out the memory page by page. Each process starts with some number of pages on hand. When a process needs more memory, the OS hands it a fresh new page.

The crucial part of all of this is that the process keeps all the pages it was given before, no reallocating, no copying, no throwing away the old pages, and this fresh new page does not have to be physically next to any of the pages given thus far. It can come from anywhere in RAM.

"But wait," you say, "I thought each process was fussy and really, really needed its memory to be contiguous, or they would complain loudly?"

Yes, they would. So we maintain the illusion that the memory is contiguous. :-) The illusion that each process individually sees is virtual memory. Beneath the surface lies physical memory, the actual memory, and the OS quietly manages the gap between the two so that no matter how messy physical memory gets, the process will always feel at home in its peaceful abode of virtual memory.

So a process can believe that as it walks an array from logical address 0x1FF0 to 0x1FF8 to 0x2000, it's frolicking through a tidy, contiguous stretch of memory. But behind the scenes, address 0x1FF0 might live on physical page 42, and address 0x1FF8 might also live on physical page 42, and all of a sudden, we cross a page boundary (!), and address 0x2000 might live on physical page 9.

The mechanism by which the OS maintains this illusion is the page table, one page table per process, responsible for mapping logical address X (virtual memory) to physical page Y (physical memory).
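
In miniature, and with hypothetical 4 KB pages matching the addresses above, that translation is just:

PAGE_SIZE = 0x1000                # 4 KB pages (toy value)
page_table = {1: 42, 2: 9}        # virtual page number -> physical page number

def translate(vaddr: int) -> tuple[int, int]:
    vpage  = vaddr // PAGE_SIZE   # which virtual page this address falls in
    offset = vaddr % PAGE_SIZE    # where it sits inside that page
    return page_table[vpage], offset

translate(0x1FF0)  # -> (42, 0xFF0): near the end of physical page 42
translate(0x2000)  # -> (9, 0x0):    crossed a page boundary, onto physical page 9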

...and now, PagedAttention

PagedAttention does the same thing for KV caches. Just as the OS cuts its physical RAM into pages for handing out on demand, we cut the KV store up into fixed-size blocks (say, 16 tokens per block), and maintain a block table per sequence that maps "logical KV position X" to "physical block Y." When a sequence needs more KV memory, we simply hand it one more block.

First, on memory waste. This paging technique significantly reduces memory waste.

On the StaticCache side, instead of reserving MAX_SEQ_LENGTH_DOOMSDAY_PREPARATION slots per sequence upfront, we allocate one block at a time, only as tokens arrive. So, a sequence that finishes in 10 tokens uses one block and nothing more; a sequence that runs to 2048 tokens accumulates blocks gradually, never holding more blocks than it currently needs.

On the DynamicCache side, when a sequence needs more KV memory, we hand it a fresh block from wherever one happens to be free, no copying or reallocation needed, and of course no fragmentation, because we've pre-carved the pool into same-sized blocks from the start, so there's no such thing as "free memory that's too small to use," just blocks that are free and blocks that aren't.

Second, we can restore our batched forward pass if we so choose. PagedAttention provides the scaffolding for solving our no-single-batched-forward-pass problem. Since each sequence's KV data lives in the same fixed-size block format, all we need to do is teach our forward pass how to read the block table. We need to tell our forward pass, "Hey, instead of assuming all sequences live in one contiguous tensor, you should consult the block table. For sequence 3, token position 47 lives in physical block 12, and for this other sequence, its token position lives in this other physical block..." As long as the forward pass can do this lookup, it doesn't matter that sequences are at different steps.

Phase 3, Step 2: Block Allocator

Commit: 8ce893b

Our pool of blocks, at heart, is a single big tensor with this shape:

self.kv_store = torch.zeros(                                                                                                         
    n_blocks, n_layers, 2, n_kv_heads, block_size, head_dim,
    device=device, dtype=dtype,                                                                                                      
)

It accommodates n_blocks physical blocks, each of which holds space for block_size tokens' worth of K and V values across all layers.

We allocate this big tensor once at startup, and as sequence after sequence comes in asking for memory, our BlockAllocator serves slices of this big tensor, aka "blocks," to our sequences on demand.

Tracking which blocks are free is almost embarrassingly simple, since we're blessed by the fact that every block is the same size, and so we don't need to futz around with finding a size match. We just maintain a double-ended queue like so:

self._free_list: deque[int] = deque(range(n_blocks))

When a sequence no longer needs a block, we free it to the back of this double-ended queue, and when another sequence comes in asking for a block, we remove a block from the front of this queue.
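
Allocation and freeing are then each a couple of lines (a sketch; method names illustrative):

def alloc(self) -> int:
    # hand out a free physical block ID from the front of the queue
    return self._free_list.popleft()

def free_blocks(self, block_ids: list[int]) -> None:
    # a finished sequence returns all its blocks to the back of the queue
    self._free_list.extend(block_ids)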

We also have a budget_n_blocks() helper that decides how large to make the pool at startup. It checks how much VRAM remains after model load, takes 80% of that, and divides by the bytes-per-block cost.
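
A sketch of what that arithmetic might look like (attribute names assumed to mirror the kv_store dimensions above):

def budget_n_blocks(self) -> int:
    free_bytes, _total = torch.cuda.mem_get_info()    # VRAM left after model load
    bytes_per_block = (
        self.n_layers * 2 * self.n_kv_heads           # K and V for every layer
        * self.block_size * self.head_dim
        * torch.finfo(self.dtype).bits // 8           # bytes per element
    )
    return int(0.8 * free_bytes) // bytes_per_block   # keep 20% headroom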

Run It Yourself: Alloc Demo

We wrote an alloc_demo that runs the DynamicScheduler to collect real token counts, then reports what block usage would have looked like under PagedAttention compared to static pre-allocation.

modal run modal_run.py::alloc_demo

You should see something like (table has been truncated):

pool capacity: 7730 blocks available after model load

 seq  prompt                              tokens   used
--------------------------------------------------------
   0  What is the capital of France?        29t      2b
   1  Explain in detail how the attenti   1295t     81b
...

See how the LLM takes 2 blocks to compose a response for "What is the capital of France?", while it takes 81 for "Explain in detail...". Meanwhile, static pre-allocation would've handed both of them 128 blocks each, regardless.

Phase 3, Step 3: Block Table

Commit: 08a0b77

The BlockTable is each sequence's personal map from logical KV position to physical block in the pool. It grows one block at a time.

from dataclasses import dataclass, field

@dataclass
class BlockTable:
    block_size: int
    block_ids: list[int] = field(default_factory=list)
    n_filled: int = 0

When a sequence is about to write a token and the current block is full, maybe_extend() asks the BlockAllocator for a fresh block and appends it to the table:

def maybe_extend(self, allocator) -> None:
    # n_filled is a multiple of block_size exactly when the current block is full
    # (or when no block exists yet), so it's time to grab a fresh one
    if self.n_filled % self.block_size == 0:
        self.block_ids.append(allocator.alloc())

See how, unlike HuggingFace's DynamicCache, there's no copying or reallocation. The sequence just gets handed a new block from wherever one happens to be free in the pool.

Finally, translating a logical token position back to its physical location is a two-step lookup: find the block, then the offset within the block:

def physical_pos(self, token_pos: int) -> tuple[int, int]:
    logical_block = token_pos // self.block_size
    offset = token_pos % self.block_size
    return self.block_ids[logical_block], offset
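
For instance, with block_size = 16, token position 47 lands in logical block 2 at offset 15, and resolves to whatever physical block the allocator happened to hand out third (hypothetical IDs below):

bt = BlockTable(block_size=16, block_ids=[7, 30, 12])  # hypothetical physical IDs
bt.physical_pos(47)  # -> (12, 15): logical block 47 // 16 == 2, offset 47 % 16 == 15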

All in all, this is sort of like a mini page table you'd see in an OS. From each sequence's POV, we maintain the illusion that the sequence sees a clean, contiguous address space from token 0 onward, while our BlockTable diligently resolves that to wherever the data actually lives in the pool. And when a sequence finishes, we return every block it ever held back to the free list in one shot, immediately available for the next sequence waiting in the queue.

Run It Yourself: Table Demo

The table_demo runs a single sequence through prefill and decode, printing the block table state at each block boundary.

modal run modal_run.py::table_demo

You should see something like (table truncated):

prompt: "Explain how the attention mechanism works in a transformer model."
prompt tokens: 24  generated: 61  total: 85  block_size: 16

  [after token 16]  (16 tokens in KV cache)
   logical   physical          tokens      fill
         0          0            0–15    16/16

  [after prefill (24 prompt tokens)]  (24 tokens in KV cache)
   logical   physical          tokens      fill
         0          0            0–15    16/16
         1          1           16–23    8/16

  ...

You might notice that logical block 0 maps to physical block 0, logical 1 to physical 1, and so on. That's only because this is a single sequence with no competition for the pool. In a real batch, with many sequences racing to claim and return blocks, physical IDs would be scattered around the pool in no particular order.

Phase 3, Step 4: PagedCache

Commit: a8c259d

PagedCache is a drop-in replacement for DynamicCache. We call into the PagedCache the same way we use other caches from HuggingFace's API; what changes is how we store and return the K/V states.

Instead of DynamicCache's torch.cat-ing onto a growing tensor (which does a lot of copying and reallocating)...

self.keys = torch.cat([self.keys, key_states], dim=-2)
self.values = torch.cat([self.values, value_states], dim=-2)
return self.keys, self.values

...our PagedCache writes each token's K/V into its block slot.

# pos is this token's logical position in the sequence;
# i indexes the token within this forward pass's key/value states
block_idx = pos // self.allocator.block_size        # which logical block
offset    = pos % self.allocator.block_size         # slot within that block
phys      = self.block_table.block_ids[block_idx]   # resolve to a physical block
self.allocator.kv_store[phys, layer_idx, 0, :, offset, :] = key_states[0, :, i, :]
self.allocator.kv_store[phys, layer_idx, 1, :, offset, :] = value_states[0, :, i, :]

Run It Yourself

Let's run the correctness checker:

modal run modal_run.py::snapshot

This checks that for every prompt (4 short, 4 long), our engine that uses PagedCache produces bit-identical output to the engine that uses DynamicCache.

Let's also compare the latency, throughput, and memory usage of our inference engines under PagedCache vs HuggingFace's DynamicCache.

modal run modal_run.py::tput
batch_size  dyn_kv  paged_kv
-----------------------------
1           0.05G     0.04G
2           0.09G     0.07G
4           0.17G     0.15G
8           0.35G     0.30G
16          0.69G     0.60G
32          1.38G     1.20G

Phase 3, Step 5: The Gather Kernel

Commit: be40ea7

When you compare the latency of our inference engine under PagedCache vs the pre-existing DynamicCache, you'll see that our latency still lags behind, hovering around 31ms, while DynamicCache mostly sits under 30ms.

batch_size   dyn_tok/s   dyn_ms/tok   paged_tok/s   paged_ms/tok   
-----------------------------------------------------------------
     1       33.0t/s     30.3ms       32.1t/s       31.2ms 
     2       36.8t/s     27.2ms       32.1t/s       31.1ms   
     4       34.7t/s     28.8ms       31.8t/s       31.5ms    
     8       37.5t/s     26.7ms       32.7t/s       30.6ms 
    16       37.4t/s     26.8ms       32.2t/s       31.0ms  
    32       37.1t/s     27.0ms       32.2t/s       31.0ms  

This is because both schedulers run N sequential forward passes per step, rather than a single batched forward pass as desired. So when we have a batch of, say, 16 sequences, we still do 16 individual forward passes like this:

for seq in self.running:
    if seq.status == Status.DECODING:
        self._decode_seq(seq)

On top of that, our PagedScheduler is slower than our DynamicScheduler because each forward pass first has to _gather that sequence's K/V out of the pool, guided by its block table. This _gather is a horribly slow operation:

def _gather(...):
    ...
    # a Python-level loop over every block the sequence owns...
    for i, phys in enumerate(self.block_table.block_ids):
        if i < n_full:
            k_chunks.append(self.allocator.kv_store[phys, layer_idx, 0])
            v_chunks.append(self.allocator.kv_store[phys, layer_idx, 1])
        ...

    # ...followed by torch.cat copying every chunk into one contiguous
    # tensor, again and again, on every single decode step
    k_out = torch.cat(k_chunks, dim=1).unsqueeze(0)
    v_out = torch.cat(v_chunks, dim=1).unsqueeze(0)

As for the single batched forward pass, you might wonder why we don't gather each sequence's K/V into a contiguous padded tensor, stack all N into a single batched cache, and do one forward pass on the whole batch. But in the common case, this is actually slower than N sequential calls, due to the extra padding and copying it introduces, which is why we left it as N sequential calls.

What, then, does vLLM do in production that makes PagedAttention so worth the trouble?

The canonical approach is to rewrite our Python _gather() as a GPU gather kernel: one call that, given N block tables, reads the K/V for all N sequences directly from the shared pool, without ever fussing with padded buffers.

This kernel is also what finally earns us back our single batched forward pass, since it takes care of the problem we had earlier of "how does the forward pass know where each sequence's K/V lives when the caches are at different lengths?"

Below is a brief sketch of the core index arithmetic such a kernel would perform; the full Triton kernel would be something like ~50-100 lines:

# resolve logical token_pos -> (physical block, offset within block)...
phys_id   = block_ids[token_pos // block_size]
offset    = token_pos % block_size
# ...then copy that token's K/V out of the pool into a contiguous output
out[head, token_pos, d] = kv_store[phys_id, layer_idx, kv, head, offset, d]
...

Often, you won't see this in production as a standalone kernel whose sole job is the gather. Instead, it gets fused with other computations, so you'll see something like a fused attention kernel that does the gather, the attention matmuls, and the softmax in one pass.

This is what production inference engines like vLLM do, using kernels from libraries like FlashInfer. FlashInfer is a library of GPU kernels for common LLM inference operations, often fused together. Inference engines often build on top of such libraries rather than writing the kernels from scratch themselves.

Writing kernels is a discipline in and of itself, one that deserves its own posts rather than being a footnote at the end of this one. So we'll leave our Triton kernel sketch as exactly that, a sketch, and save the craft of kernel writing for a different day.

On the Name

If you've made it this far, now you know where the "Kernel" in Standard Kernel comes from. Writing kernels is actually a small part of what we do. The harder problem, often, is knowing where to write them, where a faster kernel would move the needle most. The kernel is something you reach for only after you've sorted out the bigger questions above it and designed the broader system of "What are we trying to compute and how do we orchestrate it?"

We came for batching, we stayed for paging, and we leave with a working PagedAttention implementation and an appreciation for where kernel writers fit in this bigger picture of ML Systems.