Accelerating LLMs with GEMM: Part 2
Introduction to GEMM and Matrix Multiplication Basics
General Matrix–Matrix Multiplication (GEMM) refers to the operation of multiplying two matrices together (often adding a scaled third matrix as well). At its core, matrix multiplication produces an output matrix where each element is the dot product of a row from the first matrix with a column of the second. For example, multiplying a 4×3 matrix by a 3×4 matrix yields a 4×4 result: each cell of the result C is the dot product of one entire row of A with one entire column of B.
In mathematical notation, if A is an m \times k matrix and B is a k \times n matrix, their product C is an m \times n matrix. The entry C_{ij} is given by:
C_{ij} = \sum_{t=1}^{k} A_{i,t} \times B_{t,j} \,.
This operation has a time complexity of O(m \times n \times k) (triple nested loops in a naïve implementation). BLAS (Basic Linear Algebra Subprograms) defines standard routines for such operations. GEMM is a Level 3 BLAS operation, meaning it involves matrix–matrix operations (Level 1 is vector ops, Level 2 is matrix–vector). In BLAS notation, a general matrix multiply often allows optional scaling and addition: one can compute C = α · A × B + β · C in a single GEMM call. For instance, the BLAS routine SGEMM performs single-precision (32-bit float) GEMM. In deep learning frameworks like PyTorch, a basic matrix multiplication C = A @ B (or torch.matmul(A, B)) under the hood will invoke an optimized GEMM (with α=1, β=0 by default).
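For concreteness, here is how the BLAS-style form C = α·A×B + β·C looks in PyTorch; the dimensions are arbitrary, and torch.addmm is the built-in routine that performs the scaled multiply-add in one call:
import torch
m, k, n = 4, 3, 4
A = torch.randn(m, k)
B = torch.randn(k, n)
C = torch.randn(m, n)
alpha, beta = 1.0, 0.5
# Plain product: each entry of A @ B is a dot product of a row of A with a column of B
product = A @ B
# Full BLAS-style GEMM: C_new = beta * C + alpha * (A @ B), computed in one call
C_new = torch.addmm(C, A, B, beta=beta, alpha=alpha)
print(torch.allclose(C_new, beta * C + alpha * product))  # True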
Why is GEMM so important? Modern neural networks—especially large language models (LLMs)—spend most of their compute time performing large matrix multiplications. Any fully connected (dense) layer in a neural network is essentially a matrix multiply between input activations and weight parameters. As we’ll see, transformer-based LLMs consist of many sublayers (attention projections, feed-forward layers, etc.) that are implemented as GEMM operations. Because matrix multiplication is such a fundamental and heavy computation, it has been heavily optimized in both software (BLAS libraries, PyTorch’s torch.mm/torch.matmul) and hardware. Using efficient GEMM is critical for speed, and optimizing GEMM yields direct improvements in model training and inference performance. Next, we’ll explore why GPUs excel at GEMM and how LLMs leverage GEMM in their architecture.
Why GPUs Excel at GEMM Operations
Graphics Processing Units (GPUs) are specialized for high parallelism, which makes them ideal for matrix math. Unlike a CPU (which has a limited number of cores optimized for sequential serial performance), a GPU contains thousands of smaller cores (threads) that can perform arithmetic simultaneously. In a matrix multiplication, there are many independent multiply-add operations (each output element’s dot-product involves summing many independent multiplications). A GPU can assign these operations across many threads to compute results in parallel. Modern GPUs also employ a Single-Instruction Multiple-Thread (SIMT) model: groups of threads (warps) execute the same instruction on different data, exactly matching the pattern of applying the same multiply-add across different elements of a matrix.
Another reason GPUs accelerate GEMM is their memory architecture. They have high-bandwidth memory and on-chip shared memory caches. Optimized GPU kernels “tile” the matrix multiply, loading sub-blocks of the matrices into fast shared memory and registers to reuse data effectively. By processing the multiplication in chunks (blocks of rows and columns), GPUs minimize slow global memory accesses and maximize computations on fast local data. The net effect is that a well-implemented GEMM on a GPU keeps thousands of arithmetic units busy most of the time (high compute utilization), whereas a naive approach would be bottlenecked by memory or insufficient parallel work.
Tensor Cores and specialized hardware: Newer GPU architectures (NVIDIA Volta, Turing, Ampere, Hopper, etc.) include Tensor Core units specifically designed to multiply small matrices very efficiently in lower precision. For example, NVIDIA’s A100 GPU can reach 312 TFLOPS (tera-floating-point ops per second) for FP16 matrix multiply using Tensor Cores, compared to about 19.5 TFLOPS for standard FP32 operations. Tensor Cores perform multiple fused multiply-adds in one clock cycle, greatly boosting throughput for GEMM in FP16/BF16 or INT8. Even TensorFloat-32 (TF32), a 19-bit format on Ampere GPUs, allows FP32 matrix multiplies to run on Tensor Cores with minimal code change. TF32 uses an 8-bit exponent (like FP32) and 10-bit mantissa (like FP16). By automatically executing FP32 ops in TF32 mode, an A100 can deliver up to 156 TFLOPS for what the user perceives as FP32 math. In practice, this means an Ampere GPU can be an order of magnitude faster at matrix multiplies than earlier generations or than a CPU, simply by virtue of these specialized units and massive parallelism.
Let’s illustrate the performance difference with a simple example (if you have a GPU available). The following code multiplies two large matrices on CPU vs GPU:
import torch, time
# Matrix dimensions
N = 2048
A = torch.randn(N, N) # CPU tensor (default float32)
B = torch.randn(N, N)
# CPU matrix multiplication timing
start = time.time()
C_cpu = A @ B # or torch.mm(A, B)
end = time.time()
print(f"CPU GEMM time: {end - start:.3f} s")
# GPU matrix multiplication timing (if CUDA is available)
if torch.cuda.is_available():
    A_gpu = A.cuda()
    B_gpu = B.cuda()
    _ = A_gpu @ B_gpu             # warm-up run (triggers CUDA/cuBLAS initialization)
    torch.cuda.synchronize()      # ensure data transfer and warm-up are done
    start = time.time()
    C_gpu = A_gpu @ B_gpu
    torch.cuda.synchronize()      # wait for the GPU to finish before stopping the timer
    end = time.time()
    print(f"GPU GEMM time: {end - start:.3f} s")
On a modern GPU, you would observe a dramatic speedup (often 10× or more) for this 2048×2048 matrix multiply compared to a CPU. This is why frameworks like PyTorch offload tensor operations to GPUs – especially GEMM – to achieve the performance needed for training large models.
GEMM in Transformer-Based LLM Architectures
Large language models (LLMs) such as Transformer architectures (e.g. GPT-3, BERT, Llama 2, etc.) rely heavily on matrix multiplications in their building blocks. Understanding where GEMM appears in these models will clarify why optimizing GEMM is so crucial:
- Embedding layers: The first part of an LLM often maps token indices to embedding vectors. If implemented naïvely, this could be seen as multiplying a one-hot vector by an embedding matrix. In practice this is just a lookup, but conceptually it’s a matrix–vector product. The embedding matrix (vocab_size × hidden_dim) is often huge, and retrieving a batch of embeddings can leverage optimized GEMM routines or batched gathers (though usually done by dedicated kernel rather than GEMM for efficiency).
- Self-Attention (Q, K, V projections): In a Transformer, each input token’s representation is linearly projected to form Query, Key, and Value vectors. These are implemented as three separate linear layers (weight matrices) applied to the input. If the model’s hidden size is d_model, each of Q, K, V is typically of size d_model as well. Computing Q = X · W_Q (where X is the batch of input features, shape [batch_size×seq_length, d_model] if flattened) is a matrix multiplication. The same goes for K and V. Often, implementations fuse these three projections for efficiency (e.g. using a single larger weight matrix or a batched GEMM), but either way the operations are GEMMs. After projection, Q, K, V might be reshaped to separate multiple attention heads, but that’s just a view of the same data.
- Attention score computation (Q × K^T): To produce attention weights, we take the dot product of each Query with each Key. This is effectively a matrix multiply: Q (of shape [batch_size * num_heads, seq_len, head_dim]) times K^T (shape [batch_size * num_heads, head_dim, seq_len]) yields an attention score matrix of shape [batch_size * num_heads, seq_len, seq_len]. This is a GEMM operation for each attention head (often done in parallel as a batched or block-sparse GEMM). The result is then normalized (softmax) – which is not GEMM but a pointwise operation.
- Attention output (softmax(QK^T) × V): After softmax, the weighted sum of Values is computed as the matrix multiplication of the softmax output and V (shape [batch_size * num_heads, seq_len, head_dim]). This produces the attended output for each head. That is another GEMM. In total, the attention mechanism involves multiple GEMMs: one for QK^T and one for applying those weights to V (plus the initial linear projections).
- Feed-Forward Network (FFN): Each Transformer block has a two-layer MLP (often called the feed-forward network). The first layer takes the d_model features and projects to a larger intermediate size (e.g. 4×d_model), usually with a nonlinear activation (like GELU) after it. The second layer projects back to d_model. Both layers are standard fully-connected layers – i.e. matrix multiplications plus biases. For example, in Llama 2 7B (hidden size 4096), the FFN expands to an intermediate size of 11008 and projects back to 4096 – huge matrix multiplications.
- Output layer: LLMs often have a final linear projection before producing logits over the vocabulary (or they reuse the embedding matrix for a softmax). This final step is again a matrix multiply of the last layer’s activations with a large weight matrix of size d_model × vocab_size. For generation, this is performed at each time step for inference (or as part of training loss computation).
As we can see, virtually every significant computation in a Transformer block is a GEMM. The layer normalization and softmax are minor by comparison (those are memory-bound, elementwise ops). Profiling studies confirm this: in a BERT-like Transformer, matrix multiplications dominate runtime. One study found that GEMM operations accounted for about 61% of the total execution time for a moderate sequence length, and even when sequence length increased (making attention more expensive), attention + GEMMs together made up ~89% of the time. In short, the bulk of a Transformer’s compute is spent inside GEMMs. This is why GPU-accelerated libraries and frameworks put so much effort into optimizing these matrix multiplies – any improvement here has an outsized impact on overall training or inference speed.
To illustrate how these look in code, consider a highly simplified Transformer attention snippet using PyTorch:
import torch
import torch.nn.functional as F
# Dummy dimensions
batch, seq, d_model, num_heads = 2, 4, 8, 2
head_dim = d_model // num_heads
# Random input sequence
x = torch.randn(batch, seq, d_model)
# Weights for Q, K, V projections (d_model -> d_model)
Wq = torch.randn(d_model, d_model)
Wk = torch.randn(d_model, d_model)
Wv = torch.randn(d_model, d_model)
# Linear projections (batch*seq, d_model) @ (d_model, d_model) -> (batch*seq, d_model)
q = x.reshape(-1, d_model) @ Wq
k = x.reshape(-1, d_model) @ Wk
v = x.reshape(-1, d_model) @ Wv
# Reshape into [batch, heads, seq, head_dim] so each head becomes an independent batch of matrices
q = q.view(batch, seq, num_heads, head_dim).transpose(1, 2)
k = k.view(batch, seq, num_heads, head_dim).transpose(1, 2)
v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)
# Compute attention scores for each head: (batch, heads, seq, seq)
scores = torch.einsum('bhqd, bhkd -> bhqk', q, k) # this is batch matmul Q * K^T
weights = F.softmax(scores / (head_dim ** 0.5), dim=-1) # softmax normalization
# Apply weights to V: (batch, heads, seq, head_dim)
attn_output = torch.einsum('bhqk, bhkd -> bhqd', weights, v) # another batch matmul
attn_output = attn_output.transpose(1, 2).reshape(batch, seq, d_model) # move heads back and merge into d_model
print("Attention output shape:", attn_output.shape)
In this toy example, we used Einstein summation (einsum) to express the batched matrix multiplications. Under the hood, PyTorch will route these to optimized batch GEMM kernels (einsum with those subscripts is essentially doing Q @ K^T and softmax @ V). The key takeaway is that each of these steps is a matrix multiplication. In a real model like Llama-2 7B, these matrices are huge: with d_model=4096 and seq=1024, each Q/K/V projection is a (1024×4096) by (4096×4096) GEMM per sequence, QK^T is a batch of 32 per-head (1024×128) by (128×1024) multiplies, and the two FFN layers are even larger multiplications.
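As a sanity check on that claim, the einsum subscripts above are interchangeable with explicit batched matmul calls; here is a small self-contained comparison with toy shapes:
import torch
b, h, s, d = 2, 2, 4, 4                      # toy batch, heads, seq, head_dim
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
v = torch.randn(b, h, s, d)
scores_einsum = torch.einsum('bhqd, bhkd -> bhqk', q, k)
scores_matmul = q @ k.transpose(-2, -1)      # the same batched GEMM, written explicitly
out_einsum = torch.einsum('bhqk, bhkd -> bhqd', scores_einsum.softmax(-1), v)
out_matmul = scores_matmul.softmax(-1) @ v   # softmax(QK^T) @ V as a batched GEMM
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-5),
      torch.allclose(out_einsum, out_matmul, atol=1e-5))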
Because so much of the model’s time is spent in these operations, efficient GEMM is synonymous with efficient LLM execution. Next, we’ll look at how we can speed up these GEMMs in PyTorch using various techniques: lower precision, batching, quantization, and kernel fusion.
Optimizing Matrix Multiplication in PyTorch
PyTorch uses highly optimized libraries (like NVIDIA’s cuBLAS or Intel MKL) for matrix operations by default. However, as users we have some control and techniques to further improve speed and memory usage for GEMM-heavy workloads like LLMs. Below are several key optimization strategies:
1. Mixed Precision (FP16 and BF16) Training
One of the most effective ways to accelerate GEMM is to use lower precision floats. By using 16-bit floating point instead of 32-bit, we halve memory usage and often get increased math throughput on hardware with tensor cores.
- FP16 (half-precision) has 1 sign bit, 5 exponent bits, 10 fraction bits. It has a smaller range and precision than FP32. Using FP16 for matrix multiplication can speed up compute significantly on GPUs with tensor core support. For example, on Volta/Ampere GPUs, matrix multiplies in FP16 are executed on tensor cores which can be 8–16× faster than FP32 on standard cores. The downside is reduced precision and potential for overflow/underflow, which is why in training we often use loss scaling to avoid FP16 underflow.
- BF16 (bfloat16) is an alternative 16-bit format with 1 sign, 8 exponent, 7 fraction bits. It has the same exponent range as FP32 (due to 8 exponent bits) but fewer mantissa bits. BF16 is popular for training because it retains dynamic range (so gradients don’t overflow easily) while still using 16 bits. New GPUs (Ampere and later) support BF16 arithmetic in tensor cores similarly to FP16. In practice, BF16 often achieves similar speed to FP16 but with simpler training (no need for loss scaling).
In PyTorch, using mixed precision is straightforward. You can cast models and data to half or bfloat16, or use the high-level torch.autocast context for automatic mixed precision (AMP). Here’s a simple demonstration of using FP16 in a GEMM:
import torch
# Create random matrices in full precision
A32 = torch.randn(1024, 1024, dtype=torch.float32, device='cuda')
B32 = torch.randn(1024, 1024, dtype=torch.float32, device='cuda')
# Multiply in FP32
C32 = A32 @ B32
# Convert to half-precision and multiply
A16 = A32.half()
B16 = B32.half()
C16 = A16 @ B16
print(C16.dtype) # torch.float16
# Compare results (convert C16 to float32 for fair comparison)
max_error = (C32 - C16.float()).abs().max().item()
print(f"Max difference between FP32 and FP16 result: {max_error}")
This code prints a small maximum difference caused by precision loss in FP16 (the exact value depends on the matrix size; with K=1024 the accumulated sums are large, so expect a relative error on the order of 1e-3). For inference, such a difference is usually negligible; for training, techniques like loss scaling or switching to BF16 preserve model stability. When enabling mixed precision training (for example with torch.cuda.amp.autocast(dtype=torch.float16)), matrix multiplies and convolutions run in lower precision on tensor cores, while certain sensitive operations (like reductions, accumulations, and some normalization layers) are kept in FP32 for accuracy. The result is often 2× or more throughput for the GEMM operations.
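To show the autocast path end to end, here is a minimal mixed-precision training-step sketch; the linear layer, input, and dummy loss are placeholders rather than part of any particular model:
import torch
import torch.nn as nn
model = nn.Linear(1024, 1024).cuda()                  # stand-in for any GEMM-heavy module
x = torch.randn(64, 1024, device='cuda')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()                  # loss scaling guards against FP16 underflow
with torch.autocast(device_type='cuda', dtype=torch.float16):
    out = model(x)                                    # the GEMM inside runs in FP16 on tensor cores
    loss = out.float().pow(2).mean()                  # dummy loss; the reduction is kept in FP32
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()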
BF16 in PyTorch: To use BF16, you can cast tensors to torch.bfloat16 similarly. On CPUs, PyTorch supports BF16 arithmetic in recent versions (though it may be emulated or use AVX512 on capable hardware); on GPUs like A100/H100, BF16 is hardware accelerated. An example is A_bf16 = A32.to(torch.bfloat16), then C_bf16 = A_bf16 @ B_bf16. BF16 trades precision for range compared with FP16: it has 3 more exponent bits but 3 fewer fraction bits, which is usually a good trade for deep learning. Many LLM training recipes now prefer BF16 as it avoids the fuss of loss scaling while still leveraging tensor core speed.
In summary, by using FP16 or BF16, we trade a tolerable amount of numerical precision for major gains in speed and memory usage. Most large model training today is done in mixed precision for this reason. (Note: When using mixed precision, ensure your GPU supports fast FP16/BF16; older GPUs without tensor cores won’t see as much benefit, and on those, 32-bit might even be faster due to better-supported vectorization.)
2. TensorFloat-32 (TF32) on NVIDIA Ampere+
NVIDIA introduced TF32 on the Ampere architecture as a way to accelerate FP32 GEMMs transparently. TF32 is a 19-bit format (10-bit mantissa, 8-bit exponent, 1 sign) used inside tensor cores. On Ampere GPUs, torch.matmul on float32 tensors can use tensor cores with TF32 math and accumulate in FP32. The idea is you get FP16-like speed with almost FP32-level range/accuracy. No code changes are needed in PyTorch – it’s controlled by a single flag:
torch.backends.cuda.matmul.allow_tf32 = True # enable TF32 matmuls (off by default since PyTorch 1.12; cuDNN convolutions allow TF32 by default)
If you need full FP32 precision (for example, for certain numeric comparisons), leave TF32 off (allow_tf32 = False), but you’ll lose the speed benefit. In practice, TF32 incurs a small accuracy drop (on a similar scale to 16-bit precision error) but is generally safe for training. It’s one reason Ampere GPUs (A100, RTX 30-series) were so much faster for training than the previous generation without requiring explicit FP16 usage – many ops got the boost automatically.
To summarize: TF32 allows float32 code to execute on tensor cores by rounding the mantissa to 10 bits. It provides roughly 2–3× speedups for GEMMs compared to classical FP32 (the A100’s TF32 tensor core peak is 156 TFLOPS vs 19.5 TFLOPS for normal FP32). If you’re using PyTorch on Ampere or newer, check that TF32 is allowed for matmuls (recent PyTorch releases disable it by default). With it enabled, you get a near-free performance boost for GEMM-heavy operations even if you aren’t manually using half precision.
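If you want to see the effect on your own hardware, a rough comparison can be made by toggling the flag around the same FP32 multiply (on pre-Ampere GPUs, which lack TF32, the two timings will be essentially identical):
import torch, time
A = torch.randn(4096, 4096, device='cuda')
B = torch.randn(4096, 4096, device='cuda')
def time_matmul(label):
    A @ B                                   # warm-up
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(10):
        A @ B
    torch.cuda.synchronize()
    print(f"{label}: {(time.time() - start) / 10 * 1e3:.2f} ms per matmul")
torch.backends.cuda.matmul.allow_tf32 = False
time_matmul("FP32 (TF32 off)")
torch.backends.cuda.matmul.allow_tf32 = True
time_matmul("FP32 with TF32 on")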
3. Batched GEMM and Parallelism
In many models, you might need to perform many smaller matrix multiplications. For example, computing the Q, K, V projections separately for each head, or multiple small attention matrices for many heads. Launching a separate GPU kernel for each small matrix multiply can be inefficient due to launch overhead and underutilization of GPU cores. Instead, batched GEMM APIs let you concatenate many matrices and multiply them in one go.
PyTorch provides torch.bmm (batch matrix-multiply) and in torch.matmul you can supply tensors with an extra batch dimension. The work will be done in a single kernel call (or a few calls) rather than looped in Python. This can significantly improve throughput when you have, say, hundreds of 64×64 matrices to multiply at once as opposed to one giant matrix.
Example of batched vs looped matrix multiplication:
import torch
import time
# Suppose we have to multiply 100 matrices of size 128x128 by 128x128
batch = 100
A = torch.randn(batch, 128, 128, device='cuda')
B = torch.randn(batch, 128, 128, device='cuda')
# Method 1: loop (inefficient on GPU due to 100 separate launches)
start = time.time()
outs = []
for i in range(batch):
    outs.append(A[i] @ B[i])
torch.cuda.synchronize()
loop_time = time.time() - start
# Method 2: batched matmul in one call
start = time.time()
C = torch.bmm(A, B) # or A @ B, which also handles batch
torch.cuda.synchronize()
batch_time = time.time() - start
print(f"Loop time: {loop_time:.4f} s, Batched time: {batch_time:.4f} s")
print("Results equal:", torch.allclose(torch.stack(outs), C))
On a GPU, the batched version will be much faster than the looped version (the difference might be small on CPU, but on GPU avoiding kernel-launch overhead is key). This is because the batched operation leverages parallelism across the batch dimension internally. In our Transformer discussion, we noted that the attention computation for each head or each element in a batch can be thought of as independent matrix multiplies – these are perfect candidates for torch.bmm. In fact, PyTorch’s multi-head attention implementation uses batched GEMMs under the hood (often by combining the heads into one big batch or by merging the Q, K, V projections into one large weight matrix to apply a single GEMM).
Whenever you find yourself computing many independent GEMMs of the same size, try to use a batched operation instead of Python loops. This ensures that the GPU does the work with maximal concurrency. The code above demonstrates using torch.bmm, but note that torch.matmul is more general: it will automatically broadcast batch dimensions. For example, if A is of shape (batch, m, k) and B is (batch, k, n), then A @ B produces (batch, m, n) by performing a batched matrix multiply. If B had shape (k, n) (no batch dimension), PyTorch would broadcast it across the batch and still perform a single batched multiply. This high-level API is convenient and efficient; a quick demonstration follows.
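Here is what that broadcasting looks like in practice; this is essentially how a single shared weight matrix is applied across a whole batch of sequences in a linear layer:
import torch
A = torch.randn(100, 64, 32)      # a batch of 100 matrices (e.g. 100 sequences of 64 tokens, 32 features)
W = torch.randn(32, 16)           # one shared weight matrix, no batch dimension
out = A @ W                       # W is broadcast across the batch: 100 independent (64x32)@(32x16) GEMMs
print(out.shape)                  # torch.Size([100, 64, 16])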
4. Quantization (INT8) for Inference Efficiency
So far we focused on floating-point arithmetic, which is standard for training. But for inference, we can often go even lower in precision – down to 8-bit integers – using quantization. The idea is to represent weights (and possibly activations) with 8-bit integers and perform matrix multiply-accumulate in integer arithmetic. This can give another 2–4× speed boost and 4× memory reduction, at the cost of some accuracy loss (which can be mitigated with careful calibration or quantization-aware training).
Modern GPUs support INT8 GEMM on tensor cores as well. NVIDIA Turing (RTX 20-series) and newer (Ampere, Hopper) have INT8 tensor core instructions. For instance, an A100 can do INT8×INT8 → INT32 accumulation at very high throughput (its INT8 tensor core rate is roughly double its FP16 rate, in the hundreds of TOPS). To use these, you typically need to use special CUDA libraries or newer PyTorch features, since PyTorch’s high-level torch.matmul won’t automatically use int8 (as of PyTorch 2.x, you often need to go through FBGEMM on CPU or use external libraries on GPU).
Quantization in PyTorch: PyTorch offers a quantization toolkit (for CPU quantized inference) and there are libraries like NVIDIA’s Transformer Engine and bitsandbytes that provide GPU int8 support. For a simple example on CPU, PyTorch dynamic quantization can convert an nn.Linear module to use int8 weights:
import torch.nn as nn
import torch.quantization
# Define a simple linear layer
lin = nn.Linear(128, 64)
# Convert to dynamically quantized version (weights int8, activations float32 dynamically quantized)
lin_int8 = torch.quantization.quantize_dynamic(lin, {nn.Linear}, dtype=torch.qint8)
# Compare dtype of weight
print("Original weight dtype:", lin.weight.dtype)
print("Quantized weight dtype:", lin_int8._packed_params._packed_params[0].dtype)
Output might show the original weight dtype as torch.float32 and the quantized weight as torch.qint8 (quantized int8). The quantized module will internally convert inputs to int8 on the fly, perform GEMM in int8, then convert output back to float. This is great for saving memory and potentially using faster instructions on CPUs that support int8 (via AVX512 VNNI etc.). On GPU, one would use different approaches: for example, the bitsandbytes library provides Int8Params and Int8Linear that allow you to load a model’s weights in int8 and perform 8-bit matrix multiplies on GPUs, using CUDA’s efficient INT8 paths.
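To make the scale-and-dequantize mechanics concrete, here is a minimal sketch of symmetric per-tensor INT8 quantization, emulated with integer tensors on the CPU (real GPU deployments would instead route this through INT8 tensor-core kernels via cuBLASLt/CUTLASS or a library such as bitsandbytes):
import torch
W = torch.randn(64, 128)                          # a float32 weight matrix
x = torch.randn(32, 128)                          # a batch of activations
# Symmetric per-tensor quantization: map [-max|.|, +max|.|] onto the int8 range [-127, 127]
w_scale = W.abs().max() / 127.0
x_scale = x.abs().max() / 127.0
W_int8 = (W / w_scale).round().clamp(-127, 127).to(torch.int8)
x_int8 = (x / x_scale).round().clamp(-127, 127).to(torch.int8)
# Integer GEMM with int32 accumulation (emulated here; tensor cores do this natively)
acc_int32 = x_int8.to(torch.int32) @ W_int8.t().to(torch.int32)
y_int8 = acc_int32.float() * (w_scale * x_scale)  # dequantize the result back to float
y_fp32 = x @ W.t()
print("Mean abs error vs FP32:", (y_fp32 - y_int8).abs().mean().item())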
It’s worth noting that int8 matrix multiplication on GPU can be extremely fast, but only if done in the tensor cores (using specialized kernels). This requires memory alignment and using the right CUDA APIs (cuBLASLt or CUTLASS). Projects like FasterTransformer and TensorRT do this to achieve real speedups. Pure Python libraries may sometimes see int8 slower than FP16 if not using tensor cores efficiently, due to overheads. However, when done right, int8 inference can significantly reduce latency and memory. For example, running INT8 GEMMs on Turing/Ampere tensor cores can more than double throughput. Also, lower precision like 4-bit is being explored (NVIDIA Hopper supports FP8 which is another direction). But int8 is a sweet spot for many practical LLM inference scenarios today.
In summary, quantization allows us to accelerate inference by using integer math. There is a trade-off with accuracy, so typically one would quantize a model after training and evaluate the accuracy drop. Techniques like post-training quantization (PTQ) with calibration or quantization-aware training (QAT) help minimize the loss in performance. Many open-source tools can quantize popular LLMs (e.g., you may have heard of LLM.int8() from Hugging Face, which uses bitsandbytes under the hood). If you deploy LLMs, it’s worth considering int8 quantization when you need faster/smaller models and can tolerate or mitigate the slight quality loss.
5. Kernel Fusion to Reduce Memory Bottlenecks
Apart from using lower precision, another powerful optimization is kernel fusion. This means combining multiple operations that would normally run separately into one GPU kernel launch. By fusing, we avoid intermediate memory writes/reads and can reuse data while it’s in registers. In GEMM-heavy workloads, often the GEMM itself is compute-bound, but surrounding operations (like adding a bias, activation functions, normalization) are memory-bound. Fusing them with the GEMM or with each other can speed up the overall operation and save memory.
Why fuse? Launching a GPU kernel has some overhead, and writing out intermediate results to GPU global memory only to read them back for the next operation wastes bandwidth and time. In an LLM, a common pattern is: Linear layer -> add bias -> apply activation (e.g. GELU) -> maybe another Linear. The add bias and GELU are cheap computation but each touches millions of elements in memory. If we do them in one fused kernel right as the GEMM produces outputs, we can do those additions/multiplications on the fly.
Example: fused bias add with GEMM. PyTorch’s torch.addmm is a simple example of fusion at the BLAS level. It computes C = beta*C + alpha*(A@B) in one call. If you have a bias term and you’re currently doing Y = A @ B; Y += bias, you can instead use torch.addmm(bias, A, B) to achieve the same result in one step. This doesn’t fuse arbitrary element-wise ops, but it fuses the addition of the bias vector to the result of matrix multiply. A quick demo:
M, N, K = 256, 256, 256
A = torch.randn(M, K, device='cuda')
B = torch.randn(K, N, device='cuda')
bias = torch.randn(M, N, device='cuda')
out1 = A @ B + bias # separate ops
out2 = torch.addmm(bias, A, B) # fused bias + matmul
print("Difference:", (out1 - out2).abs().max().item())
This will typically print Difference: 0.0 (or a negligibly small value), since the two formulations compute the same thing. By using addmm, we let the library handle adding the bias as part of the GEMM call (or at least without a full separate pass over the output). In practice, this saves a kernel launch and a read/write of the entire output matrix.
On a larger scale, deep learning libraries and compilers go further: they can fuse normalization and activation operations. For instance, NVIDIA’s Transformer kernels fuse layer norm with residual bias adds, etc., and LinkedIn’s Liger (discussed below) fuses even more. According to one tech report, fusing bias and layernorm after attention yielded a 3.2% speed boost for a single Transformer layer, and fusing bias+activation after GEMM improved that sub-layer by 61%. Those numbers illustrate that while each fusion might save only a few microseconds, in a giant model with hundreds of layers those savings multiply, and you also save memory usage (less storage for intermediate activations).
PyTorch 2.x and TorchScript: PyTorch has introduced features like torch.compile (a JIT compiler) that can automatically fuse operations across op boundaries using graph-level optimization. For example, torch.compile can detect patterns like linear->activation and generate a fused kernel (often using Triton, which we discuss next). In earlier versions, using torch.jit.trace/Script could also fuse elementwise ops. When writing custom CUDA kernels, one can manually fuse as needed, but it’s complex and beyond beginner scope. Tools like Triton simplify this.
To conclude this section: fuse what you can. If you see a sequence of operations in your model that always occur one after another (and do not have data dependencies that prevent fusion), consider whether PyTorch provides a fused version or whether a compiler can merge them. Examples include nn.Linear (which under the hood uses addmm to fuse bias), F.relu or F.gelu which can sometimes be fused by compilers into surrounding ops, and custom fused kernels for things like attention (e.g. FlashAttention fuses several steps of attention into one highly-optimized kernel). By reducing kernel launches and memory traffic, you make better use of the GPU’s capabilities.
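PyTorch 2.x exposes one such fused attention kernel directly as F.scaled_dot_product_attention, which dispatches to a FlashAttention-style implementation when shapes and dtypes allow. A small comparison against the unfused steps (the shapes here are arbitrary):
import torch
import torch.nn.functional as F
batch, heads, seq, head_dim = 2, 8, 128, 64
q = torch.randn(batch, heads, seq, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Unfused: three kernels (QK^T, softmax, weights @ V) plus intermediates written to memory
scores = (q @ k.transpose(-2, -1)) / head_dim ** 0.5
out_unfused = scores.softmax(dim=-1) @ v
# Fused: a single kernel chosen by PyTorch when available
out_fused = F.scaled_dot_product_attention(q, k, v)
print("Max difference:", (out_unfused - out_fused).abs().max().item())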
With these optimizations (mixed precision, batching, quantization, fusion) in mind, we already have a toolkit to speed up LLM computations significantly on existing hardware. Next, we’ll dive a bit deeper into writing custom GPU kernels, using Triton, and how the open-source Liger project leverages that to push LLM training efficiency further.
Writing Custom GEMM Kernels with Triton
Despite the high performance of libraries like cuBLAS, there are cases where you may want to write your own GPU kernel for GEMM or a GEMM fused with other operations. Perhaps you want to fuse an unusual combination of ops that framework libraries don’t cover, or experiment with novel data layouts or precisions. Traditionally, this meant writing CUDA C++ code – which is complex and hardware-specific. Triton is a framework that makes this much easier by allowing you to write GPU kernels in Python with a JIT compiler that optimizes and generates PTX (or AMD GCN) under the hood.
What is Triton? Triton (open-sourced by OpenAI) is a Python library and compiler for writing custom GPU kernels. You write a Triton kernel as a Python function, using an API that includes pointers and vectorized operations, and decorate it with @triton.jit. Triton handles splitting the work among threads, optimizing memory accesses, and can even auto-tune tile sizes. The advantage is you get performance approaching that of hand-written CUDA, but with far less effort and without needing to depend on vendor-specific libraries. You can also easily integrate Triton kernels into PyTorch as custom ops.
Why not just trust cuBLAS? For standard GEMM, vendor libraries are very fast. However, they are black boxes – you cannot easily change what they do. If you need a slightly different operation (say GEMM + nonlinearity in one pass), the library won’t provide that. As the Triton tutorial notes, these proprietary libraries cannot be customized to specific needs of modern DL workloads (they’re general-purpose). By writing a custom kernel, you can, for example, fuse an activation or implement a lower precision computation that isn’t supported out-of-the-box. Triton gives you a way to do this in a productive manner.
Tiling and parallelism in Triton: When writing a GEMM kernel, a common strategy is block tiling. We split the output matrix C into small tiles (e.g. 128×128 or 64×64), and assign each tile to a GPU thread block (Triton calls it a program instance). That program will load the corresponding sub-blocks of A and B, multiply them, and accumulate the partial result. It may need to iterate if the tile is larger than what fits in registers/shared memory. In plain PyTorch, the blocked loop structure looks like this (a runnable but deliberately slow illustration of the idea):
import torch
# Blocked (tiled) GEMM -- assumes M, N, K are multiples of the block sizes for brevity
def blocked_gemm(A, B, BLOCK_M=64, BLOCK_N=64, BLOCK_K=64):
    M, K = A.shape
    _, N = B.shape
    C = torch.empty(M, N, dtype=A.dtype, device=A.device)
    for m in range(0, M, BLOCK_M):          # each (m, n) tile is independent -> parallel on GPU
        for n in range(0, N, BLOCK_N):
            acc = torch.zeros(BLOCK_M, BLOCK_N, dtype=A.dtype, device=A.device)
            for k in range(0, K, BLOCK_K):  # accumulate partial products along K
                a_block = A[m : m+BLOCK_M, k : k+BLOCK_K]
                b_block = B[k : k+BLOCK_K, n : n+BLOCK_N]
                acc += a_block @ b_block    # small tile-level matrix multiply
            C[m : m+BLOCK_M, n : n+BLOCK_N] = acc
    return C
Each iteration of the outer two loops (m,n) can be done in parallel by different GPU thread groups. The inner k-loop is computed by each group to accumulate the full dot product for that tile. Triton allows you to implement exactly this logic. You would use tl.arange to create index ranges for the tile, use tl.load to get the A and B sub-blocks, and perform the multiplication and sum, then tl.store to write back the result. Triton kernels often use 2D tile indices and a loop for the k-dimension. The Triton tutorial code for matmul uses these techniques and achieves performance on par with cuBLAS by carefully choosing tile sizes and using tl.zeros for the accumulator in higher precision (e.g. accumulate in FP32 for an FP16 GEMM to improve accuracy).
To give a flavor of Triton code, here’s a simplified example of a Triton kernel that computes C = A + B (elementwise addition) just to show the structure and launching mechanism:
import triton
import triton.language as tl
@triton.jit
def vector_add_kernel(A_ptr, B_ptr, C_ptr, N: tl.constexpr):
    # Each kernel instance will handle a chunk of 128 elements
    pid = tl.program_id(axis=0)
    offset = pid * 128 # start index for this program
    # Define an index range [offset, offset+128)
    idx = offset + tl.arange(0, 128)
    # Create a mask for bounds checking
    mask = idx < N
    # Load A and B (if out of bounds, use 0)
    a = tl.load(A_ptr + idx, mask=mask, other=0.0)
    b = tl.load(B_ptr + idx, mask=mask, other=0.0)
    # Compute and store result
    c = a + b
    tl.store(C_ptr + idx, c, mask=mask)
# Prepare input data
N = 1000
A = torch.randn(N, dtype=torch.float32, device='cuda')
B = torch.randn(N, dtype=torch.float32, device='cuda')
C = torch.empty(N, dtype=torch.float32, device='cuda')
# Launch the Triton kernel with enough blocks to cover N elements
grid_size = (triton.cdiv(N, 128),) # number of program instances (each covers 128)
vector_add_kernel[grid_size](A, B, C, N)
torch.cuda.synchronize()
print("Max error:", (C - (A+B)).abs().max().item())
This Triton kernel divides the work such that each “program” handles 128 elements of the vectors. We calculate the index range for each program using tl.program_id(axis=0) (which gives the block index along the 0th dimension of the grid). We then load chunks of A and B, perform addition, and store into C. The mask ensures we don’t go out of bounds when N isn’t a multiple of 128. After launching, we synchronize and verify that C matches A+B. The result should be correct with max error 0, and indeed we have effectively written and run a custom GPU kernel without a single line of CUDA C++.
Now, extrapolate this idea to a matrix multiply: instead of adding two vectors, the kernel would multiply sub-matrices. Triton provides the tools, but you have to manage the indices for 2D tiles and the loop over K. The Triton docs walk through building a high-performance GEMM kernel step by step, including how to handle memory hierarchies and parallelization. Impressively, with ~50 lines of Triton code you can implement an FP16 GEMM that matches cuBLAS’s speed. The key is that Triton allows fine control: you can decide to, say, fuse an activation by simply adding that line in the kernel. This level of customization is powerful for research.
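For reference, here is a condensed sketch of such a kernel, closely modeled on the official Triton matmul tutorial. The block sizes are illustrative (a real kernel would autotune them), the FP16 inputs are accumulated in FP32, and edge handling assumes M and N are multiples of the block sizes:
import torch
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Each program instance computes one BLOCK_M x BLOCK_N tile of C
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Pointers to the first A and B sub-blocks for this tile
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)      # accumulate in FP32 for accuracy
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        acc += tl.dot(a, b)                                    # tile-level matmul (uses tensor cores)
        a_ptrs += BLOCK_K * stride_ak                          # advance along K
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)
# Launch: one program per output tile
M = N = K = 1024
A = torch.randn(M, K, device='cuda', dtype=torch.float16)
B = torch.randn(K, N, device='cuda', dtype=torch.float16)
C = torch.empty(M, N, device='cuda', dtype=torch.float16)
grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
matmul_kernel[grid](A, B, C, M, N, K,
                    A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1),
                    BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)
print("Max error vs torch.matmul:", (C.float() - A.float() @ B.float()).abs().max().item())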
Automatic tuning: Triton also has an autotune decorator to try different tile sizes (BLOCK_M, BLOCK_N, BLOCK_K) and find the fastest on a given GPU. This is important because the optimal tile shape might differ on, say, A100 vs RTX 3090 due to memory and SM differences. The Triton compiler handles low-level optimizations (like using fast math instructions, vectorization, avoiding bank conflicts in shared memory) so you can focus on the high-level algorithm.
In practice, you might not write your own GEMM from scratch (since libraries do it well), but you might use Triton to fuse a custom elementwise function into a GEMM or implement a niche operation (like a custom attention pattern or an optimized normalization) that isn’t available in PyTorch. For example, one could implement a fused attention kernel that performs the softmax and the matrix multiply in one go (some projects have done similar things with Triton). The main takeaway is: Triton empowers developers to write GPU kernels tailored to their model’s needs, without needing to wait for NVIDIA or others to provide a specialized library. This opens up opportunities for optimization that were previously only in reach for CUDA experts.
Liger: Optimized Triton Kernels for LLM Training
While writing custom kernels is great, not everyone has the time to hand-tune every part of an LLM. This is where projects like Liger come in. Liger Kernel (by LinkedIn engineers) is an open-source collection of Triton-based kernels specifically optimized for LLM training and inference. It basically provides drop-in replacements for certain PyTorch/HuggingFace modules that achieve better performance and memory usage through kernel fusion and other tricks.
Key points about Liger:
- It targets inefficiencies in the training pipeline by fusing operations and reducing memory traffic. The authors report on average a 20% increase in training throughput and a 60% reduction in GPU memory usage for popular LLMs compared to standard Hugging Face implementations. That’s a significant speed-up, which can translate to many hours saved in training or allow larger batch sizes or sequence lengths within the same hardware memory limits.
- Liger is implemented with Triton, which means these kernels are vendor-agnostic (they work on NVIDIA and AMD GPUs via ROCm, as long as Triton supports them). It also means one can read the Python code for the kernels, which is more approachable than reading PTX or assembly.
- It provides optimized versions of components like RMSNorm (Root Mean Square LayerNorm used in some LLMs), RoPE (rotary positional embeddings), SwiGLU activation (used in PaLM, etc.), cross-entropy loss, and even fused modules like an optimized FusedLinearCrossEntropy. By fusing these or implementing them with lower-level efficiency, Liger saves memory and time. For example, instead of doing a separate layernorm kernel launch for RMSNorm, they integrate it with other ops.
- Liger is easy to use: it’s designed to be a drop-in. As the GitHub README highlights, you can apply it with “one line of code” by monkey-patching a HuggingFace transformer model. It has an AutoLigerKernelForCausalLM wrapper that will automatically swap out certain components of a loaded model with Liger’s optimized ones.
Let’s see a quick example of how one might use Liger in code:
!pip install liger-kernel # install Liger (assuming suitable environment)
from liger_kernel.transformers import AutoLigerKernelForCausalLM
# Load a pre-trained model with Liger's wrapper (if supported)
model_name = "huggingface/llama-2-7b" # example model name (ensure you have access or use a smaller one)
model = AutoLigerKernelForCausalLM.from_pretrained(model_name).cuda()
With these two lines, model is now a Llama-2 7B model that internally uses Liger’s Triton kernels for supported operations. If the model type is supported (Liger supports common architectures like OPT, GPT-NeoX, LLaMA family, etc.), it patches the forward functions of the model’s submodules to use Triton kernels. From the outside, you use model as usual, but it will consume less memory and run faster. Novice users can get improvements with minimal effort, and advanced users can even mix and match Liger modules manually (the library exposes individual Triton ops if needed).
What kind of improvements might you see? The Liger paper mentions an example of multi-GPU (data parallel) training where Liger gave 20% higher throughput without any convergence issues. The memory savings (60% in some cases) come from not materializing as many intermediate tensors. For instance, Liger might fuse an operation that normally would produce an intermediate of size (batch, seq, hidden), and thus avoid storing it. This is very useful for long sequence training where memory is at a premium.
Integration with PyTorch 2.0: Liger can be combined with torch.compile for even more effect. There was a joint blog post on how torch.compile and Liger together can maximize performance (Torch Compile can orchestrate the whole graph, and use Liger’s kernels where applicable, plus do its own optimizations). The ecosystem is moving toward more of these “plugin” libraries where Triton kernels augment the default ones.
Finally, Liger is continuously evolving – the maintainers added support for new hardware (even AMD GPUs) and new model types. It’s a good example of how open-source contributions can push LLM efficiency forward. For a user building an LLM from scratch or fine-tuning one, it’s worth keeping an eye on such libraries; you might get significant speedups with a simple pip install. Liger shows that by focusing on the “last mile” of kernel optimization (where the general frameworks may not yet be optimal), you can unlock performance gains that hardware alone doesn’t give you.
Benchmarking Tips and Conclusion
We have covered a lot of ground: from the basics of matrix multiplication and GEMM, to why GPUs and lower precision make it fast, to how LLMs use GEMM everywhere, and finally to advanced optimization via Triton and Liger. When applying these concepts, it’s important to benchmark and ensure that changes are actually improving performance:
- Use proper timing: When measuring GPU code, remember to synchronize before and after the section of interest. For example, use torch.cuda.synchronize() around time.time() calls, or better, use torch.cuda.Event for fine-grained measurements. Python’s %%timeit or timeit may not automatically account for asynchronous GPU calls. Always warm up the GPU with a few runs before timing (to let any caching or autotuning stabilize). A timing sketch using CUDA events follows this list.
- Profile memory: Reducing precision or fusing ops should reduce memory usage. You can check memory with torch.cuda.memory_allocated() or torch.cuda.max_memory_allocated() to compare before/after. Ensure that the 60% memory reduction from something like Liger is actually realized in your use case.
- Compare throughput (tokens/s or samples/s): For end-to-end model performance, rather than just microseconds of a single GEMM, measure how many samples per second (for training) or tokens per second (for generation) you achieve. That is the true metric that matters for training speed or inference latency. Sometimes an optimization of one part might not show up if another part becomes the bottleneck, so profiling can help find the new bottleneck.
- Quality/accuracy checks: When using lower precision or quantization, always validate that the model’s output (loss or accuracy) remains acceptable. PyTorch’s torch.allclose or checking max error (as we did in small demos) is useful for unit tests. For quantized models, evaluate on a validation set to ensure accuracy drop is within tolerances.
Below is an addendum showing why and how torch.compile can help further accelerate LLMs (in addition to the GEMM-focused optimizations). We’ll keep the language simple and practical.
Further Gains Using torch.compile
With PyTorch 2.0, there’s a powerful new mechanism to accelerate your workloads: torch.compile. This feature (also known as TorchDynamo + TorchInductor under the hood) captures your model’s Python-level graph, transforms or fuses ops, and then generates optimized kernels. In the context of large language models (LLMs), we have already discussed how important GEMM operations are, and how libraries like cuBLAS, Triton, and Liger can make these GEMMs blazing fast.
But what about the rest of the model? LLMs also contain various elementwise and reduction operations (softmax, layer norms, etc.), plus other overhead from Python loops and submodules that can hamper performance if not properly optimized. That’s where torch.compile shines:
- Graph Capture & Fusion: By dynamically tracing your model, torch.compile sees the entire sequence of PyTorch ops in forward/backward passes. It fuses smaller ops into fewer GPU kernels (similar to custom Triton kernels we saw in Liger), drastically cutting kernel launch overhead, memory reads/writes, and Python overhead.
- Automatic Integration with Triton: Under the hood, the TorchInductor backend for torch.compile uses Triton. So if you already appreciate the benefits of custom Triton kernels, you’ll be glad to know that torch.compile can generate them automatically for you. You don’t need to write your own kernel code.
- Even More GEMM Tuning: While vendor libraries like cuBLAS remain best in class for standard GEMM patterns, torch.compile can help fuse surrounding ops: for instance, the typical pattern y = activation(x @ W + b) could be generated as one fused kernel. Some of these sub-ops also contain small matrix multiplies or broadcast expansions that can be better batched or merged. This means that even if your main GEMMs are already cuBLAS-level fast, torch.compile may reduce overhead on other parts of your network.
- Less Python Overhead: An LLM typically involves many repeated blocks (multi-head attention layers, feed-forward layers, etc.). Each forward pass in eager mode calls the same submodules in a Python loop. torch.compile reduces this overhead because it partially compiles the control flow into a lower-level IR (FX Graph). So you run fewer Python-level instructions each pass.
- Optional Additional Optimizations:
  - Modes like reduce-overhead and max-autotune: You can specify mode="reduce-overhead" to reduce repeated kernel launches or overhead for each call.
  - Integration with CUDA graphs: If your batch shape is consistent, PyTorch can record a CUDA graph once, then repeatedly launch it with minimal overhead.
Below, we’ll outline how you can integrate torch.compile into your LLM workflow:
import torch
# Suppose you have a standard LLM model or a custom LLM architecture in PyTorch
model = MyLLM() # e.g. a Hugging Face transformer, or your own Transformer class
model.cuda()
model.train()
# Enable half precision or BF16 for big speed-ups on GPU
# (assuming Ampere or newer GPU with Tensor Cores; pure-FP16 training usually also needs
#  loss scaling -- BF16 or torch.autocast avoids that, as discussed earlier)
model = model.half()
# Then compile the model
model = torch.compile(model, mode="max-autotune")
Now, every time you do model(input) for training or inference, PyTorch will capture the forward pass, optimize the sub-graphs (including or around your GEMMs), and produce a more efficient set of kernels. If your LLM has consistent input sizes (batch dimension, sequence length, etc.), or if you run many steps with the same shapes, the compilation overhead is amortized quickly over multiple iterations. You’ll likely see performance improvements on top of the gains from half precision or int8 GEMMs.
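To verify the gain on your own setup, you can benchmark the eager and compiled versions side by side; the FFN-style block below is just a stand-in, and the measured difference will vary with GPU and shapes:
import torch, time
@torch.no_grad()
def bench(m, x, iters=20):
    for _ in range(3):
        m(x)                              # warm-up (the first compiled call triggers compilation)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        m(x)
    torch.cuda.synchronize()
    return (time.time() - start) / iters * 1e3
ffn = torch.nn.Sequential(
    torch.nn.Linear(4096, 11008), torch.nn.GELU(), torch.nn.Linear(11008, 4096)
).cuda().half()
x = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)  # (tokens, d_model)
compiled_ffn = torch.compile(ffn)
print(f"eager: {bench(ffn, x):.2f} ms, compiled: {bench(compiled_ffn, x):.2f} ms")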
Example: Combining torch.compile with Liger/Triton
If you’re already using Liger or custom Triton kernels for large chunks of your model (like fused RMSNorm or fused MLP), torch.compile can still help by:
- Fusing other ops not covered by Liger.
- Potentially integrating “Liger kernels” seamlessly in the compiled graph. Some parts of your model that Liger didn’t optimize might still get fused or auto-tuned by TorchInductor.
Things to Note
- First pass overhead: The very first time you run model(input) after torch.compile, PyTorch will trace and compile the model. This can take a noticeable amount of time (seconds to minutes, depending on model complexity). But once compiled, subsequent iterations are much faster. For inference servers or training loops that run many batches, this overhead is quickly paid back.
- Graph breaks: If your code has data-dependent Python logic (like random branching or certain unsupported ops), torch.compile may break the graph into pieces. Each piece is compiled separately. The more breaks, the less speedup. You can use torch._dynamo.explain(model)(input) to see where breaks occur and refactor code if needed.
- Compatibility: torch.compile typically works best with PyTorch 2.0+ on newer GPUs (Ampere or Hopper with half or BF16). For older GPUs, the speedups might be smaller. Also, check that your environment has the right PyTorch version with Triton. If you see no speed gain, profile for graph breaks or see if your model is purely memory-bound (very large embeddings or multi-gpu communication overhead).
- Use Cases:
  - Training big LLMs from scratch or fine-tuning with standard frameworks: torch.compile can reduce overhead in the forward/backward pass, fusing activations, normalization, etc.
  - Inference or serving with consistent shapes: You can reduce latency or boost throughput by letting torch.compile fuse the model’s ops into fewer, bigger kernels.
Summary
Even if you’ve done all the usual GEMM optimizations (FP16, BF16, int8, batched ops, vendor libraries, or Triton custom kernels), you may still gain an additional 5–30% speed boost (numbers vary) by letting torch.compile fuse extra ops, remove Python overhead, and produce a more streamlined execution graph.
For large-scale training, every percentage counts. For latency-sensitive inference, a further ~10% reduction might let you serve more requests per second. torch.compile is thus an excellent final step in your optimization pipeline after you’ve tackled the major “low-hanging fruit” (like ensuring your GEMMs are as efficient as possible).
TL;DR: After using advanced GEMM optimizations, don’t skip torch.compile. It can automatically fuse the rest of your LLM’s code path and can yield additional speed-ups with only a one-line change—giving you a more efficient end-to-end execution.
In conclusion, GEMM is the engine that powers LLMs. Mastering how to accelerate GEMM means you can train and inference large models more efficiently. We started from the simple idea of multiplying two matrices and built up to the complex reality of an LLM, and saw that at each step, improvements like using half precision or fusing operations translate to real-world gains. Thanks to tools like PyTorch’s AMP and projects like Triton and Liger, these optimizations are becoming more accessible. A developer with basic Python knowledge can now leverage GPU tensor cores and custom kernels without writing low-level code.
By applying the techniques covered – from mixed precision to custom fused kernels – you can achieve significant speedups: it’s not uncommon to see 2–3× faster training times and models that would barely fit in memory now training with room to spare. As LLMs continue to grow, efficient GEMM will only become more important (hardware advances like new tensor core precisions and software advances like smarter compilers will keep this field evolving). With this tutorial, you should have a solid foundation to understand and optimize the matrix multiplications that lie at the heart of modern AI models, enabling you to train the next billion-parameter model just a bit faster than before!
Sources: Optimizations and concepts were drawn from NVIDIA documentation and hardware specs for tensor core capabilities, from research on transformer performance breakdowns, and from the Liger Kernel report which demonstrated the benefits of kernel fusion in LLM training. The Triton tutorial was referenced for insights into custom kernel implementation. These illustrate the significant gains possible by focusing on GEMM efficiency in deep learning workloads.