Tuesday, April 22, 2025

 

Accelerating LLMs with GEMM: Part 2


Introduction to GEMM and Matrix Multiplication Basics


General Matrix–Matrix Multiplication (GEMM) refers to the operation of multiplying two matrices together (often adding a scaled third matrix as well). At its core, matrix multiplication produces an output matrix where each element is the dot product of a row from the first matrix with a column of the second. For example, multiplying a 4×3 matrix by a 3×4 matrix yields a 4×4 result, and each cell in the result is computed by summing the elementwise products of one entire row of the first matrix and one entire column of the second.

[Figure: matrix multiplication combines a row of matrix A (green) with a column of matrix B (red) to produce each element of the result matrix C (blue).]


In mathematical notation, if A is an m \times k matrix and B is a k \times n matrix, their product C is an m \times n matrix. The entry C_{ij} is given by:


C_{ij} = \sum_{t=1}^{k} A_{i,t} \times B_{t,j} \,.


This operation has a time complexity of O(m \times n \times k) (triple nested loops in a naïve implementation). BLAS (Basic Linear Algebra Subprograms) defines standard routines for such operations. GEMM is a Level 3 BLAS operation, meaning it involves matrix–matrix operations (Level 1 is vector ops, Level 2 is matrix–vector). In BLAS notation, a general matrix multiply often allows optional scaling and addition: one can compute C = α · A × B + β · C in a single GEMM call. For instance, the BLAS routine SGEMM performs single-precision (32-bit float) GEMM. In deep learning frameworks like PyTorch, a basic matrix multiplication C = A @ B (or torch.matmul(A, B)) under the hood will invoke an optimized GEMM (with α=1, β=0 by default).
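To make the α/β semantics concrete, here is a minimal pure-Python sketch of what a GEMM call computes. The function name gemm_reference is hypothetical, and the triple loop only mirrors the formula above; real BLAS libraries use blocked, vectorized kernels rather than anything like this.

```python
def gemm_reference(alpha, A, B, beta, C):
    """Return alpha * (A @ B) + beta * C for matrices stored as lists of rows."""
    m, k, n = len(A), len(B), len(B[0])
    assert all(len(row) == k for row in A), "inner dimensions must match"
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for t in range(k):          # dot product of row i of A with column j of B
                acc += A[i][t] * B[t][j]
            out[i][j] = alpha * acc + beta * C[i][j]
    return out

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[1.0, 1.0], [1.0, 1.0]]

plain = gemm_reference(1.0, A, B, 0.0, C)    # alpha=1, beta=0: an ordinary A @ B
scaled = gemm_reference(2.0, A, B, 0.5, C)   # 2 * (A @ B) + 0.5 * C
print(plain)
print(scaled)
```

With α=1 and β=0 the call reduces to a plain product, which is the case a framework-level A @ B corresponds to.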


Why is GEMM so important? Modern neural networks—especially large language models (LLMs)—spend most of their compute time performing large matrix multiplications. Any fully connected (dense) layer in a neural network is essentially a matrix multiply between input activations and weight parameters. As we’ll see, transformer-based LLMs consist of many sublayers (attention projections, feed-forward layers, etc.) that are implemented as GEMM operations. Because matrix multiplication is such a fundamental and heavy computation, it has been heavily optimized in both software (BLAS libraries, PyTorch’s torch.mm/torch.matmul) and hardware. Using efficient GEMM is critical for speed, and optimizing GEMM yields direct improvements in model training and inference performance. Next, we’ll explore why GPUs excel at GEMM and how LLMs leverage GEMM in their architecture.


Why GPUs Excel at GEMM Operations


Graphics Processing Units (GPUs) are specialized for high parallelism, which makes them ideal for matrix math. Unlike a CPU (which has a limited number of cores optimized for sequential performance), a GPU contains thousands of smaller arithmetic cores that can run many threads simultaneously. In a matrix multiplication, there are many independent multiply-add operations (each output element’s dot-product involves summing many independent multiplications). A GPU can assign these operations across many threads to compute results in parallel. Modern GPUs also employ a Single-Instruction Multiple-Thread (SIMT) model: groups of threads (warps) execute the same instruction on different data, which matches the pattern of applying the same multiply-add across different elements of a matrix.


Another reason GPUs accelerate GEMM is their memory architecture. They have high-bandwidth memory and on-chip shared memory caches. Optimized GPU kernels “tile” the matrix multiply, loading sub-blocks of the matrices into fast shared memory and registers to reuse data effectively. By processing the multiplication in chunks (blocks of rows and columns), GPUs minimize slow global memory accesses and maximize computations on fast local data. The net effect is that a well-implemented GEMM on a GPU keeps thousands of arithmetic units busy most of the time (high compute utilization), whereas a naive approach would be bottlenecked by memory or insufficient parallel work.


Tensor Cores and specialized hardware: Newer GPU architectures (NVIDIA Volta, Turing, Ampere, Hopper, etc.) include Tensor Core units specifically designed to multiply small matrices very efficiently in lower precision. For example, NVIDIA’s A100 GPU can reach 312 TFLOPS (tera-floating-point ops per second) for FP16 matrix multiply using Tensor Cores, compared to about 19.5 TFLOPS for standard FP32 operations. Tensor Cores perform multiple fused multiply-adds in one clock cycle, greatly boosting throughput for GEMM in FP16/BF16 or INT8. Even TensorFloat-32 (TF32), a 19-bit format on Ampere GPUs, allows FP32 matrix multiplies to run on Tensor Cores with minimal code change. TF32 uses an 8-bit exponent (like FP32) and 10-bit mantissa (like FP16). By executing FP32 ops in TF32 mode, an A100 can deliver up to 156 TFLOPS for what the user perceives as FP32 math. In practice, this means an Ampere GPU can be an order of magnitude faster at matrix multiplies than earlier generations or than a CPU, simply by virtue of these specialized units and massive parallelism.


Let’s illustrate the performance difference with a simple example (if you have a GPU available). The following code multiplies two large matrices on CPU vs GPU:

import torch, time

# Matrix dimensions
N = 2048
A = torch.randn(N, N)  # CPU tensor (default float32)
B = torch.randn(N, N)

# CPU matrix multiplication timing
start = time.time()
C_cpu = A @ B  # or torch.mm(A, B)
end = time.time()
print(f"CPU GEMM time: {end - start:.3f} s")

# GPU matrix multiplication timing (if CUDA is available)
if torch.cuda.is_available():
    A_gpu = A.cuda()
    B_gpu = B.cuda()
    _ = A_gpu @ B_gpu                  # warm-up run so one-time CUDA/cuBLAS setup is not timed
    torch.cuda.synchronize()           # wait for the copies and the warm-up to finish
    start = time.time()
    C_gpu = A_gpu @ B_gpu
    torch.cuda.synchronize()           # wait for GPU to finish
    end = time.time()
    print(f"GPU GEMM time: {end - start:.3f} s")

On a modern GPU, you would observe a dramatic speedup (often 10× or more) for this 2048×2048 matrix multiply compared to a CPU. This is why frameworks like PyTorch offload tensor operations to GPUs – especially GEMM – to achieve the performance needed for training large models.
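A measured time becomes easier to interpret once it is converted to throughput and compared with the peak figures quoted earlier (for example 19.5 TFLOPS for FP32 on an A100). The sketch below does that arithmetic; the helper names and the two timings are hypothetical placeholders, not measurements.

```python
def gemm_flops(m, n, k):
    """Floating-point operations in an (m x k) @ (k x n) product: one multiply and one add per term."""
    return 2 * m * n * k

def achieved_tflops(m, n, k, seconds):
    """Throughput implied by a measured wall-clock time."""
    return gemm_flops(m, n, k) / seconds / 1e12

N = 2048
flops = gemm_flops(N, N, N)
print(f"FLOPs for one {N}x{N} GEMM: {flops:,}")

# Hypothetical timings, for illustration only (not measurements)
for label, seconds in [("example CPU run", 0.200), ("example GPU run", 0.002)]:
    print(f"{label}: {achieved_tflops(N, N, N, seconds):.2f} TFLOPS")
```

Plugging in your own timings from the benchmark above shows how close a given run gets to the hardware's advertised peak.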


GEMM in Transformer-Based LLM Architectures


Large language models (LLMs) such as Transformer architectures (e.g. GPT-3, BERT, Llama 2, etc.) rely heavily on matrix multiplications in their building blocks. Understanding where GEMM appears in these models will clarify why optimizing GEMM is so crucial:

  • Embedding layers: The first part of an LLM often maps token indices to embedding vectors. If implemented naïvely, this could be seen as multiplying a one-hot vector by an embedding matrix. In practice this is just a lookup, but conceptually it’s a matrix–vector product. The embedding matrix (vocab_size × hidden_dim) is often huge, and retrieving a batch of embeddings can leverage optimized GEMM routines or batched gathers (though usually done by dedicated kernel rather than GEMM for efficiency).

  • Self-Attention (Q, K, V projections): In a Transformer, each input token’s representation is linearly projected to form Query, Key, and Value vectors. These are implemented as three separate linear layers (weight matrices) applied to the input. If the model’s hidden size is d_model, each of Q, K, V is typically of size d_model as well. Computing Q = X · W_Q (where X is the batch of input features, shape [batch_size×seq_length, d_model] if flattened) is a matrix multiplication. The same goes for K and V. Often, implementations fuse these three projections for efficiency (e.g. using a single larger weight matrix or a batched GEMM), but either way the operations are GEMMs. After projection, Q, K, V might be reshaped to separate multiple attention heads, but that’s just a view of the same data.

  • Attention score computation (Q × K^T): To produce attention weights, we take the dot product of each Query with each Key. This is effectively a matrix multiply: Q (of shape [batch_size * num_heads, seq_len, head_dim]) times K^T (shape [batch_size * num_heads, head_dim, seq_len]) yields an attention score matrix of shape [batch_size * num_heads, seq_len, seq_len]. This is one GEMM per attention head, usually executed in parallel as a single batched GEMM. The result is then scaled and normalized with a softmax, which is not a GEMM but a row-wise normalization (an exponential per element followed by a sum along each row).

  • Attention output (softmax(QK^T) × V): After softmax, the weighted sum of Values is computed as the matrix multiplication of the softmax output and V (shape [batch_size * num_heads, seq_len, head_dim]). This produces the attended output for each head. That is another GEMM. In total, the attention mechanism involves multiple GEMMs: one for QK^T and one for applying those weights to V (plus the initial linear projections).

  • Feed-Forward Network (FFN): Each Transformer block has a two-layer MLP (often called the feed-forward network). The first layer takes the d_model features and projects to a larger intermediate size (e.g. 4×d_model), usually with a nonlinear activation (like GELU) after it. The second layer projects back to d_model. Both layers are standard fully-connected layers – i.e. matrix multiplications, optionally plus biases. For example, in Llama 2 7B (hidden size 4096), the FFN expands to 11008 and then projects back to 4096, which are huge matrix multiplications. (Llama-style models use a gated variant with three weight matrices and no biases, but each of them is still a GEMM.)

  • Output layer: LLMs often have a final linear projection before producing logits over the vocabulary (or they reuse the embedding matrix for a softmax). This final step is again a matrix multiply of the last layer’s activations with a large weight matrix of size d_model × vocab_size. For generation, this is performed at each time step for inference (or as part of training loss computation).


As we can see, virtually every significant computation in a Transformer block is a GEMM. The layer normalization and softmax are minor by comparison (those are memory-bound operations that do little arithmetic per element). Profiling studies confirm this: in a BERT-like Transformer, matrix multiplications dominate runtime. One study found that GEMM operations accounted for about 61% of the total execution time for a moderate sequence length, and even when sequence length increased (making attention more expensive), attention + GEMMs together made up ~89% of the time. In short, the bulk of a Transformer’s compute is spent inside GEMMs. This is why GPU-accelerated libraries and frameworks put so much effort into optimizing these matrix multiplies – any improvement here has an outsized impact on overall training or inference speed.
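A rough FLOP count shows how the GEMMs listed above compare with each other. This is a back-of-the-envelope sketch under stated assumptions: one sequence, a plain two-layer FFN with d_ff = 4×d_model, and only the GEMMs named in the list (the attention output projection and all non-GEMM work are left out). The function name and the chosen sizes are illustrative.

```python
def layer_gemm_flops(seq, d_model, d_ff):
    """FLOPs per GEMM group for one sequence passing through one Transformer block."""
    return {
        "qkv_projections": 3 * 2 * seq * d_model * d_model,
        "attention_scores": 2 * seq * seq * d_model,   # Q @ K^T, summed over all heads
        "attention_output": 2 * seq * seq * d_model,   # softmax(QK^T) @ V, summed over all heads
        "ffn": 2 * 2 * seq * d_model * d_ff,           # two linear layers
    }

breakdown = layer_gemm_flops(seq=1024, d_model=4096, d_ff=4 * 4096)
total = sum(breakdown.values())
for name, flops in breakdown.items():
    print(f"{name:18s} {flops / 1e9:8.1f} GFLOP  ({100 * flops / total:4.1f}%)")
```

At these sizes the FFN and the projections dominate; the two attention GEMMs grow with seq² and take a larger share as the sequence gets longer.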


To illustrate how these look in code, consider a highly simplified Transformer attention snippet using PyTorch:

import torch
import torch.nn.functional as F

# Dummy dimensions
batch, seq, d_model, num_heads = 2, 4, 8, 2
head_dim = d_model // num_heads

# Random input sequence
x = torch.randn(batch, seq, d_model)

# Weights for Q, K, V projections (d_model -> d_model)
Wq = torch.randn(d_model, d_model)
Wk = torch.randn(d_model, d_model)
Wv = torch.randn(d_model, d_model)

# Linear projections (batch*seq, d_model) @ (d_model, d_model) -> (batch*seq, d_model)
q = x.reshape(-1, d_model) @ Wq  
k = x.reshape(-1, d_model) @ Wk  
v = x.reshape(-1, d_model) @ Wv

# Reshape into [batch, heads, seq, head_dim] so the einsum labels below line up
q = q.view(batch, seq, num_heads, head_dim).transpose(1, 2)
k = k.view(batch, seq, num_heads, head_dim).transpose(1, 2)
v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)

# Compute attention scores for each head: (batch, heads, seq, seq)
scores = torch.einsum('bhqd, bhkd -> bhqk', q, k)  # this is batch matmul Q * K^T
weights = F.softmax(scores / (head_dim ** 0.5), dim=-1)  # softmax normalization
# Apply weights to V: (batch, heads, seq, head_dim)
attn_output = torch.einsum('bhqk, bhkd -> bhqd', weights, v)  # another batch matmul
# Move heads back next to head_dim, then merge them: (batch, seq, d_model)
attn_output = attn_output.transpose(1, 2).reshape(batch, seq, d_model)

print("Attention output shape:", attn_output.shape)

In this toy example, we used Einstein summation (einsum) to express the batched matrix multiplications. Under the hood, PyTorch will route these to optimized batch GEMM kernels (einsum with those subscripts is essentially doing Q @ K^T and softmax @ V). The key takeaway is that each of these steps is a matrix multiplication. In a real model like Llama-2, these matrices are huge: with d_model=4096, 32 heads (head_dim=128), and seq=1024, the QK^T step is a 1024×128 times 128×1024 multiplication for each of the 32 heads, which costs the same number of multiply-adds as a single 1024×4096 times 4096×1024 product, and the FFN layers are even larger multiplications.
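The snippet above applies the Q, K, and V projections as three separate GEMMs. As noted in the architecture overview, implementations often fuse them into one larger GEMM. Here is a minimal sketch of that idea on CPU; the variable names (W_qkv, q_f, and so on) are illustrative, and real implementations typically store the fused weight directly rather than concatenating on the fly.

```python
import torch

torch.manual_seed(0)
batch, seq, d_model = 2, 4, 8
x = torch.randn(batch * seq, d_model)

Wq = torch.randn(d_model, d_model)
Wk = torch.randn(d_model, d_model)
Wv = torch.randn(d_model, d_model)

# Three separate GEMMs
q, k, v = x @ Wq, x @ Wk, x @ Wv

# One fused GEMM: concatenate the weights along the output dimension, then split the result
W_qkv = torch.cat([Wq, Wk, Wv], dim=1)          # (d_model, 3 * d_model)
q_f, k_f, v_f = (x @ W_qkv).chunk(3, dim=1)     # each (batch * seq, d_model)

print(W_qkv.shape)
print(torch.allclose(q, q_f, atol=1e-5),
      torch.allclose(k, k_f, atol=1e-5),
      torch.allclose(v, v_f, atol=1e-5))
```

The fused version reads x once and issues a single, larger GEMM, which tends to use the hardware better than three smaller ones.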


Because so much of the model’s time is spent in these operations, efficient GEMM is synonymous with efficient LLM execution. Next, we’ll look at how we can speed up these GEMMs in PyTorch using various techniques: lower precision, batching, quantization, and kernel fusion.


Optimizing Matrix Multiplication in PyTorch


PyTorch uses highly optimized libraries (like NVIDIA’s cuBLAS or Intel MKL) for matrix operations by default. However, as users we have some control and techniques to further improve speed and memory usage for GEMM-heavy workloads like LLMs. Below are several key optimization strategies:


1. Mixed Precision (FP16 and BF16) Training


One of the most effective ways to accelerate GEMM is to use lower precision floats. By using 16-bit floating point instead of 32-bit, we halve memory usage and often get increased math throughput on hardware with tensor cores.

  • FP16 (half-precision) has 1 sign bit, 5 exponent bits, 10 fraction bits. It has a smaller range and precision than FP32. Using FP16 for matrix multiplication can speed up compute significantly on GPUs with tensor core support. For example, on Volta/Ampere GPUs, matrix multiplies in FP16 are executed on tensor cores which can be 8–16× faster than FP32 on standard cores. The downside is reduced precision and potential for overflow/underflow, which is why in training we often use loss scaling to avoid FP16 underflow.

  • BF16 (bfloat16) is an alternative 16-bit format with 1 sign, 8 exponent, 7 fraction bits. It has the same exponent range as FP32 (due to 8 exponent bits) but fewer mantissa bits. BF16 is popular for training because it retains dynamic range (so gradients don’t overflow easily) while still using 16 bits. New GPUs (Ampere and later) support BF16 arithmetic in tensor cores similarly to FP16. In practice, BF16 often achieves similar speed to FP16 but with simpler training (no need for loss scaling).
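The range and precision trade-off between these formats follows directly from their bit layouts. The sketch below derives the largest finite value, the smallest normal value, and the machine epsilon from the exponent and fraction widths given above; the helper name is hypothetical and the formulas assume the usual IEEE-style encoding (biased exponent, top exponent code reserved for inf/NaN).

```python
def float_format_stats(exp_bits, frac_bits):
    """Largest finite value, smallest normal value, and machine epsilon of an IEEE-style format."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = (2 ** exp_bits - 2) - bias          # top exponent code is reserved for inf/NaN
    largest = (2 - 2.0 ** -frac_bits) * 2.0 ** max_exp
    smallest_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -frac_bits                   # gap between 1.0 and the next representable value
    return largest, smallest_normal, epsilon

formats = {"FP16": (5, 10), "BF16": (8, 7), "FP32": (8, 23)}
for name, (e, f) in formats.items():
    largest, smallest, eps = float_format_stats(e, f)
    print(f"{name}: max={largest:.4g}  min normal={smallest:.4g}  epsilon={eps:.4g}")
```

FP16 tops out at 65504 but resolves finer steps near 1.0, while BF16 spans roughly the same range as FP32 with coarser steps.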


In PyTorch, using mixed precision is straightforward. You can cast models and data to half or bfloat16, or use the high-level torch.autocast context for automatic mixed precision (AMP). Here’s a simple demonstration of using FP16 in a GEMM:

import torch
# Create random matrices in full precision
A32 = torch.randn(1024, 1024, dtype=torch.float32, device='cuda')
B32 = torch.randn(1024, 1024, dtype=torch.float32, device='cuda')
# Multiply in FP32
C32 = A32 @ B32

# Convert to half-precision and multiply
A16 = A32.half()
B16 = B32.half()
C16 = A16 @ B16

print(C16.dtype)  # torch.float16
# Compare results (convert C16 to float32 for fair comparison)
max_error = (C32 - C16.float()).abs().max().item()
print(f"Max difference between FP32 and FP16 result: {max_error}")

This code prints the maximum absolute difference between the two results. The exact value depends on the random inputs; because the outputs here have magnitudes in the tens and FP16 keeps only about three significant decimal digits, expect a difference on the order of 1e-2 to 1e-1. The difference is due to precision loss in FP16. For inference, a relative error of this size is usually negligible; for training, techniques like loss scaling or switching to BF16 preserve model stability. When enabling mixed precision training (for example with torch.autocast(device_type="cuda", dtype=torch.float16)), matrix multiplies and convolutions run in lower precision on tensor cores, while certain sensitive operations (like reductions, accumulations, and some normalization layers) are kept in FP32 for accuracy. The result is often 2× or more throughput for the GEMM operations.
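Loss scaling is easiest to see on a single number. The sketch below uses Python's struct module (its "e" format is IEEE half precision) to round values to FP16 without needing a GPU. The gradient value and the scale factor are hypothetical; real AMP implementations choose and adjust the scale dynamically.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value and return it as a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8                      # a small gradient, below FP16's smallest subnormal (about 6e-8)
loss_scale = 65536.0             # hypothetical scale factor

unscaled = to_fp16(grad)                 # underflows to zero: the gradient is lost
scaled = to_fp16(grad * loss_scale)      # the scaled value is representable in FP16
recovered = scaled / loss_scale          # unscale in higher precision before the optimizer step

print(unscaled)
print(recovered)
```

Multiplying the loss by the scale factor shifts every gradient up by the same factor, so small gradients survive the FP16 round trip and can be divided back down afterwards.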


BF16 in PyTorch: To use BF16, you can cast tensors to torch.bfloat16 similarly. On CPUs, PyTorch supports BF16 arithmetic in recent versions (though it may be emulated or use AVX512 on capable hardware); on GPUs like A100/H100, BF16 is hardware accelerated. An example is A_bf16 = A32.to(torch.bfloat16) and B_bf16 = B32.to(torch.bfloat16), then C_bf16 = A_bf16 @ B_bf16. BF16 carries less precision than FP16 (7 fraction bits instead of 10) but a much wider dynamic range (8 exponent bits instead of 5, the same as FP32). Many LLM training recipes now prefer BF16 as it avoids the fuss of loss scaling while still leveraging tensor core speed.


In summary, by using FP16 or BF16, we trade a tolerable amount of numerical precision for major gains in speed and memory usage. Most large model training today is done in mixed precision for this reason. (Note: When using mixed precision, ensure your GPU supports fast FP16/BF16; older GPUs without tensor cores won’t see as much benefit, and on those, 32-bit might even be faster due to better-supported vectorization.)


2. TensorFloat-32 (TF32) on NVIDIA Ampere+


NVIDIA introduced TF32 on the Ampere architecture as a way to accelerate FP32 GEMMs with minimal code changes. TF32 is a 19-bit format (10-bit mantissa, 8-bit exponent, 1 sign) used inside tensor cores. When TF32 is enabled, torch.matmul on float32 tensors uses tensor cores with TF32 math on Ampere and newer GPUs and accumulates in FP32. The idea is you get much of the tensor-core speedup while keeping FP32's range and roughly FP16-level mantissa precision. Whether it is on by default depends on the PyTorch version: PyTorch 1.7 through 1.11 enabled TF32 for matmuls by default, while 1.12 and later leave it off for matmuls (it remains on for cuDNN convolutions). Enabling it is a one-line flag:

torch.backends.cuda.matmul.allow_tf32 = True  # off by default for matmuls since PyTorch 1.12
# equivalently: torch.set_float32_matmul_precision("high")

If you need full FP32 precision (for example, for certain numeric comparisons), you can turn TF32 off (allow_tf32 = False), but you’ll lose the speed benefit. In practice, TF32 incurs a small accuracy drop (similar in scale to 16-bit mantissa rounding error) and is generally safe for training. It’s one reason Ampere GPUs (A100, RTX 30-series) were so much faster for training than the previous generation without requiring explicit FP16 usage: convolutions get the boost automatically, and matmuls get it with a single flag.


To summarize: TF32 allows float32 code to execute on tensor cores by rounding the mantissa to 10 bits. On paper the A100’s TF32 tensor core peak is 156 TFLOPS vs 19.5 TFLOPS for standard FP32, an 8× ratio, although measured end-to-end speedups are usually smaller. If you’re using PyTorch on Ampere or newer and can tolerate the reduced mantissa precision, check that TF32 is enabled (recent PyTorch versions do not enable it for matmuls by default). This way you get a performance boost for GEMM-heavy operations even if you aren’t manually using half precision.
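The effect of a 10-bit mantissa can be emulated on the CPU by clearing the low 13 bits of a float32 value. The sketch below does this with the struct module. It truncates rather than rounds, which is an assumption made for simplicity and is not claimed to match the hardware's conversion exactly; the function name is hypothetical.

```python
import struct

def truncate_to_tf32(x):
    """Keep the sign, 8 exponent bits, and top 10 of float32's 23 mantissa bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000           # clear the low 13 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for value in [1.0, 3.14159265, 1e-30, 1e30]:
    approx = truncate_to_tf32(value)
    rel_err = abs(approx - value) / abs(value)
    print(f"{value:.8g} -> {approx:.8g}  (relative error {rel_err:.2e})")
```

The very small and very large inputs stay representable because the exponent field is untouched, while the relative error stays below about 1e-3, in line with a 10-bit mantissa.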


3. Batched GEMM and Parallelism


In many models, you might need to perform many smaller matrix multiplications. For example, computing the Q, K, V projections separately for each head, or multiple small attention matrices for many heads. Launching a separate GPU kernel for each small matrix multiply can be inefficient due to launch overhead and underutilization of GPU cores. Instead, batched GEMM APIs let you concatenate many matrices and multiply them in one go.


PyTorch provides torch.bmm (batch matrix-multiply), and torch.matmul also accepts tensors with an extra batch dimension. The work will be done in a single kernel call (or a few calls) rather than looped in Python. This can significantly improve throughput when you have, say, hundreds of 64×64 matrices to multiply at once, as opposed to multiplying them one at a time.


Example of batched vs looped matrix multiplication:

import torch
import time

# Suppose we have to multiply 100 matrices of size 128x128 by 128x128
batch = 100
A = torch.randn(batch, 128, 128, device='cuda')
B = torch.randn(batch, 128, 128, device='cuda')

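# Warm-up so one-time CUDA/cuBLAS initialization is not charged to the first method
_ = A[0] @ B[0]
_ = torch.bmm(A, B)
torch.cuda.synchronize()

# Method 1: loop (inefficient on GPU due to 100 separate launches)
start = time.time()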
outs = []
for i in range(batch):
    outs.append(A[i] @ B[i])
torch.cuda.synchronize()
loop_time = time.time() - start

# Method 2: batched matmul in one call
start = time.time()
C = torch.bmm(A, B)  # or A @ B, which also handles batch
torch.cuda.synchronize()
batch_time = time.time() - start

print(f"Loop time: {loop_time:.4f} s, Batched time: {batch_time:.4f} s")
# Small absolute tolerance: the two paths may sum in a different order
print("Results equal:", torch.allclose(torch.stack(outs), C, atol=1e-4))

On a GPU, the batched version will be much faster than the looped version (the difference might be small on CPU, but on GPU avoiding kernel-launch overhead is key). This is because the batched operation leverages parallelism across the batch dimension internally. In our Transformer discussion, we noted that the attention computation for each head or each element in a batch can be thought of as independent matrix multiplies – these are perfect candidates for torch.bmm. In fact, PyTorch’s multi-head attention implementation uses batched GEMMs under the hood (often by combining the heads into one big batch or by merging the Q, K, V projections into one large weight matrix to apply a single GEMM).


Whenever you find yourself computing many independent GEMMs of the same size, try to use a batched operation instead of Python loops. This ensures that the GPU does the work with maximal concurrency. The code above demonstrates using torch.bmm, but note that torch.matmul is more general: it will automatically broadcast batch dimensions. For example, if A is of shape (batch, m, k) and B is (batch, k, n), then A @ B produces (batch, m, n) by performing batched matrix multiply. If B had shape (k, n) (no batch dim), PyTorch would broadcast it and still do a batched operation of batch multiplies. This high-level API is convenient and efficient.
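The broadcasting behavior described above can be checked on CPU with a few lines. This is a small sketch with arbitrary illustrative shapes; it compares the broadcast result against an explicit Python loop.

```python
import torch

torch.manual_seed(0)
batch, m, k, n = 4, 3, 5, 2
A = torch.randn(batch, m, k)
B_shared = torch.randn(k, n)            # no batch dimension

C = A @ B_shared                        # B_shared is broadcast across the batch
C_loop = torch.stack([A[i] @ B_shared for i in range(batch)])

print(C.shape)
print(torch.allclose(C, C_loop, atol=1e-5))
```

Sharing one weight matrix across a batch of inputs is exactly the situation a linear layer is in, which is why this form shows up so often.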


4. Quantization (INT8) for Inference Efficiency


So far we focused on floating-point arithmetic, which is standard for training. But for inference, we can often go even lower in precision – down to 8-bit integers – using quantization. The idea is to represent weights (and possibly activations) with 8-bit integers and perform matrix multiply-accumulate in integer arithmetic. This can give another 2–4× speed boost and a 4× memory reduction relative to FP32 (2× relative to FP16), at the cost of some accuracy loss (which can be mitigated with careful calibration or quantization-aware training).
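The core idea fits in a few lines of plain Python. The sketch below uses symmetric per-tensor quantization (one scale per tensor, codes in [-127, 127]), which is one common scheme among several; the function name and the sample values are illustrative.

```python
def quantize_symmetric(values):
    """Map floats to int8 codes in [-127, 127] with a single per-tensor scale."""
    scale = max(abs(v) for v in values) / 127.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale

weights = [0.42, -1.30, 0.07, 0.88]
activations = [1.5, -0.25, 2.0, 0.5]

w_q, w_scale = quantize_symmetric(weights)
a_q, a_scale = quantize_symmetric(activations)

# Integer multiply-accumulate (the accumulator would be INT32 in hardware)
acc_int = sum(w * a for w, a in zip(w_q, a_q))
approx = acc_int * w_scale * a_scale          # dequantize once, at the end
exact = sum(w * a for w, a in zip(weights, activations))

print(w_q, a_q)
print(f"exact={exact:.4f}  int8 approx={approx:.4f}")
```

All of the multiply-adds happen on small integers, and the floating-point scales are applied only once per output element, which is what makes the integer path cheap.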


Modern GPUs support INT8 GEMM on tensor cores as well. NVIDIA Turing (RTX 20-series) and newer (Ampere, Hopper) have INT8 tensor core instructions. For instance, an A100 can do INT8 × INT8 multiplies with INT32 accumulation at very high throughput (the A100’s peak INT8 tensor core rate is 624 TOPS, twice its 312 TFLOPS FP16 rate). To use these, you typically need to use special CUDA libraries or newer PyTorch features, since PyTorch’s high-level torch.matmul won’t automatically use int8 (as of PyTorch 2.x, you often need to go through FBGEMM on CPU or use external libraries on GPU).


Quantization in PyTorch: PyTorch offers a quantization toolkit (for CPU quantized inference), and libraries like bitsandbytes and NVIDIA’s TensorRT provide GPU int8 support (NVIDIA’s Transformer Engine targets FP8 instead). For a simple example on CPU, PyTorch dynamic quantization can convert an nn.Linear module to use int8 weights:

import torch
import torch.nn as nn

# Wrap the linear layer in a container: quantize_dynamic swaps child modules, not the root module
model = nn.Sequential(nn.Linear(128, 64))
# Convert to dynamically quantized version (weights int8, activations quantized on the fly)
model_int8 = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

# Compare dtype of weight
print("Original weight dtype:", model[0].weight.dtype)
# On the quantized module, weight() is a method that returns the quantized tensor
print("Quantized weight dtype:", model_int8[0].weight().dtype)

Output might show the original weight dtype as torch.float32 and the quantized weight as torch.qint8 (quantized int8). The quantized module will internally convert inputs to int8 on the fly, perform GEMM in int8, then convert output back to float. This is great for saving memory and potentially using faster instructions on CPUs that support int8 (via AVX512 VNNI etc.). On GPU, one would use different approaches: for example, the bitsandbytes library provides Int8Params and Linear8bitLt, which allow you to load a model’s weights in int8 and perform 8-bit matrix multiplies on GPUs, using CUDA’s efficient INT8 paths.


It’s worth noting that int8 matrix multiplication on GPU can be extremely fast, but only if done in the tensor cores (using specialized kernels). This requires memory alignment and using the right CUDA APIs (cuBLASLt or CUTLASS). Projects like FasterTransformer and TensorRT do this to achieve real speedups. Pure Python libraries may sometimes see int8 slower than FP16 if not using tensor cores efficiently, due to overheads. However, when done right, int8 inference can significantly reduce latency and memory. For example, running INT8 GEMMs on Turing/Ampere tensor cores can more than double throughput. Also, lower precision like 4-bit is being explored (NVIDIA Hopper supports FP8 which is another direction). But int8 is a sweet spot for many practical LLM inference scenarios today.


In summary, quantization allows us to accelerate inference by using integer math. There is a trade-off with accuracy, so typically one would quantize a model after training and evaluate the accuracy drop. Techniques like post-training quantization (PTQ) with calibration or quantization-aware training (QAT) help minimize the loss in accuracy. Many open-source tools can quantize popular LLMs (e.g., you may have heard of LLM.int8(), which is implemented in bitsandbytes and integrated into Hugging Face Transformers). If you deploy LLMs, it’s worth considering int8 quantization when you need faster/smaller models and can tolerate or mitigate the slight quality loss.


5. Kernel Fusion to Reduce Memory Bottlenecks


Apart from using lower precision, another powerful optimization is kernel fusion. This means combining multiple operations that would normally run separately into one GPU kernel launch. By fusing, we avoid intermediate memory writes/reads and can reuse data while it’s in registers. In GEMM-heavy workloads, often the GEMM itself is compute-bound, but surrounding operations (like adding a bias, activation functions, normalization) are memory-bound. Fusing them with the GEMM or with each other can speed up the overall operation and save memory.


Why fuse? Launching a GPU kernel has some overhead, and writing out intermediate results to GPU global memory only to read them back for the next operation wastes bandwidth and time. In an LLM, a common pattern is: Linear layer -> add bias -> apply activation (e.g. GELU) -> maybe another Linear. The add bias and GELU are cheap computation but each touches millions of elements in memory. If we do them in one fused kernel right as the GEMM produces outputs, we can do those additions/multiplications on the fly.
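A simple traffic count makes the argument concrete. The model below is deliberately crude: it counts only reads and writes of the GEMM output tensor in global memory, ignores caches and the small bias vector, and assumes a perfectly fused epilogue. The function name and the tensor shape are hypothetical.

```python
def elementwise_traffic_bytes(rows, cols, bytes_per_elem, fused):
    """Global-memory bytes moved for the GEMM output when followed by bias-add and an activation."""
    tensor_bytes = rows * cols * bytes_per_elem
    if fused:
        # Bias and activation are applied while the tile is still in registers,
        # so only the final write of the output remains.
        return tensor_bytes
    # Unfused: the GEMM writes the output (1 pass), then bias-add reads and
    # writes it (2 passes), then the activation reads and writes it (2 passes).
    return tensor_bytes * (1 + 2 + 2)

rows, cols = 4096, 11008          # hypothetical activation shape (tokens x FFN width)
unfused = elementwise_traffic_bytes(rows, cols, 2, fused=False)   # 2 bytes per FP16 element
fused = elementwise_traffic_bytes(rows, cols, 2, fused=True)
print(f"unfused: {unfused / 1e6:.1f} MB, fused: {fused / 1e6:.1f} MB, ratio: {unfused / fused:.0f}x")
```

Under these assumptions the unfused sequence moves five times as much output data, even though the extra arithmetic is negligible.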


Example: fused bias add with GEMM. PyTorch’s torch.addmm is a simple example of fusion at the BLAS level. It computes C = beta*C + alpha*(A@B) in one call. If you have a bias term and you’re currently doing Y = A @ B; Y += bias, you can instead use torch.addmm(bias, A, B) to achieve the same result in one step. This doesn’t fuse arbitrary element-wise ops, but it fuses the addition of the bias vector to the result of matrix multiply. A quick demo:

import torch

M, N, K = 256, 256, 256
A = torch.randn(M, K, device='cuda')
B = torch.randn(K, N, device='cuda')
bias = torch.randn(N, device='cuda')  # one bias value per output column, broadcast over rows

out1 = A @ B + bias  # separate ops
out2 = torch.addmm(bias, A, B)  # fused bias + matmul
print("Difference:", (out1 - out2).abs().max().item())

This will print a difference of 0.0 or a value at floating-point rounding level (the two paths may add the bias at a different point in the computation, so the results are not guaranteed to be bit-identical). By using addmm, we let the library handle adding the bias inside the GEMM kernel (or at least without a full separate pass over out). In practice, this saves one kernel launch and one read/write of the entire output matrix.


On a larger scale, deep learning libraries and compilers go further: they can fuse normalization and activation operations. For instance, NVIDIA’s Transformer kernels fuse layer norm with residual bias adds, etc., and LinkedIn’s Liger (discussed below) fuses even more. According to one tech report, fusing bias and layernorm after attention yielded a 3.2% speed boost for a single Transformer layer, and fusing bias+activation after GEMM improved that sub-layer by 61%. Those numbers illustrate that while each fusion might save only a few microseconds, in a giant model with hundreds of layers those savings multiply, and you also save memory usage (less storage for intermediate activations).


PyTorch 2.x and TorchScript: PyTorch 2.x introduced torch.compile (a JIT compiler) that can automatically fuse operations across op boundaries using graph-level optimization. For example, torch.compile can detect patterns like linear->activation and generate a fused kernel (often using Triton, which we discuss next). In earlier versions, TorchScript (torch.jit.trace or torch.jit.script) could also fuse elementwise ops. When writing custom CUDA kernels, one can manually fuse as needed, but it’s complex and beyond beginner scope. Tools like Triton simplify this.


To conclude this section: fuse what you can. If you see a sequence of operations in your model that always occur one after another (and do not have data dependencies that prevent fusion), consider whether PyTorch provides a fused version or whether a compiler can merge them. Examples include nn.Linear (which under the hood uses addmm to fuse bias), F.relu or F.gelu which can sometimes be fused by compilers into surrounding ops, and custom fused kernels for things like attention (e.g. FlashAttention fuses several steps of attention into one highly-optimized kernel). By reducing kernel launches and memory traffic, you make better use of the GPU’s capabilities.


With these optimizations (mixed precision, batching, quantization, fusion) in mind, we already have a toolkit to speed up LLM computations significantly on existing hardware. Next, we’ll dive a bit deeper into writing custom GPU kernels, using Triton, and how the open-source Liger project leverages that to push LLM training efficiency further.


Writing Custom GEMM Kernels with Triton


Despite the high performance of libraries like cuBLAS, there are cases where you may want to write your own GPU kernel for GEMM or a GEMM fused with other operations. Perhaps you want to fuse an unusual combination of ops that framework libraries don’t cover, or experiment with novel data layouts or precisions. Traditionally, this meant writing CUDA C++ code – which is complex and hardware-specific. Triton is a framework that makes this much easier by allowing you to write GPU kernels in Python with a JIT compiler that optimizes and generates PTX (or AMD GCN) under the hood.


What is Triton? Triton (open-sourced by OpenAI) is a Python library and compiler for writing custom GPU kernels. You write a Triton kernel as a Python function, using an API that includes pointers and vectorized operations, and decorate it with @triton.jit. Triton handles splitting the work among threads, optimizing memory accesses, and can even auto-tune tile sizes. The advantage is you get performance approaching that of hand-written CUDA, but with far less effort and without needing to depend on vendor-specific libraries. You can also easily integrate Triton kernels into PyTorch as custom ops.


Why not just trust cuBLAS? For standard GEMM, vendor libraries are very fast. However, they are black boxes – you cannot easily change what they do. If you need a slightly different operation (say GEMM + nonlinearity in one pass), the library won’t provide that. As the Triton tutorial notes, these proprietary libraries cannot be customized to specific needs of modern DL workloads (they’re general-purpose). By writing a custom kernel, you can, for example, fuse an activation or implement a lower precision computation that isn’t supported out-of-the-box. Triton gives you a way to do this in a productive manner.


Tiling and parallelism in Triton: When writing a GEMM kernel, a common strategy is block tiling. We split the output matrix C into small tiles (e.g. 128×128 or 64×64), and assign each tile to a GPU thread block (Triton calls it a program instance). That program loads the corresponding sub-blocks of A and B, multiplies them, and accumulates the partial result. Because a full row-block of A and column-block of B generally do not fit in registers/shared memory at once, it iterates over the K dimension in chunks of BLOCK_K. Pseudocode for a tiled GEMM might look like:

# Pseudocode for blocked GEMM
for m in range(0, M, BLOCK_M):
    for n in range(0, N, BLOCK_N):
        acc = zeros((BLOCK_M, BLOCK_N))
        for k in range(0, K, BLOCK_K):
            a_block = A[m : m+BLOCK_M, k : k+BLOCK_K]
            b_block = B[k : k+BLOCK_K, n : n+BLOCK_N]
            acc += a_block @ b_block  # small matrix multiply or elementwise multiply-add
        C[m : m+BLOCK_M, n : n+BLOCK_N] = acc

Each iteration of the outer two loops (m, n) can be done in parallel by different GPU thread groups. Each group runs the inner k-loop itself to accumulate the full dot product for its tile. Triton allows you to implement exactly this logic. You would use tl.arange to create index ranges for the tile, tl.load to fetch the A and B sub-blocks, tl.dot to multiply them and add into the accumulator, and tl.store to write back the result. Triton kernels often use 2D tile indices and a loop over the k-dimension. The Triton tutorial code for matmul uses these techniques and reports performance on par with cuBLAS by carefully choosing tile sizes and creating the accumulator with tl.zeros in higher precision (e.g. accumulating in FP32 for an FP16 GEMM to improve accuracy).
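The benefit of a higher-precision accumulator can be seen without a GPU. The sketch below is not Triton code: it emulates FP16 and FP32 rounding with Python's struct module and sums 4096 products of 0.01 × 1.0 with each accumulator type. The helper names fp16, fp32, and dot are hypothetical and exist only for this illustration.

```python
import struct

def fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision (FP16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def fp32(x: float) -> float:
    """Round x to the nearest IEEE single-precision (FP32) value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def dot(a, b, round_acc):
    """Dot product of FP16 inputs; round_acc sets the accumulator precision."""
    acc = 0.0
    for x, y in zip(a, b):
        prod = fp16(x) * fp16(y)     # product of two FP16 values is exact in a double
        acc = round_acc(acc + prod)  # round the running sum after every step
    return acc

K = 4096
a = [0.01] * K
b = [1.0] * K

exact = K * fp16(0.01)      # true sum of the FP16-rounded inputs
acc_fp16 = dot(a, b, fp16)  # emulated FP16 accumulator
acc_fp32 = dot(a, b, fp32)  # emulated FP32 accumulator

print("exact sum:       ", exact)
print("FP16 accumulator:", acc_fp16)
print("FP32 accumulator:", acc_fp32)
```

In this setup the FP16 running sum stalls once it reaches 32, because each new product is smaller than half the spacing between adjacent FP16 values at that magnitude, while the FP32 accumulator reproduces the exact sum of 40.96875.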


To give a flavor of Triton code, here’s a simplified example of a Triton kernel that computes C = A + B (elementwise addition) just to show the structure and launching mechanism:

import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(A_ptr, B_ptr, C_ptr, N, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one chunk of BLOCK_SIZE elements
    pid = tl.program_id(axis=0)
    offset = pid * BLOCK_SIZE  # start index for this program
    # Define an index range [offset, offset+BLOCK_SIZE)
    idx = offset + tl.arange(0, BLOCK_SIZE)
    # Create a mask for bounds checking
    mask = idx < N
    # Load A and B (masked-off lanes get 0 instead of reading memory)
    a = tl.load(A_ptr + idx, mask=mask, other=0.0)
    b = tl.load(B_ptr + idx, mask=mask, other=0.0)
    # Compute and store result
    c = a + b
    tl.store(C_ptr + idx, c, mask=mask)

# Prepare input data
N = 1000
A = torch.randn(N, dtype=torch.float32, device='cuda')
B = torch.randn(N, dtype=torch.float32, device='cuda')
C = torch.empty(N, dtype=torch.float32, device='cuda')

# Launch the Triton kernel with enough blocks to cover N elements
BLOCK_SIZE = 128  # compile-time constant; tl.arange needs a power-of-two size
grid_size = (triton.cdiv(N, BLOCK_SIZE),)  # number of program instances (each covers 128)
vector_add_kernel[grid_size](A, B, C, N, BLOCK_SIZE=BLOCK_SIZE)
torch.cuda.synchronize()
print("Max error:", (C - (A+B)).abs().max().item())

This Triton kernel divides the work such that each “program” handles 128 elements of the vectors. We calculate the index range for each program using tl.program_id(axis=0) (which gives the block index along the 0th dimension of the grid). We then load chunks of A and B, perform addition, and store into C. The mask ensures we don’t go out of bounds when N isn’t a multiple of 128. After launching, we synchronize and verify that C matches A+B. The result should be correct with max error 0, and indeed we have effectively written and run a custom GPU kernel without a single line of CUDA C++.


Now, extrapolate this idea to a matrix multiply: instead of adding two vectors, the kernel would multiply sub-matrices. Triton provides the tools, but you have to manage the indices for 2D tiles and the loop over K. The Triton docs walk through building a high-performance GEMM kernel step by step, including how to handle memory hierarchies and parallelization. Impressively, with ~50 lines of Triton code you can implement an FP16 GEMM that matches cuBLAS’s speed. The key is that Triton allows fine control: you can decide to, say, fuse an activation by simply adding that line in the kernel. This level of customization is powerful for research.
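The index bookkeeping is the part that tends to trip people up, so here is a CPU-only sketch of it in plain Python. It is not Triton code and not the tutorial's kernel: matmul_program is a hypothetical stand-in for one program instance, and the ordinary for-loop over pid stands in for the GPU launching all instances in parallel. The if-checks play the role of the masks passed to tl.load and tl.store.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the role triton.cdiv plays when sizing a launch grid."""
    return (a + b - 1) // b

def matmul_program(pid, A, B, C, M, N, K, BM, BN, BK):
    """Work of one emulated program instance: a single BM x BN tile of C."""
    num_pid_n = cdiv(N, BN)
    pid_m, pid_n = divmod(pid, num_pid_n)   # 1D program id -> 2D tile index
    rows = [pid_m * BM + i for i in range(BM)]
    cols = [pid_n * BN + j for j in range(BN)]
    acc = [[0.0] * BN for _ in range(BM)]   # per-tile accumulator
    for k0 in range(0, K, BK):              # walk the K dimension in chunks
        ks = [k for k in range(k0, k0 + BK) if k < K]  # mask the K tail
        for i, r in enumerate(rows):
            for j, c in enumerate(cols):
                if r < M and c < N:         # mask rows/cols past the edge
                    acc[i][j] += sum(A[r][k] * B[k][c] for k in ks)
    for i, r in enumerate(rows):
        for j, c in enumerate(cols):
            if r < M and c < N:             # masked store
                C[r][c] = acc[i][j]

def matmul(A, B, BM=4, BN=4, BK=2):
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    grid = cdiv(M, BM) * cdiv(N, BN)        # one program per output tile
    for pid in range(grid):                 # on a GPU these run in parallel
        matmul_program(pid, A, B, C, M, N, K, BM, BN, BK)
    return C

# Shapes deliberately not multiples of the tile sizes, to exercise the masks
A = [[float(i + j) for j in range(3)] for i in range(5)]      # 5 x 3
B = [[float(i * j + 1) for j in range(6)] for i in range(3)]  # 3 x 6
C = matmul(A, B)
print(C[0])
```

A real Triton kernel replaces the Python lists with pointer arithmetic and the innermost sum with tl.dot, but the mapping from program id to tile and the masking at the edges follow the same pattern.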


Automatic tuning: Triton also has an autotune decorator to try different tile sizes (BLOCK_M, BLOCK_N, BLOCK_K) and find the fastest on a given GPU. This is important because the optimal tile shape might differ on, say, A100 vs RTX 3090 due to memory and SM differences. The Triton compiler handles low-level optimizations (like using fast math instructions, vectorization, avoiding bank conflicts in shared memory) so you can focus on the high-level algorithm.


In practice, you might not write your own GEMM from scratch (since libraries do it well), but you might use Triton to fuse a custom elementwise function into a GEMM or implement a niche operation (like a custom attention pattern or an optimized normalization) that isn’t available in PyTorch. For example, one could implement a fused attention kernel that performs the softmax and the matrix multiply in one go (some projects have done similar things with Triton). The main takeaway is: Triton empowers developers to write GPU kernels tailored to their model’s needs, without needing to wait for NVIDIA or others to provide a specialized library. This opens up opportunities for optimization that were previously only in reach for CUDA experts.


Liger: Optimized Triton Kernels for LLM Training


While writing custom kernels is great, not everyone has the time to hand-tune every part of an LLM. This is where projects like Liger come in. Liger Kernel (by LinkedIn engineers) is an open-source collection of Triton-based kernels optimized primarily for LLM training. It provides drop-in replacements for certain PyTorch/Hugging Face modules that achieve better performance and memory usage through kernel fusion and other tricks.


Key points about Liger:

  • It targets inefficiencies in the training pipeline by fusing operations and reducing memory traffic. The authors report on average a 20% increase in training throughput and a 60% reduction in GPU memory usage for popular LLMs compared to standard Hugging Face implementations. That’s a significant speed-up, which can translate to many hours saved in training or allow larger batch sizes or sequence lengths within the same hardware memory limits.

  • Liger is implemented with Triton, which means these kernels are vendor-agnostic (they work on NVIDIA and AMD GPUs via ROCm, as long as Triton supports them). It also means one can read the Python code for the kernels, which is more approachable than reading PTX or assembly.

  • It provides optimized versions of components like RMSNorm (root-mean-square normalization, used in LLaMA-style models), RoPE (rotary positional embeddings), the SwiGLU activation (used in PaLM, LLaMA, etc.), cross-entropy loss, and fused modules such as FusedLinearCrossEntropy. By fusing these or implementing them with lower-level efficiency, Liger saves memory and time. For example, the fused linear cross-entropy kernel computes the loss chunk by chunk, so the full logits tensor never has to be materialized at once.

  • Liger is easy to use: it’s designed to be a drop-in. As the GitHub README highlights, you can apply it with “one line of code” by monkey-patching a Hugging Face transformer model. It has an AutoLigerKernelForCausalLM wrapper that will automatically swap out certain components of a loaded model with Liger’s optimized ones.


Let’s see a quick example of how one might use Liger in code:

# Install first (in a shell, or prefixed with "!" in a notebook): pip install liger-kernel
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Load a pre-trained model with Liger's wrapper (if supported)
model_name = "meta-llama/Llama-2-7b-hf"  # example model ID (gated: ensure you have access, or use a smaller one)
model = AutoLigerKernelForCausalLM.from_pretrained(model_name).cuda()

With these two lines, model is now a Llama-2 7B model that internally uses Liger’s Triton kernels for supported operations. If the model type is supported (the README lists families such as LLaMA, Mistral, Gemma, and Qwen2; check it for the current list), Liger patches the forward functions of the model’s submodules to use Triton kernels. From the outside, you use model as usual, but it will consume less memory and run faster. Novice users can get improvements with minimal effort, and advanced users can even mix and match Liger modules manually (the library exposes individual Triton ops if needed).


What kind of improvements might you see? The Liger paper mentions an example of multi-GPU (data parallel) training where Liger gave 20% higher throughput without any convergence issues. The memory savings (60% in some cases) come from not materializing as many intermediate tensors. For instance, Liger might fuse an operation that would normally produce an intermediate of size (batch, seq, hidden), so that tensor is never stored. This is very useful for long-sequence training, where memory is at a premium.
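To get a feel for how large such intermediates are, the arithmetic is simple enough to script. The shapes below are illustrative assumptions (batch 8, sequence length 4096, hidden size 4096, and a 32,000-entry vocabulary as in Llama 2), not measurements from the Liger paper, and tensor_bytes is a hypothetical helper.

```python
def tensor_bytes(shape, bytes_per_element):
    """Memory footprint of a dense tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element

batch, seq, hidden, vocab = 8, 4096, 4096, 32000  # assumed example sizes

# One (batch, seq, hidden) activation stored in BF16 (2 bytes per element)
activation = tensor_bytes((batch, seq, hidden), 2)

# The (batch, seq, vocab) logits tensor stored in FP32 (4 bytes per element)
logits = tensor_bytes((batch, seq, vocab), 4)

print(f"activation: {activation / 2**20:.0f} MiB")
print(f"logits:     {logits / 2**30:.2f} GiB")
```

Every intermediate of this kind that a fused kernel avoids writing out saves that much memory per layer (or per loss computation), which is why the savings grow with sequence length.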


Integration with PyTorch 2.0: Liger can be combined with torch.compile for even more effect. There was a joint blog post on how torch.compile and Liger together can maximize performance (Torch Compile can orchestrate the whole graph, and use Liger’s kernels where applicable, plus do its own optimizations). The ecosystem is moving toward more of these “plugin” libraries where Triton kernels augment the default ones.


Finally, Liger is continuously evolving: the maintainers have added support for new hardware (even AMD GPUs) and new model types. It’s a good example of how open-source contributions can push LLM efficiency forward. For a user building an LLM from scratch or fine-tuning one, it’s worth keeping an eye on such libraries; you might get significant speedups with a simple pip install. Liger shows that by focusing on the “last mile” of kernel optimization (where the general frameworks may not yet be optimal), you can unlock performance gains that hardware alone doesn’t give you.


Benchmarking Tips and Conclusion


We have covered a lot of ground: from the basics of matrix multiplication and GEMM, to why GPUs and lower precision make it fast, to how LLMs use GEMM everywhere, and finally to advanced optimization via Triton and Liger. When applying these concepts, it’s important to benchmark and ensure that changes are actually improving performance:

  • Use proper timing: GPU kernels launch asynchronously, so synchronize before starting and before stopping the clock. For example, call torch.cuda.synchronize() around time.time() calls, or better, use torch.cuda.Event for fine-grained measurements. IPython’s %%timeit and Python’s timeit do not account for asynchronous GPU calls on their own. Always warm up the GPU with a few runs before timing (to let any caching or autotuning stabilize).

  • Profile memory: Reducing precision or fusing ops should reduce memory usage. You can check memory with torch.cuda.memory_allocated() or torch.cuda.max_memory_allocated() to compare before/after. Ensure that the 60% memory reduction from something like Liger is actually realized in your use case.

  • Compare throughput (tokens/s or samples/s): For end-to-end model performance, rather than just microseconds of a single GEMM, measure how many samples per second (for training) or tokens per second (for generation) you achieve. That is the true metric that matters for training speed or inference latency. Sometimes an optimization of one part might not show up if another part becomes the bottleneck, so profiling can help find the new bottleneck.

  • Quality/accuracy checks: When using lower precision or quantization, always validate that the model’s output (loss or accuracy) remains acceptable. PyTorch’s torch.allclose or checking max error (as we did in small demos) is useful for unit tests. For quantized models, evaluate on a validation set to ensure accuracy drop is within tolerances.
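The timing advice above can be wrapped in a small helper. This is a minimal sketch rather than a full benchmarking tool: the names benchmark and sync are hypothetical, and the example workload is a CPU stand-in so the snippet runs anywhere. On a GPU you would pass sync=torch.cuda.synchronize and a function that runs one training or generation step.

```python
import time

def benchmark(fn, warmup=3, iters=10, sync=None):
    """Return the mean wall-clock seconds per call of fn().

    sync: optional callable that blocks until queued work has finished
          (on a GPU, pass torch.cuda.synchronize).
    """
    for _ in range(warmup):   # warm-up runs: caches, autotuning, compilation
        fn()
    if sync is not None:
        sync()                # make sure warm-up work is done before timing
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()                # wait for all timed work to actually finish
    return (time.perf_counter() - start) / iters

# CPU stand-in workload; replace with a model step on real hardware
def step():
    return sum(i * i for i in range(10_000))

seconds_per_step = benchmark(step)
tokens_per_step = 8 * 512     # assumed batch size x sequence length
print(f"{seconds_per_step * 1e3:.3f} ms/step")
print(f"{tokens_per_step / seconds_per_step:.0f} tokens/s (if each step processed {tokens_per_step} tokens)")
```

Reporting tokens per second from the same helper keeps the micro-benchmark and the end-to-end metric on one footing, which makes it easier to spot when a faster kernel does not translate into a faster model.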


 

Below is an addendum showing why and how torch.compile can help further accelerate LLMs (in addition to the GEMM-focused optimizations). We’ll keep the language simple and practical.


8. Further Gains Using torch.compile


With PyTorch 2.0, there’s a powerful new mechanism to accelerate your workloads: torch.compile. This feature (also known as TorchDynamo + TorchInductor under the hood) captures your model’s Python-level graph, transforms or fuses ops, and then generates optimized kernels. In the context of large language models (LLMs), we have already discussed how important GEMM operations are, and how libraries like cuBLAS, Triton, and Liger can make these GEMMs blazing fast.


But what about the rest of the model? LLMs also contain various elementwise and reduction operations (softmax, layer norms, etc.), plus other overhead from Python loops and submodules that can hamper performance if not properly optimized. That’s where torch.compile shines:

  1. Graph Capture & Fusion

    By dynamically tracing your model, torch.compile sees the entire sequence of PyTorch ops in forward/backward passes. It fuses smaller ops into fewer GPU kernels (similar to custom Triton kernels we saw in Liger), drastically cutting kernel launch overhead, memory reads/writes, and Python overhead.

  2. Automatic Integration with Triton

    Under the hood, the TorchInductor backend for torch.compile uses Triton. So if you already appreciate the benefits of custom Triton kernels, you’ll be glad to know that torch.compile can generate them automatically for you. You don’t need to write your own kernel code.

  3. Even More GEMM Tuning

    While vendor libraries like cuBLAS remain best in class for standard GEMM patterns, torch.compile can help fuse surrounding ops: for instance, the typical pattern y = activation(x @ W + b) could be generated as one fused kernel. Some of these sub-ops also contain small matrix multiplies or broadcast expansions that can be better batched or merged.

    This means that even if your main GEMMs are already cuBLAS-level fast, torch.compile may reduce overhead on other parts of your network.

  4. Less Python Overhead

    An LLM typically involves many repeated blocks (multi-head attention layers, feed-forward layers, etc.). Each forward pass in eager mode calls the same submodules in a Python loop. torch.compile reduces this overhead because it partially compiles the control flow into a lower-level IR (FX Graph). So you run fewer Python-level instructions each pass.

  5. Optional Additional Optimizations

    • Modes like reduce-overhead and max-autotune: mode="reduce-overhead" uses CUDA graphs to cut per-call launch and Python overhead, while mode="max-autotune" spends extra compile time benchmarking more kernel configurations to find faster ones.

    • Integration with CUDA graphs: If your batch shape is consistent, PyTorch can record a CUDA graph once, then repeatedly launch it with minimal overhead.


Below, we’ll outline how you can integrate torch.compile into your LLM workflow:

import torch
import torch.nn as nn

# Stand-in for your LLM (a Hugging Face transformer or your own Transformer
# class would go here); this single layer is only a runnable placeholder
model = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
model.cuda()
model.train()

# Enable half precision or BF16 for big speed-ups on GPU
# (assuming Ampere or newer GPU with Tensor Cores).
# For training, torch.autocast is usually more robust than a blanket .half()
model = model.half()

# Then compile the model
model = torch.compile(model, mode="max-autotune")

Now, the first time you call model(input) for training or inference, PyTorch captures the forward pass, optimizes the sub-graphs (including or around your GEMMs), and produces a more efficient set of kernels; later calls with the same shapes reuse them. If your LLM has consistent input sizes (batch dimension, sequence length, etc.), or if you run many steps with the same shapes, the compilation overhead is amortized quickly over multiple iterations. You’ll likely see performance improvements on top of the gains from half precision or int8 GEMMs.


Example: Combining torch.compile with Liger/Triton


If you’re already using Liger or custom Triton kernels for large chunks of your model (like fused RMSNorm or fused MLP), torch.compile can still help by:

  • Fusing other ops not covered by Liger.

  • Potentially integrating “Liger kernels” seamlessly in the compiled graph. Some parts of your model that Liger didn’t optimize might still get fused or auto-tuned by TorchInductor.


Things to Note

  1. First pass overhead: The very first time you run model(input) after torch.compile, PyTorch will trace and compile the model. This can take a noticeable amount of time (seconds to minutes, depending on model complexity). But once compiled, subsequent iterations are much faster. For inference servers or training loops that run many batches, this overhead is quickly paid back.

  2. Graph breaks: If your code has data-dependent Python logic (like random branching or certain unsupported ops), torch.compile may break the graph into pieces. Each piece is compiled separately. The more breaks, the less speedup. You can use torch._dynamo.explain(model)(input) to see where breaks occur and refactor code if needed.

  3. Compatibility: torch.compile requires PyTorch 2.0 or later and typically works best on newer GPUs (Ampere or Hopper, with FP16 or BF16). For older GPUs, the speedups might be smaller. Also, check that your environment has a PyTorch build that ships with Triton. If you see no speed gain, profile for graph breaks or check whether your workload is dominated by something compilation cannot fix (very large embedding lookups or multi-GPU communication overhead).

  4. Use Cases:

    • Training big LLMs from scratch or fine-tuning with standard frameworks: torch.compile can reduce overhead in the forward/backward pass, fusing activations, normalization, etc.

    • Inference or serving with consistent shapes: You can reduce latency or boost throughput by letting torch.compile fuse the model’s ops into fewer, bigger kernels.


Summary


Even if you’ve done all the usual GEMM optimizations (FP16, BF16, int8, batched ops, vendor libraries, or Triton custom kernels), you may still gain an additional 5–30% speed boost (numbers vary) by letting torch.compile fuse extra ops, remove Python overhead, and produce a more streamlined execution graph.


For large-scale training, every percentage counts. For latency-sensitive inference, a further ~10% reduction might let you serve more requests per second. torch.compile is thus an excellent final step in your optimization pipeline after you’ve tackled the major “low-hanging fruit” (like ensuring your GEMMs are as efficient as possible).


TL;DR: After using advanced GEMM optimizations, don’t skip torch.compile. It can automatically fuse the rest of your LLM’s code path and can yield additional speed-ups with only a one-line change—giving you a more efficient end-to-end execution.

 

 

 

In conclusion, GEMM is the engine that powers LLMs. Mastering how to accelerate GEMM means you can train large models and run inference on them more efficiently. We started from the simple idea of multiplying two matrices and built up to the complex reality of an LLM, and saw that at each step, improvements like using half precision or fusing operations translate to real-world gains. Thanks to tools like PyTorch’s AMP and projects like Triton and Liger, these optimizations are becoming more accessible. A developer with basic Python knowledge can now leverage GPU tensor cores and custom kernels without writing low-level code.


By applying the techniques covered – from mixed precision to custom fused kernels – you can achieve significant speedups: it’s not uncommon to see 2–3× faster training times and models that would barely fit in memory now training with room to spare. As LLMs continue to grow, efficient GEMM will only become more important (hardware advances like new tensor core precisions and software advances like smarter compilers will keep this field evolving). With this tutorial, you should have a solid foundation to understand and optimize the matrix multiplications that lie at the heart of modern AI models, enabling you to train the next billion-parameter model just a bit faster than before!


Sources: Optimizations and concepts were drawn from NVIDIA documentation and hardware specs for tensor core capabilities, from research on transformer performance breakdowns, and from the Liger Kernel report which demonstrated the benefits of kernel fusion in LLM training. The Triton tutorial was referenced for insights into custom kernel implementation. These illustrate the significant gains possible by focusing on GEMM efficiency in deep learning workloads.

 

Accelerating Large Language Models with GEMM Optimization in PyTorch

1. Introduction: The Need for Speed in LLMs

Large Language Models (LLMs) have revolutionized natural language processing, powering applications from sophisticated chatbots to advanced content generation. However, their immense size and computational complexity pose significant challenges, particularly regarding inference and training speed. Achieving acceptable performance often requires specialized hardware like GPUs and sophisticated optimization techniques.

At the heart of LLM computation lies a fundamental operation: General Matrix-Matrix Multiplication (GEMM). Understanding and optimizing GEMM is paramount to accelerating LLMs. This report serves as a comprehensive tutorial for developers seeking to enhance LLM performance within the PyTorch framework. Starting with the basics of GEMM and its relevance to LLMs, we will progressively explore optimization strategies, from simple data type adjustments to advanced techniques involving specialized libraries like bitsandbytes, fused operations via scaled_dot_product_attention, and custom kernel development with Triton, including pre-built solutions like Liger. The goal is to equip developers, regardless of their initial expertise level, with the knowledge and practical code examples needed to significantly speed up their LLM workloads.

2. What is GEMM? The Computational Core of LLMs

GEMM stands for General Matrix-Matrix Multiplication. It's a fundamental operation in linear algebra and forms the computational backbone of deep learning, especially for models like LLMs.

The standard GEMM operation is defined by the following equation, often expressed using Basic Linear Algebra Subprograms (BLAS) notation:

C = \alpha A B + \beta C

Where:

  • A is an m \times k matrix.
  • B is a k \times n matrix.
  • C is an m \times n matrix (the output matrix, which can also be an input for accumulation).
  • \alpha and \beta are scalar values.

In the context of deep learning and LLMs, \alpha is typically 1, and \beta is often 0 (for a simple multiplication C = A B) or 1 (when adding a bias or residual connection, though this is often handled as a separate operation or fused). The core computation involves multiplying matrix A by matrix B to produce matrix C.
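As a reference point, the definition C = \alpha A B + \beta C can be written directly as nested loops. This is a minimal pure-Python sketch for checking small cases by hand; the function name gemm and its argument order are illustrative and do not mirror the BLAS calling convention.

```python
def gemm(alpha, A, B, beta, C):
    """Return alpha * (A @ B) + beta * C for matrices given as lists of rows."""
    m, k, n = len(A), len(B), len(B[0])
    assert all(len(row) == k for row in A), "inner dimensions must match"
    assert len(C) == m and all(len(row) == n for row in C), "C must be m x n"
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            # Dot product of row i of A with column j of B
            dot = sum(A[i][t] * B[t][j] for t in range(k))
            row.append(alpha * dot + beta * C[i][j])
        out.append(row)
    return out

A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = [[1, 1], [1, 1]]

plain = gemm(1, A, B, 0, C)    # alpha=1, beta=0: ordinary matrix product
scaled = gemm(2, A, B, 1, C)   # alpha=2, beta=1: scale the product, add C
print(plain)    # [[19, 22], [43, 50]]
print(scaled)   # [[39, 45], [87, 101]]
```

The three nested loops make the O(m × n × k) cost explicit; optimized libraries compute the same result but reorder and tile these loops to keep data in fast memory.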

Nearly every computationally intensive part of a standard Transformer-based LLM relies heavily on GEMM operations:

  1. Linear Layers (Fully Connected Layers): These layers, defined in PyTorch as torch.nn.Linear, perform the operation Y = X W^T + b, where X is the input, W is the weight matrix, and b is the bias. The core is a GEMM.
  2. Attention Mechanism Projections: In multi-head self-attention, the input activations (X) are projected into Query (Q), Key (K), and Value (V) matrices using separate weight matrices (W_Q, W_K, W_V). Each projection (Q = X W_Q, K = X W_K, V = X W_V) is a GEMM [User Material 1]. The output projection, which combines the attention outputs, is also a GEMM [User Material 1].
  3. Feed-Forward Networks (MLP Blocks): Transformer blocks typically contain a position-wise feed-forward network, usually consisting of two linear layers with an activation function in between (e.g., MLP-up, MLP-gate, MLP-down projections). Each of these linear transformations is a GEMM [User Material 1].

Because these operations are repeated across numerous layers and attention heads within an LLM, the model's overall performance becomes critically dependent on the efficiency of the underlying GEMM computations.

3. Why GPUs Excel at GEMM

Graphics Processing Units (GPUs) are the hardware of choice for training and deploying LLMs primarily because their architecture is exceptionally well-suited for the massive parallelism inherent in GEMM operations. Several factors contribute to this efficiency [User Material 2]:

  1. High Arithmetic Intensity: Arithmetic intensity is the ratio of floating-point operations (FLOPs) to bytes of data moved from memory. GEMM exhibits very high arithmetic intensity. Consider the calculation C = A B. To compute a single element C_{ij}, we perform k multiplications and k - 1 additions (~2k FLOPs). However, the elements A_{i,t} (from row i of A) and B_{t,j} (from column j of B) can be reused across computations for other elements in the same output row or column. By loading blocks (tiles) of A and B into fast on-chip memory (registers and shared memory), GPUs can perform many multiply-accumulates per loaded element before needing to fetch new data from slower main memory (HBM). This reuse minimizes the memory bandwidth bottleneck.
  2. Regular Memory Access Patterns: GEMM involves accessing matrix elements in predictable patterns (rows of A, columns of B). GPUs can efficiently prefetch and stream these blocks of data ("block-wise" access) into their processing units, ensuring that the compute cores are constantly supplied with data. This contrasts with irregular memory access patterns found in other computations (like sparse operations or graph traversals) which can stall the GPU.
  3. Amenability to Tiling and Parallelism: The GEMM operation can be naturally decomposed into smaller, independent matrix multiplications on sub-blocks (tiles) of the original matrices. This "tiling" strategy allows the overall computation to be split across the thousands of lightweight threads available on a modern GPU. Each thread block can compute a tile of the output matrix independently, leading to massive parallelism.
  4. Vectorization (SIMT Architecture): GPUs employ a Single Instruction, Multiple Threads (SIMT) execution model, in which a single instruction is executed by many threads at once, each on its own data element. Modern GPUs also have specialized units (like NVIDIA's Tensor Cores) that can perform small matrix multiplications (e.g., on 4 \times 4 tiles) on lower-precision data types (FP16, BF16, INT8) in a single clock cycle [User Material 2]. This vector-friendly nature allows one instruction to perform the work of many, significantly boosting throughput.
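The arithmetic-intensity argument in the first point can be made concrete with a few lines of Python. This sketch assumes the idealized case where each matrix is moved between main memory and the chip exactly once (perfect on-chip reuse) and uses 2 bytes per element, as for FP16 or BF16; the function name is hypothetical.

```python
def gemm_arithmetic_intensity(m, n, k, bytes_per_element=2):
    """FLOPs per byte for C = A @ B, assuming A, B and C each move once."""
    flops = 2 * m * n * k                                # one multiply + one add per term
    bytes_moved = (m * k + k * n + m * n) * bytes_per_element
    return flops / bytes_moved

# Square matrix-matrix product: lots of reuse per byte loaded
square = gemm_arithmetic_intensity(4096, 4096, 4096)

# Matrix-vector product (n = 1): every weight is used only once
matvec = gemm_arithmetic_intensity(4096, 1, 4096)

print(f"4096^3 GEMM:   {square:.1f} FLOPs/byte")
print(f"matrix-vector: {matvec:.2f} FLOPs/byte")
```

The gap of roughly three orders of magnitude is why large GEMMs can be compute-bound while matrix-vector products, such as single-token decoding at batch size 1, are limited by memory bandwidth.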

Recognizing the importance of GEMM, hardware vendors provide highly optimized libraries and hardware units specifically for this task:

  • NVIDIA: Offers cuBLAS/cuBLASLt (highly tuned GEMM libraries), Tensor Cores (hardware units for accelerated mixed-precision GEMM), and the open-source CUTLASS templates for building custom GEMM kernels. OpenAI's Triton, also open source, serves a similar purpose from Python [2].
  • AMD: Provides rocBLAS (similar library for AMD GPUs) [User Material 2].
  • Intel: Offers oneDNN (formerly MKL-DNN) for optimized primitives on Intel hardware [User Material 2].

These specialized tools ensure that GEMM operations run close to the theoretical peak performance of the hardware.

4. LLMs as "GEMM Farms": The Scale of the Challenge

The sheer number of GEMM operations required during a single forward pass of an LLM highlights why optimizing them is crucial. Let's consider a moderately sized model like Llama 2 7B as an example [User Material 3]. A typical Transformer block in such a model involves several GEMM-based sub-modules:

  • Self-Attention Projections (Q, K, V): 3 GEMMs to project the input X (shape N \times d_{model}) into Q, K, V (shape N \times d_k per head, where d_k is the head dimension and N often incorporates batch size, number of heads, and sequence length). Matrix shapes (all heads together): (N \times d_{model}) \times (d_{model} \times d_{model}).
  • Self-Attention Output Projection: 1 GEMM to project the combined attention outputs back to the model dimension. Matrix shape: (N \times d_{model}) \times (d_{model} \times d_{model}).
  • MLP Feed-Forward Network: Typically involves 2 or 3 GEMMs. A common structure uses an up-projection and a gate projection (often to d_{ff} dimensions), followed by an element-wise operation, and then a down-projection back to d_{model} dimensions.
    • MLP-up: (N \times d_{model}) \times (d_{model} \times d_{ff})
    • MLP-gate: (N \times d_{model}) \times (d_{model} \times d_{ff}) (Parallel to MLP-up)
    • MLP-down: (N \times d_{ff}) \times (d_{ff} \times d_{model})

This totals approximately 7 GEMM operations per Transformer block. For a model like Llama 2 7B with 32 blocks, a single forward pass requires roughly 7 \times 32 = 224 major GEMM calls.

This massive number means that even small percentage improvements in the performance of individual GEMM operations accumulate significantly, leading to noticeable reductions in overall wall-clock time for both training and inference [User Material 3]. LLMs effectively function as "GEMM farms," constantly executing these matrix multiplications. Therefore, ensuring these operations map efficiently onto the underlying hardware is the primary path to LLM acceleration.
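The count of seven GEMMs per block can be turned into a rough FLOP budget. The sketch below uses the published Llama 2 7B dimensions (hidden size 4096, MLP size 11008, 32 blocks) and the projection names used in the Hugging Face implementation; it deliberately ignores the attention score computation, the embedding lookup, and the final vocabulary projection, so treat the total as a lower bound.

```python
D_MODEL = 4096    # hidden size of Llama 2 7B
D_FF = 11008      # MLP intermediate size
N_LAYERS = 32     # number of Transformer blocks

# Weight shape (in_features, out_features) of each GEMM in one block
gemms_per_block = {
    "q_proj": (D_MODEL, D_MODEL),
    "k_proj": (D_MODEL, D_MODEL),
    "v_proj": (D_MODEL, D_MODEL),
    "o_proj": (D_MODEL, D_MODEL),
    "gate_proj": (D_MODEL, D_FF),
    "up_proj": (D_MODEL, D_FF),
    "down_proj": (D_FF, D_MODEL),
}

total_gemm_calls = len(gemms_per_block) * N_LAYERS

# Each token costs 2 * in_features * out_features FLOPs per GEMM
flops_per_token_per_block = sum(2 * k * n for k, n in gemms_per_block.values())
flops_per_token = flops_per_token_per_block * N_LAYERS

print("GEMM calls per forward pass:", total_gemm_calls)
print(f"GEMM FLOPs per token: {flops_per_token / 1e9:.2f} GFLOPs")
```

At about 13 GFLOPs per token from these projections alone, a batch of a few thousand tokens already amounts to tens of TFLOPs per forward pass, which is why small per-GEMM gains add up.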

5. Level 1 Optimizations: The Low-Hanging Fruit in PyTorch

Fortunately, significant performance gains can often be achieved with relatively simple changes within PyTorch, leveraging the highly optimized libraries and hardware features discussed earlier. These techniques form the foundation of LLM optimization.

Technique 1: Leverage Lower Precision (FP16, BF16, TF32)

Modern GPUs achieve much higher throughput when using lower-precision data types compared to standard 32-bit floating-point (FP32).

  • FP16 (Half Precision): Uses 16 bits (1 sign, 5 exponent, 10 mantissa). Offers significant speedups (2-8x on Tensor Cores) and memory savings (50% reduction) compared to FP32 [User Material 4]. However, its limited range can sometimes lead to numerical instability (overflow/underflow) during training, often requiring techniques like gradient scaling.
  • BF16 (Brain Floating Point): Also uses 16 bits (1 sign, 8 exponent, 7 mantissa). It maintains the same dynamic range as FP32 (due to the 8-bit exponent) but has lower precision than FP16. It is generally more stable for training than FP16, especially on hardware that supports it natively (e.g., NVIDIA Ampere and newer, Google TPUs). Offers similar speed and memory benefits as FP16 [User Material 4].
  • TF32 (TensorFloat-32): An NVIDIA-specific format available on Ampere and later GPUs [3]. It uses 19 bits internally for computation (1 sign, 8 exponent, 10 mantissa) while still storing data in 32 bits. It provides a speedup over FP32 by utilizing Tensor Cores for matrix multiplication, while retaining the same numerical range as FP32 and roughly FP16-level precision [User Material 4]. It's often enabled by default for convolutions but needs explicit enabling for matrix multiplications in PyTorch.
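The range-versus-precision trade-off between FP16 and BF16 can be checked with the standard library alone. This is an emulation sketch, not how PyTorch converts tensors: to_fp16 relies on struct's half-precision format and maps overflow to infinity (as hardware would), and to_bf16 keeps the top 16 bits of the FP32 encoding with round-to-nearest-even. Both helper names are hypothetical, and NaN handling is omitted.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to FP16 (5 exponent bits, 10 mantissa bits)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)   # too large for FP16

def to_bf16(x: float) -> float:
    """Round x to BF16 (8 exponent bits, 7 mantissa bits) via its FP32 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest even on the 16 bits that are about to be dropped
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Precision: FP16 resolves steps near 1.0 that BF16 cannot
print(to_fp16(1.001), to_bf16(1.001))       # 1.0009765625 1.0

# Range: 70000 exceeds the FP16 maximum (65504) but fits easily in BF16
print(to_fp16(70000.0), to_bf16(70000.0))   # inf 70144.0
```

The first line shows why FP16 is the more precise of the two, and the second why BF16 usually trains without the loss scaling that FP16 needs.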

Practical Implementation in PyTorch:

You can enable lower precision in several ways:

  1. Set Default Dtype: Change the global default floating-point type. New tensors will be created with this type.
  2. Explicit Casting: Cast specific models or tensors using .half() (for FP16) or .to(torch.bfloat16).
  3. Enable TF32: Use torch.backends.cuda.matmul.allow_tf32 = True to allow PyTorch's cuBLAS integration to use TF32 hardware paths for FP32 matrix multiplications.
import torch

# Check CUDA availability
if not torch.cuda.is_available():
    print("CUDA not available. Skipping dtype examples.")
else:
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    # --- Option 1: Set Default Dtype (Affects new tensors) ---
    # Use BF16 if available (Ampere+ GPU), otherwise FP16
    if torch.cuda.is_bf16_supported():
        print("\nSetting default dtype to BF16")
        original_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.bfloat16)
        current_default = torch.get_default_dtype()
    else:
        print("\nSetting default dtype to FP16")
        original_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float16)
        current_default = torch.get_default_dtype()

    print(f"Current default dtype: {current_default}")
    # Example model/tensors will now default to BF16/FP16
    # class MyModel(torch.nn.Module):
    #     def __init__(self):
    #         super().__init__()
    #         self.linear = torch.nn.Linear(10, 10) # Created with default dtype
    # model = MyModel().cuda()
    # input_tensor = torch.randn(5, 10, device='cuda') # Created with default dtype
    # print(f"Model parameter dtype: {next(model.parameters()).dtype}")
    # print(f"Input tensor dtype: {input_tensor.dtype}")

    # Reset default dtype
    torch.set_default_dtype(original_dtype)
    print(f"Reset default dtype to: {torch.get_default_dtype()}")

    # --- Option 2: Explicitly Cast Tensors/Models ---
    print("\nExplicit Casting Example:")
    model_fp32 = torch.nn.Linear(128, 128).cuda() # Starts as FP32
    input_fp32 = torch.randn(16, 128, device='cuda', dtype=torch.float32)
    print(f"Original model dtype: {next(model_fp32.parameters()).dtype}")
    print(f"Original input dtype: {input_fp32.dtype}")

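    # Cast model and input to FP16
    # Note: nn.Module.half() converts parameters in place and returns the same
    # module, so model_fp32 is also FP16 from here on (deep-copy first to keep both).
    model_fp16 = model_fp32.half()
    input_fp16 = input_fp32.half() # Tensor.half() returns a new FP16 tensor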
    print(f"Casted model dtype (FP16): {next(model_fp16.parameters()).dtype}")
    print(f"Casted input dtype (FP16): {input_fp16.dtype}")
    # output_fp16 = model_fp16(input_fp16) # Now runs in FP16

    # Cast model and input to BF16 (if supported); Module.to() also converts in place
    if torch.cuda.is_bf16_supported():
        model_bf16 = model_fp32.to(torch.bfloat16)
        input_bf16 = input_fp32.to(torch.bfloat16)
        print(f"Casted model dtype (BF16): {next(model_bf16.parameters()).dtype}")
        print(f"Casted input dtype (BF16): {input_bf16.dtype}")
        # output_bf16 = model_bf16(input_bf16) # Now runs in BF16
    else:
        print("BF16 not supported on this GPU.")

    # --- Option 3: Enable TF32 for Matmuls (if using FP32 inputs) ---
    print("\nEnabling TF32 for FP32 Matmuls:")
    # Check if TF32 is available (Compute Capability >= 8.0) and enable it
    # get_device_capability() returns a (major, minor) tuple, so compare the major version
    if torch.cuda.get_device_capability()[0] >= 8:
        print("GPU supports TF32. Enabling TF32 for matmul.")
        torch.backends.cuda.matmul.allow_tf32 = True
        # Note: TF32 for convolutions is usually enabled by default via cudnn
        # torch.backends.cudnn.allow_tf32 = True
        print(f"torch.backends.cuda.matmul.allow_tf32 = {torch.backends.cuda.matmul.allow_tf32}")
    else:
        print("GPU does not support TF32 (requires Ampere architecture or newer).")
        torch.backends.cuda.matmul.allow_tf32 = False # Ensure it's off

    # Example FP32 matmul potentially using TF32 now
    if torch.backends.cuda.matmul.allow_tf32:
        a_fp32 = torch.randn(512, 512, device='cuda', dtype=torch.float32)
        b_fp32 = torch.randn(512, 512, device='cuda', dtype=torch.float32)
        # This matmul might use TF32 hardware if enabled and available
        # The actual kernel dispatch depends on cuBLAS heuristics
        c_maybe_tf32 = a_fp32 @ b_fp32
        print("Performed FP32 matmul (potentially using TF32). Output dtype:", c_maybe_tf32.dtype)

    # Disable TF32 if needed for specific comparisons or operations
    torch.backends.cuda.matmul.allow_tf32 = False
    print(f"Disabled TF32 for matmul. Current setting: {torch.backends.cuda.matmul.allow_tf32}")

 
 

Technique 2: Aligning with Hardware (Dimension Multiples)

The performance of GEMM operations, especially when using Tensor Cores, is sensitive to the dimensions of the input matrices (M, N, K). Tensor Cores and the underlying libraries often process data in fixed-size blocks or tiles [User Material 4]. If the matrix dimensions are not multiples of these tile sizes (commonly 8 or 16 for modern NVIDIA hardware), the hardware execution units (lanes) might be partially idle during computation, leading to wasted cycles and reduced efficiency. Libraries often pad internally to the next tile boundary, but the padded elements still cost work and add overhead.
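
To get a rough feel for how much work padding can waste, the sketch below rounds each GEMM dimension up to the next tile multiple and reports the fraction of multiply-accumulates spent on padding. It is a deliberately simplified model: the single tile size of 16 and the helper names `padded_size` and `wasted_fraction` are assumptions for illustration, and real kernels pick among many tile shapes.

```python
def padded_size(dim: int, tile: int) -> int:
    """Smallest multiple of `tile` that is >= dim."""
    return (dim + tile - 1) // tile * tile


def wasted_fraction(m: int, n: int, k: int, tile: int = 16) -> float:
    """Fraction of multiply-accumulates spent on padding when every GEMM
    dimension is rounded up to a multiple of `tile` (simplified model)."""
    useful = m * n * k
    padded = padded_size(m, tile) * padded_size(n, tile) * padded_size(k, tile)
    return 1.0 - useful / padded


# Illustrative shapes only; they are not measurements from real hardware.
for dims in [(1024, 1024, 1024), (1000, 1000, 1000), (1009, 1009, 1009)]:
    print(dims, f"wasted work: {wasted_fraction(*dims):.2%}")
```

Under this model an already aligned shape wastes nothing, while a dimension that sits just past a tile boundary wastes the most.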

Practical Advice:

During model architecture design or selection, favor dimensions that align well with the hardware. Specifically, choose hidden sizes (d_model), MLP intermediate dimensions, attention head dimensions, and vocabulary sizes that are multiples of 8, or ideally 16. Even better alignment can sometimes be achieved with multiples of 64 or 128, which can also align better with memory transaction sizes [User Material 4].

Python
# Conceptual Example: Choosing Model Dimensions

# Less Optimal (Not multiple of 8/16)
hidden_size_bad = 1000
mlp_intermediate_bad = 3000
vocab_size_unpadded = 32007

print(f"Less Optimal hidden size: {hidden_size_bad}")
print(f"Less Optimal MLP intermediate size: {mlp_intermediate_bad}")
print(f"Less Optimal vocab size: {vocab_size_unpadded}")

# More Optimal (Multiples of 8, 16, 64, or 128)
hidden_size_good = 1024 # Multiple of 128
mlp_intermediate_good = 4096 # Multiple of 128

# Pad vocabulary size to a multiple of 64 or 128 for efficiency
padding_multiple = 64
padded_vocab_size = ((vocab_size_unpadded + padding_multiple - 1) // padding_multiple) * padding_multiple

print(f"\nMore Optimal hidden size: {hidden_size_good} (Multiple of 128)")
print(f"More Optimal MLP intermediate size: {mlp_intermediate_good} (Multiple of 128)")
print(f"Padded vocab size (Multiple of {padding_multiple}): {padded_vocab_size}")

# Rationale: Choosing dimensions like 1024 or 4096 ensures that when these dimensions
# are involved in GEMMs (as M, N, or K), they divide evenly by common tile sizes
# like 8 or 16, maximizing Tensor Core utilization. Padding the vocabulary size
# helps ensure the embedding lookup table and the final output projection layer
# (which often involves a GEMM with the vocabulary size) are well-aligned.


Technique 3: Automatic Acceleration with torch.compile

PyTorch 2.0 introduced torch.compile, a powerful feature that acts as a just-in-time (JIT) compiler for PyTorch models.4 By simply wrapping your model with torch.compile(), you can often achieve significant speedups with minimal code changes.

How it Works: Under the hood, torch.compile uses several components 5:

  1. TorchDynamo: Captures the PyTorch computational graph from your Python code safely and with low overhead. It identifies parts of the graph that can be compiled.
  2. AOTAutograd: Traces the autograd engine to capture the backward pass ahead-of-time, allowing the compiler to optimize both forward and backward computations.
  3. PrimTorch: Canonicalizes the ~2000+ PyTorch operators into a smaller set of ~250 primitive operators, providing a stable target for backends.
  4. TorchInductor: The default compiler backend. It takes the graph of primitive operators and generates optimized low-level code. For GPUs, TorchInductor heavily relies on OpenAI Triton 5 to generate fast custom kernels, often fusing operations together automatically.5

This means torch.compile can automatically perform many optimizations, including generating efficient, potentially fused GEMM kernels tailored to your model and hardware, without requiring manual kernel writing.

Usage:

The basic usage is straightforward. torch.compile offers different modes that trade off compilation time for runtime performance 5:

  • mode="default": Good balance, optimized for large models, low compile overhead.
  • mode="reduce-overhead": Reduces per-call Python and framework overhead (using CUDA graphs) at the cost of some extra memory; most helpful for small models or small batches where that overhead dominates.
  • mode="max-autotune": Longest compilation time, potentially fastest runtime by extensively searching for optimal kernel configurations.
Python
import torch
import time
import copy
import torch.nn as nn

# Define a simple model (e.g., a couple of linear layers mimicking an MLP block)
class SimpleMLP(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        # Ensure dimensions are multiples of 8/16 for good baseline performance
        assert dim % 16 == 0, "Dimension should be multiple of 16"
        intermediate_dim = 4 * dim
        assert intermediate_dim % 16 == 0, "Intermediate dimension should be multiple of 16"

        self.linear1 = nn.Linear(dim, intermediate_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# Ensure we are on a GPU
if not torch.cuda.is_available():
    print("CUDA not available, skipping torch.compile example.")
elif not hasattr(torch, "compile"):  # Feature check; comparing version strings is unreliable
    print("torch.compile requires PyTorch 2.0 or later. Skipping example.")
else:
    device = 'cuda'
    dim_size = 2048 # Use a multiple of 16
    batch_size = 64
    # For demonstration, treat sequence as part of batch dim for a simple 2D input
    # In a real LLM, input might be (batch_size * seq_len, dim_size)
    num_tokens = batch_size * 128

    # Create model instance and move to GPU
    model = SimpleMLP(dim=dim_size).to(device)
    # Use appropriate dtype (e.g., float16 or bfloat16 if supported)
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    model = model.to(dtype)
    input_tensor = torch.randn(num_tokens, dim_size, device=device, dtype=dtype)

    print(f"PyTorch Version: {torch.__version__}")
    print(f"Model and tensor created on GPU with dtype: {dtype}.")
    print(f"Input shape: {input_tensor.shape}")

    # --- Baseline Eager Execution ---
    print("\nRunning Eager Mode...")
    # Warmup runs
    for _ in range(5):
        _ = model(input_tensor)
    torch.cuda.synchronize() # Wait for GPU work to finish
    start_time = time.time()
    num_runs = 20
    for _ in range(num_runs):
        _ = model(input_tensor)
    torch.cuda.synchronize()
    eager_time = time.time() - start_time
    print(f"Eager execution time ({num_runs} runs): {eager_time:.4f} seconds")

    # --- Using torch.compile ---
    print("\nCompiling model with torch.compile (mode='default')...")
    # Make a copy to compile, leave original intact if needed
    # Note: Compilation happens on the first call(s) to the compiled model
    compiled_model = torch.compile(copy.deepcopy(model)) # Default mode
    # Other modes:
    # compiled_model = torch.compile(model, mode="reduce-overhead")
    # compiled_model = torch.compile(model, mode="max-autotune")

    print("Running Compiled Mode...")
    # Warmup runs for compiled model (includes compilation time on first run)
    for _ in range(5):
        _ = compiled_model(input_tensor)
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(num_runs):
        _ = compiled_model(input_tensor)
    torch.cuda.synchronize()
    compiled_time = time.time() - start_time
    print(f"Compiled execution time ({num_runs} runs): {compiled_time:.4f} seconds")

    # Compare results (optional, should be numerically close)
    # no_grad skips autograd bookkeeping; eval() below matters for dropout/batchnorm
    with torch.no_grad():
        model.eval() # Set model to evaluation mode
        compiled_model.eval()
        eager_output = model(input_tensor)
        compiled_output = compiled_model(input_tensor)
        # Use appropriate tolerance for the chosen dtype
        atol = 1e-2 if dtype == torch.float16 else 1e-1
        rtol = 1e-2 if dtype == torch.float16 else 1e-1
        if torch.allclose(eager_output, compiled_output, atol=atol, rtol=rtol):
            print("\n✅ Outputs match between eager and compiled model (within tolerance).")
        else:
            print("\n❌ Outputs differ significantly between eager and compiled model.")
            print(f"   Max absolute difference: {torch.abs(eager_output - compiled_output).max().item()}")
            print(f"   Mean absolute difference: {torch.abs(eager_output - compiled_output).mean().item()}")


    if compiled_time > 0:
        print(f"\nSpeedup factor (Eager / Compiled): {eager_time / compiled_time:.2f}x")
    else:
        print("\nCompiled time was zero or negative, cannot calculate speedup.")


Using lower precision, ensuring dimension alignment, and leveraging torch.compile represent the first line of defense in optimizing LLM performance. They are relatively easy to implement and often yield substantial gains by better utilizing the underlying hardware and compiler technologies.

6. Level 2 Optimizations: Getting More Advanced

Beyond the basic techniques, further performance can be unlocked by employing strategies that require more specific library usage or a deeper understanding of the computational patterns within LLMs. These include batching small operations, quantizing models, and fusing sequences of operations.

Technique 4: Batching Small Matmuls (torch.bmm)

LLMs often involve operations that result in many small, independent matrix multiplications. For example, in multi-head attention, computations might be performed independently for each head before results are combined. Executing these small GEMMs one by one using standard matrix multiplication (@ or torch.matmul) within a Python loop is highly inefficient. Each call incurs kernel launch latency – the overhead associated with telling the GPU to start a new computation. This latency can dominate the actual computation time for small matrices [User Material 4].

GPUs thrive on parallelism and perform best when given large, contiguous chunks of work. PyTorch provides torch.bmm (Batch Matrix Multiply) specifically for this scenario. It takes batches of matrices as input (shapes (B, m, k) and (B, k, n), where B is the batch dimension) and performs B independent matrix multiplications (each an m × k matrix times a k × n matrix) using a single, optimized kernel launch [User Material 4, User Material 5.2]. This significantly reduces the kernel launch overhead and improves GPU utilization.

Reproducible Code:

Python
import torch
import time

# Setup: Batch of small matrices
batch_size = 256 # Larger batch size makes the overhead difference more apparent
m, k, n = 32, 64, 48 # Small dimensions, but keep them multiples of 8

# Ensure CUDA is available
if not torch.cuda.is_available():
    print("CUDA not available. Skipping torch.bmm example.")
else:
    device = 'cuda'
    dtype = torch.float16 # Use half precision for speed

    # Create batches of matrices on GPU
    A_batch = torch.randn(batch_size, m, k, device=device, dtype=dtype)
    B_batch = torch.randn(batch_size, k, n, device=device, dtype=dtype)
    # Pre-allocate output tensors
    C_loop = torch.empty(batch_size, m, n, device=device, dtype=dtype)
    # C_bmm will be created by the bmm operation

    print(f"Matrices: Batch={batch_size}, M={m}, K={k}, N={n}, dtype={dtype}")
    print(f"Input A shape: {A_batch.shape}")
    print(f"Input B shape: {B_batch.shape}")

    # --- Method 1: Python Loop (Inefficient) ---
    print("\nRunning with Python loop...")
    # Warmup - run a few iterations first
    for i in range(min(batch_size, 10)):
         _ = A_batch[i] @ B_batch[i]
    torch.cuda.synchronize() # Wait for warmup to finish

    start_time = time.time()
    num_runs = 10 # Run the loop multiple times for averaging if needed
    for _ in range(num_runs):
        for i in range(batch_size):
            C_loop[i] = A_batch[i] @ B_batch[i] # Standard matmul inside loop
    torch.cuda.synchronize() # Wait for all computations to finish
    loop_time = (time.time() - start_time) / num_runs
    print(f"Python loop average time: {loop_time:.6f} seconds")

    # --- Method 2: Batched Matrix Multiply (torch.bmm) ---
    print("\nRunning with torch.bmm...")
    # Warmup
    _ = torch.bmm(A_batch, B_batch)
    torch.cuda.synchronize()

    start_time = time.time()
    for _ in range(num_runs):
        C_bmm = torch.bmm(A_batch, B_batch) # Single kernel launch for the whole batch
    torch.cuda.synchronize()
    bmm_time = (time.time() - start_time) / num_runs
    print(f"torch.bmm average time: {bmm_time:.6f} seconds")

    # Check results for correctness
    # FP16 results can differ by a rounding step between kernels, so use a loose tolerance
    if torch.allclose(C_loop, C_bmm, atol=1e-2, rtol=1e-2):
         print("\n✅ Outputs match between loop and bmm.")
    else:
         print("\n❌ Outputs differ between loop and bmm.")
         print(f"   Max difference: {torch.abs(C_loop - C_bmm).max().item()}")

    if bmm_time > 0:
        print(f"\nSpeedup factor (loop time / bmm time): {loop_time / bmm_time:.2f}x")
    else:
        print("\nCannot calculate speedup factor.")


Whenever you encounter a situation requiring identical matrix multiplications across a batch or group dimension, consider using torch.bmm or related functions like torch.einsum (which can express batch matrix multiplication and more complex tensor contractions) to consolidate the work into fewer, larger kernel launches.
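
As a small illustration of that equivalence, the sketch below checks on the CPU that torch.einsum and the broadcasting @ operator give the same result as torch.bmm, and shows how per-head attention scores can be written as one batched contraction. The tensor sizes are arbitrary values chosen for the example.

```python
import torch

torch.manual_seed(0)
B, m, k, n = 4, 8, 16, 8  # small illustrative sizes
A = torch.randn(B, m, k)
W = torch.randn(B, k, n)

out_bmm = torch.bmm(A, W)                        # (B, m, n), one batched call
out_einsum = torch.einsum("bmk,bkn->bmn", A, W)  # same contraction, spelled out
out_matmul = A @ W                               # matmul broadcasts over the batch dim

print(torch.allclose(out_bmm, out_einsum, atol=1e-4))
print(torch.allclose(out_bmm, out_matmul, atol=1e-4))

# Per-head attention scores as a single batched contraction instead of a loop over heads.
batch, heads, seq, head_dim = 2, 3, 5, 4
q = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
scores_einsum = torch.einsum("bhqd,bhkd->bhqk", q, key)
scores_matmul = q @ key.transpose(-2, -1)
print(scores_einsum.shape)
```

Which spelling to use is mostly a readability choice; what matters for performance is that the whole batch goes to the GPU in one call rather than one call per matrix.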

Technique 5: Quantization (INT8/INT4) with bitsandbytes

Quantization is a powerful technique, especially for optimizing LLM inference. It involves converting model parameters (weights) and sometimes activations from higher-precision floating-point formats (like FP32 or FP16) to lower-precision integer formats, typically 8-bit integers (INT8) or even 4-bit integers (INT4).1

Benefits:

  • Reduced Memory Footprint: INT8 weights require 4x less memory than FP32 and 2x less than FP16. INT4 offers even greater savings (8x vs FP32, 4x vs FP16). This allows larger models to fit into limited GPU memory.8
  • Reduced Memory Bandwidth: Moving less data from HBM to the compute units alleviates memory bandwidth bottlenecks, which is often the limiting factor in LLM inference [User Material 4].
  • Faster Computation: Many GPUs have specialized hardware instructions for performing integer arithmetic faster than floating-point operations. INT8 computations can leverage Tensor Cores for significant speedups [User Material 4].
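
The memory ratios above follow directly from the bytes stored per weight. The sketch below works through the arithmetic for a hypothetical 7-billion-parameter model; the parameter count and the helper name `weight_memory_gib` are assumptions for illustration, and the figures cover weights only (no quantization scales, activations, or KV cache).

```python
BYTES_PER_PARAM = {"FP32": 4.0, "FP16/BF16": 2.0, "INT8": 1.0, "INT4": 0.5}


def weight_memory_gib(num_params: int, bytes_per_param: float) -> float:
    """Storage needed for the weights alone, in GiB."""
    return num_params * bytes_per_param / 2**30


num_params = 7_000_000_000  # hypothetical 7B-parameter model
for fmt, nbytes in BYTES_PER_PARAM.items():
    print(f"{fmt:>9}: {weight_memory_gib(num_params, nbytes):6.2f} GiB")
```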

Trade-offs:

  • Potential Accuracy Loss: Representing values with fewer bits inevitably introduces approximation errors (quantization noise). While techniques are designed to minimize this, there can be a drop in model accuracy, which needs careful evaluation.8
  • Complexity: Implementing quantization correctly often involves calibration (determining the optimal mapping from float to int ranges) or Quantization-Aware Training (QAT), although libraries like bitsandbytes simplify Post-Training Quantization (PTQ) significantly.8
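
To make the quantization noise concrete, here is a minimal sketch of symmetric absmax INT8 quantization in plain Python. It is one simple scheme chosen for illustration, not the method any particular library uses, and the sample weights are made up. It also shows why a single large outlier hurts: the outlier stretches the scale, so the small values lose resolution.

```python
def quantize_absmax_int8(values):
    """Symmetric absmax quantization: the largest |value| maps to 127."""
    absmax = max(abs(v) for v in values)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    return [qi * scale for qi in q]


def max_roundtrip_error(values):
    q, scale = quantize_absmax_int8(values)
    return max(abs(v - r) for v, r in zip(values, dequantize(q, scale)))


weights = [0.02, -0.51, 0.33, 1.27, -0.08]  # made-up sample values
q, scale = quantize_absmax_int8(weights)
err_plain = max_roundtrip_error(weights)
err_outlier = max_roundtrip_error(weights + [60.0])  # same values plus one outlier

print("quantized:", q)
print("max error without outlier:", err_plain)
print("max error with outlier:   ", err_outlier)
```

The round-trip error is bounded by half the scale, which is why techniques that isolate outliers or use finer-grained scales tend to preserve accuracy better.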

The bitsandbytes Library:

bitsandbytes is a popular open-source library that makes advanced quantization techniques accessible within the PyTorch ecosystem.8 It's particularly well-known for:

  • LLM.int8(): An 8-bit quantization scheme primarily for inference. nn.Linear weights are stored as INT8 while the layer's inputs and outputs stay in FP16. During the GEMM it uses mixed-precision decomposition to maintain accuracy: the bulk of the computation runs in INT8, while the few outlier activation feature dimensions that are crucial for LLM performance are computed separately in FP16.11
  • 4-bit Quantization (NF4, FP4): bitsandbytes also supports 4-bit quantization, notably the NF4 (NormalFloat 4-bit) data type, which was introduced with the QLoRA method for efficient fine-tuning.1
  • Ease of Use: It integrates smoothly with Hugging Face's transformers library, often enabling 8-bit or 4-bit loading with simple flags like load_in_8bit=True or load_in_4bit=True passed via a BitsAndBytesConfig.8 It also provides standalone quantized layer implementations like bnb.nn.Linear8bitLt and bnb.nn.Linear4bit.1
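
The mixed-precision decomposition idea behind LLM.int8() can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions, not the bitsandbytes kernel: the function names, the nested-list matrices, and the sample data are hypothetical, and the 6.0 threshold simply mirrors the value used elsewhere in this article. Feature dimensions containing an activation at or above the threshold are multiplied in floating point; everything else is quantized to integers, multiplied, and rescaled.

```python
def quantize_rows(mat):
    """Per-row absmax INT8 quantization: integer rows plus one scale per row."""
    q_rows, scales = [], []
    for row in mat:
        scale = max((abs(v) for v in row), default=0.0) / 127.0 or 1.0
        q_rows.append([round(v / scale) for v in row])
        scales.append(scale)
    return q_rows, scales


def mixed_precision_matmul(x, w, threshold=6.0):
    """x: tokens x features, w: features x outputs (nested lists of floats)."""
    n_feat, n_out = len(w), len(w[0])
    outlier = [j for j in range(n_feat) if any(abs(row[j]) >= threshold for row in x)]
    regular = [j for j in range(n_feat) if j not in outlier]

    # Regular features: quantize activations per token and weights per output column.
    x_reg = [[row[j] for j in regular] for row in x]
    w_reg_t = [[w[j][c] for j in regular] for c in range(n_out)]
    xq, xs = quantize_rows(x_reg)
    wq, ws = quantize_rows(w_reg_t)

    out = []
    for i, row in enumerate(x):
        out_row = []
        for c in range(n_out):
            int_acc = sum(a * b for a, b in zip(xq[i], wq[c]))  # integer accumulation
            fp_acc = sum(row[j] * w[j][c] for j in outlier)     # outliers stay in float
            out_row.append(int_acc * xs[i] * ws[c] + fp_acc)
        out.append(out_row)
    return out


def max_abs_error(a, b):
    return max(abs(p - r) for ra, rb in zip(a, b) for p, r in zip(ra, rb))


x = [[0.5, -1.0, 8.0, 0.25], [1.5, 0.75, -7.0, -0.5]]  # feature 2 holds the outliers
w = [[0.2, -0.4], [0.6, 0.1], [-0.3, 0.9], [1.0, -0.7]]
exact = [[sum(x[i][j] * w[j][c] for j in range(4)) for c in range(2)] for i in range(2)]

mixed = mixed_precision_matmul(x, w)
int8_only = mixed_precision_matmul(x, w, threshold=float("inf"))
print("max error, mixed precision:", max_abs_error(mixed, exact))
print("max error, INT8 everywhere:", max_abs_error(int8_only, exact))
```

Setting the threshold to infinity disables the outlier path, which makes it easy to compare the two error figures on the same inputs.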

Reproducible Code (Using Linear8bitLt directly):

This example shows how to replace a standard nn.Linear layer with the bitsandbytes 8-bit version for inference.

Python
import time  # Needed for the timing code below
import torch
import torch.nn as nn

# Check if bitsandbytes is installed and seems compatible
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Linear8bitLt
    # Perform a basic check that might fail if CUDA setup is incorrect
    if torch.cuda.is_available():
        p = torch.nn.Parameter(torch.rand(10, 10).cuda())
        adam = bnb.optim.Adam8bit([p]) # Try creating an 8-bit optimizer
        print("bitsandbytes imported and basic check passed.")
        bnb_available = True
    else:
        print("bitsandbytes imported, but CUDA not available.")
        bnb_available = False
except ImportError:
    print("bitsandbytes not installed.")
    bnb_available = False
except Exception as e:
    print(f"bitsandbytes might not be compatible or installed correctly: {e}")
    bnb_available = False

if bnb_available and torch.cuda.is_available():
    print("\nRunning bitsandbytes Linear8bitLt Example...")
    # Define model dimensions (use multiples of 8/16)
    in_features = 256
    out_features = 512
    batch_size = 32
    num_tokens = batch_size * 16 # Example token count

    # 1. Create or load a standard FP16 model/layer
    # In a real scenario, you'd load pretrained weights
    fp16_linear_layer = nn.Linear(in_features, out_features, bias=False).cuda().half()
    fp16_input = torch.randn(num_tokens, in_features, device='cuda', dtype=torch.float16)
    print(f"Input shape: {fp16_input.shape}")
    print(f"Original FP16 Layer: {fp16_linear_layer}")

    # 2. Create the equivalent layer using bitsandbytes Linear8bitLt for INT8 inference
    # For the common LLM.int8() inference pattern (INT8 weights, FP16 inputs/outputs):
    # - has_fp16_weights=False: store the weights as INT8 once the layer is on the GPU.
    # - threshold=6.0: specific to the LLM.int8() mixed-precision decomposition;
    #   activation features whose magnitude exceeds it are handled separately in FP16.
    int8_linear_layer = Linear8bitLt(
        in_features,
        out_features,
        bias=False, # Match the original layer's bias setting
        has_fp16_weights=False, # Critical for INT8 weight storage
        threshold=6.0 # Outlier threshold commonly used for LLM.int8()
    ) # Keep the layer on the CPU until the weights are loaded

    # 3. Load the FP16 weights, then move the layer to the GPU.
    # Following the bitsandbytes documentation, the weights are quantized when the
    # layer is moved to the CUDA device, so load_state_dict comes first.
    int8_linear_layer.load_state_dict(fp16_linear_layer.state_dict())
    int8_linear_layer = int8_linear_layer.cuda() # Quantization happens here
    print(f"Bitsandbytes INT8 Layer: {int8_linear_layer}")
    print("Loaded FP16 weights into INT8 layer (quantized on the move to CUDA).")

    # Verify weight storage (internal attribute, may change)
    if hasattr(int8_linear_layer, 'weight') and int8_linear_layer.weight.dtype == torch.int8:
         print(f"Verified internal weight dtype is: {int8_linear_layer.weight.dtype}")
    else:
         print("Could not verify internal weight dtype (might be stored differently).")


    # 4. Perform inference using the INT8 layer
    # Inputs stay in FP16 for LLM.int8(). Run once untimed first so one-time
    # setup costs are not counted in the measurement.
    _ = int8_linear_layer(fp16_input)
    torch.cuda.synchronize()
    start_time = time.time()
    int8_output = int8_linear_layer(fp16_input)
    torch.cuda.synchronize()
    int8_time = time.time() - start_time

    print(f"\nOutput shape (INT8 Layer): {int8_output.shape}")
    # The output is typically returned in FP16 after the mixed-precision computation
    print(f"Output dtype (INT8 Layer): {int8_output.dtype}")
    print(f"Inference time (INT8 Layer): {int8_time:.6f} seconds")

    # Optional: Compare output and speed with the original FP16 layer
    _ = fp16_linear_layer(fp16_input) # Untimed warm-up run
    torch.cuda.synchronize()
    start_time = time.time()
    fp16_output = fp16_linear_layer(fp16_input)
    torch.cuda.synchronize()
    fp16_time = time.time() - start_time
    print(f"Inference time (FP16 Layer): {fp16_time:.6f} seconds")

    diff = torch.abs(fp16_output - int8_output).mean()
    print(f"Mean absolute difference between FP16 and INT8 outputs: {diff.item():.4f}")
    # A single small layer is a noisy benchmark; INT8 is not guaranteed to be faster here
    if int8_time > 0:
        print(f"Speedup factor (FP16 time / INT8 time): {fp16_time / int8_time:.2f}x")


    # --- Using Hugging Face Integration (Conceptual Code) ---
    print("\n--- Conceptual Hugging Face Integration ---")
    print("# This requires transformers, accelerate, and bitsandbytes installed.")
    print("# Example:")
    print("# from transformers import AutoModelForCausalLM, BitsAndBytesConfig")
    print("# ")
    print("# model_id = 'meta-llama/Llama-2-7b-hf' # Or any compatible model")
    print("# ")
    print("# # Configure 8-bit quantization")
    print("# quantization_config = BitsAndBytesConfig(")
    print("#     load_in_8bit=True,")
    print("#     # Optional: outlier threshold for LLM.int8()")
    print("#     # llm_int8_threshold=6.0")
    print("# )")
    print("# ")
    print("# # Load the model with quantization enabled")
    print("# model = AutoModelForCausalLM.from_pretrained(")
    print("#     model_id,")
    print("#     quantization_config=quantization_config,")
    print("#     device_map='auto' # Handles placing layers on devices")
    print("# )")
    print("# ")
    print("# print(f'Model {model_id} loaded with 8-bit quantization via Hugging Face!')")
    print("# print(model)")

else:
    print("Skipping bitsandbytes Linear8bitLt example due to unavailability.")


Quantization, particularly using libraries like bitsandbytes, offers a compelling way to reduce the memory demands and potentially accelerate the inference of large models, making them more deployable in resource-constrained environments.

Technique 6: Fused Operations (FlashAttention & xFormers via SDPA)

Kernel fusion is a powerful optimization technique where multiple distinct computational steps are combined into a single GPU kernel [User Material 4].

Why Fuse?

Modern GPUs have vastly more computational power (FLOPs) than memory bandwidth. Many operations in neural networks are "memory-bound," meaning the time taken is limited by how fast data can be read from and written to the GPU's main memory (HBM), rather than by the speed of the calculations themselves. Fusion combats this by 13:

  1. Reducing Kernel Launch Overhead: Launching one fused kernel is faster than launching multiple smaller ones sequentially.
  2. Minimizing HBM Traffic: By keeping intermediate results within the GPU's much faster on-chip memory (SRAM or registers) between fused operations, fusion avoids costly round trips to HBM. Data is loaded once, processed through multiple steps, and then written back.
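
A back-of-the-envelope model makes the HBM savings visible. The sketch below assumes a chain of unary element-wise operations where every unfused kernel reads its input once and writes its output once, while a fused kernel reads and writes only once in total. The function name, tensor size, and FP16 element size are assumptions for illustration, and caches are ignored.

```python
def elementwise_chain_traffic_bytes(num_elements, num_ops, bytes_per_element=2, fused=False):
    """HBM bytes moved by a chain of unary element-wise ops (simplified model)."""
    passes = 1 if fused else num_ops
    return passes * 2 * num_elements * bytes_per_element  # one read + one write per pass


n = 8192 * 4096  # hypothetical activation tensor (tokens x hidden size)
unfused = elementwise_chain_traffic_bytes(n, num_ops=3)
fused = elementwise_chain_traffic_bytes(n, num_ops=3, fused=True)

print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB")
print(f"traffic reduction: {unfused / fused:.0f}x")
```

For memory-bound operations, runtime tracks bytes moved fairly closely, so a reduction in traffic of this kind is where most of the fusion speedup comes from.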

Attention Fusion: The Killer App

The self-attention mechanism in Transformers is a prime candidate for fusion. Standard attention calculation involves several steps:

  1. Calculate QK^T (GEMM).
  2. Scale the result.
  3. Apply attention mask (element-wise).
  4. Apply softmax (row-wise).
  5. Apply dropout (element-wise, during training).
  6. Multiply by V (GEMM).

Crucially, the intermediate attention score matrix (QK^T after scaling and masking) can be very large, especially for long sequences (size N × N for sequence length N). In a naive implementation, this large matrix must be written to and read back from HBM, consuming significant memory (O(N^2)) and bandwidth [User Material 4].
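
The quadratic growth is easy to quantify. The sketch below computes the size of the materialized score matrix for a batch of 8 with 12 heads in FP16 (the same batch and head counts as the SDPA example later in this section); the helper name `attention_scores_bytes` is an assumption for illustration.

```python
def attention_scores_bytes(batch, heads, seq_len, bytes_per_element=2):
    """Bytes needed to materialize one seq_len x seq_len score matrix per head."""
    return batch * heads * seq_len * seq_len * bytes_per_element


for seq_len in (2048, 4096, 8192):
    mib = attention_scores_bytes(8, 12, seq_len) / 2**20
    print(f"seq_len={seq_len:5d}: {mib:8.0f} MiB")
```

Doubling the sequence length quadruples this buffer, and it is allocated on top of the model weights and other activations.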

FlashAttention and xFormers:

Libraries and algorithms have been developed to optimize attention by fusing these steps and avoiding the materialization of the full attention matrix:

  • FlashAttention: An algorithm that computes the exact attention output without ever storing the full attention matrix in HBM.13 It works by loading blocks (tiles) of Q, K, and V into the fast on-chip SRAM, computing the attention output for that block (including the softmax), and writing the result for that block back to HBM. It uses clever numerical techniques (softmax rescaling) to ensure the final result is correct even though it's computed block by block.13 This reduces the memory complexity from O(N^2) to O(N) in the sequence length N and significantly speeds up computation by dramatically reducing HBM access.13 FlashAttention-2 and FlashAttention-3 offer further performance improvements on newer hardware.13
  • xFormers: A library from Meta AI containing various optimized building blocks for Transformers, including highly efficient memory-optimized attention implementations.3 These implementations also aim to reduce memory usage and improve speed compared to naive attention.
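
The softmax rescaling trick can be shown in plain Python for a single query. The sketch below is a toy illustration, not the FlashAttention kernel: it assumes scalar values instead of value vectors, a hypothetical block size of 4, and made-up scores. It processes the scores block by block while keeping only a running maximum, a running normalizer, and a running unnormalized output, then checks the result against a reference that materializes every softmax weight.

```python
import math


def attention_row_naive(scores, values):
    """Reference: materialize all softmax weights, then take the weighted sum."""
    row_max = max(scores)
    weights = [math.exp(s - row_max) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def attention_row_blockwise(scores, values, block_size=4):
    """Same result, computed one block at a time with constant extra state."""
    running_max = float("-inf")
    running_sum = 0.0  # softmax normalizer accumulated so far
    running_out = 0.0  # unnormalized weighted sum accumulated so far
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        # Rescale everything accumulated under the old max to the new max.
        correction = math.exp(running_max - new_max)
        p_blk = [math.exp(s - new_max) for s in s_blk]
        running_sum = running_sum * correction + sum(p_blk)
        running_out = running_out * correction + sum(p * v for p, v in zip(p_blk, v_blk))
        running_max = new_max
    return running_out / running_sum


scores = [0.3, -1.2, 2.5, 0.0, 4.1, -0.7, 1.9, 3.3, -2.0, 0.8]  # made-up values
values = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 0.0, 1.5, 4.0, -0.5]

reference = attention_row_naive(scores, values)
blockwise = attention_row_blockwise(scores, values)
print(reference, blockwise)
```

Because the correction factor rescales earlier partial sums whenever a larger score appears, no block ever needs to see the full row, which is what lets the real kernels keep their working set in on-chip memory.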

PyTorch Integration: scaled_dot_product_attention (SDPA)

Recognizing the importance of fused attention, PyTorch 2.0 introduced torch.nn.functional.scaled_dot_product_attention (SDPA). This function provides a unified interface for calculating scaled dot product attention.19 When running on CUDA with appropriate hardware and inputs, SDPA automatically dispatches to highly optimized fused kernels, including 4:

  • FlashAttention (v1, v2, or v3 depending on PyTorch version and hardware).13
  • Memory-Efficient Attention kernels (inspired by or directly from xFormers).
  • A fallback PyTorch C++ implementation if fused kernels are unavailable.

Using F.scaled_dot_product_attention is now the standard and recommended way to implement attention in PyTorch to automatically benefit from these powerful fusion optimizations.

Reproducible Code (Using SDPA):

Python
import torch
import torch.nn.functional as F
import time
import math

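# Attempt to import the SDPA backend context manager; handle versions without it.
# Note: newer PyTorch releases deprecate torch.backends.cuda.sdp_kernel in favor of
# torch.nn.attention.sdpa_kernel.
sdp_kernel = None  # Must match the name checked below, or a failed import raises NameError
try:
    from torch.backends.cuda import sdp_kernel
except ImportError:
    print("Warning: Could not import sdp_kernel. Explicit backend control disabled.")
    print("         (Requires PyTorch 2.0+ with CUDA support)")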

# Setup Parameters
batch_size = 8
num_heads = 12
seq_len = 2048 # Longer sequence length benefits more from fusion
head_dim = 64
dtype = torch.float16 # Fused kernels often optimized for FP16/BF16
device = 'cuda'

# Check prerequisites
if not torch.cuda.is_available():
    print("CUDA not available, skipping SDPA example.")
elif not hasattr(F, "scaled_dot_product_attention"):  # Feature check instead of comparing version strings
    print("Fused attention via SDPA requires PyTorch 2.0 or later. Skipping.")
else:
    print(f"PyTorch Version: {torch.__version__}")
    print(f"Using device: {device}, dtype: {dtype}")
    print(f"Batch={batch_size}, Heads={num_heads}, SeqLen={seq_len}, HeadDim={head_dim}")

    # Create realistic input tensors
    query = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    # Example: Causal mask for decoder self-attention
    is_causal_flag = True

    # --- Using torch.nn.functional.scaled_dot_product_attention ---
    print("\nRunning with torch.nn.functional.scaled_dot_product_attention...")
    num_runs = 20
    warmup_runs = 5

    results = {}

    # --- Backend Option 1: Math Kernel (Reference, Non-Fused) ---
    if sdp_kernel: # Only run if context manager is available
        print("\nRunning with Math backend (Reference)...")
        try:
            # Force use of the basic math implementation
            with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                # Warmup
                for _ in range(warmup_runs):
                    _ = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal_flag)
                torch.cuda.synchronize()
                start_time = time.time()
                for _ in range(num_runs):
                    output_math = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal_flag)
                torch.cuda.synchronize()
            math_time = (time.time() - start_time) / num_runs
            results['math'] = math_time
            print(f"Math backend average time: {math_time:.6f} seconds")
        except RuntimeError as e:
            print(f"Math backend failed or not applicable: {e}")
            results['math'] = float('inf')
    else:
        print("\nSkipping Math backend (sdp_kernel not available).")
        results['math'] = float('inf')


    # --- Backend Option 2: FlashAttention Kernel ---
    if sdp_kernel:
        print("\nRunning with FlashAttention backend (if available)...")
        try:
            # Force use of FlashAttention
            with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                 # Warmup
                for _ in range(warmup_runs):
                    _ = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal_flag)
                torch.cuda.synchronize()
                start_time = time.time()
                for _ in range(num_runs):
                    output_flash = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal_flag)
                torch.cuda.synchronize()
            flash_time = (time.time() - start_time) / num_runs
            results['flash'] = flash_time
            print(f"FlashAttention backend average time: {flash_time:.6f} seconds")
            if not math.isinf(results['math']) and flash_time > 0:
                print(f"  Speedup vs Math: {results['math'] / flash_time:.2f}x")
        except RuntimeError as e:
            print(f"FlashAttention backend not supported or failed: {e}")
            results['flash'] = float('inf')
    else:
        print("\nSkipping FlashAttention backend (sdp_kernel not available).")
        results['flash'] = float('inf')


    # --- Backend Option 3: Memory Efficient Attention Kernel ---
    if sdp_kernel:
        print("\nRunning with Memory Efficient backend (if available)...")
        try:
            # Force use of Memory Efficient Attention
            with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
                 # Warmup
                for _ in range(warmup_runs):
                    _ = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal_flag)
                torch.cuda.synchronize()
                start_time = time.time()
                for _ in range(num_runs):
                    output_mem = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal_flag)
                torch.cuda.synchronize()
            mem_efficient_time = (time.time() - start_time) / num_runs
            results['mem_efficient'] = mem_efficient_time
            print(f"Memory Efficient backend average time: {mem_efficient_time:.6f} seconds")
            if not math.isinf(results['math']) and mem_efficient_time > 0:
                 print(f"  Speedup vs Math: {results['math'] / mem_efficient_time:.2f}x")
        except RuntimeError as e:
            print(f"Memory Efficient backend not supported or failed: {e}")
            results['mem_efficient'] = float('inf')
    else:
        print("\nSkipping Memory Efficient backend (sdp_kernel not available).")
        results['mem_efficient'] = float('inf')

    # --- Automatic Backend Selection (Recommended Usage) ---
    print("\nRunning SDPA with automatic backend selection...")
    # No context manager needed, PyTorch chooses the best available fused kernel
    try:
        # Warmup
        for _ in range(warmup_runs):
            _ = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal_flag)
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(num_runs):
            output_auto = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal_flag)
        torch.cuda.synchronize()
        auto_time = (time.time() - start_time) / num_runs
        results['auto'] = auto_time
        print(f"Automatic backend average time: {auto_time:.6f} seconds")
        # Compare automatic time to the fastest explicit fused kernel time
        best_fused_time = min(results.get('flash', float('inf')), results.get('mem_efficient', float('inf')))
        if not math.isinf(best_fused_time) and best_fused_time > 0:
             print(f"  Ratio vs best explicit fused backend: {auto_time / best_fused_time:.2f}")
        if not math.isinf(results['math']) and auto_time > 0:
             print(f"  Speedup vs Math: {results['math'] / auto_time:.2f}x")

    except Exception as e:
        print(f"Automatic SDPA failed: {e}")
        results['auto'] = float('inf')

    # (Optional) Check if outputs match across successful backends
    # Note: Numerical differences might exist due to computation order/precision used in fused kernels
    print("\nChecking output consistency (if multiple backends ran successfully)...")
    outputs_to_compare = []
    if not math.isinf(results['math']): outputs_to_compare.append(output_math)
    if not math.isinf(results['flash']): outputs_to_compare.append(output_flash)
    if not math.isinf(results['mem_efficient']): outputs_to_compare.append(output_mem)
    if not math.isinf(results['auto']): outputs_to_compare.append(output_auto)

    if len(outputs_to_compare) > 1:
        base_output = outputs_to_compare[0]
        match = True
        for i in range(1, len(outputs_to_compare)):
            if not torch.allclose(base_output, outputs_to_compare[i], atol=1e-2, rtol=1e-2): # Tolerance for FP16
                print(f"❌ Mismatch detected between backend outputs!")
                match = False
                break
        if match:
            print("✅ Outputs appear consistent across successful backends (within tolerance).")
    else:
        print("Not enough successful backend runs to compare outputs.")

3

Fusion, especially in memory-intensive operations like attention, is a cornerstone of modern LLM optimization. By minimizing data movement, fused kernels like those automatically invoked by PyTorch's SDPA significantly boost performance and enable models to handle much longer input sequences.
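
To put rough numbers on the data-movement argument, the sketch below estimates how much memory an unfused attention implementation would need just to hold the full score matrix, compared with the Q, K, V and output tensors themselves. The configuration (1 batch, 32 heads, head dimension 128, FP16) is a hypothetical example chosen for illustration, and the helper names are not part of PyTorch.

```python
def attention_scores_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize one (seq_len x seq_len) score matrix per head."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def qkvo_bytes(batch: int, heads: int, seq_len: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Bytes for Q, K, V and the output O (four tensors of identical shape)."""
    return 4 * batch * heads * seq_len * head_dim * bytes_per_elem


if __name__ == "__main__":
    # Hypothetical configuration, for illustration only.
    batch, heads, head_dim = 1, 32, 128
    for seq_len in (1024, 4096, 16384):
        scores = attention_scores_bytes(batch, heads, seq_len)
        io = qkvo_bytes(batch, heads, seq_len, head_dim)
        print(f"seq_len={seq_len:>6}: scores={scores / 2**20:8.0f} MiB, "
              f"Q/K/V/O={io / 2**20:6.0f} MiB")
```

The score matrix grows quadratically with sequence length while the inputs and output grow linearly, which is why keeping the scores on-chip (as fused attention kernels do) matters more and more as sequences get longer.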

7. Level 3: Unleashing Custom Power with Triton

While libraries like cuBLAS, bitsandbytes, and the fused kernels within PyTorch provide significant acceleration, there are scenarios where even more performance is desired, or where novel operations need to be implemented efficiently. This is where tools like OpenAI Triton come into play, allowing developers to write custom, high-performance GPU kernels directly within Python.

What is Triton?

Triton is an open-source programming language and compiler designed to make writing efficient GPU code easier and more productive than traditional CUDA C++.6 It provides a Pythonic syntax, allowing developers familiar with NumPy and PyTorch to write parallel GPU kernels.6 Triton's compiler takes this high-level Python code, performs numerous optimizations (like instruction scheduling, memory access coalescing, and automatic management of shared memory), and compiles it down to low-level machine code (PTX for NVIDIA GPUs or other backends) using frameworks like LLVM/MLIR.22

Why Use Triton?

  1. Custom Kernel Fusion: Triton excels at creating custom fused kernels. You can combine sequences of operations specific to your model (e.g., GEMM + bias + custom activation + residual add) into a single kernel, minimizing memory traffic beyond what standard libraries or even torch.compile might achieve automatically.22
  2. Rapid Prototyping: The Python-based syntax allows for much faster iteration and experimentation with different kernel implementations compared to CUDA C++.22
  3. Performance: Triton aims to generate kernels with performance comparable to those written by CUDA experts, often achieving near hardware limits, especially for operations like GEMM and attention.22 It includes an auto-tuner to find optimal configurations (like block sizes) for specific hardware.
  4. Accessibility: It significantly lowers the barrier to entry for GPU programming, enabling more researchers and engineers to write custom kernels without deep CUDA expertise.6
  5. Integration with PyTorch: Triton kernels can be seamlessly called from PyTorch. Furthermore, TorchInductor (the default backend for torch.compile in PyTorch 2.0+) heavily utilizes Triton to generate its optimized GPU kernels.5 Understanding Triton can provide insights into how torch.compile works.

Core Triton Concepts and Syntax:

Triton kernels operate on blocks of data (tiles) processed in parallel by GPU thread blocks. Key elements include:

  • @triton.jit: The JIT compiler decorator for kernel functions.6
  • tl.program_id(axis): Returns the unique ID of the current thread block executing the kernel instance.6 Used for partitioning work.
  • tl.arange(start, end): Creates a 1D tensor representing a range of indices, often used for calculating memory offsets within a block.6
  • tl.load(pointer, mask=..., other=...): Loads a tile of data from global GPU memory (HBM) into faster on-chip memory (SRAM/registers). The mask handles boundary conditions safely.6
  • tl.store(pointer, value, mask=...): Stores a tile of data from on-chip memory back to HBM, again using mask for safety.6
  • tl.dot(a, b): Performs block-level matrix multiplication, typically leveraging Tensor Cores for acceleration. Assumes a and b are tiles loaded into SRAM.6
  • tl.* math ops: Functions like tl.exp and tl.maximum operate element-wise on Triton tensors, while reductions such as tl.sum and tl.max collapse a tensor along an axis.6
  • tl.constexpr: Marks variables (like BLOCK_SIZE) as compile-time constants, allowing the compiler to generate more specialized and optimized code.25
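
The interplay of tl.program_id, tl.arange and the load/store mask can be traced on the CPU before writing any GPU code. The following is a plain-Python sketch that mimics the index arithmetic only; the helper names cdiv and program_offsets are made up for this illustration and are not Triton APIs.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same quantity triton.cdiv computes."""
    return (a + b - 1) // b


def program_offsets(pid: int, block_size: int, n_elements: int):
    """Return (offsets, mask) for one program instance of a 1D kernel."""
    block_start = pid * block_size
    # Mirrors: block_start + tl.arange(0, BLOCK_SIZE)
    offsets = [block_start + i for i in range(block_size)]
    # Mirrors: offsets < n_elements
    mask = [off < n_elements for off in offsets]
    return offsets, mask


if __name__ == "__main__":
    n_elements, block_size = 10, 4
    grid = cdiv(n_elements, block_size)  # number of program instances to launch
    for pid in range(grid):
        offsets, mask = program_offsets(pid, block_size, n_elements)
        print(f"pid={pid} offsets={offsets} mask={mask}")
```

With 10 elements and a block size of 4, three program instances are launched, and the last one has two lanes masked off because their offsets fall past the end of the tensor.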

Simple Triton Kernel Example (Element-wise ReLU):

Before tackling GEMM, let's look at a simpler example: applying the ReLU activation function. While PyTorch's built-in ReLU is fast, this illustrates the basic structure of a Triton kernel.

Python
import torch
import triton
import triton.language as tl

@triton.jit
def relu_kernel(x_ptr, # Pointer to input tensor
                y_ptr, # Pointer to output tensor
                n_elements, # Total number of elements in the tensor
                BLOCK_SIZE: tl.constexpr # Size of blocks processed by each kernel instance
                ):
    """
    Triton kernel for element-wise ReLU activation: y = max(0, x).
    """
    # 1. Calculate the range of elements this kernel instance will process
    pid = tl.program_id(axis=0) # Get the unique ID (0, 1, 2,...) for this instance
    # Calculate the starting offset for this instance
    block_start = pid * BLOCK_SIZE
    # Create a range of offsets relative to the block start
    offsets = block_start + tl.arange(0, BLOCK_SIZE) # e.g., [1024, 1025,..., 2047] if pid=1, BLOCK_SIZE=1024

    # 2. Create a mask to handle elements potentially out of bounds
    # This is important if n_elements is not perfectly divisible by BLOCK_SIZE
    mask = offsets < n_elements

    # 3. Load the input data block safely using the mask
    # Elements outside the mask will be loaded with 0.0 (or another specified value)
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # 4. Apply the ReLU activation element-wise
    # tl.maximum performs element-wise max operation
    relu_output = tl.maximum(x, 0.0)

    # 5. Store the result back to the output tensor, using the mask
    # Only elements within the valid range (mask=True) are written
    tl.store(y_ptr + offsets, relu_output, mask=mask)

def triton_relu(x: torch.Tensor):
    """
    Wrapper function to launch the Triton ReLU kernel.
    """
    # Allocate output tensor
    y = torch.empty_like(x)
    # Ensure tensors are contiguous in memory, as many kernels assume this
    assert x.is_contiguous(), "Input tensor must be contiguous"
    assert y.is_contiguous(), "Output tensor must be contiguous"

    n_elements = x.numel() # Get the total number of elements

    # Define kernel launch parameters
    # BLOCK_SIZE determines how many elements each kernel instance handles.
    # This is a tunable parameter affecting performance.
    BLOCK_SIZE = 1024 # Powers of 2 are common

    # Grid size determines how many kernel instances to launch.
    # We need enough instances to cover all elements.
    # triton.cdiv(a, b) computes ceil(a / b)
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),) # Launch a 1D grid of instances

    # Launch the kernel
    # Pass pointers, dimensions, and constexpr block size
    relu_kernel[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    return y

# Example Usage
if torch.cuda.is_available():
    # Create a sample tensor on GPU
    input_tensor = torch.randn(2048 * 2, device='cuda') * 10 - 5 # Include negative values

    # Run Triton ReLU
    output_triton = triton_relu(input_tensor)

    # Run standard PyTorch ReLU for comparison
    output_torch = torch.relu(input_tensor)

    print("--- Triton ReLU Example ---")
    print(f"Input shape: {input_tensor.shape}")
    # print(f"Input sample: {input_tensor[:8]}")
    # print(f"Triton output sample: {output_triton[:8]}")
    # print(f"PyTorch output sample: {output_torch[:8]}")

    # Verify correctness
    if torch.allclose(output_triton, output_torch):
        print("✅ Triton and Torch ReLU match")
    else:
        print("❌ Triton and Torch ReLU differ")
        print(f"   Max difference: {torch.abs(output_triton - output_torch).max().item()}")
else:
    print("CUDA not available, skipping Triton ReLU example.")

Triton GEMM Example:

Now, let's look at the structure of a basic GEMM kernel in Triton. This kernel implements C=A×B. It divides the output matrix C into blocks and assigns each block to a kernel instance (program ID). Each instance then iterates through blocks of the K dimension, loading tiles of A and B into fast memory, performing a block-level matrix multiplication using tl.dot, and accumulating the result.
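
The tiling scheme can be checked on the CPU first. The sketch below is a plain-Python emulation, not Triton code: each loop iteration over pid plays the role of one program instance, clipping with min() stands in for the masks, and the tiny block sizes are arbitrary values picked so the edge cases are exercised.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return (a + b - 1) // b


def naive_matmul(A, B):
    """Reference: each C[i][j] is the dot product of row i of A and column j of B."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def tiled_matmul(A, B, block_m=2, block_n=2, block_k=2):
    """Compute C = A x B one output tile at a time, stepping through K in blocks."""
    M, K = len(A), len(A[0])
    assert K == len(B), "Incompatible dimensions"
    N = len(B[0])
    C = [[0.0] * N for _ in range(M)]
    num_pid_m, num_pid_n = cdiv(M, block_m), cdiv(N, block_n)

    for pid in range(num_pid_m * num_pid_n):      # one iteration per program instance
        pid_m, pid_n = divmod(pid, num_pid_n)      # same mapping as pid // n, pid % n
        rows = range(pid_m * block_m, min((pid_m + 1) * block_m, M))  # clip = mask on M
        cols = range(pid_n * block_n, min((pid_n + 1) * block_n, N))  # clip = mask on N
        acc = {(i, j): 0.0 for i in rows for j in cols}               # tile accumulator
        for k0 in range(0, K, block_k):            # loop over K in steps of block_k
            ks = range(k0, min(k0 + block_k, K))   # clip = mask on K
            for i in rows:
                for j in cols:
                    acc[(i, j)] += sum(A[i][k] * B[k][j] for k in ks)
        for (i, j), value in acc.items():          # write the finished tile back
            C[i][j] = value
    return C


if __name__ == "__main__":
    # 3x5 times 5x4: none of the dimensions is a multiple of the block size.
    A = [[i + j for j in range(5)] for i in range(3)]
    B = [[i * j + 1 for j in range(4)] for i in range(5)]
    print(tiled_matmul(A, B) == naive_matmul(A, B))
```

Each output tile is owned by exactly one program instance, so no synchronization between instances is needed; only the accumulation over K happens sequentially inside an instance.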

Python
import time
import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # Strides for row-major layout
        stride_am, stride_ak, # Stride to move to next row (M) or next column (K) in A
        stride_bk, stride_bn, # Stride to move to next row (K) or next column (N) in B
        stride_cm, stride_cn, # Stride to move to next row (M) or next column (N) in C
        # Tile dimensions (compile-time constants)
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
        # Optional: Grouping for L2 cache locality (not shown in this basic version)
        # GROUP_SIZE_M: tl.constexpr
        ):
    """
    Kernel for computing the matrix multiplication C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    Each kernel instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N block of C.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # We use a 1D launch grid, so pid is a scalar.
    pid = tl.program_id(axis=0)
    # Calculate how many blocks are needed grid-wise
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) # Number of blocks in M dimension
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # Number of blocks in N dimension
    # Calculate the grid indices for the current block
    pid_m = pid // num_pid_n # Row index of the block in the grid
    pid_n = pid % num_pid_n  # Column index of the block in the grid

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We need ranges of indices for the M, N, and K dimensions within a block.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) # Row indices for the C block
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) # Col indices for the C block
    offs_k = tl.arange(0, BLOCK_SIZE_K) # Indices for the K dimension

    # Calculate pointers for the initial K-block of A and B
    # Use broadcasting to create pointers for the entire tile
    # a_ptrs shape: (BLOCK_SIZE_M, BLOCK_SIZE_K)
    # b_ptrs shape: (BLOCK_SIZE_K, BLOCK_SIZE_N)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate results in high precision (float32)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Loop over the K dimension in steps of BLOCK_SIZE_K
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the current blocks of A and B
        # Masks guard the K dimension as well as the M and N edges, so shapes
        # that are not multiples of the block sizes never read out of bounds
        # 'other=0.0' pads with zeros if out of bounds
        a = tl.load(a_ptrs + k * BLOCK_SIZE_K * stride_ak, # Advance pointer in K dim
                    mask=(offs_m[:, None] < M) & (offs_k[None, :] + k * BLOCK_SIZE_K < K), other=0.0)
        b = tl.load(b_ptrs + k * BLOCK_SIZE_K * stride_bk, # Advance pointer in K dim
                    mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_n[None, :] < N), other=0.0)

        # Perform block-level matrix multiplication and accumulate
        accumulator += tl.dot(a, b)

        # Pointers `a_ptrs` and `b_ptrs` themselves are NOT advanced here
        # because the offset `k * BLOCK_SIZE_K * stride_ak` is added inside tl.load.
        # Alternatively, one could advance the base pointers like:
        # a = tl.load(a_ptrs, mask=...)
        # b = tl.load(b_ptrs, mask=...)
        # accumulator += tl.dot(a, b)
        # a_ptrs += BLOCK_SIZE_K * stride_ak
        # b_ptrs += BLOCK_SIZE_K * stride_bk

    # Convert accumulator to the desired output type (e.g., float16)
    c = accumulator.to(c_ptr.dtype.element_ty)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    # Calculate pointers for the C block
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    # Create masks to prevent writing out of bounds for the M and N dimensions
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def triton_matmul(a, b):
    """
    Wrapper function to launch the Triton Matmul kernel.
    """
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype) # Match input dtype

    # Define block sizes (these are crucial for performance and can be tuned)
    # Common values, but optimal ones depend on GPU architecture and matrix shapes
    BLOCK_SIZE_M = 128
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 32
    # GROUP_SIZE_M = 8 # Example for potential L2 cache optimization

    # Launch grid: 1D grid where each program instance computes one block of C
    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), )

    # Launch the kernel
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K
        # GROUP_SIZE_M=GROUP_SIZE_M # If using grouping
    )
    return c

# Example Usage
if torch.cuda.is_available():
    torch.manual_seed(0)
    # Use dimensions that are multiples of block sizes if possible
    M, N, K = 512, 1024, 256
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)

    print("\n--- Triton GEMM Example ---")
    print(f"Input shapes: A=({M}, {K}), B=({K}, {N})")
    print(f"Using dtype: {a.dtype}")

    # Time Triton GEMM
    # Warmup
    _ = triton_matmul(a, b)
    torch.cuda.synchronize()
    start_time = time.time()
    triton_output = triton_matmul(a, b)
    torch.cuda.synchronize()
    triton_time = time.time() - start_time
    print(f"Triton GEMM time: {triton_time:.6f} seconds")

    # Time PyTorch GEMM
    # Warmup
    _ = torch.matmul(a, b)
    torch.cuda.synchronize()
    start_time = time.time()
    torch_output = torch.matmul(a, b)
    torch.cuda.synchronize()
    torch_time = time.time() - start_time
    print(f"Torch GEMM time: {torch_time:.6f} seconds")

    # Verify correctness
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0.01): # Adjust tolerance for FP16
        print("✅ Triton and Torch GEMM match")
    else:
        print("❌ Triton and Torch GEMM differ")
        print(f"   Max difference: {torch.abs(triton_output - torch_output).max().item()}")

    if triton_time > 0:
        print(f"Torch time / Triton time: {torch_time / triton_time:.2f}x (values above 1 mean Triton was faster)")
        # Note: Torch matmul often calls highly optimized cuBLAS kernels.
        # Beating it requires careful Triton tuning and potentially fusion.
        # This basic kernel might not outperform cuBLAS without further optimization.

else:
    print("CUDA not available, skipping Triton GEMM example.")

6

Triton represents a significant step towards making custom GPU kernel development more accessible. While torch.compile leverages it automatically, understanding Triton allows developers to manually craft highly optimized kernels for specific bottlenecks or novel operations, pushing performance beyond standard libraries when necessary.
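
The GEMM kernel above mentions an optional GROUP_SIZE_M parameter without showing it. The idea, as described in the Triton matmul tutorial, is to reorder which output tile each program ID computes so that instances running at the same time share more of A and B in the L2 cache. Below is a CPU-only sketch of that index remapping in plain Python; the 8x8 tile grid and the assumption that 8 instances run concurrently are illustrative, and strips_touched is a made-up helper.

```python
def row_major_tile(pid: int, num_pid_n: int):
    """Basic mapping: consecutive pids walk along one row of output tiles."""
    return divmod(pid, num_pid_n)


def grouped_tile(pid: int, num_pid_m: int, num_pid_n: int, group_size_m: int):
    """Grouped mapping: consecutive pids walk down a band of group_size_m tile rows."""
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may contain fewer rows than group_size_m.
    rows_in_group = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % num_pid_in_group) % rows_in_group
    pid_n = (pid % num_pid_in_group) // rows_in_group
    return pid_m, pid_n


def strips_touched(tiles) -> int:
    """Distinct row strips of A plus column strips of B needed by a set of tiles."""
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})


if __name__ == "__main__":
    num_pid_m = num_pid_n = 8
    first_wave = range(8)  # hypothetical: 8 program instances running concurrently
    row_major = [row_major_tile(p, num_pid_n) for p in first_wave]
    grouped = [grouped_tile(p, num_pid_m, num_pid_n, group_size_m=4) for p in first_wave]
    print("row-major strips:", strips_touched(row_major))
    print("grouped strips:  ", strips_touched(grouped))
```

Both orderings compute the same set of tiles, but in this example the first eight instances need 9 distinct strips of A and B under the row-major order and only 6 under the grouped order, which is the kind of reuse that can raise the L2 hit rate.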

8. Leveraging Pre-Optimized Kernels: Introduction to Liger

While Triton empowers developers to write custom kernels, another approach is to leverage pre-built, optimized Triton kernels packaged into libraries. Liger-Kernel is a prominent example specifically targeting LLM acceleration.14

What is Liger?

Liger-Kernel is an open-source library developed by LinkedIn containing a collection of efficient Triton kernels designed to accelerate common operations found in LLM training and inference pipelines.14 It aims to provide drop-in replacements for standard Hugging Face Transformer components, offering performance improvements with minimal code changes.

Key Features & Optimizations:

Liger achieves performance gains primarily through kernel fusion, similar in principle to FlashAttention but applied to different parts of the Transformer block:

  • Fused Kernels: It provides optimized Triton implementations for operations like:
    • RMSNorm: Fuses the normalization calculation with element-wise operations.
    • RoPE (Rotary Position Embeddings): Fuses the embedding calculation and application.
    • SwiGLU / GeGLU: Fuses the gated linear unit activations (often involving multiple element-wise operations and GEMMs).
    • CrossEntropy Loss: Can fuse the final linear projection with the cross-entropy loss calculation, avoiding materialization of the full logit tensor, which is beneficial for large vocabularies.31
    • Other potential fusions depending on the library version.
    By fusing these operations, Liger reduces memory I/O and kernel launch overhead.14
  • Performance Gains: Liger reports significant throughput increases (e.g., up to ~20-26% in multi-GPU training) and substantial memory reduction (e.g., up to ~60%) compared to baseline Hugging Face implementations. This can enable training with larger batch sizes or longer context lengths on the same hardware.14
  • Ease of Use: Liger is designed for easy integration. It offers:
    • AutoLigerKernelForCausalLM: An AutoModel class that automatically patches supported Hugging Face models upon loading.30
    • Model-Specific Patching APIs (e.g., apply_liger_kernel_to_llama): Functions to apply Liger kernels to specific model architectures before instantiation.30
    • Individual Kernel Modules: For advanced users composing custom models.32
  • Lightweight Dependency: Requires only PyTorch and Triton, avoiding complex dependencies.14
  • Compatibility: Supports multi-GPU training setups (FSDP, DeepSpeed) and both NVIDIA CUDA and AMD ROCm platforms.14
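
The benefit of fusing the final linear projection with the cross-entropy loss can be illustrated without a GPU. The sketch below is plain Python and is not Liger's implementation: it only shows that computing logits chunk by chunk gives the same loss as materializing the full tokens x vocabulary logit tensor, so the full tensor never has to exist at once. All function names, the chunk size, and the toy data are assumptions made for this example.

```python
import math


def logits_bytes(tokens: int, vocab: int, bytes_per_elem: int = 4) -> int:
    """Memory needed to hold the full logit tensor."""
    return tokens * vocab * bytes_per_elem


def project(h, W):
    """Final linear layer for one token: one logit per vocabulary row of W."""
    return [sum(hi * wi for hi, wi in zip(h, row)) for row in W]


def cross_entropy(logits, target: int) -> float:
    """Numerically stable -log softmax(logits)[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def loss_full(H, W, targets) -> float:
    all_logits = [project(h, W) for h in H]  # materializes tokens x vocab
    return sum(cross_entropy(l, t) for l, t in zip(all_logits, targets)) / len(H)


def loss_chunked(H, W, targets, chunk_size: int = 2) -> float:
    total = 0.0
    for start in range(0, len(H), chunk_size):
        stop = start + chunk_size
        chunk_logits = [project(h, W) for h in H[start:stop]]  # only chunk_size x vocab
        total += sum(cross_entropy(l, t) for l, t in zip(chunk_logits, targets[start:stop]))
    return total / len(H)


if __name__ == "__main__":
    H = [[0.1, -0.2, 0.3], [0.5, 0.4, -0.1], [-0.3, 0.2, 0.2], [0.0, 0.1, 0.9], [0.7, -0.6, 0.2]]
    W = [[0.2, 0.1, -0.4], [-0.5, 0.3, 0.8], [0.9, -0.7, 0.1], [0.0, 0.6, -0.2]]
    targets = [0, 3, 1, 2, 1]
    print(loss_full(H, W, targets), loss_chunked(H, W, targets))
    # 4096 tokens with a 32,000-entry vocabulary in FP32:
    print(logits_bytes(4096, 32000) / 2**20, "MiB of logits")
```

A production kernel also has to produce gradients, which this forward-only sketch leaves out.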

Reproducible Code (Applying Liger to HF Model):

Python
# NOTE: Requires Liger-Kernel: pip install liger-kernel
# NOTE: Requires transformers: pip install transformers
# NOTE: Requires accelerate for device_map='auto': pip install accelerate
# NOTE: Replace "meta-llama/Llama-2-7b-hf" with an actual model path or identifier
#       accessible to your environment (requires appropriate access/download).

import torch
import transformers
import os
import importlib

# Check if Liger is installed and CUDA is available
try:
    import liger_kernel
    # Check for specific components to be more certain
    from liger_kernel.transformers import AutoLigerKernelForCausalLM, apply_liger_kernel_to_llama
    liger_available = True
    # getattr guards against builds that do not expose a __version__ attribute
    print(f"Liger-Kernel version {getattr(liger_kernel, '__version__', 'unknown')} found.")
except ImportError:
    liger_available = False
    print("Liger-Kernel not installed. Skipping Liger example.")

if not torch.cuda.is_available():
     print("CUDA not available. Skipping Liger example.")
     liger_available = False # Ensure flag is false if no CUDA

if liger_available:
    # --- Option 1: Using AutoLigerKernelForCausalLM (Simplest) ---
    print("\n--- Method 1: Using AutoLigerKernelForCausalLM ---")
    # Use a known supported model identifier (replace if necessary)
    # Ensure you have access permission and enough disk space/memory.
    model_name_or_path = "meta-llama/Llama-2-7b-hf"
    print(f"Attempting to load '{model_name_or_path}' with AutoLigerKernelForCausalLM...")
    print("Note: This may download the model if not cached locally.")

    try:
        # This AutoModel wrapper automatically patches if the model type (e.g., Llama) is supported by Liger.
        model_auto = AutoLigerKernelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            device_map='auto' # Requires 'accelerate' library
        )
        print("\nModel loaded successfully using AutoLigerKernelForCausalLM!")
        print("Liger kernels should be automatically applied where supported.")
        # Example: Check if a layer known to be patched exists and is of Liger type (conceptual)
        # This requires knowing Liger's internal class names, which might change.
        # Example check for RMSNorm in the first block of a Llama model:
        if hasattr(model_auto, 'model') and hasattr(model_auto.model, 'layers') and len(model_auto.model.layers) > 0:
             first_layer = model_auto.model.layers[0]
             if hasattr(first_layer, 'input_layernorm'):
                 norm_layer = first_layer.input_layernorm
                 if 'liger_kernel' in str(norm_layer.__class__).lower():
                      print(f"Detected Liger RMSNorm in layer 0: {type(norm_layer)}")
                 else:
                      print(f"Layer 0 RMSNorm type: {type(norm_layer)} (Not Liger?)")

        # Clean up memory
        del model_auto
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"\n[Error] Could not load model with AutoLigerKernelForCausalLM: {e}")
        print("  Possible reasons: Model not found, insufficient memory/disk, missing dependencies (transformers, accelerate), or access restrictions.")


    # --- Option 2: Using Model-Specific Patching API (for Llama) ---
    print("\n--- Method 2: Using apply_liger_kernel_to_llama ---")
    # Check if the model name suggests it's a Llama model for this specific patcher
    if "llama" in model_name_or_path.lower():
        print(f"Applying Liger kernel patches for Llama architecture...")
        try:
            # Apply patches *before* loading the model.
            # Defaults usually apply all relevant kernels. Can be specific:
            apply_liger_kernel_to_llama(
                rms_norm=True,
                rope=True,
                swiglu=True,
                # cross_entropy=True, # Affects how the training loss is computed
                # fused_linear_cross_entropy=False
            )
            print("Liger patches applied to transformers Llama classes.")

            print(f"\nAttempting to load '{model_name_or_path}' with standard AutoModelForCausalLM...")
            # Load the model using the standard Hugging Face class AFTER patching
            model_patched = transformers.AutoModelForCausalLM.from_pretrained(
                model_name_or_path,
                torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
                device_map='auto'
            )
            print("\nModel loaded successfully using standard AutoModel after patching!")
            print("The instantiated model should now use Liger kernels where applicable.")

            # Example check (similar to above)
            if hasattr(model_patched, 'model') and hasattr(model_patched.model, 'layers') and len(model_patched.model.layers) > 0:
                 first_layer = model_patched.model.layers[0]
                 if hasattr(first_layer, 'input_layernorm'):
                     norm_layer = first_layer.input_layernorm
                     if 'liger_kernel' in str(norm_layer.__class__).lower():
                          print(f"Detected Liger RMSNorm in layer 0: {type(norm_layer)}")
                     else:
                          print(f"Layer 0 RMSNorm type: {type(norm_layer)} (Not Liger?)")

            # Clean up memory
            del model_patched
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"\n[Error] Could not load model after applying Liger patches: {e}")
            print("  Possible reasons: Model not found, insufficient memory/disk, missing dependencies, or access restrictions.")
    else:
        print(f"Model name '{model_name_or_path}' does not contain 'llama', skipping Llama-specific patching example.")

14

Libraries like Liger-Kernel represent a valuable layer in the optimization stack. By leveraging Triton internally, they provide the performance benefits of custom kernels without requiring end-users to write or tune Triton code themselves. This modular approach allows the community to build and share optimized components, accelerating progress in LLM efficiency.

9. When GEMM Isn't the Whole Story

While this tutorial has focused heavily on GEMM optimization due to its dominance in standard dense Transformer LLMs, it's important to acknowledge scenarios where other operations might become bottlenecks, or where the computational pattern shifts away from GEMM entirely [User Material 6].

  • Convolutions: In Convolutional Neural Networks (CNNs) or hybrid architectures incorporating convolutional layers, the convolution operation itself is key. Libraries like NVIDIA's cuDNN (used by torch.nn.Conv2d) employ various algorithms for convolutions, including:
    • GEMM-based Methods: im2col (or im2row) explicitly unfolds input patches into a matrix so that the convolution becomes one large GEMM, while implicit GEMM performs the same mapping on the fly without materializing the unfolded matrix. Either way, the work can then be accelerated by Tensor Cores.
    • Transform-based Methods: Algorithms like Winograd or FFT-based convolutions can be faster than GEMM-based approaches for certain kernel sizes and hardware. cuDNN selects an algorithm based on heuristics.
  • Sparse Models: Techniques like Mixture-of-Experts (MoE) or model pruning introduce sparsity into the computations. MoE layers still run dense GEMMs inside each expert, but they add token routing and split the work into many smaller (grouped) GEMMs, so load balancing across experts becomes a concern. Pruned models with unstructured sparsity instead rely on Sparse Matrix Multiplication (SpMM) kernels, whose optimization focuses on handling irregular memory access and load balancing for the non-zero elements [User Material 6].
  • Alternative Architectures & Long Sequences: For handling extremely long sequences where the quadratic complexity of standard attention becomes prohibitive even with FlashAttention, alternative architectures are being explored. State-Space Models (SSMs) like Mamba, or certain recurrent architectures, may rely less on traditional GEMM and more on operations like parallel scans or structured matrix-vector multiplications [User Material 6]. Even within attention, techniques like FlashAttention are critical adaptations for long sequences, fundamentally changing the memory access pattern compared to naive GEMM-based attention [User Material 4].
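
The im2col idea from the convolution bullet above can be made concrete with a small CPU sketch. It is plain Python for a single-channel image with stride 1 and no padding, and it computes cross-correlation, which is the operation deep learning frameworks call convolution. The function names are made up for this illustration and say nothing about how cuDNN implements it.

```python
def im2col(image, kh: int, kw: int):
    """Unfold every kh x kw patch (stride 1, no padding) into one row."""
    H, W = len(image), len(image[0])
    return [
        [image[i + di][j + dj] for di in range(kh) for dj in range(kw)]
        for i in range(H - kh + 1)
        for j in range(W - kw + 1)
    ]


def conv2d_direct(image, kernel):
    """Reference: slide the kernel over the image and sum the products."""
    kh, kw = len(kernel), len(kernel[0])
    H, W = len(image), len(image[0])
    return [
        [sum(image[i + di][j + dj] * kernel[di][dj] for di in range(kh) for dj in range(kw))
         for j in range(W - kw + 1)]
        for i in range(H - kh + 1)
    ]


def conv2d_via_gemm(image, kernels):
    """Convolution with several filters as one matrix product.

    patches: (num_patches x kh*kw), weights: (num_filters x kh*kw).
    Result:  (num_patches x num_filters), i.e. patches @ weights^T.
    """
    kh, kw = len(kernels[0]), len(kernels[0][0])
    patches = im2col(image, kh, kw)
    weights = [[w for row in k for w in row] for k in kernels]
    return [[sum(p * w for p, w in zip(patch, filt)) for filt in weights] for patch in patches]


if __name__ == "__main__":
    image = [[r * 4 + c for c in range(4)] for r in range(4)]
    kernels = [[[1, 0], [0, -1]], [[1, 1], [1, 1]]]
    print(conv2d_via_gemm(image, kernels))
```

The unfolded matrix repeats each input pixel up to kh*kw times, which is the memory cost that implicit GEMM avoids by generating patch addresses on the fly.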

While GEMM optimization is crucial for today's prevalent dense LLMs, developers should be aware that the critical bottlenecks can shift depending on the model architecture and specific application. Profiling the model remains essential to identify where computational time is actually spent.

10. Conclusion: Your GEMM Optimization Toolkit

Accelerating Large Language Models hinges significantly on optimizing the underlying General Matrix-Matrix Multiplication (GEMM) operations that dominate their computational cost. This tutorial has journeyed from the fundamentals of GEMM and its importance in LLMs to a multi-level approach for optimization within the PyTorch ecosystem.

We began by understanding why GPUs are so effective at GEMM: the operation's high arithmetic intensity maps well onto their massive parallelism and specialized hardware like Tensor Cores. We then explored practical optimization techniques, starting with the foundational Level 1 strategies:

  • Using lower-precision data types (FP16, BF16, TF32) to boost throughput and reduce memory usage.
  • Aligning matrix dimensions with hardware tile sizes (multiples of 8/16) to maximize utilization.
  • Leveraging torch.compile for automatic kernel fusion and optimization, often utilizing Triton under the hood.

Moving to Level 2, we examined more advanced techniques requiring specific libraries or structural awareness:

  • Employing torch.bmm to batch many small matrix multiplications into fewer, larger kernel calls.
  • Utilizing quantization (INT8/INT4) via libraries like bitsandbytes to drastically cut memory footprint and potentially speed up inference.
  • Leveraging fused attention through PyTorch's scaled_dot_product_attention, which automatically uses backends like FlashAttention to overcome memory bandwidth limitations and enable longer sequence processing.
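As a quick check on two of these Level 2 tools, the sketch below confirms that `torch.bmm` matches a Python loop of `torch.mm` calls, and that `scaled_dot_product_attention` matches a hand-written causal attention. All tensor shapes are arbitrary assumptions. The snippet runs on CPU, so it demonstrates numerical equivalence only; which fused backend is selected, and how much faster it is, depends on the hardware and PyTorch build.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# --- Batched GEMM: one bmm call instead of B separate mm calls ---
B, M, K, N = 16, 32, 64, 48
a = torch.randn(B, M, K)
b = torch.randn(B, K, N)

looped = torch.stack([torch.mm(a[i], b[i]) for i in range(B)])
batched = torch.bmm(a, b)
bmm_matches = torch.allclose(looped, batched, rtol=1e-4, atol=1e-4)

# --- Fused attention: SDPA vs. an explicit softmax(QK^T / sqrt(d)) V ---
batch, heads, seq, dim = 2, 4, 128, 64
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

# Naive version materializes the full (seq x seq) score matrix.
scores = (q @ k.transpose(-2, -1)) / math.sqrt(dim)
causal = torch.ones(seq, seq, dtype=torch.bool).tril()
scores = scores.masked_fill(~causal, float("-inf"))
naive = torch.softmax(scores, dim=-1) @ v

# SDPA computes the same result and may avoid materializing the scores.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
sdpa_matches = torch.allclose(naive, fused, rtol=1e-4, atol=1e-4)

print(bmm_matches, sdpa_matches)
```

Because the outputs agree, swapping a hand-written attention for the SDPA call is usually a drop-in change, with the memory savings coming from never storing the `seq × seq` score matrix.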

Finally, at Level 3, we introduced Triton as a powerful tool for writing custom, high-performance GPU kernels directly in Python, enabling bespoke fusions and fine-grained control. We also looked at Liger-Kernel as an example of a library providing pre-built, optimized Triton kernels for common LLM operations beyond attention, simplifying access to advanced fusion techniques.

The core principles underlying these optimizations are consistent: reduce data movement between slow and fast memory (HBM vs SRAM/registers), maximize the utilization of specialized hardware units (like Tensor Cores), and exploit the massive parallelism offered by GPUs.
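A back-of-the-envelope calculation shows why data movement is the recurring theme. The sketch below estimates arithmetic intensity (FLOPs per byte moved) under the simplifying assumption that each operand is read from or written to main memory exactly once. The function names and the chosen shapes are illustrative assumptions.

```python
def gemm_arithmetic_intensity(m, n, k, bytes_per_elem):
    """FLOPs per byte for C = A @ B, assuming each matrix moves once."""
    flops = 2 * m * n * k                               # one multiply + one add per term
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    return flops / bytes_moved


def add_arithmetic_intensity(bytes_per_elem):
    """FLOPs per byte for c = a + b: one FLOP, two reads, one write."""
    return 1 / (3 * bytes_per_elem)


# Large square GEMM in FP16 (2 bytes per element).
big = gemm_arithmetic_intensity(4096, 4096, 4096, 2)

# Matrix-vector shaped GEMM (m = 1), as in single-token decoding.
skinny = gemm_arithmetic_intensity(1, 4096, 4096, 2)

# Elementwise add in FP16.
elementwise = add_arithmetic_intensity(2)

print(f"large GEMM:  {big:.1f} FLOPs/byte")
print(f"skinny GEMM: {skinny:.2f} FLOPs/byte")
print(f"elementwise: {elementwise:.2f} FLOPs/byte")
```

Under these assumptions the large GEMM performs roughly 1365 FLOPs per byte, while the skinny GEMM and the elementwise add perform about 1 and 0.17. The latter two are limited by memory bandwidth rather than compute, which is why fusion and batching, both of which cut memory traffic per FLOP, help them most.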

Choosing the right technique depends on the specific context, performance requirements, and development resources available. Profiling your model is crucial to identify the true bottlenecks before applying optimizations. Start with the Level 1 techniques, as they often provide substantial gains with minimal effort. If further speed is needed, explore batching, quantization (especially for inference), and ensure fused attention is active via SDPA. For cutting-edge performance or unique model components, delving into Triton or leveraging libraries like Liger may be necessary.

The following table summarizes the key techniques discussed:

Table: LLM GEMM Optimization Techniques in PyTorch

| Technique | Key PyTorch/Library Tool | Impact (Speed/Memory) | Ease of Use | When to Apply | Key Considerations/Trade-offs |
|---|---|---|---|---|---|
| Lower Precision (FP16/BF16) | model.half(), .to(torch.bfloat16), set_default_dtype | ++ Speed, + Memory | ★★★★★ | Almost always (Training & Inference) | FP16: Range issues (needs scaling). BF16: Ampere+. Accuracy. |
| Enable TF32 | torch.backends.cuda.matmul.allow_tf32 = True | + Speed (for FP32) | ★★★★★ | FP32 inputs on Ampere+ GPUs | Minimal accuracy loss vs FP32. No memory saving. |
| Align Dimensions | Model architecture design (multiples of 8/16/64) | + Speed | ★★★☆☆ | During model design/selection | Requires architectural changes or padding. |
| Automatic Compilation | torch.compile(model) | ++ Speed, +/- Memory | ★★★★☆ | PyTorch 2.0+, Training & Inference | Compile time overhead. Potential graph breaks. |
| Batched GEMM | torch.bmm() | ++ Speed (for many small GEMMs) | ★★★★☆ | Batches of small matrices | Requires specific input shape (B, M, K). |
| Quantization (INT8/INT4) | bitsandbytes (load_in_8bit, Linear8bitLt) | + Speed, ++ Memory | ★★★★☆ | Primarily Inference (or QAT) | Accuracy loss possible. Careful setup. |
| Fused Attention | torch.nn.functional.scaled_dot_product_attention (SDPA) | ++ Speed, ++ Memory (Seq Len) | ★★★★☆ | PyTorch 2.0+, Attention layers | Auto-uses FlashAttention/xFormers if avail. |
| Custom Kernels (Triton) | @triton.jit, tl.* functions | +++ Speed, +++ Memory (Custom) | ★★☆☆☆ | Specific bottlenecks, novel fusions | Kernel development effort & tuning required. |
| Pre-built Kernels (Liger) | liger_kernel patching APIs / AutoModel | ++ Speed, ++ Memory | ★★★★☆ | Supported LLM layers (RMSNorm, etc.) | Library dependency. Model architecture support needed. |

By systematically applying these GEMM optimization techniques, developers can significantly enhance the performance of their LLMs in PyTorch, making them faster, more efficient, and capable of tackling larger and more complex tasks.