
Complete Guide to Advanced Transformer Components

Let me walk you through the advanced components used in modern Transformers, step by step. We'll build from the basics up to cutting-edge optimizations!

1. Normalization Layers

Standard LayerNorm (Basic)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class StandardLayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(d_model, eps=eps)
 
    def forward(self, x):
        # Normalizes each sample across features
        return self.ln(x)

RMSNorm (Advanced)

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Norm - used in LLaMA and Gopher
    Simpler than LayerNorm (no mean centering) and often more stable at scale
    """
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
 
    def forward(self, x):
        # Calculate RMS norm
        norm = x.norm(dim=-1, keepdim=True) * (x.size(-1) ** -0.5)
        return self.weight * x / (norm + self.eps)

Why RMSNorm is Better:

  • Simpler: No mean centering, just scaling by RMS
  • Faster: Fewer operations than LayerNorm
  • More stable: Better numerical properties at scale
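
A quick sanity check makes the difference concrete. This is an illustrative snippet using the two classes defined above; the exact numbers assume d_model=512.

x = torch.randn(2, 16, 512)
ln = StandardLayerNorm(512)
rms = RMSNorm(512)
print(sum(p.numel() for p in ln.parameters()))   # 1024 (weight + bias)
print(sum(p.numel() for p in rms.parameters()))  # 512  (weight only, no bias)
print(ln(x).shape, rms(x).shape)                 # both torch.Size([2, 16, 512])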

2. Advanced Activation Functions

GELU (Standard)

class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)  # Gaussian Error Linear Unit

SwiGLU (State-of-the-Art)

class SwiGLU(nn.Module):
    """
    Swish-Gated Linear Unit - used in PaLM, LLaMA
    Combines Swish activation with gating mechanism
    """
    def __init__(self, d_model, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or int(d_model * 8/3)  # ~2.67x expansion
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)
 
    def forward(self, x):
        # SwiGLU(x) = W2( SiLU(W1 x) * W3 x ); the SiLU branch acts as the gate
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
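
As a rough check (illustrative, assuming d_model=512), the ~8/3 expansion keeps SwiGLU's parameter count close to a standard 4x GELU FFN even though it uses three weight matrices:

ffn = SwiGLU(d_model=512)
x = torch.randn(2, 16, 512)
print(ffn(x).shape)                              # torch.Size([2, 16, 512])
print(sum(p.numel() for p in ffn.parameters()))  # 2,096,640 vs 2,097,152 for a 4x GELU FFN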

GLU (Gated Linear Unit)

class GLU(nn.Module):
    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden_dim)
        self.w2 = nn.Linear(d_model, hidden_dim)
        self.out = nn.Linear(hidden_dim, d_model)
 
    def forward(self, x):
        gate = torch.sigmoid(self.w1(x))
        projected = self.w2(x)
        return self.out(gate * projected)

3. Positional Encoding Variants

Absolute Sinusoidal (Original Transformer)

class AbsolutePositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * 
                           (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))
 
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

Learned Positional Embeddings

class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)
 
    def forward(self, x):
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        return x + self.pos_emb(positions)

Rotary Position Encoding Basics (Educational, Clean Implementation)

import torch
import math
 
def rotary_position_encoding(seq_len, d_model, theta=10000.0):
    """
    Generate rotary position encoding.
    
    Args:
        seq_len: Length of the sequence
        d_model: Model dimension (must be even)
        theta: Base for the geometric progression (default: 10000.0)
    
    Returns:
        cos_pos: Cosine position encoding [seq_len, d_model//2]
        sin_pos: Sine position encoding [seq_len, d_model//2]
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    
    # Generate position indices
    position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)  # [seq_len, 1]
    
    # Generate dimension indices
    dim_idx = torch.arange(0, d_model, 2, dtype=torch.float)  # [d_model//2]
    
    # Calculate frequencies
    freqs = 1.0 / (theta ** (dim_idx / d_model))  # [d_model//2]
    
    # Calculate angles
    angles = position * freqs  # [seq_len, d_model//2]
    
    # Generate cos and sin encodings
    cos_pos = torch.cos(angles)  # [seq_len, d_model//2]
    sin_pos = torch.sin(angles)  # [seq_len, d_model//2]
    
    return cos_pos, sin_pos
 
def apply_rotary_encoding(x, cos_pos, sin_pos):
    """
    Apply rotary position encoding to input tensor.
    
    Args:
        x: Input tensor [batch_size, seq_len, d_model]
        cos_pos: Cosine position encoding [seq_len, d_model//2]
        sin_pos: Sine position encoding [seq_len, d_model//2]
    
    Returns:
        Tensor with rotary position encoding applied
    """
    # Split x into even and odd dimensions
    x_even = x[..., 0::2]  # [batch_size, seq_len, d_model//2]
    x_odd = x[..., 1::2]   # [batch_size, seq_len, d_model//2]
    
    # Apply rotation
    x_rotated_even = x_even * cos_pos - x_odd * sin_pos
    x_rotated_odd = x_even * sin_pos + x_odd * cos_pos
    
    # Interleave back
    x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
    x_rotated = x_rotated.flatten(-2)
    
    return x_rotated
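
A short usage example (illustrative) for the two helpers above; the rotation acts pairwise on the feature dimensions, so per-token vector norms are preserved:

x = torch.randn(2, 10, 64)                        # [batch, seq_len, d_model]
cos_pos, sin_pos = rotary_position_encoding(seq_len=10, d_model=64)
x_rot = apply_rotary_encoding(x, cos_pos, sin_pos)
print(x_rot.shape)                                                     # torch.Size([2, 10, 64])
print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))   # True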

Rotary Position Embedding (RoPE) - SOTA

class RotaryPositionEmbedding(nn.Module):
    """
    Rotary Position Embedding - used in GPT-NeoX, LLaMA
    Encodes relative position by rotating Q and K vectors
    """
    def __init__(self, dim, max_seq_len=32768):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len = max_seq_len
 
    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Return the rotation angles, duplicated to cover the full head dim: [1, seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb[None, :, :]
 
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embedding to query and key tensors of shape [B, H, L, D]"""
    def rotate_half(x):
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat((-x2, x1), dim=-1)
    
    # Slice to the current sequence length (dim -2 of q/k); broadcasting handles batch and heads
    cos, sin = cos[:, :q.shape[-2], :], sin[:, :q.shape[-2], :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

ALiBi (Attention with Linear Biases)

class ALiBiPositionalBias(nn.Module):
    """
    Attention with Linear Biases - used in BLOOM
    Adds a fixed (not learned), head-specific linear bias to attention scores
    """
    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Per-head slopes: geometric sequence 2^(-8/n), 2^(-16/n), ... (ALiBi paper,
        # exact when num_heads is a power of two)
        slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)])
        self.register_buffer('slopes', slopes)
 
    def get_bias(self, seq_len):
        # Distance |i - j| between query position i and key position j
        pos = torch.arange(seq_len)
        distances = (pos[:, None] - pos[None, :]).abs()
        # Negative bias: attention to distant positions is penalized, more strongly for larger slopes
        bias = -distances * self.slopes.view(-1, 1, 1)
        return bias.unsqueeze(0)  # [1, num_heads, seq_len, seq_len]
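
Usage is a one-liner inside any attention module (a sketch; the shapes here are illustrative): add the bias to the raw attention scores before the softmax.

alibi = ALiBiPositionalBias(num_heads=8)
scores = torch.randn(1, 8, 128, 128)           # [batch, heads, q_len, k_len]
scores = scores + alibi.get_bias(seq_len=128)  # broadcasts over the batch dim
probs = torch.softmax(scores, dim=-1)
print(probs.shape)                             # torch.Size([1, 8, 128, 128])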

4. Advanced Attention Mechanisms

Multi-Head Attention with RoPE

class MultiHeadAttentionRoPE(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        
        self.rope = RotaryPositionEmbedding(self.head_dim)
        self.scale = self.head_dim ** -0.5
 
    def forward(self, x):
        B, L, D = x.shape
        
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Apply RoPE: the module returns rotation angles; take their cos/sin here
        angles = self.rope(L, x.device)   # [1, L, head_dim]
        cos, sin = angles.cos(), angles.sin()
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        
        # Standard attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)
        
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(attn_out)
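
A minimal smoke test (illustrative values) for the module above:

attn = MultiHeadAttentionRoPE(d_model=512, num_heads=8)
x = torch.randn(2, 32, 512)
print(attn(x).shape)  # torch.Size([2, 32, 512])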

Grouped Query Attention (GQA)

class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention - used in LLaMA 2
    Reduces memory usage by sharing K,V across query groups
    """
    def __init__(self, d_model, num_heads, num_kv_heads=None):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = d_model // num_heads
        self.group_size = num_heads // self.num_kv_heads
        
        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim, bias=False)
        # K/V heads use the same head_dim as the query heads; there are just fewer of them
        self.k_proj = nn.Linear(d_model, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        
        self.scale = self.head_dim ** -0.5
 
    def forward(self, x):
        B, L, D = x.shape
        
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        
        # Repeat K,V for each group
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)
        
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_out = torch.matmul(attn_probs, v)
        
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        return self.out_proj(attn_out)
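
To see the memory saving, compare the K projection sizes (illustrative; with 8 query heads and 2 KV heads, the K/V projections, and hence the KV cache, are 4x smaller):

mha = GroupedQueryAttention(d_model=512, num_heads=8)                  # 8 KV heads (standard MHA)
gqa = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=2)  # 2 KV heads
print(mha.k_proj.weight.shape)  # torch.Size([512, 512])
print(gqa.k_proj.weight.shape)  # torch.Size([128, 512])
x = torch.randn(2, 32, 512)
print(gqa(x).shape)             # torch.Size([2, 32, 512])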

Flash Attention Integration

class FlashMultiHeadAttention(nn.Module):
    """
    Flash Attention wrapper - memory efficient attention
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        
        try:
            from flash_attn import flash_attn_func
            self.flash_attn_func = flash_attn_func
            self.has_flash = True
        except ImportError:
            self.has_flash = False
 
    def forward(self, x):
        B, L, D = x.shape
        qkv = self.qkv_proj(x).view(B, L, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        
        if self.has_flash and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
            # Use Flash Attention (requires CUDA tensors in fp16/bf16)
            attn_out = self.flash_attn_func(q, k, v)
        else:
            # Fallback to standard attention
            q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            attn_probs = torch.softmax(attn_scores, dim=-1)
            attn_out = torch.matmul(attn_probs, v).transpose(1, 2)
        
        attn_out = attn_out.contiguous().view(B, L, D)
        return self.out_proj(attn_out)
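
If the flash-attn package isn't available, PyTorch 2.x ships F.scaled_dot_product_attention, which dispatches to fused/FlashAttention-style kernels when the hardware and dtypes allow it. This is a standalone sketch of that alternative call, not part of the module above:

q = torch.randn(2, 8, 32, 64)  # [batch, heads, seq_len, head_dim]
k = torch.randn(2, 8, 32, 64)
v = torch.randn(2, 8, 32, 64)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 32, 64])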

5. Advanced Feed-Forward Networks

Standard FFN

class StandardFFN(nn.Module):
    def __init__(self, d_model, hidden_dim=None, dropout=0.1):
        super().__init__()
        hidden_dim = hidden_dim or 4 * d_model
        self.w1 = nn.Linear(d_model, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x):
        return self.w2(self.dropout(F.gelu(self.w1(x))))

SwiGLU FFN (SOTA)

class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, hidden_dim=None, dropout=0.1):
        super().__init__()
        hidden_dim = hidden_dim or int(d_model * 8/3)
        self.swiglu = SwiGLU(d_model, hidden_dim)
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x):
        return self.dropout(self.swiglu(x))

6. Complete Advanced Transformer Layer

class AdvancedTransformerLayer(nn.Module):
    """
    State-of-the-art Transformer layer with all advanced components
    """
    def __init__(self, d_model, num_heads, num_kv_heads=None, dropout=0.1):
        super().__init__()
        
        # Advanced normalization
        self.attn_norm = RMSNorm(d_model)
        self.ffn_norm = RMSNorm(d_model)
        
        # Advanced attention with RoPE and GQA
        self.attn = GroupedQueryAttention(d_model, num_heads, num_kv_heads)
        
        # Advanced feed-forward with SwiGLU
        self.ffn = SwiGLUFFN(d_model, dropout=dropout)
        
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x):
        # Pre-norm attention block
        residual = x
        x = self.attn_norm(x)
        x = self.attn(x)
        x = self.dropout(x) + residual
        
        # Pre-norm FFN block
        residual = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + residual
        
        return x
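
A quick smoke test (illustrative hyperparameters): stack a few layers and run a forward pass. Since each layer maps [B, L, d_model] to the same shape, nn.Sequential is enough here.

block = nn.Sequential(*[
    AdvancedTransformerLayer(d_model=512, num_heads=8, num_kv_heads=2)
    for _ in range(4)
])
x = torch.randn(2, 64, 512)
print(block(x).shape)  # torch.Size([2, 64, 512])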

7. Model Initialization and Optimization

Advanced Weight Initialization

def init_weights(module):
    """Advanced weight initialization for better training"""
    if isinstance(module, nn.Linear):
        # Xavier/Glorot-style normal initialization (std from fan-in and fan-out)
        std = (2.0 / (module.in_features + module.out_features)) ** 0.5
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
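
Applied recursively with Module.apply (illustrative):

layer = AdvancedTransformerLayer(d_model=512, num_heads=8)
layer.apply(init_weights)  # walks every submodule and re-initializes Linear/Embedding weights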

Gradient Clipping and Scaling

class GradientClipper:
    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm
    
    def clip_gradients(self, model):
        return torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_norm)
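
Typical placement in a training step (a sketch with a dummy loss; the AdamW settings are illustrative):

model = AdvancedTransformerLayer(d_model=512, num_heads=8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
clipper = GradientClipper(max_norm=1.0)

loss = model(torch.randn(2, 16, 512)).pow(2).mean()  # dummy loss for illustration
loss.backward()
grad_norm = clipper.clip_gradients(model)            # clip before the optimizer update
optimizer.step()
optimizer.zero_grad()
print(float(grad_norm))                              # total gradient norm before clipping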

Key Improvements Over Standard Transformers

  1. RMSNorm: cheaper than LayerNorm (no mean centering) and stable at scale
  2. SwiGLU: consistently outperforms GELU/ReLU feed-forward blocks in practice
  3. RoPE: better length extrapolation than absolute positional encodings
  4. GQA: shrinks the KV cache roughly in proportion to num_kv_heads/num_heads, with minimal quality loss
  5. Flash Attention: avoids materializing the full attention matrix, cutting memory use and speeding up training

These components are what separate research prototypes from production-scale models like LLaMA and PaLM. Now you're ready to understand why these same optimizations matter for Hyena!
