VAE Data Diagnostics Guide for ShapeVAE
I was working on ShapeVAE from Hunyuan-3D for a Rutgers project, and training kept breaking down with posterior collapse and a badly behaved KL divergence. Before feeding the data in, should I run checks to make sure it is sound, and if so, which ones?
Great question! Yes, you should absolutely perform data diagnostics before training a VAE. Here are the essential tests to ensure your data will work well with VAE training:
Essential Pre-Training Data Tests
1. Data Normalization Check (CRITICAL)
Why: VAEs assume data is roughly zero-centered with unit variance. Unnormalized data causes training instability.
```python
import numpy as np
import torch

# Load your mesh dataset
sdfs = []  # Your SDF samples
for mesh in dataset:
    sdf = sample_sdf(mesh, resolution=128)
    sdfs.append(sdf)
sdfs = np.array(sdfs)  # Shape: [N, num_points]

# Check statistics
mean = sdfs.mean()
std = sdfs.std()
min_val = sdfs.min()
max_val = sdfs.max()

print(f"Mean: {mean:.4f}")
print(f"Std: {std:.4f}")
print(f"Min: {min_val:.4f}")
print(f"Max: {max_val:.4f}")

# ✅ GOOD: mean ≈ 0, std ≈ 1
# ❌ BAD: mean = 50, std = 200 (needs normalization)
```
What you want:
- Mean close to 0 (between -0.5 and 0.5)
- Std close to 1 (between 0.5 and 2.0)
- No extreme outliers (min/max within ±5 std); an automated version of these three checks is sketched below
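If you want these criteria as a hard gate in your data pipeline, here is a minimal sketch; the thresholds are simply the rules of thumb listed above, not anything mandated by ShapeVAE:

```python
import numpy as np

def check_normalization(sdfs: np.ndarray):
    """Raise if the pooled SDF statistics violate the rules of thumb above."""
    mean, std = sdfs.mean(), sdfs.std()
    assert abs(mean) <= 0.5, f"Mean {mean:.3f} is not near 0 -- normalize first"
    assert 0.5 <= std <= 2.0, f"Std {std:.3f} is far from 1 -- rescale first"
    # Outlier rule: min/max should stay within ±5 std of the mean
    assert sdfs.min() >= mean - 5 * std and sdfs.max() <= mean + 5 * std, \
        "Extreme values beyond ±5 std -- clip or inspect outliers"
```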
If not normalized, apply this:
```python
# Global normalization
mean = sdfs.mean()
std = sdfs.std()
sdfs_normalized = (sdfs - mean) / (std + 1e-8)

# Save normalization params for inference
# (np.save pickles the dict; load it back with np.load(..., allow_pickle=True).item())
np.save('norm_params.npy', {'mean': mean, 'std': std})
```
2. Distribution Shape Test
Why: The VAE's Gaussian prior N(0, 1) works best when the data has a roughly Gaussian-ish distribution.
```python
import matplotlib.pyplot as plt
from scipy import stats

# Plot histogram of SDF values
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.hist(sdfs.flatten(), bins=100, alpha=0.7)
plt.xlabel('SDF Value')
plt.ylabel('Frequency')
plt.title('SDF Distribution')

# Q-Q plot (compare to normal distribution)
plt.subplot(1, 2, 2)
stats.probplot(sdfs.flatten()[::100], dist="norm", plot=plt)  # Subsample for speed
plt.title('Q-Q Plot (vs Normal)')
plt.tight_layout()
plt.savefig('data_distribution.png')

# Compute skewness and kurtosis
skewness = stats.skew(sdfs.flatten())
kurtosis = stats.kurtosis(sdfs.flatten())
print(f"Skewness: {skewness:.4f}")  # Should be close to 0
print(f"Kurtosis: {kurtosis:.4f}")  # Excess kurtosis; should be close to 0
```
What you want:
- Skewness between -1 and 1 (not heavily skewed)
- Kurtosis between -1 and 3 (not too heavy/light tails)
- Histogram looks roughly bell-shaped (doesn't need to be perfect)
If heavily skewed (e.g., SDF has many zeros near surface):
```python
# Apply a power transform; Yeo-Johnson handles the negative SDF values
# (Box-Cox requires strictly positive inputs, so it does not fit raw SDFs)
from scipy.stats import yeojohnson
sdfs_transformed, lambda_param = yeojohnson(sdfs.flatten())
```
3. Dataset Diversity Check
Why: Low diversity → VAE will collapse to dataset mean.
```python
# Compute pairwise distances between meshes
def chamfer_distance(mesh1, mesh2):
    # Your chamfer distance implementation
    # (a minimal point-cloud version is sketched after this section)
    pass

# Sample 100 random pairs
diversities = []
for _ in range(100):
    i, j = np.random.choice(len(dataset), 2, replace=False)
    dist = chamfer_distance(dataset[i], dataset[j])
    diversities.append(dist)

avg_diversity = np.mean(diversities)
std_diversity = np.std(diversities)
print(f"Avg pairwise distance: {avg_diversity:.4f}")
print(f"Std pairwise distance: {std_diversity:.4f}")

# ✅ GOOD: High avg diversity (>0.1 for normalized shapes)
# ❌ BAD: Low diversity (<0.01) → all shapes look similar
```
What you want:
- Coefficient of variation > 0.3: `std_diversity / avg_diversity > 0.3`
- If too low → you need more diverse data, or the VAE will just memorize the mean
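The `chamfer_distance` stub above is left to you. Here is a minimal NumPy version over sampled surface points, assuming each dataset entry can give you an [M, 3] point array (the `sample_points` helper name is hypothetical):

```python
import numpy as np

def chamfer_distance(points1: np.ndarray, points2: np.ndarray) -> float:
    """Symmetric Chamfer distance between point sets of shape [M, 3] and [K, 3]."""
    # Pairwise squared distances: [M, K]
    diff = points1[:, None, :] - points2[None, :, :]
    d2 = (diff ** 2).sum(-1)
    # Average nearest-neighbor distance in both directions
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()

# Usage with meshes: convert each mesh to points first, e.g.
# dist = chamfer_distance(sample_points(dataset[i]), sample_points(dataset[j]))
```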
4. Outlier Detection
Why: Extreme outliers break VAE training (gradients explode).
```python
# Detect outliers using the IQR method
Q1 = np.percentile(sdfs, 25)
Q3 = np.percentile(sdfs, 75)
IQR = Q3 - Q1
lower_bound = Q1 - 3 * IQR
upper_bound = Q3 + 3 * IQR

outlier_mask = (sdfs < lower_bound) | (sdfs > upper_bound)
outlier_percentage = outlier_mask.sum() / sdfs.size * 100
print(f"Outliers: {outlier_percentage:.2f}%")

if outlier_percentage > 5:
    print("⚠️ WARNING: Too many outliers! Consider:")
    print("  1. Clipping: sdfs = np.clip(sdfs, lower_bound, upper_bound)")
    print("  2. Removing bad samples (see the sketch below)")
    print("  3. Checking the data preprocessing pipeline")
```
What you want:
- Less than 5% outliers
- More than 10% outliers → something is wrong with data collection or preprocessing
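Value clipping fixes stray SDF values, but if a handful of meshes are responsible for most of the outliers (bad scaling, degenerate geometry), it is usually better to find and drop those samples. A minimal sketch, assuming `sdfs` has shape [N, num_points] and the bounds from the snippet above:

```python
# Fraction of outlier points per mesh, using the IQR bounds computed above
per_sample_outliers = ((sdfs < lower_bound) | (sdfs > upper_bound)).mean(axis=1)

# Flag meshes where more than 1% of points are outliers (threshold is a judgment call)
bad_idx = np.where(per_sample_outliers > 0.01)[0]
print(f"Suspicious meshes: {len(bad_idx)} / {len(sdfs)}")

# Inspect them, and if they are genuinely broken, drop them
sdfs_clean = np.delete(sdfs, bad_idx, axis=0)
```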
5. Variance Per Dimension
Why: If some dimensions have zero variance, VAE wastes capacity encoding nothing.
```python
# Compute variance per dimension (e.g., per query point in the SDF sample)
per_dim_variance = sdfs.var(axis=0)  # Variance across the dataset

# Count low-variance dimensions
low_var_dims = (per_dim_variance < 0.01).sum()
total_dims = len(per_dim_variance)
print(f"Low variance dims: {low_var_dims}/{total_dims} ({low_var_dims/total_dims*100:.1f}%)")

# ✅ GOOD: <10% low-variance dimensions
# ❌ BAD: >30% low-variance → consider dimensionality reduction
```
6. Class/Category Balance (If Multi-Class)
Why: Imbalanced classes → VAE learns dominant class only.
```python
# If your dataset has categories (e.g., chairs, tables, cars)
from collections import Counter

categories = [sample['category'] for sample in dataset]
category_counts = Counter(categories)

print("Category distribution:")
for cat, count in category_counts.most_common():
    print(f"  {cat}: {count} ({count/len(dataset)*100:.1f}%)")

# Check imbalance ratio
max_count = max(category_counts.values())
min_count = min(category_counts.values())
imbalance_ratio = max_count / min_count
print(f"Imbalance ratio: {imbalance_ratio:.2f}")

# ✅ GOOD: ratio < 10
# ⚠️ MEDIUM: ratio 10-50 (consider oversampling)
# ❌ BAD: ratio > 50 (need rebalancing)
```
If imbalanced:
```python
# Oversample minority classes
from torch.utils.data import DataLoader, WeightedRandomSampler

# Create sample weights (inverse class frequency)
class_weights = {cat: 1.0 / count for cat, count in category_counts.items()}
sample_weights = [class_weights[sample['category']] for sample in dataset]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset))
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
```
7. Reconstruction Baseline (Pre-VAE Sanity Check)
Why: If a plain autoencoder cannot reconstruct your data, the VAE will definitely fail; the VAE only adds a KL constraint on top, so it can never do better on reconstruction.
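The training block below uses `YourEncoder` / `YourDecoder` as placeholders. If you don't want to wire up the full ShapeVAE backbone just for this sanity check, a throwaway MLP pair over flattened SDF samples is enough; the class names and layer sizes here are illustrative, not from Hunyuan-3D:

```python
import torch.nn as nn

class MinimalEncoder(nn.Module):
    def __init__(self, in_dim, latent_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 1024), nn.ReLU(),
            nn.Linear(1024, latent_dim),
        )

    def forward(self, x):
        # x: [B, in_dim] flattened SDF samples
        return self.net(x)

class MinimalDecoder(nn.Module):
    def __init__(self, out_dim, latent_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 1024), nn.ReLU(),
            nn.Linear(1024, out_dim),
        )

    def forward(self, z):
        return self.net(z)
```

If even this pair can't drive the MSE down on your data, the problem is in the data or preprocessing, not in the KL term.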
```python
# Train a simple autoencoder (no KL term) for 10 epochs
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = YourEncoder()  # plug in your ShapeVAE encoder (a throwaway MLP also works for this check)
        self.decoder = YourDecoder()  # plug in your ShapeVAE decoder

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        return recon

ae = SimpleAE()
optimizer = torch.optim.Adam(ae.parameters(), lr=1e-4)

for epoch in range(10):
    for batch in dataloader:
        recon = ae(batch)
        loss = F.mse_loss(recon, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# ✅ GOOD: Loss decreases smoothly to <0.01
# ❌ BAD: Loss stuck above 0.1 → encoder/decoder architecture problem
```
Comprehensive Data Diagnostic Script
```python
import numpy as np
from scipy import stats

def diagnose_vae_data(dataset, sample_fn=None):
    """
    Comprehensive VAE data diagnostic.

    Args:
        dataset: Your dataset object
        sample_fn: Optional function to convert a raw sample (e.g., a mesh) into an SDF array
    """
    print("=" * 50)
    print("VAE DATA DIAGNOSTIC REPORT")
    print("=" * 50)

    # 1. Load and flatten data
    all_samples = []
    for i, sample in enumerate(dataset):
        if sample_fn:
            data = sample_fn(sample)
        else:
            data = sample  # Assume already processed
        all_samples.append(data.flatten())
        if i >= 1000:  # Only use the first ~1000 samples for speed
            break
    all_samples = np.array(all_samples)
    flat_data = all_samples.flatten()

    # 2. Basic statistics
    print("\n📊 BASIC STATISTICS")
    print(f"  Dataset size: {len(all_samples)}")
    print(f"  Sample shape: {all_samples[0].shape}")
    print(f"  Mean: {flat_data.mean():.4f}")
    print(f"  Std: {flat_data.std():.4f}")
    print(f"  Min: {flat_data.min():.4f}")
    print(f"  Max: {flat_data.max():.4f}")

    # Check normalization
    if abs(flat_data.mean()) > 0.5 or abs(flat_data.std() - 1.0) > 0.5:
        print("  ⚠️ WARNING: Data not normalized! Apply z-score normalization.")
    else:
        print("  ✅ Data normalization looks good")

    # 3. Distribution check
    print("\n📈 DISTRIBUTION ANALYSIS")
    skewness = stats.skew(flat_data)
    kurtosis = stats.kurtosis(flat_data)
    print(f"  Skewness: {skewness:.4f}")
    print(f"  Kurtosis: {kurtosis:.4f}")
    if abs(skewness) > 1.0:
        print("  ⚠️ WARNING: Data is skewed. Consider a transformation.")
    if abs(kurtosis) > 3.0:
        print("  ⚠️ WARNING: Heavy-tailed distribution.")
    if abs(skewness) < 1.0 and abs(kurtosis) < 3.0:
        print("  ✅ Distribution shape looks reasonable")

    # 4. Outlier detection
    print("\n🔍 OUTLIER DETECTION")
    Q1, Q3 = np.percentile(flat_data, [25, 75])
    IQR = Q3 - Q1
    outliers = ((flat_data < Q1 - 3 * IQR) | (flat_data > Q3 + 3 * IQR)).sum()
    outlier_pct = outliers / len(flat_data) * 100
    print(f"  Outliers: {outlier_pct:.2f}%")
    if outlier_pct > 5:
        print("  ⚠️ WARNING: High outlier percentage!")
    else:
        print("  ✅ Outlier percentage acceptable")

    # 5. Variance check
    print("\n📉 VARIANCE ANALYSIS")
    per_dim_var = all_samples.var(axis=0)  # Variance of each dimension across the dataset
    low_var_pct = (per_dim_var < 0.01).sum() / len(per_dim_var) * 100
    print(f"  Low variance dimensions: {low_var_pct:.1f}%")
    if low_var_pct > 30:
        print("  ⚠️ WARNING: Many low-variance dimensions!")
    else:
        print("  ✅ Variance distribution looks good")

    # 6. Diversity estimate
    print("\n🎨 DIVERSITY CHECK")
    sample_indices = np.random.choice(len(all_samples), min(100, len(all_samples)), replace=False)
    pairwise_dists = []
    for i in range(len(sample_indices) - 1):
        dist = np.linalg.norm(all_samples[sample_indices[i]] - all_samples[sample_indices[i + 1]])
        pairwise_dists.append(dist)
    avg_dist = np.mean(pairwise_dists)
    std_dist = np.std(pairwise_dists)
    cv = std_dist / avg_dist if avg_dist > 0 else 0
    print(f"  Avg pairwise distance: {avg_dist:.4f}")
    print(f"  Coefficient of variation: {cv:.4f}")
    if cv < 0.2:
        print("  ⚠️ WARNING: Low diversity! Dataset may be too similar.")
    else:
        print("  ✅ Dataset diversity looks good")

    # 7. Summary
    print("\n" + "=" * 50)
    print("RECOMMENDATION:")
    issues = []
    if abs(flat_data.mean()) > 0.5 or abs(flat_data.std() - 1.0) > 0.5:
        issues.append("Normalize data")
    if abs(skewness) > 1.0:
        issues.append("Apply a transformation (log/Yeo-Johnson)")
    if outlier_pct > 5:
        issues.append("Remove/clip outliers")
    if low_var_pct > 30:
        issues.append("Reduce dimensionality")
    if cv < 0.2:
        issues.append("Collect more diverse data")

    if not issues:
        print("✅ Data looks ready for VAE training!")
    else:
        print("⚠️ Fix these issues before training:")
        for issue in issues:
            print(f"  - {issue}")
    print("=" * 50)

    return {
        'mean': flat_data.mean(),
        'std': flat_data.std(),
        'skewness': skewness,
        'kurtosis': kurtosis,
        'outlier_pct': outlier_pct,
        'diversity_cv': cv,
    }

# Usage
report = diagnose_vae_data(your_dataset)
```
Quick Answer to Your Question
Do you NEED to do these tests?
Strictly required (will break training if not done):
- ✅ Normalization check - CRITICAL
- ✅ Outlier detection - CRITICAL
Highly recommended (training will be painful without):
- ✅ Distribution shape - Important for tuning β
- ✅ Simple AE baseline - Saves debugging time
Optional but useful:
- Variance per dimension (helps architecture design)
- Diversity check (helps prevent collapse)
- Class balance (only if multi-class)
TL;DR Minimum Viable Check
If you only do ONE thing:
```python
# Normalize your data
# (assuming `dataset` here is a single array/tensor holding all SDF samples)
data_mean = dataset.mean()
data_std = dataset.std()
dataset_normalized = (dataset - data_mean) / data_std

print(f"Mean: {dataset_normalized.mean():.4f} (should be ~0)")
print(f"Std: {dataset_normalized.std():.4f} (should be ~1)")

# Save for inference
torch.save({'mean': data_mean, 'std': data_std}, 'norm_params.pt')
```
If mean ≈ 0 and std ≈ 1, you're 80% good to go!
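One last point: when you later sample shapes from the trained ShapeVAE, undo this normalization on the decoder output using the parameters you just saved. A minimal sketch; the `decoder(z)` call is a placeholder for your model, not the Hunyuan-3D API:

```python
import torch

# Load the normalization parameters saved above
norm = torch.load('norm_params.pt')

# Hypothetical decoder output: SDF predictions in normalized units
sdf_pred_normalized = decoder(z)  # placeholder for your ShapeVAE decoder
sdf_pred = sdf_pred_normalized * norm['std'] + norm['mean']  # back to the original SDF scale
```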