Inference Engines 0/N: Foundations
It's become important for me to understand LLM inference engines at work, especially since we're optimizing them, and so I decided that I'd walk through the process of writing and optimizing a minimal inference engine from scratch.
How to Read This Series
I'm dividing this series up into phases, one post per phase, where I start with a naive inference engine (that's this post), and with each phase, I introduce a common optimization.
Each phase will have its own branch in the tiny-decode GitHub repository. Within each phase, I divide my work into tiny, digestible steps that I push as commits to the branch.
So, as you read through my post, you can go in and look at each individual commit, copy over my changes to your editor, run it, and voila, see the inference engine improve step by step.
(These posts take a rather long time to write. Basically, I try things in a private repo until I land on a version I'm happy with, then I think through how to break it up into digestible steps for a reader like you, and I write it a second time, this time in an order that deceptively looks like I knew exactly what decisions to make in exactly the right order, and that's the public copy you get.)
Call Me Ishmael
Call me Ishmael. Some years ago, never mind how long precisely, having little or no money in my purse, and nothing particular to interest me on shore, I thought I would sail about a little and see the watery part of the world. — Herman Melville, Moby Dick
Moby Dick's about a guy who spends an entire novel chasing a white whale. I didn't really get him then, I don't really get him now, but I think I'm starting to.
Some time ago, never mind how long precisely, I found myself unable to look away from a single line of code:
out = lm.model(input_ids=input_ids, use_cache=False)
That one line, that there's our white whale.
Call it an obsession, call it a fixation, call me a madman.
Don't worry if it means nothing to you yet. We have ~6 phases ahead of us to swing at that whale, and today's phase 0 is just our scaffolding so that when we start swinging, we'll know if we hit anything. By the end of this post, this phase 0, you'll understand why every piece of scaffolding we build exists in service of making that one line faster.
Why do we have inference engines?
So, you've taught yourself deep learning with your trusted deep learning course, and you trained an LLM that works beautifully in a Jupyter notebook.
Now, getting that LLM in front of millions of users turns out to be a very different engineering problem. When you open up ChatGPT, or Claude, or Gemini, it's actually quite remarkable how the chatbot can hold a conversation that spans dozens of messages, stays on topic, and maintains a consistent personality throughout.
To understand why that's hard, it helps for us to put ourselves in that chatbot's shoes. Basically, imagine that you wake up in a strange, dark room with a curious-looking memo that you've never seen before in your life.
1 - Stranger on internet (user): Hey!
You read the memo from beginning to end and write a response on the memo.
1 - Stranger on internet (user): Hey!
2 - You (LLM): your glorious response
Upon sending this message, a mysterious force knocks you out, and not too long after, you wake up in a strange, dark room with this curious-looking memo that you've never seen before in your life.
1 - Stranger on internet (user): Hey!
2 - You (LLM): your glorious response
3 - Stranger on internet (user): their glorious response
You read the memo from beginning to end and write a response on the memo. And then you get knocked out again, and you wake up once again in this strange room with this curious memo, and on, and on, and on...
The important part is that every time the chatbot wakes up, it's reading the memo from beginning to end. So, on the first round, the memo is 1 message long. On the second round, it's 3 messages long. On the third round, it's 5 messages, then 7, then 9, then 11. After who knows how many rounds, it's reading 1+3+5+7+9+11+... messages from beginning to end.
If you think about it, that's an egregious amount of reading the chatbot is doing (the total reading grows quadratically with the number of rounds), all because it was born with the unfortunate handicap of not being able to remember things.
And when you think about how many users are talking to this chatbot every day, and how long of a conversation they're having, and consider your users' predilections for writing long novellas and engaging in lengthy therapy sessions, it's quite impressive how we can serve these chatbots so quickly, reliably, and affordably.
Phase 0, Step 1: Our LLM Wakes Up in a Dark Room
Commit: 0a50648
So, an inference engine begins with an LLM waking up in a dark room. This dark room can be your local laptop, or a rented GPU in the cloud, just some form of hardware that our LLM can run on.
For the purposes of this series, we'll be using Modal.
Why Modal? If you have a GPU lab at home, then yeah, by all means, go use it! But no two readers' machines are identical, and different hardware can produce very different results.
We could use a cloud provider like AWS or GCP, but getting GPU quota approved can sometimes involve wading through customer support (speaking from personal experience), and it's easy to rack up charges from a forgotten running machine.
Modal sidesteps all of this, as it's serverless, so GPUs spin up on demand, and you're only billed for the seconds that your code runs. For a toy learning series like this, Modal is perfect!
After this spiel, I should say that, no, I'm not sponsored by Modal, but I wouldn't mind getting some free credits ;o.
So, head on over to modal.com and create a free account; it should take about two minutes. Once you're in, copy modal_run.py into a fresh folder on your machine.
And now copy generate.py into your editor, in the same folder as modal_run.py.
To quickly go over what it does, our generate.py loads GPT-2 from HuggingFace. We're using GPT-2 because it's small enough to iterate on quickly, while still being a real transformer with all the same structural properties as much larger models.
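(For context, here's a rough sketch of what the model-loading half of generate.py might look like. The file in the repo is the source of truth; the load_model name and the exact LoadedModel definition below are my illustration of what it does.)

from dataclasses import dataclass

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase


@dataclass
class LoadedModel:
    model: torch.nn.Module
    tokenizer: PreTrainedTokenizerBase
    device: torch.device


def load_model(name: str = "gpt2") -> LoadedModel:
    # Pick the GPU if one is available, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name).to(device).eval()
    return LoadedModel(model=model, tokenizer=tokenizer, device=device)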
The part we'll be zooming in on is this generate function, which is the loop responsible for producing the LLM's output, one token at a time.
def generate(lm: LoadedModel, prompt: str, max_new_tokens: int = 64) -> str:
    tok = lm.tokenizer
    input_ids = tok(prompt, return_tensors="pt").input_ids.to(lm.device)
    eos_id = tok.eos_token_id
    for _ in range(max_new_tokens):
        out = lm.model(input_ids=input_ids, use_cache=False)  # ⭒⭒!!
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)
        if eos_id is not None and next_id.item() == eos_id:
            break
    return tok.decode(input_ids[0], skip_special_tokens=True)
Stop and Think
A few things worth orienting yourself to before we move on:
- You can think of each iteration as the LLM "waking up" in the dark room all over again: lm.model(input_ids=input_ids, ...) is the LLM reading the memo (input_ids) from beginning to end, and out.logits[:, -1, :].argmax(...) is it picking the single next token to compose. That token gets appended to the end of the memo input_ids, and then the loop goes around again, with the memo one token longer. This is why generation gets slower the longer the conversation goes.
- eos_id is the stop token, a special token the LLM emits when it decides it's done. Without it, the loop would just run until max_new_tokens is exhausted.
⭒⭒!! The most expensive line, by far, is the lm.model(...) call. So when we talk about optimizing inference, we're really talking about making that line faster. The rest of this series follows from that fact.
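(You don't have to take my word for it: here's a throwaway profiling sketch, not part of the repo, that you can wrap around one iteration of the loop, assuming lm and input_ids are set up the way generate sets them up.)

from torch.profiler import profile, ProfilerActivity

# Profile a single "wake up, read the memo, pick a token" iteration.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out = lm.model(input_ids=input_ids, use_cache=False)
    next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

# The model forward should dominate everything else in the table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))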
There she blows, our white whale.
Run It Yourself
Now, let's see what happens when, in our Modal app, we hand our LLM a memo with the words, "The transformer architecture revolutionized NLP because".
@app.local_entrypoint()
def main(
    prompt: str = "The transformer architecture revolutionized NLP because",
    max_new_tokens: int = 64,
):
    print(run_remote.remote(prompt=prompt, max_new_tokens=max_new_tokens))
Give it a run! (Modal might ask you to authenticate the first time around.)
modal run ./modal_run.py
After our LLM's laborious process of waking-up-and-reading-and-rereading-and-writing-to-the-memo, you should see our LLM write something:
The transformer architecture revolutionized NLP because it was able to provide a high-speed, low-cost,
and low-power power supply.
The NLP transformer architecture revolutionized NLP because it was able to provide a high-speed,
low-cost, and low-power power supply. The NLP transformer architecture was able to provide a
And just like that, we've written a baby LLM inference engine. Now we'll spend the rest of our days making it fast.
Phase 0, Step 2: Add Timing
Commit: 501dbca
We know the engine works, but we have no idea how fast it is. Before we can optimize anything, we need something to optimize against - in other words, a baseline of sorts against which to compare ourselves.
So, let's add per-token timing, like so:
@torch.inference_mode()
def generate(
    lm: LoadedModel, prompt: str, max_new_tokens: int = 64
) -> tuple[str, list[float]]:
    tok = lm.tokenizer
    input_ids = tok(prompt, return_tensors="pt").input_ids.to(lm.device)
    eos_id = tok.eos_token_id
    per_token_seconds: list[float] = []
    for _ in range(max_new_tokens):
        torch.cuda.synchronize()  # (・_・?)
        t0 = time.perf_counter()  # (・_・???)
        out = lm.model(input_ids=input_ids, use_cache=False)
        next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)
        torch.cuda.synchronize()  # (・_・?)
        per_token_seconds.append(time.perf_counter() - t0)  # (・_・???)
        if eos_id is not None and next_id.item() == eos_id:
            break
    text = tok.decode(input_ids[0], skip_special_tokens=True)
    return text, per_token_seconds
Since you might ask, it's worth noting the following:
(・_・?) torch.cuda.synchronize() is necessary because GPU operations are asynchronous by default. Without it, time.perf_counter() would measure the time for the CPU to dispatch work to the GPU, rather than the time taken for the GPU to finish said work. It's kind of like timing how long it takes you to ask your AI coding assistant to do something ("do it make no mistakes ultrathink"), rather than timing how long it takes the AI coding assistant to... do it. The synchronize() call tells the CPU to wait until the GPU is done, so we're measuring the right thing.
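(If you'd like to see the difference for yourself, here's a tiny standalone sketch, not part of the repo, that times a big matmul both ways on a CUDA machine; the unsynchronized timing should come out suspiciously close to zero.)

import time

import torch

x = torch.randn(4096, 4096, device="cuda")
x @ x  # warmup, so the timed matmuls aren't paying one-time setup costs
torch.cuda.synchronize()

# Without synchronize: we only time the CPU handing work to the GPU.
t0 = time.perf_counter()
y = x @ x
no_sync = time.perf_counter() - t0

# With synchronize: we wait for the GPU to actually finish the work.
torch.cuda.synchronize()  # drain anything still in flight from above
t0 = time.perf_counter()
y = x @ x
torch.cuda.synchronize()
with_sync = time.perf_counter() - t0

print(f"no sync:   {no_sync * 1000:.3f} ms")
print(f"with sync: {with_sync * 1000:.3f} ms")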
Stop and Think
(・_・???) You might wonder why we time each token individually rather than just timing the whole loop. Please think about this before reading on.
In the meantime, we'll take a quick break and do some relaxing runs, so you can let the question ferment in the back of your head.
Run It Yourself
When you run the timing...
modal run ./modal_run.py
...you should see a number at the bottom, something like this:
70.77 tok/s
Remember that number, it'll be important for comparing with Step 3 very soon.
Phase 0, Step 3: The Cold Start Problem
Commit: b6dd3ca
If you run the engine a few times, you'll notice that the numbers aren't stable.
The first run is almost always slower than the runs that follow. This is what we call the cold start problem; the first time the model runs, the GPU has to "warm up" of sorts, loading weights into its fast cache, compiling some kernels, and doing a handful of other one-time setup tasks. The subsequent runs skip all of that and get the privilege of running on "warm" hardware, so they're faster and more consistent.
So we do what any reasonable benchmark does, which is throw away the first run as a warmup, then take the mean across several runs.
The code itself is not particularly exciting; it's the standard warmup-and-average loop you'll find in most benchmarking harnesses. The more interesting question is what exactly we're averaging and why, which we'll talk about in Step 4.
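(Still, for concreteness, here's a minimal sketch of that loop, assuming the Step 2 generate() from above and a runs count; the commit has the real thing, so the variable names here are just my illustration.)

import statistics

# Run once and throw the result away: this run absorbs the cold-start costs.
_, warm_secs = generate(lm, prompt, max_new_tokens=max_new_tokens)
print(f"warmup (discarded): [{len(warm_secs) / sum(warm_secs):.2f}] tok/s")

# Then average over the "warm" runs.
throughputs = []
for _ in range(runs):
    _, per_token_seconds = generate(lm, prompt, max_new_tokens=max_new_tokens)
    throughputs.append(len(per_token_seconds) / sum(per_token_seconds))

print(f"per-run tok/s: {[round(t, 2) for t in throughputs]}")
print(f"mean: {statistics.mean(throughputs):.2f} tok/s")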
Run It Yourself
Now when you rerun the timing...
modal run ./modal_run.py
You should see a series of numbers followed by a mean, like this:
warmup (discarded): [63.69] tok/s
per-run tok/s: [191.30, 189.16, 194.16]
mean: 191.54 tok/s
Do you see the cold-start problem in action? See how the throughput of the first run is dramatically lower than that of the subsequent runs, and that it should rightfully be discarded?
If you're getting strange results, i.e. your number from Step 2 is wildly different from the figure you see under warmup (discarded), please contact me.
You might be taken aback by how dramatic the ratio is, that, oh, what's this, (191.54-63.69)/191.54 = ~66.7% of step 2's time was spent on cold-start overhead? For a small model like GPT-2 where each forward pass is only a couple milliseconds of real work, the cold-start overhead is comparatively large. Once we use a bigger model or each iteration runs longer, we can expect this warmup ratio to shrink.
Phase 0, Step 4: Prefill and Decode
Commit: 03141aa
Okay, now that we've let your brain relax a bit, do you have your best guess for this question?
(・_・???) You might wonder why we time each token individually rather than just timing the whole loop. Please think about this before reading on.
It's to lay the groundwork for the optimizations we'll have ahead of us. The basic motivation behind our optimizations is as follows:
The first token is special. Before the LLM can produce anything, it has to read the entire memo from beginning to end, and this is called the prefill stage. This first reading, this prefill stage, is very much unavoidable, and the prefill cost scales with how long the memo is.
However, as painful and unavoidable as the cost of the first token is, it doesn't have to be the same way for the next tokens. If we're smart and we come up with good optimizations, the tokens after that can be much cheaper.
So, we say that every token after that is the decode stage, which (if we've done our optimizations right) is comparatively cheap and roughly constant per token.
If we only timed the full loop, these two very different costs would be difficult to separate. By timing per token, we can track prefill and decode separately: the time for our first token is our time to first token (TTFT), a measure of prefill, and the steady-state mean over the rest of our tokens is our decode speed, a measure of decode.
for _ in range(runs):
    text, per_token_seconds = generate(lm, prompt, max_new_tokens=max_new_tokens)
    throughputs.append(len(per_token_seconds) / sum(per_token_seconds))
    ttfts_ms.append(per_token_seconds[0] * 1000.0)  # prefill
    if len(per_token_seconds) > 1:
        steady_ms.append(statistics.mean(per_token_seconds[1:]) * 1000.0)  # decode
Run It Yourself
Now when you rerun the benchmarks...
modal run ./modal_run.py
You should see the prefill and decode numbers.
tokens/sec: 189.86
TTFT: 5.1 ms (prefill)
steady ms/t: 5.27 ms (decode)
Do you see how TTFT and ms/token are basically equal, at 5.1 vs 5.27 ms? This is to be expected in this naive implementation of our inference engine. As we progress through this series, you should expect our decode metric, ms/token, to go down. :)
Intermission: Refactor
We want to add one more metric, peak memory, to our benchmarking harness, but our file modal_run.py has been getting a bit unwieldy with us adding more and more benchmarking logic to it. We can also expect generate.py to grow quite a bit, as we add more optimizations to it.
It's time we create two directories, one for each:
- engine/ will house our inference engine, which has our model loader model.py and the autoregressive loop decode.py (which will change throughout this series and split off into baby files of its own).
- harness/ will house our benchmarking harness, bench.py, which contains the logic we added in Steps 3 and 4 (and which shouldn't change much).
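(As a rough sketch, and with the caveat that the exact file names live in the repo, the layout after this refactor looks something like this:)

engine/
    model.py     # the model loader (our LoadedModel)
    decode.py    # the autoregressive generate() loop
harness/
    bench.py     # warmup, averaging, TTFT/decode split, and (soon) peak memory
modal_run.py     # the Modal entrypoint that ties the two together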
Phase 0, Step 5: Peak Memory
Commit: 5395784
Great, now we can add our metric to our benchmarking harness: peak memory.
@dataclass
class BenchResult:
    model: str
    runs: int
    tokens_per_sec: float
    ms_per_token: float
    ttft_ms: float
    peak_memory_mb: float  # <- new metric!
We measure peak memory using two lines of code, where max_memory_allocated() tells us the most GPU memory we've ever held at any single moment during the benchmark.
torch.cuda.reset_peak_memory_stats() # before the runs
...
peak_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)
Stop and Think
Why do we measure peak and not average? The reason is that out-of-memory issues (OOM, which I affectionately pronounce like gloom) are sort of a binary life-or-death situation. If you exceed your GPU's memory by even a couple of bytes for a couple of milliseconds, you get hit with the OOM, and the GPU kills your process. So it's pretty important that we keep our peak memory usage below the dangerous threshold that could cause an OOM.
Run It Yourself
Now when you rerun the benchmarks...
modal run ./modal_run.py
...you should see peak memory.
model=gpt2 runs=3
tokens/sec: 180.50
ms/token: 5.54 (steady-state, decode)
TTFT: 5.4 ms (prefill)
peak memory: 267 MB # <- new metric
I expect at least two of our planned phases to improve this metric significantly (PagedAttention, Quantization), and what's not gratifying about seeing yet another kind of number go down?
Phase 0, Step 6: Trust, but Verify
Commit: 992653f
We're almost ready to start optimizing.
One last thing: imagine that we code up our next phase (KV cache) with all its wonderful improvements, and our engine is 2x faster, and all the pretty numbers we introduced above now look even prettier.
Except, now, the new Phase 1 output reads...
"The transformer architecture revolutionized NLP because it provided a flexible, parallel approach..."
...instead of Phase 0's original...
"The transformer architecture revolutionized NLP because it was able to provide a high-speed, low-cost..."
Both are plausible English, and both sound like things that GPT-2 might say.
But... are we sure we haven't introduced a bug?
Q: Aren't LLMs nondeterministic? People on Twitter say that all the time!
A: They can be, but it depends on how you pick each token.
To go a little more into the weeds, the LLM doesn't output a single word, but rather a probability distribution across its entire vocabulary. For example, "The cat sat on the ___" might give "mat" a 40% chance, "floor" 25%, "couch" 15%, and so on.
In production, most LLMs sample from this distribution at random; they roll a weighted die so that "mat" comes up most often but not every time, with a slight chance of "floor" or "couch" too. So, two runs on the same input with the same weights can produce different outputs.
Meanwhile, our engine uses greedy decoding, which is a different form of token selection. Recall the .argmax() you saw earlier: we take whichever token has the highest probability every time. So two runs on the same input with the same weights should produce the same output.
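(Here's a small sketch of the two selection strategies side by side, assuming logits holds the model's output for the last position, i.e. out.logits[:, -1, :]; the sampling branch is not what our engine does, it's just there for contrast.)

import torch

# logits for the last position: shape (1, vocab_size)
probs = torch.softmax(logits, dim=-1)

# Greedy decoding (what our engine does): always pick the single most likely token.
greedy_id = probs.argmax(dim=-1, keepdim=True)

# Sampling (what most production setups do): roll a weighted die over the vocabulary.
sampled_id = torch.multinomial(probs, num_samples=1)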
So, we add a file, snapshot.py, for our byte-for-byte correctness check:
- Run Phase 0. Hash the generated text. Save the hash to disk.
- Run your experimental new phase. Hash its generated text and compare this with the saved hash.
- Match -> yay! Mismatch -> bug.
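(A minimal sketch of what snapshot.py's core logic might look like; the real file also records things like the model name and token count you see in the log lines below, so treat these function names as illustrative.)

import hashlib
import json
from pathlib import Path

SNAPSHOT_PATH = Path("snapshots/greedy.json")


def text_hash(text: str) -> str:
    # Hash the full generated text, byte for byte.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def write_snapshot(text: str) -> None:
    SNAPSHOT_PATH.parent.mkdir(parents=True, exist_ok=True)
    SNAPSHOT_PATH.write_text(json.dumps({"sha256": text_hash(text)}))


def check_snapshot(text: str) -> bool:
    expected = json.loads(SNAPSHOT_PATH.read_text())["sha256"]
    actual = text_hash(text)
    if actual == expected:
        print("[snapshot] snapshot match")
        return True
    print(f"[snapshot] snapshot MISMATCH\n  expected: {expected}\n  actual:   {actual}")
    return False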
Run It Yourself
First, save the hash of Phase 0's output to disk, by running with --write-snapshot. This will save the hash to snapshots/greedy.json, and you only need to do this once.
modal run modal_run.py --write-snapshot
From there, for each optimization that needs this sort of correctness check, run it with --check, and it'll tell you whether the hash matches or doesn't.
modal run modal_run.py --check
You should see something like
[snapshot] snapshot match (gpt2, n_tokens=64)
in the logs.
For fun, try finagling the output into something that's obviously not "The transformer architecture revolutionized NLP because it was able to provide a high-speed, low-cost...", like text="asdf". Rightfully, you should see something like
[snapshot] snapshot MISMATCH for (gpt2, n_tokens=64)
expected: f23961b8a2847fa78efa239821c614f9827f8640bb74f1769a6e1b603e50a4cb
actual: f0e4c2f76c58916ec258f246851bea091d14d4247a2fc3e18694461b1816e13b
expected text: 'The transformer architecture revolutionized NLP because it was able to provide a high-speed,
low-cost, and low-power pow'...
actual text: 'asdf'...
Stop and Think
When it comes to correctness checking, on what occasions do we care about a byte-for-byte match, and on what occasions do we not?
In short, we care about exact matches in math-preserving optimizations. When we say "math," you can think of the model as a mathematical function that takes an input, does some transformations to it, and produces an output.
This mathematical function changes when we change how the model does transformations.
For math-preserving optimizations, e.g. ones that have to do with how we move data around (like paging, batching, FlashAttention), the underlying mathematical function doesn't change. For these, the hashes must match, and if they don't, we know we have a bug.
However, we have a different story for math-altering optimizations. For example, Quantization replaces the weights of the model, thereby changing how the LLM does mathematical transformations, and hence the generated text will diverge from our original implementation in Phase 0.
For math-altering optimizations like these, we need to relax our correctness check, which we'll figure out when we get to it.
Goodbye and Our White Whale
And that's Phase 0, a working (if slow) inference engine, with a benchmarking harness and some semblance of a correctness check that we trust.
And we did all of that so we could finally get a clear fix on the object of our pursuit, our white whale if you will, this one line:
out = lm.model(input_ids=input_ids, use_cache=False)
And finally (finally!) we're ready for Phase 1. In Phase 1, we'll take our first swing at our white whale, with the KV Cache.
Until next time, and see you there.
