Embeddings to Attention - Relating Tokens to Each Other
Deep dive into attention mechanisms: why transformers replaced RNNs, scaled dot-product attention, multi-head attention, and how context length affects performance
Building On Previous Knowledge
In the previous progression, you learned how tokens become embeddings—vectors that capture meaning. Each token has its own embedding vector.
But there’s a problem: embeddings are independent. The vector for “bank” doesn’t know whether it’s about finance or rivers until it sees the surrounding words.
Attention solves this by letting each token’s representation incorporate information from other tokens. After attention, “bank” in “river bank” has a different representation than “bank” in “savings bank”—because it attended to different context.
What Goes Wrong Without This:
- Symptom: Your model truncates long documents and misses important information.
  Cause: You treated context as infinite. Attention is O(n²) in memory; 128K context doesn't mean you can use 128K without consequences.
- Symptom: Model gives inconsistent answers to the same question.
  Cause: In long contexts, attention can miss relevant information. "Lost in the middle": models attend more to the beginning and end.
- Symptom: Reasoning fails on complex multi-step problems.
  Cause: Attention struggles to carry information across many hops, and each hop through attention layers is lossy.
Why Attention Matters
Before attention, sequence models used recurrence (RNNs, LSTMs):
Process sequentially:
token_1 → state_1 → token_2 → state_2 → ... → token_n → state_n
Problems:
1. Can't parallelize (each step depends on previous)
2. Long-range dependencies are hard (gradient vanishing)
3. Information bottleneck (fixed-size state)
A 1000-word document must compress through a single state vector.
Attention allows direct connections:
Every token can directly access every other token:
token_1 ←→ token_2 ←→ token_3 ←→ ... ←→ token_n
           (all pairwise connections)
Benefits:
1. Fully parallelizable (all attention computed at once)
2. Direct long-range access (no bottleneck)
3. Dynamic weighting (attend more to relevant tokens)
This is why Transformers replaced RNNs everywhere.
The Core Idea: Weighted Mixing
Attention is surprisingly simple at its core:
Input: Sequence of token embeddings [v1, v2, v3, v4]
For each token, compute a new representation by
MIXING all tokens weighted by relevance:
new_v2 = 0.1*v1 + 0.6*v2 + 0.2*v3 + 0.1*v4
         (the four weights sum to 1.0; they are produced by a softmax)
The weights (attention scores) determine how much
each token contributes to the new representation.
For the sentence “The cat sat on the mat”:
When processing "sat":
- High attention to "cat" (subject of sat)
- Medium attention to "mat" (related to sitting)
- Low attention to "the" (less informative)
Result: "sat" embedding now contains information
about WHAT sat (cat) and WHERE (mat).
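To make the mixing concrete, here is a minimal sketch in PyTorch. The embeddings and attention weights below are made up for illustration; in a real model the weights come from the softmax over query-key scores described in the next sections.

import torch

# Toy embeddings for 4 tokens (dimension 3); values chosen arbitrarily
v = torch.tensor([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [1.0, 1.0, 0.0]])

# Illustrative attention weights for token 2 (they sum to 1.0)
weights = torch.tensor([0.1, 0.6, 0.2, 0.1])

# New representation of token 2: a weighted sum of all token vectors
new_v2 = weights @ v
print(new_v2)  # tensor([0.2000, 0.7000, 0.2000])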
Query, Key, Value
The Q, K, V framework formalizes how attention scores are computed:
INTUITION: The Library Metaphor

  Query (Q):  What am I looking for?
              "I need books about machine learning"

  Key (K):    What does each item contain?
              Book 1: "Introduction to AI"
              Book 2: "Cooking recipes"
              Book 3: "Deep Learning fundamentals"

  Value (V):  The actual content to retrieve
              (the book's actual contents)

  Match the Query against the Keys → weight the Values by match quality.
In practice, Q, K, V are linear projections of the input embeddings:
Input embedding: x (dimension d_model)
Q = x @ W_Q # project to query space
K = x @ W_K # project to key space
V = x @ W_V # project to value space
Where W_Q, W_K, W_V are learned weight matrices.
Each token gets its own Q, K, V vectors.
Token i's query asks: "What should I attend to?"
Token j's key advertises: "Here's what I contain"
Token j's value provides: "Here's my information if you want it"
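As a rough sketch (untrained, randomly initialized projections; the sizes are illustrative), this is what those projections look like in PyTorch:

import torch
import torch.nn as nn

d_model, d_k = 64, 64  # illustrative sizes; d_k often equals d_model / num_heads

# Learned projection matrices (nn.Linear stores the weight matrix internally)
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)

x = torch.randn(10, d_model)  # 10 token embeddings

Q, K, V = W_Q(x), W_K(x), W_V(x)  # each token gets its own query, key, value
print(Q.shape, K.shape, V.shape)  # torch.Size([10, 64]) for each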
Scaled Dot-Product Attention
The standard attention formula:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Let's break this down:
Step 1: Compute Attention Scores
scores = Q @ K^T
For a sequence of n tokens, each with d_k dimensional Q and K:
Q: (n, d_k)
K: (n, d_k)
K^T: (d_k, n)
Q @ K^T: (n, n) ← attention scores matrix
scores[i][j] = how much token i should attend to token j
Step 2: Scale
scaled_scores = scores / √d_k
Why scale?
Dot products grow with dimension: for roughly unit-variance Q and K, the dot product has variance ≈ d_k.
Large scores → softmax becomes very peaked (nearly all weight on one token)
→ gradients through the softmax vanish.
Dividing by √d_k keeps the score variance stable regardless of dimension.
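A quick numerical check of this claim, assuming roughly unit-variance queries and keys: the variance of the raw dot products grows with d_k, while the scaled scores stay near variance 1.

import torch

torch.manual_seed(0)
for d_k in (16, 64, 256, 1024):
    q = torch.randn(10000, d_k)          # random queries with unit-variance entries
    k = torch.randn(10000, d_k)          # random keys with unit-variance entries
    raw = (q * k).sum(dim=-1)            # raw dot products: variance grows ~ d_k
    scaled = raw / (d_k ** 0.5)          # scaled scores: variance stays near 1
    print(f"d_k={d_k:4d}  var(raw)={raw.var().item():7.1f}  var(scaled)={scaled.var().item():.2f}")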
Step 3: Softmax
attention_weights = softmax(scaled_scores)
Softmax converts scores to probabilities:
- All values between 0 and 1
- Each row sums to 1.0
- High scores → high weights, low scores → near zero
Example row of scores:  [2.1, 0.5, -1.0, 0.8]
After softmax:          [0.66, 0.13, 0.03, 0.18]
The token with score 2.1 receives most of the attention weight.
Step 4: Weighted Sum
output = attention_weights @ V
Each output vector is a weighted combination of all value vectors:
output_i = Σ (attention_weight[i][j] * V[j])
This is where information actually flows between tokens.
Complete Picture
SCALED DOT-PRODUCT ATTENTION (data flow)

  Inputs: Q (n×d_k), K (n×d_k), V (n×d_v)
    1. MatMul:   Q @ K^T → (n×n) attention scores
    2. Scale:    divide by √d_k
    3. Softmax:  convert each row to probabilities
    4. MatMul:   weights @ V → Output (n×d_v)
Multi-Head Attention
One attention pattern isn’t enough. Different relationships need different attention:
"The animal didn't cross the street because it was too tired."
Different questions need different attention patterns:
- Q: What is "it"? → attend "it" to "animal" (coreference)
- Q: What action? → attend verbs to subjects
- Q: What's the reason? → attend "tired" to "didn't cross"
Solution: Multiple attention "heads", each learning different patterns.
Multi-head attention runs h parallel attention operations:
MULTI-HEAD ATTENTION (data flow)

  Input X (n, d_model)
    → split into h heads (Head_1 ... Head_h), each with its own Q, K, V projections
    → each head runs scaled dot-product attention, producing (n, d_v/h)
    → Concat: combine all head outputs → (n, d_model)
    → W_O: project back to d_model → Output (n, d_model)
Typical configurations:
+------------------+---------------+---------------+
| Model            | d_model       | Heads (h)     |
+------------------+---------------+---------------+
| BERT-base        | 768           | 12            |
| GPT-2            | 768           | 12            |
| GPT-3 (175B)     | 12288         | 96            |
| LLaMA 7B         | 4096          | 32            |
+------------------+---------------+---------------+
Each head has d_k = d_model / h dimensions.
More heads = more diverse attention patterns.
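As a sketch, PyTorch's built-in nn.MultiheadAttention module implements exactly this split-attend-concat-project pattern; the sizes below mirror the BERT-base row of the table above.

import torch
import torch.nn as nn

d_model, num_heads, seq_len = 768, 12, 16   # BERT-base-like sizes

# Built-in module: splits into heads, attends, concatenates, projects with W_O
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

x = torch.randn(1, seq_len, d_model)        # one sequence of token embeddings
out, weights = mha(x, x, x)                 # self-attention: Q, K, V all derived from x

print(out.shape)      # torch.Size([1, 16, 768]) -- same shape as the input
print(weights.shape)  # torch.Size([1, 16, 16]) -- attention weights (averaged over heads)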
Context Window and Attention
The context window limit exists because attention is O(n²):
For sequence length n:
- Attention matrix: n × n
- Memory: O(n²)
- Compute: O(n²)
+------------------+---------------+---------------+
| Context Length   | Attention     | Memory (fp32) |
+------------------+---------------+---------------+
| 1K tokens        | 1M entries    | ~4 MB         |
| 4K tokens        | 16M entries   | ~64 MB        |
| 32K tokens       | 1B entries    | ~4 GB         |
| 128K tokens      | 16B entries   | ~64 GB        |
+------------------+---------------+---------------+
This is why long-context models are expensive.
128K context doesn't mean free 128K—it means 128K² cost.
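The numbers in the table above come from a short back-of-the-envelope calculation (one n×n score matrix at 4 bytes per fp32 entry):

# Reproduce the table: one n x n attention-score matrix, fp32 (4 bytes/entry)
for n in (1_000, 4_000, 32_000, 128_000):
    entries = n * n
    mem_mb = entries * 4 / 1e6
    print(f"{n:>7} tokens -> {entries:>15,} entries -> ~{mem_mb:>9,.0f} MB per matrix")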
Techniques for Longer Context
1. Sparse Attention
Instead of n² full attention, attend to subset:
- Local attention: only nearby tokens
- Strided attention: every k-th token
- Random attention: sample positions
BigBird, Longformer use O(n) attention patterns.
Trade-off: some information paths are blocked.
2. Flash Attention
Not mathematically different—same result.
But implements attention in a memory-efficient way:
- Never materializes full n×n matrix
- Computes in tiles that fit in GPU SRAM
- 2-4x faster, and memory grows linearly with sequence length instead of quadratically
This is why modern context windows keep growing.
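For example (a sketch; which kernel actually runs depends on your PyTorch version, hardware, and dtypes), PyTorch 2.x exposes a fused scaled_dot_product_attention that can dispatch to FlashAttention-style backends:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) layout expected by the fused kernel
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# Same math as softmax(QK^T / sqrt(d_k)) V, but the full 1024x1024 score
# matrix is never materialized when a fused backend is available.
out = F.scaled_dot_product_attention(q, k, v)  # also accepts attn_mask / is_causal
print(out.shape)  # torch.Size([1, 8, 1024, 64])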
3. Sliding Window / RoPE
Combine:
- Rotary Position Embeddings (RoPE) for relative positions
- Sliding window for bounded attention
- Global tokens that always attend everywhere
LLaMA uses RoPE; Mistral combines RoPE with sliding-window attention.
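A minimal sketch of a sliding-window causal mask (the window size is illustrative; the 1 = "may attend" convention matches the mask argument of the attention function implemented later in this article):

import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Causal mask where each token may attend only to itself and the
    previous `window - 1` tokens (1 = may attend, 0 = masked out)."""
    i = torch.arange(seq_len).unsqueeze(1)   # query positions (column vector)
    j = torch.arange(seq_len).unsqueeze(0)   # key positions (row vector)
    return ((j <= i) & (j > i - window)).long()

print(sliding_window_mask(6, window=3))
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])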
The “Lost in the Middle” Problem
Even with long context, attention has limitations:
Position in context vs. attention received: attention tends to follow a
U-shaped curve. Tokens near the beginning and end of the context receive
more attention, while content in the middle can be "lost."
Practical impact:
- Put critical information at beginning or end of prompts
- Don't bury important context in the middle of long documents
- Test your application with information at different positions
Attention Visualization
What attention patterns look like:
"The cat sat on the mat"
Attention weights (simplified, one head):
        The   cat   sat   on    the   mat
The   [ 0.3   0.2   0.1   0.1   0.2   0.1 ]
cat   [ 0.2   0.4   0.2   0.0   0.1   0.1 ]
sat   [ 0.1   0.5   0.2   0.1   0.0   0.1 ]   ← "sat" attends heavily to "cat"
on    [ 0.1   0.1   0.3   0.2   0.1   0.2 ]
the   [ 0.1   0.1   0.1   0.2   0.3   0.2 ]
mat   [ 0.1   0.1   0.2   0.2   0.2   0.2 ]
Different heads learn different patterns:
- Head 1: Subject-verb relationships
- Head 2: Positional (nearby tokens)
- Head 3: Syntactic structure
Code Example
Minimal implementation of scaled dot-product attention:
import torch
import torch.nn.functional as F
from typing import Optional


def scaled_dot_product_attention(
    Q: torch.Tensor,                       # (batch, n, d_k)
    K: torch.Tensor,                       # (batch, n, d_k)
    V: torch.Tensor,                       # (batch, n, d_v)
    mask: Optional[torch.Tensor] = None,   # optional mask (0 = don't attend)
) -> torch.Tensor:
    """
    Compute scaled dot-product attention.

    Returns:
        Output tensor of shape (batch, n, d_v)
    """
    d_k = Q.size(-1)

    # Step 1: Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, n, n)

    # Step 2: Scale
    scores = scores / (d_k ** 0.5)

    # Optional: Apply mask (for causal/padding)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)  # (batch, n, n)

    # Step 4: Weighted sum of values
    output = torch.matmul(attention_weights, V)    # (batch, n, d_v)
    return output


# Example usage
batch_size, seq_len, d_model = 2, 10, 64

# Random Q, K, V (in practice, these come from linear projections)
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 10, 64)
Key Takeaways
1. Attention lets tokens incorporate information from all other tokens
2. Q, K, V are projections that define what to attend to and what to retrieve
3. Scaled dot-product attention: softmax(QK^T / √d_k) @ V
4. Multi-head attention runs h parallel attention operations
- Each head can learn different relationship patterns
5. Context window limits exist because attention is O(n²)
- 128K context = 128K² computation
6. "Lost in the middle" is real
- Critical information should be at beginning or end
Verify Your Understanding
Before proceeding, you should be able to:
- Draw the attention formula and explain each component — What does the softmax do? Why scale by √d_k? What does multiplying by V accomplish?
- Explain why multi-head attention is better than single-head — Give a concrete example of different “types” of relationships different heads might learn.
- Your LLM has 128K context but struggles to answer questions about content in the middle. What’s happening? How would you restructure your prompt?
- Calculate the memory required for full attention with 32K tokens at float16 precision. How does this change with 64K tokens?
What’s Next
After this, you can:
- Continue → Attention → Generation — how models produce text token by token
- Go deeper → Explore transformer architectures, pre-training objectives