Complete Guide to Advanced Transformer Components
Let me teach you all the advanced components used in modern Transformers step by step. We'll build from basic to cutting-edge optimizations!
1. Normalization Layers
Standard LayerNorm (Basic)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class StandardLayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(d_model, eps=eps)

    def forward(self, x):
        # Normalizes each sample across the feature dimension
        return self.ln(x)
RMSNorm (Advanced)
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Norm - used in LLaMA, PaLM
    Simpler and cheaper than LayerNorm: scales by the RMS of the features, with no mean centering
    """
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # RMS over the feature dimension: sqrt(mean(x^2))
        rms = x.norm(dim=-1, keepdim=True) * (x.size(-1) ** -0.5)
        return self.weight * x / (rms + self.eps)
Why RMSNorm is Better:
- Simpler: No mean centering, just scaling by the root mean square of the features
- Faster: Fewer operations and fewer statistics to compute than LayerNorm (see the quick check below)
- Works well at scale: Adopted by very large models such as LLaMA without hurting training stability
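A minimal sanity check using the two classes defined above (sizes are arbitrary): LayerNorm removes the per-token mean, while RMSNorm only rescales to unit RMS.

x = torch.randn(2, 4, 512)                      # [batch, seq_len, d_model]
ln = StandardLayerNorm(512)
rms = RMSNorm(512)

y_ln, y_rms = ln(x), rms(x)
print(y_ln.mean(-1).abs().max())            # ~0: LayerNorm subtracts the per-token mean
print(y_rms.mean(-1).abs().max())           # generally non-zero: RMSNorm only rescales
print(y_rms.pow(2).mean(-1).sqrt().mean())  # ~1: unit RMS after scaling (weight initialized to ones)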
2. Advanced Activation Functions
GELU (Standard)
class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)  # Gaussian Error Linear Unit
SwiGLU (State-of-the-Art)
class SwiGLU(nn.Module):
    """
    Swish-Gated Linear Unit - used in PaLM, LLaMA
    Combines the Swish (SiLU) activation with a gating mechanism
    """
    def __init__(self, d_model, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or int(d_model * 8 / 3)  # ~2.67x expansion
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU(x) = W2( SiLU(W1 x) * W3 x )
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
GLU (Gated Linear Unit)
class GLU(nn.Module):
    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden_dim)
        self.w2 = nn.Linear(d_model, hidden_dim)
        self.out = nn.Linear(hidden_dim, d_model)

    def forward(self, x):
        gate = torch.sigmoid(self.w1(x))
        projected = self.w2(x)
        return self.out(gate * projected)
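A quick shape check for the three activation modules above (the 512 and 1365 sizes are just illustrative):

x = torch.randn(2, 16, 512)        # [batch, seq_len, d_model]
print(GELU()(x).shape)             # torch.Size([2, 16, 512]) - elementwise, shape preserved
print(SwiGLU(512)(x).shape)        # torch.Size([2, 16, 512]) - hidden dim defaults to ~8/3 * d_model
print(GLU(512, 1365)(x).shape)     # torch.Size([2, 16, 512]) - explicit hidden dim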
3. Positional Encoding Variants
Absolute Sinusoidal (Original Transformer)
class AbsolutePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
Learned Positional Embeddings
class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return x + self.pos_emb(positions)
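Both modules above simply add position information to the token embeddings; a quick shape check:

x = torch.randn(2, 100, 512)
print(AbsolutePositionalEncoding(512)(x).shape)   # torch.Size([2, 100, 512])
print(LearnedPositionalEmbedding(512)(x).shape)   # torch.Size([2, 100, 512])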
Basic Rotary Position Encoding - Educational, Clean Implementation
def rotary_position_encoding(seq_len, d_model, theta=10000.0):
    """
    Generate rotary position encoding.

    Args:
        seq_len: Length of the sequence
        d_model: Model dimension (must be even)
        theta: Base for the geometric progression (default: 10000.0)

    Returns:
        cos_pos: Cosine position encoding [seq_len, d_model//2]
        sin_pos: Sine position encoding [seq_len, d_model//2]
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")

    # Position indices
    position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)  # [seq_len, 1]
    # Dimension indices (one per pair of features)
    dim_idx = torch.arange(0, d_model, 2, dtype=torch.float)          # [d_model//2]
    # Frequencies follow a geometric progression
    freqs = 1.0 / (theta ** (dim_idx / d_model))                      # [d_model//2]
    # Rotation angles
    angles = position * freqs                                         # [seq_len, d_model//2]

    cos_pos = torch.cos(angles)  # [seq_len, d_model//2]
    sin_pos = torch.sin(angles)  # [seq_len, d_model//2]
    return cos_pos, sin_pos
def apply_rotary_encoding(x, cos_pos, sin_pos):
    """
    Apply rotary position encoding to input tensor.

    Args:
        x: Input tensor [batch_size, seq_len, d_model]
        cos_pos: Cosine position encoding [seq_len, d_model//2]
        sin_pos: Sine position encoding [seq_len, d_model//2]

    Returns:
        Tensor with rotary position encoding applied
    """
    # Split x into even and odd feature dimensions
    x_even = x[..., 0::2]  # [batch_size, seq_len, d_model//2]
    x_odd = x[..., 1::2]   # [batch_size, seq_len, d_model//2]

    # Rotate each (even, odd) pair by its position-dependent angle
    x_rotated_even = x_even * cos_pos - x_odd * sin_pos
    x_rotated_odd = x_even * sin_pos + x_odd * cos_pos

    # Interleave the rotated pairs back into the original layout
    x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
    x_rotated = x_rotated.flatten(-2)
    return x_rotated
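A short check of the two functions above: the rotation preserves vector norms, and dot products between rotated vectors depend only on their relative offset, which is exactly why RoPE encodes relative position (sizes here are illustrative).

seq_len, d_model = 8, 64
cos_pos, sin_pos = rotary_position_encoding(seq_len, d_model)
x = torch.randn(1, seq_len, d_model)
x_rot = apply_rotary_encoding(x, cos_pos, sin_pos)

# Rotation is norm-preserving at every position
print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))  # True

# Dot products depend only on the relative offset: place the same (q, k) pair
# at positions (0, 2) and at positions (3, 5) and compare the scores
q, k = torch.randn(d_model), torch.randn(d_model)
qk = torch.stack([q, k]).unsqueeze(0)                            # [1, 2, d_model]
rot_a = apply_rotary_encoding(qk, cos_pos[[0, 2]], sin_pos[[0, 2]])
rot_b = apply_rotary_encoding(qk, cos_pos[[3, 5]], sin_pos[[3, 5]])
print(torch.allclose(rot_a[0, 0] @ rot_a[0, 1],
                     rot_b[0, 0] @ rot_b[0, 1], atol=1e-5))       # True (same offset of 2)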
Rotary Position Embedding (RoPE) - SOTA
class RotaryPositionEmbedding(nn.Module):
    """
    Rotary Position Embedding - used in GPT-NeoX, LLaMA
    Encodes relative position by rotating Q and K vectors
    """
    def __init__(self, dim, max_seq_len=32768):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)  # [seq_len, dim//2]
        emb = torch.cat((freqs, freqs), dim=-1)            # [seq_len, dim]
        # Return cos/sin with shape [1, 1, seq_len, dim] so they broadcast over batch and heads
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embedding to query and key tensors of shape [B, num_heads, L, head_dim]"""
    def rotate_half(x):
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat((-x2, x1), dim=-1)

    cos, sin = cos[:, :, :q.shape[-2], :], sin[:, :, :q.shape[-2], :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
ALiBi (Attention with Linear Biases)
class ALiBiPositionalBias(nn.Module):
    """
    Attention with Linear Biases - used in BLOOM
    Adds a fixed, head-specific linear penalty to attention scores;
    no position embeddings are added to the token representations
    """
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Geometric sequence of slopes, one per head: 2^(-8/n), 2^(-16/n), ...
        slopes = torch.tensor([2 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])
        self.register_buffer('slopes', slopes)

    def get_bias(self, seq_len):
        # j - i: zero on the diagonal, increasingly negative for more distant (earlier) keys;
        # future positions get a positive bias but are masked out in causal attention
        distances = torch.arange(seq_len)[None, :] - torch.arange(seq_len)[:, None]
        bias = distances * self.slopes.view(-1, 1, 1)
        return bias.unsqueeze(0)  # [1, num_heads, seq_len, seq_len], added to scores before softmax
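A sketch of how the bias plugs into causal attention; the raw scores here are random placeholders with the usual [batch, num_heads, L, L] shape.

alibi = ALiBiPositionalBias(num_heads=8)
scores = torch.randn(2, 8, 16, 16)                 # placeholder q @ k^T / sqrt(d) scores
causal_mask = torch.triu(torch.ones(16, 16, dtype=torch.bool), diagonal=1)

scores = scores + alibi.get_bias(16)               # distant keys receive a larger negative bias
scores = scores.masked_fill(causal_mask, float('-inf'))
attn = torch.softmax(scores, dim=-1)
print(attn.shape)                                  # torch.Size([2, 8, 16, 16])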
4. Advanced Attention Mechanisms
Multi-Head Attention with RoPE
class MultiHeadAttentionRoPE(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionEmbedding(self.head_dim)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, L, D = x.shape
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys
        cos, sin = self.rope(L, x.device)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Standard scaled dot-product attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(attn_out)
Grouped Query Attention (GQA)
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention - used in LLaMA 2
    Reduces memory usage by sharing K,V heads across groups of query heads
    """
    def __init__(self, d_model, num_heads, num_kv_heads=None):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        assert num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        self.head_dim = d_model // num_heads  # K/V heads use the same head_dim as query heads
        self.group_size = num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, L, D = x.shape
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Repeat K,V so each group of query heads shares one K/V head
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(attn_out)
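A quick comparison of the K/V projection sizes, which is where the savings (and the smaller KV cache) come from; the numbers below are specific to this configuration.

mha = GroupedQueryAttention(d_model=512, num_heads=8)                   # 8 KV heads (standard MHA)
gqa = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=2)   # 2 shared KV heads

kv_params = lambda m: sum(p.numel() for p in [m.k_proj.weight, m.v_proj.weight])
print(kv_params(mha))  # 524288 = 2 * 512 * 512
print(kv_params(gqa))  # 131072 = 2 * 512 * 128, i.e. 4x smaller K/V projections

x = torch.randn(2, 16, 512)
print(gqa(x).shape)    # torch.Size([2, 16, 512])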
Flash Attention Integration
class FlashMultiHeadAttention(nn.Module):
    """
    Flash Attention wrapper - memory-efficient exact attention
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        try:
            from flash_attn import flash_attn_func
            self.flash_attn_func = flash_attn_func
            self.has_flash = True
        except ImportError:
            self.has_flash = False

    def forward(self, x):
        B, L, D = x.shape
        qkv = self.qkv_proj(x).view(B, L, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)  # each [B, L, num_heads, head_dim]

        # flash_attn kernels require fp16/bf16 tensors on CUDA
        if self.has_flash and x.is_cuda and x.dtype in (torch.float16, torch.bfloat16):
            attn_out = self.flash_attn_func(q, k, v)
        else:
            # Fallback to standard attention
            q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            attn_probs = torch.softmax(attn_scores, dim=-1)
            attn_out = torch.matmul(attn_probs, v).transpose(1, 2)

        attn_out = attn_out.contiguous().view(B, L, D)
        return self.out_proj(attn_out)
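If the flash_attn package is not installed, PyTorch 2.0+ exposes a similar fused, memory-efficient kernel through torch.nn.functional.scaled_dot_product_attention; a minimal sketch (the helper name is just for illustration):

def sdpa_attention(q, k, v, causal=True):
    # q, k, v: [B, num_heads, L, head_dim]; dispatches to a flash/memory-efficient
    # kernel when one is available for the current device and dtype
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

q = k = v = torch.randn(2, 8, 16, 64)
print(sdpa_attention(q, k, v).shape)  # torch.Size([2, 8, 16, 64])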
5. Advanced Feed-Forward Networks
Standard FFN
class StandardFFN(nn.Module):
    def __init__(self, d_model, hidden_dim=None, dropout=0.1):
        super().__init__()
        hidden_dim = hidden_dim or 4 * d_model
        self.w1 = nn.Linear(d_model, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w2(self.dropout(F.gelu(self.w1(x))))
SwiGLU FFN (SOTA)
class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, hidden_dim=None, dropout=0.1):
        super().__init__()
        hidden_dim = hidden_dim or int(d_model * 8 / 3)
        self.swiglu = SwiGLU(d_model, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.swiglu(x))
6. Complete Advanced Transformer Layer
class AdvancedTransformerLayer(nn.Module):
    """
    State-of-the-art Transformer layer combining the advanced components above
    """
    def __init__(self, d_model, num_heads, num_kv_heads=None, dropout=0.1):
        super().__init__()
        # Advanced normalization
        self.attn_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)
        # Grouped Query Attention (RoPE can be applied inside the attention module, as shown earlier)
        self.attn = GroupedQueryAttention(d_model, num_heads, num_kv_heads)
        # Feed-forward with SwiGLU
        self.ffn = SwiGLUFFN(d_model, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Pre-norm attention block
        residual = x
        x = self.attn_norm(x)
        x = self.attn(x)
        x = self.dropout(x) + residual

        # Pre-norm FFN block
        residual = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + residual
        return x
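A tiny usage sketch stacking a few of these layers into a toy model (sizes chosen arbitrarily):

d_model, num_heads, num_kv_heads = 512, 8, 2
layers = nn.ModuleList([
    AdvancedTransformerLayer(d_model, num_heads, num_kv_heads)
    for _ in range(4)
])

x = torch.randn(2, 128, d_model)   # [batch, seq_len, d_model]
for layer in layers:
    x = layer(x)
print(x.shape)                     # torch.Size([2, 128, 512])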
7. Model Initialization and Optimization
Advanced Weight Initialization
def init_weights(module):
    """Weight initialization for more reliable training"""
    if isinstance(module, nn.Linear):
        # Xavier/Glorot-style initialization (normal variant)
        std = (2.0 / (module.in_features + module.out_features)) ** 0.5
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
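Applied recursively with Module.apply; here a small two-layer stack of the AdvancedTransformerLayer defined above serves as the example model.

model = nn.Sequential(*[AdvancedTransformerLayer(512, 8, 2) for _ in range(2)])
model.apply(init_weights)   # visits every submodule, re-initializing Linear and Embedding weights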
Gradient Clipping and Scaling
class GradientClipper:
    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm

    def clip_gradients(self, model):
        return torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_norm)
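Continuing with the model from above, a sketch of where clipping sits in a training step; the batch and loss here are placeholders just to produce gradients.

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
clipper = GradientClipper(max_norm=1.0)

x = torch.randn(2, 64, 512)                # placeholder batch
loss = model(x).pow(2).mean()              # placeholder loss, only to produce gradients

optimizer.zero_grad()
loss.backward()
grad_norm = clipper.clip_gradients(model)  # clip before the optimizer step; returns the pre-clip norm
optimizer.step()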
Key Improvements Over Standard Transformers
- RMSNorm: Cheaper than LayerNorm (no mean statistics) with comparable or better training stability
- SwiGLU: Better quality than GELU/ReLU feed-forwards at matched parameter count
- RoPE: Better length extrapolation than absolute position encodings
- GQA: KV-cache memory shrinks roughly by the ratio num_heads/num_kv_heads, with minimal quality loss
- Flash Attention: Never materializes the full L x L attention matrix, giving large memory savings and typically 2-4x faster attention
These components are what separate research code from production models like LLaMA and PaLM. Now you're ready to understand why these same optimizations matter for Hyena!