The Hyena Operator: Deep Dive into Advanced Components
Now that you understand advanced Transformer components, let's work through the complete Hyena operator in detail. The Hyena operator is what makes StripedHyena competitive with attention while being much more efficient.
1. The Core Hyena Philosophy
Three Key Properties from Attention
The Hyena operator was designed to replicate attention's three crucial properties:[6]
- Data Control: Linear operator controlled by input data
- Sublinear Parameter Scaling: Parameters independent of sequence length
- Unrestricted Context: Can model dependencies between any positions
2. Mathematical Foundation
The Hyena Recurrence
# Mathematical definition (the Hyena recurrence):
# y_0 = v (value projection)
# For i = 1 to N:
#     y_i = (y_{i-1} * h_i) ⊙ x_i
# Output: y = y_N
# Where * = convolution with the implicit filter h_i, ⊙ = element-wise multiplication by the gate x_i
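To make the recurrence concrete, here is a minimal reference implementation that uses an explicit causal convolution instead of the FFT. The function name and tensor shapes are illustrative, not taken from the StripedHyena codebase.
import torch
def hyena_recurrence_reference(v, gates, filters):
    """Naive O(N^2) reference. v: [seq_len, d]; gates, filters: lists of [seq_len, d] tensors."""
    y = v
    for x_i, h_i in zip(gates, filters):
        seq_len, d = y.shape
        conv = torch.zeros_like(y)
        for t in range(seq_len):
            # Causal convolution: conv[t] = sum over s <= t of h_i[t - s] * y[s], per channel
            conv[t] = (h_i[:t + 1].flip(0) * y[:t + 1]).sum(dim=0)
        y = conv * x_i  # Element-wise gating by the data-dependent branch
    return y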
Matrix Form Equivalence
Hyena can equivalently be expressed as a matrix-vector product: y = H(u) v,
where H(u) is a data-controlled matrix built by interleaving diagonal (gating) matrices from the x_i projections with Toeplitz (convolution) matrices from the filters h_i[6]
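For order N = 1 and a single channel, this equivalence is easy to check numerically: H(u) is the product of a diagonal gating matrix and a lower-triangular Toeplitz matrix. The snippet below is only a sanity check under that assumption, not part of any official implementation.
import torch
# For a single channel and order N = 1: y = (v * h) ⊙ x equals y = H(u) v,
# with H(u) = diag(x) @ T_h, where T_h is the lower-triangular Toeplitz matrix built from h.
seq_len = 8
x, v, h = torch.randn(3, seq_len)
T_h = torch.zeros(seq_len, seq_len)
for t in range(seq_len):
    T_h[t, :t + 1] = h[:t + 1].flip(0)  # T_h[t, s] = h[t - s] for s <= t (causal Toeplitz)
H_u = torch.diag(x) @ T_h               # Data-controlled matrix: gating times convolution
y_matrix = H_u @ v
# Same result from the recurrence form: causal convolution with h, then element-wise gating by x
conv = torch.stack([(h[:t + 1].flip(0) * v[:t + 1]).sum() for t in range(seq_len)])
y_recurrence = conv * x
assert torch.allclose(y_matrix, y_recurrence, atol=1e-5)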
3. Complete Hyena Operator Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fft import fft, ifft
import math
class PositionalEncoding(nn.Module):
"""Advanced positional encoding for Hyena filters"""
def __init__(self, d_model, max_len=32768):
super().__init__()
self.d_model = d_model
# Exponential decay encoding
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
# Multiple frequency components
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# Add exponential decay component
        decay_term = torch.exp(-position / max_len * 4)  # Fixed exponential decay (stored as a buffer, not learned)
pe = pe * decay_term
self.register_buffer('pe', pe)
def forward(self, seq_len):
return self.pe[:seq_len]
class ImplicitFilterNetwork(nn.Module):
"""
Advanced implicit filter generation with multiple techniques:
- Positional encoding input
- Windowing function
- Multiple filter types
"""
def __init__(self, d_model, seq_len, order, hidden_dim=64):
super().__init__()
self.d_model = d_model
self.seq_len = seq_len
self.order = order
# Enhanced positional encoding
self.pos_encoder = PositionalEncoding(hidden_dim//2, seq_len)
# Separate MLPs for each filter in the recurrence
self.filter_networks = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_dim//2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, d_model),
nn.Tanh() # Bound the filter values
) for _ in range(order)
])
# Learnable window function parameters
self.window_params = nn.Parameter(torch.ones(order, 2)) # (alpha, beta) per filter
def apply_window(self, filter_vals, filter_idx, seq_len):
"""Apply learnable window function to filter"""
alpha, beta = self.window_params[filter_idx]
# Create window positions
positions = torch.linspace(0, 1, seq_len, device=filter_vals.device)
# Exponential window: alpha * exp(-beta * position)
window = alpha * torch.exp(-beta * positions).unsqueeze(-1)
return filter_vals * window
def forward(self, seq_len):
# Get positional encoding
pos_emb = self.pos_encoder(seq_len) # [seq_len, hidden_dim//2]
filters = []
for i, network in enumerate(self.filter_networks):
# Generate raw filter
filter_vals = network(pos_emb) # [seq_len, d_model]
# Apply window function
filter_vals = self.apply_window(filter_vals, i, seq_len)
filters.append(filter_vals.transpose(0, 1)) # [d_model, seq_len]
return filters
class AdvancedFFTConvolution(nn.Module):
"""
Optimized FFT convolution with stability improvements
"""
def __init__(self):
super().__init__()
    def forward(self, x, filter_weights):
        batch, seq_len, d_model = x.shape
        # Optimal FFT length: next power of 2 >= 2*seq_len - 1 (zero-padding avoids circular convolution)
        fft_len = 1 << (2 * seq_len - 1).bit_length()
        x = x.transpose(1, 2)  # [batch, d_model, seq_len]
        # Complex FFT convolution: multiply in the frequency domain, then truncate to seq_len
        x_f = fft(x, n=fft_len)
        h_f = fft(filter_weights, n=fft_len)  # filter_weights: [d_model, seq_len]
        y = ifft(x_f * h_f.unsqueeze(0), n=fft_len).real[..., :seq_len]
        return y.transpose(1, 2)  # Back to [batch, seq_len, d_model]
class ShortConvolution(nn.Module):
    """
    Short local convolution applied to each branch before gating
    """
    def __init__(self, d_model, kernel_size=3, conv_type='depthwise'):
        super().__init__()
        self.conv_type = conv_type
        if conv_type == 'separable':
            # Depthwise + pointwise (separable) convolution
            self.depthwise = nn.Conv1d(d_model, d_model, kernel_size,
                                       padding='same', groups=d_model)
            self.pointwise = nn.Conv1d(d_model, d_model, kernel_size=1)
        else:
            # Plain depthwise convolution
            self.conv = nn.Conv1d(d_model, d_model, kernel_size,
                                  padding='same', groups=d_model)
        self.activation = nn.SiLU()
    def forward(self, x):
        x = x.transpose(1, 2)  # [batch, d_model, seq_len]
        if self.conv_type == 'separable':
            x = self.depthwise(x)
            x = self.pointwise(x)
        else:
            x = self.conv(x)
        x = self.activation(x)
        return x.transpose(1, 2)  # Back to [batch, seq_len, d_model]
class HyenaOperator(nn.Module):
"""
Complete advanced Hyena operator with all optimizations
"""
def __init__(self, d_model, seq_len, order=2, hidden_dim=64,
conv_type='depthwise', dropout=0.1):
super().__init__()
self.d_model = d_model
self.seq_len = seq_len
self.order = order
# Input projections (N+1 branches)
self.input_projection = nn.Linear(d_model, (order + 1) * d_model, bias=False)
# Short convolutions for each branch
self.short_convs = nn.ModuleList([
ShortConvolution(d_model, conv_type=conv_type)
for _ in range(order + 1)
])
# Implicit filter network
self.filter_network = ImplicitFilterNetwork(d_model, seq_len, order, hidden_dim)
# FFT convolution
self.fft_conv = AdvancedFFTConvolution()
# Normalization and dropout
self.layer_norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
batch, seq_len, d_model = x.shape
# Step 1: Input projections
projections = self.input_projection(x) # [batch, seq_len, (order+1)*d_model]
projections = projections.view(batch, seq_len, self.order + 1, d_model)
# Step 2: Short convolutions on each branch
branches = []
for i in range(self.order + 1):
branch = projections[:, :, i, :] # [batch, seq_len, d_model]
branch = self.short_convs[i](branch)
branch = self.layer_norm(branch) # Normalize each branch
branches.append(branch)
# Step 3: Generate implicit filters
filters = self.filter_network(seq_len) # List of [d_model, seq_len] tensors
# Step 4: Hyena recurrence
y = branches[-1] # Start with "value" branch
for i in range(self.order):
# Long convolution via FFT
y = self.fft_conv(y, filters[i])
# Element-wise gating (data-dependent control)
gate = branches[i]
y = y * gate
# Residual connection and dropout
if i > 0: # Skip connection after first iteration
y = y + branches[-1] * 0.1 # Scaled residual
y = self.dropout(y)
return y
class HyenaLayer(nn.Module):
"""
Complete Hyena layer with pre-norm and feed-forward
"""
def __init__(self, d_model, seq_len, order=2, dropout=0.1):
super().__init__()
# Use RMSNorm for better performance
self.pre_norm = RMSNorm(d_model)
self.post_norm = RMSNorm(d_model)
# Hyena operator
self.hyena = HyenaOperator(d_model, seq_len, order, dropout=dropout)
# SwiGLU feed-forward
self.ffn = SwiGLU(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# Hyena block with pre-norm
residual = x
x = self.pre_norm(x)
x = self.hyena(x)
x = self.dropout(x) + residual
# Feed-forward block with pre-norm
residual = x
x = self.post_norm(x)
x = self.ffn(x)
x = x + residual
return x
# From our previous advanced Transformer components
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-8):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x):
norm = x.norm(dim=-1, keepdim=True) * (x.size(-1) ** -0.5)
return self.weight * x / (norm + self.eps)
class SwiGLU(nn.Module):
def __init__(self, d_model, hidden_dim=None):
super().__init__()
hidden_dim = hidden_dim or int(d_model * 8/3)
self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
self.w3 = nn.Linear(d_model, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
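As a quick sanity check of the classes above, the following snippet (with arbitrary, small hyperparameters) runs one forward pass and confirms that the layer preserves the [batch, seq_len, d_model] shape:
torch.manual_seed(0)
layer = HyenaLayer(d_model=64, seq_len=128, order=2, dropout=0.0)
x = torch.randn(2, 128, 64)   # [batch, seq_len, d_model]
y = layer(x)
print(y.shape)                # torch.Size([2, 128, 64])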
4. Advanced Features Explained
A. Enhanced Positional Encoding
- Multiple frequencies: Standard sinusoidal components, damped by an exponential decay term
- Learnable components: Window parameters adapt during training
- Exponential decay: Emphasizes local vs. global patterns
B. Windowing Functions
- Purpose: Shape filter response (like audio DSP)
- Learnable: α and β parameters trained end-to-end
- Effect: Controls how local vs. global the filter response is (see the sketch below)
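The snippet below (with purely illustrative α and β values) shows the trade-off: a larger β makes the window decay quickly, so the filter emphasizes nearby positions, while a small β keeps the filter nearly global.
import torch
positions = torch.linspace(0, 1, 8)
for beta in (0.5, 8.0):  # small beta -> slow decay (global), large beta -> fast decay (local)
    window = 1.0 * torch.exp(-beta * positions)  # alpha = 1.0
    print(f"beta={beta}:", [round(w, 2) for w in window.tolist()])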
C. Numerical Stability
- Complex FFT: The full complex transform (fft/ifft) is used for the long convolution
- Power-of-2 FFT lengths: Hardware-friendly transform sizes
- Proper padding: Zero-padding to roughly twice the sequence length avoids circular convolution artifacts (demonstrated below)
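The padding point can be verified directly. Without zero-padding, pointwise multiplication of FFTs computes a circular convolution, so late filter taps wrap around and contaminate early positions; padding to about twice the sequence length recovers the linear (causal) convolution. The single-channel example below assumes nothing beyond torch.fft:
import torch
from torch.fft import fft, ifft
seq_len = 8
x, h = torch.randn(2, seq_len)
# No padding: circular convolution (wrap-around artifacts)
y_circular = ifft(fft(x) * fft(h)).real
# Zero-padded to the next power of 2 >= 2*seq_len - 1: linear convolution, then truncate
fft_len = 16
y_linear = ifft(fft(x, n=fft_len) * fft(h, n=fft_len), n=fft_len).real[:seq_len]
print(torch.allclose(y_circular, y_linear))  # False: the two results differ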
D. Advanced Gating
- Residual connections: Gradient flow improvement
- Dropout: Regularization within recurrence
- Scaled connections: Balanced information flow
5. Training Optimizations
Special Parameter Treatment
def configure_optimizers(model):
"""Special treatment for Hyena parameters"""
# Standard parameters
standard_params = []
# Special Hyena parameters (no weight decay)
hyena_params = []
for name, param in model.named_parameters():
if 'window_params' in name or 'filter_network' in name:
hyena_params.append(param)
else:
standard_params.append(param)
optimizer = torch.optim.AdamW([
{'params': standard_params, 'weight_decay': 0.1},
{'params': hyena_params, 'weight_decay': 0.0} # No decay for filters!
], lr=1e-4)
return optimizer
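Usage is the same as any AdamW setup; the split above simply exempts the implicit filter MLPs and window parameters from weight decay, since decaying them would pull the learned filters toward zero. (This grouping convention is common in long-convolution and SSM codebases; the exact rule shown here is an assumption, not the published StripedHyena recipe.) A toy training loop with a dummy reconstruction target:
model = HyenaLayer(d_model=64, seq_len=128)
optimizer = configure_optimizers(model)
for step in range(10):
    x = torch.randn(2, 128, 64)
    loss = (model(x) - x).pow(2).mean()  # Dummy objective, just to exercise the optimizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()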
6. Key Insights
What Makes This Advanced:
- Implicit Parameterization: Filters generated by MLPs, not stored directly
- Hierarchical Processing: Each recurrence step adds complexity
- Data-Controlled Gating: Input determines information flow
- Efficient Implementation: FFT makes long convolutions tractable
- Hybrid Design: Combines best of CNNs and attention
Performance Characteristics:
- Time Complexity: O(N log N) vs O(N²) for attention
- Memory: O(N) vs O(N²) for attention
- Quality: Competitive with attention on language tasks[2][6]
- Speed: Roughly 2x faster than optimized attention at 8K tokens and up to 100x faster at 64K[10] (rough benchmark sketch below)
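The complexity gap can be sanity-checked with a rough CPU micro-benchmark that compares an FFT long convolution (O(N log N)) against materializing a full attention matrix (O(N²)). Absolute numbers depend entirely on hardware and kernels; the point is only the scaling trend.
import math, time
import torch
from torch.fft import fft, ifft
def avg_time(fn, reps=5):
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(reps):
        fn()
    return (time.perf_counter() - start) / reps
d_model = 64
for seq_len in (1024, 4096, 8192):
    x = torch.randn(1, d_model, seq_len)
    h = torch.randn(d_model, seq_len)
    q = torch.randn(1, seq_len, d_model)
    fft_len = 1 << (2 * seq_len - 1).bit_length()
    fft_conv = lambda: ifft(fft(x, n=fft_len) * fft(h, n=fft_len), n=fft_len).real[..., :seq_len]
    attention = lambda: torch.softmax(q @ q.transpose(1, 2) / math.sqrt(d_model), dim=-1) @ q
    print(f"N={seq_len}: fft_conv {avg_time(fft_conv) * 1e3:.1f} ms, attention {avg_time(attention) * 1e3:.1f} ms")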
This advanced Hyena operator is at the heart of efficient long-context sequence modeling: it is what powers both StripedHyena language models and the Evo biological foundation models.
References
[1] https://github.com/togethercomputer/stripedhyena
[2] https://www.together.ai/blog/stripedhyena-7b
[3] https://github.com/NousResearch/StripedHyenaTrainer
[4] https://patmcguinness.substack.com/p/beyond-transformers-with-mamba
[5] https://www.marktechpost.com/2023/12/13/together-ai-introduces-stripedhyena-7b-an-alternative-artificial-intelligence-model-competitive-with-the-best-open-source-transformers-in-short-and-long-context-evaluations/
[6] https://proceedings.mlr.press/v202/poli23a/poli23a.pdf
[7] https://www.biorxiv.org/content/10.1101/2024.02.27.582234v1.full
[8] https://arxiv.org/pdf/2403.17844.pdf
[9] https://ermongroup.github.io/blog/hyena/
[10] https://arxiv.org/abs/2302.10866