Advanced Hyena

The Hyena Operator: Deep Dive into Advanced Components

Now that you understand advanced Transformer components, let me teach you the complete Hyena operator with all its sophisticated details. The Hyena operator is what makes StripedHyena competitive with attention while scaling far more efficiently with sequence length.

1. The Core Hyena Philosophy

Three Key Properties from Attention

The Hyena operator was designed to replicate attention's three crucial properties:[6]

  1. Data Control: Linear operator controlled by input data
  2. Sublinear Parameter Scaling: Parameters independent of sequence length
  3. Unrestricted Context: Can model dependencies between any positions
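
To see what property 2 means in practice, here is a rough parameter-count comparison (a sketch with assumed toy sizes and made-up variable names): an explicit long-convolution filter needs seq_len × d_model weights, while Hyena's implicit filter is a small MLP whose parameter count does not depend on sequence length.

import torch.nn as nn

d_model, hidden = 256, 64

# Explicit long filter: one weight per (position, channel) pair, grows linearly with length
explicit_params = {L: L * d_model for L in (1_024, 65_536)}   # 262,144 and 16,777,216

# Implicit filter: a small MLP maps positional features to filter values
implicit_mlp = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                             nn.Linear(hidden, d_model))
implicit_params = sum(p.numel() for p in implicit_mlp.parameters())  # 20,800, independent of L

print(explicit_params, implicit_params)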

2. Mathematical Foundation

The Hyena Recurrence

# Mathematical definition of the Hyena recurrence:
#   y_0 = v                        (value projection)
#   y_i = (y_{i-1} * h_i) āŠ™ x_i    for i = 1, ..., N
# where * denotes convolution and āŠ™ denotes element-wise multiplication;
# the output of the operator is y_N
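
To make the recurrence concrete, here is a minimal reference sketch that unrolls it directly with an explicit causal convolution. It assumes toy tensors of shape [seq_len, d_model] and causal filters, and the helper name is just for illustration; it is far too slow for real use, but it computes exactly what the comments above describe.

import torch

def hyena_recurrence_reference(x_branches, v, filters):
    """x_branches: N gating inputs x_i, v: value, filters: N filters h_i (all [seq_len, d_model])."""
    seq_len, _ = v.shape
    y = v
    for x_i, h_i in zip(x_branches, filters):   # i = 1, ..., N
        y_conv = torch.zeros_like(y)
        for t in range(seq_len):                # (y_{i-1} * h_i)_t = sum over s <= t of h_i[t-s] y_{i-1}[s]
            for s in range(t + 1):
                y_conv[t] += h_i[t - s] * y[s]
        y = y_conv * x_i                        # āŠ™ x_i: element-wise gating
    return y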

Matrix Form Equivalence

Hyena can be expressed as y = H(u)v, where H(u) is a data-controlled matrix built from the input-dependent gates and the convolution filters[6]
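
As a sanity check on this equivalence, the sketch below (single channel, toy sizes, hypothetical helper name) builds H(u) explicitly as alternating diagonal gating matrices and lower-triangular Toeplitz convolution matrices; applying it to v reproduces the unrolled recurrence.

import torch

def data_controlled_matrix(x_gates, filters):
    """Build H(u) = D_N S_N ... D_1 S_1 for a single channel (toy sizes)."""
    seq_len = x_gates[0].shape[0]
    H = torch.eye(seq_len)
    for x_i, h_i in zip(x_gates, filters):
        S = torch.zeros(seq_len, seq_len)
        for t in range(seq_len):
            for s in range(t + 1):
                S[t, s] = h_i[t - s]      # lower-triangular Toeplitz: causal convolution with h_i
        H = torch.diag(x_i) @ S @ H       # diagonal gate after each convolution
    return H

# y = data_controlled_matrix(x_gates, filters) @ v matches the recurrence for the same x_i, h_i, v.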

3. Complete Hyena Operator Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fft import fft, ifft
import math
 
class PositionalEncoding(nn.Module):
    """Advanced positional encoding for Hyena filters"""
    def __init__(self, d_model, max_len=32768):
        super().__init__()
        self.d_model = d_model
        
        # Exponential decay encoding
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        # Multiple frequency components
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add exponential decay component
        decay_term = torch.exp(-position / max_len * 4)  # Fixed exponential decay (not learned)
        pe = pe * decay_term
        
        self.register_buffer('pe', pe)
 
    def forward(self, seq_len):
        return self.pe[:seq_len]
 
class ImplicitFilterNetwork(nn.Module):
    """
    Advanced implicit filter generation with multiple techniques:
    - Positional encoding input
    - Windowing function
    - Multiple filter types
    """
    def __init__(self, d_model, seq_len, order, hidden_dim=64):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.order = order
        
        # Enhanced positional encoding
        self.pos_encoder = PositionalEncoding(hidden_dim//2, seq_len)
        
        # Separate MLPs for each filter in the recurrence
        self.filter_networks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim//2, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, d_model),
                nn.Tanh()  # Bound the filter values
            ) for _ in range(order)
        ])
        
        # Learnable window function parameters
        self.window_params = nn.Parameter(torch.ones(order, 2))  # (alpha, beta) per filter
 
    def apply_window(self, filter_vals, filter_idx, seq_len):
        """Apply learnable window function to filter"""
        alpha, beta = self.window_params[filter_idx]
        
        # Create window positions
        positions = torch.linspace(0, 1, seq_len, device=filter_vals.device)
        
        # Exponential window: alpha * exp(-beta * position)
        window = alpha * torch.exp(-beta * positions).unsqueeze(-1)
        
        return filter_vals * window
 
    def forward(self, seq_len):
        # Get positional encoding
        pos_emb = self.pos_encoder(seq_len)  # [seq_len, hidden_dim//2]
        
        filters = []
        for i, network in enumerate(self.filter_networks):
            # Generate raw filter
            filter_vals = network(pos_emb)  # [seq_len, d_model]
            
            # Apply window function
            filter_vals = self.apply_window(filter_vals, i, seq_len)
            
            filters.append(filter_vals.transpose(0, 1))  # [d_model, seq_len]
        
        return filters
 
class AdvancedFFTConvolution(nn.Module):
    """FFT-based long convolution, padded to avoid circular-convolution artifacts"""
    def forward(self, x, filter_weights):
        batch, seq_len, d_model = x.shape

        # Next power of 2 that covers the full linear convolution
        fft_len = 2 ** math.ceil(math.log2(2 * seq_len))

        # Channel-wise convolution in [batch, d_model, seq_len] layout
        x_f = fft(x.transpose(1, 2), n=fft_len)   # [batch, d_model, fft_len]
        h_f = fft(filter_weights, n=fft_len)      # [d_model, fft_len]

        # Multiply in the frequency domain, transform back, crop to the original length
        y = ifft(x_f * h_f.unsqueeze(0)).real[..., :seq_len]
        return y.transpose(1, 2)                  # Back to [batch, seq_len, d_model]
 
class ShortConvolution(nn.Module):
    """Short local convolution (depthwise or separable) applied to each Hyena branch"""
    def __init__(self, d_model, kernel_size=3, conv_type='depthwise'):
        super().__init__()
        self.conv_type = conv_type
        padding = kernel_size - 1  # Causal padding, trimmed in forward()

        if conv_type == 'separable':
            self.depthwise = nn.Conv1d(d_model, d_model, kernel_size,
                                       padding=padding, groups=d_model)
            self.pointwise = nn.Conv1d(d_model, d_model, 1)
        else:
            self.conv = nn.Conv1d(d_model, d_model, kernel_size,
                                  padding=padding, groups=d_model)
        self.activation = nn.SiLU()

    def forward(self, x):
        seq_len = x.size(1)
        x = x.transpose(1, 2)  # Conv1d expects [batch, d_model, seq_len]

        if self.conv_type == 'separable':
            x = self.depthwise(x)
            x = self.pointwise(x)
        else:
            x = self.conv(x)

        x = self.activation(x[..., :seq_len])  # Trim the causal padding
        return x.transpose(1, 2)               # Back to [batch, seq_len, d_model]
 
class HyenaOperator(nn.Module):
    """
    Complete advanced Hyena operator with all optimizations
    """
    def __init__(self, d_model, seq_len, order=2, hidden_dim=64, 
                 conv_type='depthwise', dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.order = order
        
        # Input projections (N+1 branches)
        self.input_projection = nn.Linear(d_model, (order + 1) * d_model, bias=False)
        
        # Short convolutions for each branch
        self.short_convs = nn.ModuleList([
            ShortConvolution(d_model, conv_type=conv_type)
            for _ in range(order + 1)
        ])
        
        # Implicit filter network
        self.filter_network = ImplicitFilterNetwork(d_model, seq_len, order, hidden_dim)
        
        # FFT convolution
        self.fft_conv = AdvancedFFTConvolution()
        
        # Normalization and dropout
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x):
        batch, seq_len, d_model = x.shape
        
        # Step 1: Input projections
        projections = self.input_projection(x)  # [batch, seq_len, (order+1)*d_model]
        projections = projections.view(batch, seq_len, self.order + 1, d_model)
        
        # Step 2: Short convolutions on each branch
        branches = []
        for i in range(self.order + 1):
            branch = projections[:, :, i, :]  # [batch, seq_len, d_model]
            branch = self.short_convs[i](branch)
            branch = self.layer_norm(branch)  # Normalize each branch
            branches.append(branch)
        
        # Step 3: Generate implicit filters
        filters = self.filter_network(seq_len)  # List of [d_model, seq_len] tensors
        
        # Step 4: Hyena recurrence
        y = branches[-1]  # Start with "value" branch
        
        for i in range(self.order):
            # Long convolution via FFT
            y = self.fft_conv(y, filters[i])
            
            # Element-wise gating (data-dependent control)
            gate = branches[i]
            y = y * gate
            
            # Residual connection and dropout
            if i > 0:  # Skip connection after first iteration
                y = y + branches[-1] * 0.1  # Scaled residual
            
            y = self.dropout(y)
        
        return y
 
class HyenaLayer(nn.Module):
    """
    Complete Hyena layer with pre-norm and feed-forward
    """
    def __init__(self, d_model, seq_len, order=2, dropout=0.1):
        super().__init__()
        
        # Use RMSNorm for better performance
        self.pre_norm = RMSNorm(d_model)
        self.post_norm = RMSNorm(d_model)
        
        # Hyena operator
        self.hyena = HyenaOperator(d_model, seq_len, order, dropout=dropout)
        
        # SwiGLU feed-forward
        self.ffn = SwiGLU(d_model)
        
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x):
        # Hyena block with pre-norm
        residual = x
        x = self.pre_norm(x)
        x = self.hyena(x)
        x = self.dropout(x) + residual
        
        # Feed-forward block with pre-norm
        residual = x
        x = self.post_norm(x)
        x = self.ffn(x)
        x = x + residual
        
        return x
 
# From our previous advanced Transformer components
class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
 
    def forward(self, x):
        norm = x.norm(dim=-1, keepdim=True) * (x.size(-1) ** -0.5)
        return self.weight * x / (norm + self.eps)
 
class SwiGLU(nn.Module):
    def __init__(self, d_model, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or int(d_model * 8/3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
        self.w3 = nn.Linear(d_model, hidden_dim, bias=False)
 
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
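
For reference, here is a quick smoke test of the layer defined above with assumed toy sizes; it is just a shape sanity check, not a training setup.

batch, seq_len, d_model = 2, 512, 256
layer = HyenaLayer(d_model, seq_len, order=2)
x = torch.randn(batch, seq_len, d_model)
y = layer(x)
print(y.shape)  # torch.Size([2, 512, 256])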

4. Advanced Features Explained

A. Enhanced Positional Encoding

  • Multiple frequencies: standard sinusoidal components, as in Transformer positional encodings
  • Learnable components: the window parameters in the filter network adapt during training
  • Exponential decay: a fixed decay term emphasizes nearby positions over distant ones

B. Windowing Functions

  • Purpose: Shape filter response (like audio DSP)
  • Learnable: α and β parameters trained end-to-end
  • Effect: Controls how local or global each filter's response is

C. Numerical Stability

  • Complex FFT: Better precision than real-only
  • Power-of-2 FFT lengths: Hardware optimization
  • Proper padding: Avoids circular convolution artifacts

D. Advanced Gating

  • Residual connections: Gradient flow improvement
  • Dropout: Regularization within recurrence
  • Scaled connections: Balanced information flow

5. Training Optimizations

Special Parameter Treatment

def configure_optimizers(model):
    """Special treatment for Hyena parameters"""
    
    # Standard parameters
    standard_params = []
    # Special Hyena parameters (no weight decay)
    hyena_params = []
    
    for name, param in model.named_parameters():
        if 'window_params' in name or 'filter_network' in name:
            hyena_params.append(param)
        else:
            standard_params.append(param)
    
    optimizer = torch.optim.AdamW([
        {'params': standard_params, 'weight_decay': 0.1},
        {'params': hyena_params, 'weight_decay': 0.0}  # No decay for filters!
    ], lr=1e-4)
    
    return optimizer
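
A minimal example of wiring this up, assuming the model is a simple stack of the HyenaLayer modules defined above (the four-layer stack and sizes here are arbitrary):

model = nn.Sequential(*[HyenaLayer(d_model=256, seq_len=512) for _ in range(4)])
optimizer = configure_optimizers(model)
print(len(optimizer.param_groups))  # 2: decayed and non-decayed parameter groups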

6. Key Insights

What Makes This Advanced:

  1. Implicit Parameterization: Filters generated by MLPs, not stored directly
  2. Hierarchical Processing: Each recurrence step adds complexity
  3. Data-Controlled Gating: Input determines information flow
  4. Efficient Implementation: FFT makes long convolutions tractable
  5. Hybrid Design: Combines best of CNNs and attention

Performance Characteristics:

  • Time Complexity: O(N log N) vs O(N²) for attention
  • Memory: O(N) vs O(N²) for attention
  • Quality: Competitive with attention on language tasks[2][6]
  • Speed: 2x faster at 8K, 100x faster at 64K sequences[10]
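
As a rough sanity check on these asymptotics (operation counts only, not wall-clock time; the measured 2x and 100x figures from [10] also reflect constant factors and hardware utilization):

import math
for n in (8_192, 65_536):
    print(n, round(n**2 / (n * math.log2(n))))  # ~630 at 8K, 4096 at 64K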

This advanced Hyena operator represents the state of the art in efficient sequence modeling: it is the core building block of both the StripedHyena language models and the Evo biological foundation models!

[1] https://github.com/togethercomputer/stripedhyena
[2] https://www.together.ai/blog/stripedhyena-7b
[3] https://github.com/NousResearch/StripedHyenaTrainer
[4] https://patmcguinness.substack.com/p/beyond-transformers-with-mamba
[5] https://www.marktechpost.com/2023/12/13/together-ai-introduces-stripedhyena-7b-an-alternative-artificial-intelligence-model-competitive-with-the-best-open-source-transformers-in-short-and-long-context-evaluations/
[6] https://proceedings.mlr.press/v202/poli23a/poli23a.pdf
[7] https://www.biorxiv.org/content/10.1101/2024.02.27.582234v1.full
[8] https://arxiv.org/pdf/2403.17844.pdf
[9] https://ermongroup.github.io/blog/hyena/
[10] https://arxiv.org/abs/2302.10866