
VAE Data Diagnostics Guide for ShapeVAE

I was working on ShapeVAE from Hunyuan-3D for a Rutgers project, and my training kept failing with posterior collapse and bad KL-divergence values. Before feeding in the data, I want to make sure it looks good.

Great question! Yes, you should absolutely perform data diagnostics before training a VAE. Here are the essential tests to ensure your data will work well with VAE training:

Essential Pre-Training Data Tests

1. Data Normalization Check (CRITICAL)

Why: VAEs assume data is roughly zero-centered with unit variance. Unnormalized data causes training instability.

import numpy as np
import torch
 
# Load your mesh dataset
sdfs = []  # Your SDF samples
for mesh in dataset:
    sdf = sample_sdf(mesh, resolution=128)
    sdfs.append(sdf)
 
sdfs = np.array(sdfs)  # Shape: [N, num_points]
 
# Check statistics
mean = sdfs.mean()
std = sdfs.std()
min_val = sdfs.min()
max_val = sdfs.max()
 
print(f"Mean: {mean:.4f}")
print(f"Std: {std:.4f}")
print(f"Min: {min_val:.4f}")
print(f"Max: {max_val:.4f}")
 
# ✅ GOOD: mean ≈ 0, std ≈ 1
# ❌ BAD: mean = 50, std = 200 (needs normalization)

What you want:

  • Mean close to 0 (between -0.5 and 0.5)
  • Std close to 1 (between 0.5 and 2.0)
  • No extreme outliers (min/max within ±5 std)

If not normalized, apply this:

# Global normalization
mean = sdfs.mean()
std = sdfs.std()
sdfs_normalized = (sdfs - mean) / (std + 1e-8)
 
# Save normalization params for inference
# (np.savez avoids pickled object arrays, which np.save would require allow_pickle to load)
np.savez('norm_params.npz', mean=mean, std=std)

2. Distribution Shape Test

Why: the VAE's Gaussian prior N(0, 1) works best when the data distribution is roughly Gaussian.

import matplotlib.pyplot as plt
 
# Plot histogram of SDF values
plt.figure(figsize=(10, 4))
 
plt.subplot(1, 2, 1)
plt.hist(sdfs.flatten(), bins=100, alpha=0.7)
plt.xlabel('SDF Value')
plt.ylabel('Frequency')
plt.title('SDF Distribution')
 
# Plot Q-Q plot (compare to normal distribution)
from scipy import stats
plt.subplot(1, 2, 2)
stats.probplot(sdfs.flatten()[::100], dist="norm", plot=plt)  # Sample for speed
plt.title('Q-Q Plot (vs Normal)')
 
plt.tight_layout()
plt.savefig('data_distribution.png')
 
# Compute skewness and kurtosis
skewness = stats.skew(sdfs.flatten())
kurtosis = stats.kurtosis(sdfs.flatten())
 
print(f"Skewness: {skewness:.4f}")  # Should be close to 0
print(f"Kurtosis: {kurtosis:.4f}")  # Should be close to 0

What you want:

  • Skewness between -1 and 1 (not heavily skewed)
  • Kurtosis between -1 and 3 (not too heavy/light tails)
  • Histogram looks roughly bell-shaped (doesn't need to be perfect)

If heavily skewed (e.g., SDF values cluster near zero at the surface):

# Box-Cox requires strictly positive input; shift by the minimum first
from scipy.stats import boxcox
shift = 1e-6 - sdfs.min()
sdfs_transformed, lambda_param = boxcox(sdfs.flatten() + shift)
# Keep `shift` and `lambda_param` so the transform can be inverted at inference
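
An alternative worth knowing: Box-Cox only accepts strictly positive input, while SDFs are negative inside the surface. SciPy's Yeo-Johnson transform handles signed data directly. A minimal sketch on synthetic skewed data standing in for SDF samples:

```python
import numpy as np
from scipy import stats

# Synthetic right-skewed, signed data standing in for SDF samples
rng = np.random.default_rng(0)
sdfs = rng.exponential(scale=0.5, size=10_000) - 0.2

# Yeo-Johnson accepts negative values, unlike Box-Cox
transformed, lmbda = stats.yeojohnson(sdfs)

print(f"Skew before: {stats.skew(sdfs):.3f}")
print(f"Skew after:  {stats.skew(transformed):.3f}")
# Keep `lmbda` so the transform can be inverted at inference time
```

The transform is monotonic, so sign information (inside vs. outside the surface) is preserved.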

3. Dataset Diversity Check

Why: Low diversity → VAE will collapse to dataset mean.

# Compute pairwise distances between meshes
# (assumes each dataset entry is, or can be converted to, an [N, 3] surface point array)
from scipy.spatial import cKDTree

def chamfer_distance(points1, points2):
    # Symmetric Chamfer distance between two [N, 3] point sets
    d12 = cKDTree(points2).query(points1)[0].mean()
    d21 = cKDTree(points1).query(points2)[0].mean()
    return d12 + d21
 
# Sample 100 random pairs
diversities = []
for _ in range(100):
    i, j = np.random.choice(len(dataset), 2, replace=False)
    dist = chamfer_distance(dataset[i], dataset[j])
    diversities.append(dist)
 
avg_diversity = np.mean(diversities)
std_diversity = np.std(diversities)
 
print(f"Avg pairwise distance: {avg_diversity:.4f}")
print(f"Std pairwise distance: {std_diversity:.4f}")
 
# ✅ GOOD: High avg diversity (>0.1 for normalized shapes)
# ❌ BAD: Low diversity (<0.01) → all shapes look similar

What you want:

  • Coefficient of variation > 0.3: std_diversity / avg_diversity > 0.3
  • If too low → you need more diverse data, or VAE will just memorize the mean

4. Outlier Detection

Why: Extreme outliers break VAE training (gradients explode).

# Detect outliers using IQR method
Q1 = np.percentile(sdfs, 25)
Q3 = np.percentile(sdfs, 75)
IQR = Q3 - Q1
 
lower_bound = Q1 - 3 * IQR
upper_bound = Q3 + 3 * IQR
 
outlier_mask = (sdfs < lower_bound) | (sdfs > upper_bound)
outlier_percentage = outlier_mask.sum() / sdfs.size * 100
 
print(f"Outliers: {outlier_percentage:.2f}%")
 
if outlier_percentage > 5:
    print("⚠️ WARNING: Too many outliers! Consider:")
    print("  1. Clipping: sdfs = np.clip(sdfs, lower_bound, upper_bound)")
    print("  2. Remove bad samples")
    print("  3. Check data preprocessing pipeline")

What you want:

  • Less than 5% outliers
  • If greater than 10% outliers → something wrong with data collection
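
If the outlier fraction is high but the samples are otherwise usable, clipping to the IQR bounds is usually safer than dropping data. A sketch of both options, shown here on synthetic data with injected outliers:

```python
import numpy as np

rng = np.random.default_rng(0)
sdfs = rng.normal(size=10_000)
sdfs[:50] = 100.0  # inject extreme outliers

Q1, Q3 = np.percentile(sdfs, [25, 75])
IQR = Q3 - Q1
lower_bound, upper_bound = Q1 - 3 * IQR, Q3 + 3 * IQR

# Option 1: clip (keeps every sample, caps extreme values)
sdfs_clipped = np.clip(sdfs, lower_bound, upper_bound)

# Option 2: drop values outside the bounds entirely
keep_mask = (sdfs >= lower_bound) & (sdfs <= upper_bound)
sdfs_filtered = sdfs[keep_mask]

print(f"Max before: {sdfs.max():.1f}, after clipping: {sdfs_clipped.max():.1f}")
```

Clipping preserves the dataset size and sample alignment, which matters if your pipeline indexes samples by position.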

5. Variance Per Dimension

Why: If some dimensions have zero variance, VAE wastes capacity encoding nothing.

# Compute variance per dimension (e.g., per point in point cloud)
per_dim_variance = sdfs.var(axis=0)  # Variance across dataset
 
# Count low-variance dimensions
low_var_dims = (per_dim_variance < 0.01).sum()
total_dims = len(per_dim_variance)
 
print(f"Low variance dims: {low_var_dims}/{total_dims} ({low_var_dims/total_dims*100:.1f}%)")
 
# ✅ GOOD: <10% low-variance dimensions
# ❌ BAD: >30% low-variance → consider dimensionality reduction
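
If many dimensions are dead, one simple remedy is to mask them out before training (note this is only a sketch: for spatially structured SDF grids you'd more likely reduce resolution or use a learned projection rather than drop raw columns):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=(500, 64))
data[:, 40:] = 0.0  # 24 dead dimensions with zero variance

per_dim_variance = data.var(axis=0)
keep = per_dim_variance >= 0.01  # same threshold as the check above

data_reduced = data[:, keep]
print(f"Kept {keep.sum()} / {data.shape[1]} dimensions")
```

Save the `keep` mask alongside your normalization parameters so inference uses the same dimensions.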

6. Class/Category Balance (If Multi-Class)

Why: Imbalanced classes → VAE learns dominant class only.

# If your dataset has categories (e.g., chairs, tables, cars)
from collections import Counter
 
categories = [sample['category'] for sample in dataset]
category_counts = Counter(categories)
 
print("Category distribution:")
for cat, count in category_counts.most_common():
    print(f"  {cat}: {count} ({count/len(dataset)*100:.1f}%)")
 
# Check imbalance ratio
max_count = max(category_counts.values())
min_count = min(category_counts.values())
imbalance_ratio = max_count / min_count
 
print(f"Imbalance ratio: {imbalance_ratio:.2f}")
 
# ✅ GOOD: ratio < 10
# ⚠️ MEDIUM: ratio 10-50 (consider oversampling)
# ❌ BAD: ratio > 50 (need rebalancing)

If imbalanced:

# Oversample minority classes
from torch.utils.data import DataLoader, WeightedRandomSampler
 
# Create sample weights (inverse frequency)
class_weights = {cat: 1.0/count for cat, count in category_counts.items()}
sample_weights = [class_weights[sample['category']] for sample in dataset]
 
sampler = WeightedRandomSampler(sample_weights, len(dataset))
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
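
To confirm the sampler actually rebalances what the model sees, count class draws over one epoch. A self-contained sketch with a toy 9:1 imbalanced dataset:

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Toy imbalanced dataset: 900 samples of class 0, 100 of class 1
labels = torch.tensor([0] * 900 + [1] * 100)
features = torch.randn(1000, 8)
dataset = TensorDataset(features, labels)

# Inverse-frequency weights, one per sample
counts = Counter(labels.tolist())
weights = torch.tensor([1.0 / counts[int(l)] for l in labels])

sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Count how often each class is drawn in one epoch
drawn = Counter()
for _, y in loader:
    drawn.update(y.tolist())
print(drawn)  # should be roughly balanced (~500 / 500)
```

Note `replacement=True` is required here: minority samples must be drawn multiple times per epoch.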

7. Reconstruction Baseline (Pre-VAE Sanity Check)

Why: If a simple autoencoder (no KL term) can't reconstruct your data, a VAE definitely won't.

# Train simple autoencoder (no KL term) for 10 epochs
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = YourEncoder()
        self.decoder = YourDecoder()
 
    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon
 
ae = SimpleAE()
optimizer = torch.optim.Adam(ae.parameters(), lr=1e-4)
 
for epoch in range(10):
    for batch in dataloader:
        recon = ae(batch)
        loss = F.mse_loss(recon, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
 
# ✅ GOOD: Loss decreases smoothly to <0.01
# ❌ BAD: Loss stuck above 0.1 → encoder/decoder architecture problem

Comprehensive Data Diagnostic Script

import numpy as np
import torch
from scipy import stats
from collections import Counter
import matplotlib.pyplot as plt
 
def diagnose_vae_data(dataset, sample_fn=None):
    """
    Comprehensive VAE data diagnostic
 
    Args:
        dataset: Your dataset object
        sample_fn: Function to convert mesh → SDF array
    """
    print("=" * 50)
    print("VAE DATA DIAGNOSTIC REPORT")
    print("=" * 50)
 
    # 1. Load and flatten data
    all_samples = []
    for i, sample in enumerate(dataset):
        if sample_fn:
            data = sample_fn(sample)
        else:
            data = sample  # Assume already processed
        all_samples.append(data.flatten())
 
        if i >= 999:  # Cap at 1,000 samples for speed
            break
 
    all_samples = np.array(all_samples)
    flat_data = all_samples.flatten()
 
    # 2. Basic statistics
    print("\n📊 BASIC STATISTICS")
    print(f"  Dataset size: {len(all_samples)}")
    print(f"  Sample shape: {all_samples[0].shape}")
    print(f"  Mean: {flat_data.mean():.4f}")
    print(f"  Std: {flat_data.std():.4f}")
    print(f"  Min: {flat_data.min():.4f}")
    print(f"  Max: {flat_data.max():.4f}")
 
    # Check normalization
    if abs(flat_data.mean()) > 0.5 or abs(flat_data.std() - 1.0) > 0.5:
        print("  ⚠️ WARNING: Data not normalized! Apply z-score normalization.")
    else:
        print("  ✅ Data normalization looks good")
 
    # 3. Distribution check
    print("\n📈 DISTRIBUTION ANALYSIS")
    skewness = stats.skew(flat_data)
    kurtosis = stats.kurtosis(flat_data)
    print(f"  Skewness: {skewness:.4f}")
    print(f"  Kurtosis: {kurtosis:.4f}")
 
    if abs(skewness) > 1.0:
        print("  ⚠️ WARNING: Data is skewed. Consider transformation.")
    if abs(kurtosis) > 3.0:
        print("  ⚠️ WARNING: Heavy-tailed distribution.")
    if abs(skewness) < 1.0 and abs(kurtosis) < 3.0:
        print("  ✅ Distribution shape looks reasonable")
 
    # 4. Outlier detection
    print("\n🔍 OUTLIER DETECTION")
    Q1, Q3 = np.percentile(flat_data, [25, 75])
    IQR = Q3 - Q1
    outliers = ((flat_data < Q1 - 3*IQR) | (flat_data > Q3 + 3*IQR)).sum()
    outlier_pct = outliers / len(flat_data) * 100
    print(f"  Outliers: {outlier_pct:.2f}%")
 
    if outlier_pct > 5:
        print("  ⚠️ WARNING: High outlier percentage!")
    else:
        print("  ✅ Outlier percentage acceptable")
 
    # 5. Variance check
    print("\n📉 VARIANCE ANALYSIS")
    per_dim_var = all_samples.var(axis=0)  # variance of each dimension across samples
    low_var_pct = (per_dim_var < 0.01).sum() / len(per_dim_var) * 100
    print(f"  Low variance dimensions: {low_var_pct:.1f}%")
 
    if low_var_pct > 30:
        print("  ⚠️ WARNING: Many low-variance dimensions!")
    else:
        print("  ✅ Variance distribution looks good")
 
    # 6. Diversity estimate
    print("\n🎨 DIVERSITY CHECK")
    sample_indices = np.random.choice(len(all_samples), min(100, len(all_samples)), replace=False)
    pairwise_dists = []
    for i in range(len(sample_indices)-1):
        dist = np.linalg.norm(all_samples[sample_indices[i]] - all_samples[sample_indices[i+1]])
        pairwise_dists.append(dist)
 
    avg_dist = np.mean(pairwise_dists)
    std_dist = np.std(pairwise_dists)
    cv = std_dist / avg_dist if avg_dist > 0 else 0
 
    print(f"  Avg pairwise distance: {avg_dist:.4f}")
    print(f"  Coefficient of variation: {cv:.4f}")
 
    if cv < 0.2:
        print("  ⚠️ WARNING: Low diversity! Dataset may be too similar.")
    else:
        print("  ✅ Dataset diversity looks good")
 
    # 7. Summary
    print("\n" + "=" * 50)
    print("RECOMMENDATION:")
 
    issues = []
    if abs(flat_data.mean()) > 0.5 or abs(flat_data.std() - 1.0) > 0.5:
        issues.append("Normalize data")
    if abs(skewness) > 1.0:
        issues.append("Apply transformation (log/Box-Cox)")
    if outlier_pct > 5:
        issues.append("Remove/clip outliers")
    if low_var_pct > 30:
        issues.append("Reduce dimensionality")
    if cv < 0.2:
        issues.append("Collect more diverse data")
 
    if not issues:
        print("✅ Data looks ready for VAE training!")
    else:
        print("⚠️ Fix these issues before training:")
        for issue in issues:
            print(f"  - {issue}")
 
    print("=" * 50)
 
    return {
        'mean': flat_data.mean(),
        'std': flat_data.std(),
        'skewness': skewness,
        'kurtosis': kurtosis,
        'outlier_pct': outlier_pct,
        'diversity_cv': cv
    }
 
# Usage
report = diagnose_vae_data(your_dataset)  # avoid naming the result `stats` (shadows scipy.stats)

Quick Answer to Your Question

Do you NEED to do these tests?

Strictly required (will break training if not done):

  • Normalization check - CRITICAL
  • Outlier detection - CRITICAL

Highly recommended (training will be painful without):

  • Distribution shape - Important for tuning β
  • Simple AE baseline - Saves debugging time

Optional but useful:

  • Variance per dimension (helps architecture design)
  • Diversity check (helps prevent collapse)
  • Class balance (only if multi-class)

TL;DR Minimum Viable Check

If you only do ONE thing:

# Normalize your data (assumes `dataset` is an array/tensor of samples)
data_mean = dataset.mean()
data_std = dataset.std()
dataset_normalized = (dataset - data_mean) / (data_std + 1e-8)
 
print(f"Mean: {dataset_normalized.mean():.4f} (should be ~0)")
print(f"Std: {dataset_normalized.std():.4f} (should be ~1)")
 
# Save for inference
torch.save({'mean': data_mean, 'std': data_std}, 'norm_params.pt')

If mean ≈ 0 and std ≈ 1, you're 80% good to go!
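
For completeness, here is the inference side of that save: reload the statistics and invert the normalization on decoder output. A minimal round-trip sketch (the file name norm_params.pt matches the snippet above):

```python
import torch

# Training side: compute and save the statistics used for normalization
data = torch.randn(1000, 16) * 3.0 + 5.0  # toy unnormalized data
data_mean, data_std = data.mean(), data.std()
torch.save({'mean': data_mean, 'std': data_std}, 'norm_params.pt')

normalized = (data - data_mean) / (data_std + 1e-8)

# Inference side: reload and invert the transform (e.g., on decoder output)
params = torch.load('norm_params.pt')
reconstructed = normalized * (params['std'] + 1e-8) + params['mean']

print(torch.allclose(reconstructed, data, atol=1e-3))  # round-trip check
```

Forgetting this denormalization step is a common source of "my generated SDFs are all near zero" bugs at sampling time.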