Performance & Scaling
- ~40% reduction in peak GPU memory via in-place momentum updates and constructing `Lambda` directly in `sink_dtype`
- Max scale on L40S 48GB: 35k x 40k (was 25k x 30k)
- Max scale on H100 80GB: 45k x 45k / 40k x 50k, 17-18s, Spearman >= 0.999
- Fix `torch.dot` int32 overflow for transport plans with N*K > 2^31
- Graceful float64 output fallback when GPU memory is exhausted
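The two memory/overflow items above follow standard patterns. Below is a minimal NumPy sketch (illustrative only; the library itself operates on PyTorch tensors, and the function and parameter names such as `momentum_update_inplace` and `beta` are assumptions, not the actual API): updating the momentum buffer in place avoids allocating a fresh array for the result, and widening to int64 before a dot product avoids int32 wraparound once the sum exceeds 2^31 - 1.

```python
import numpy as np

def momentum_update_inplace(m, g, beta=0.9):
    """Update the momentum buffer in place: m <- beta*m + (1-beta)*g.

    Reuses m's storage for the result instead of allocating a new
    array, which is what lowers peak memory versus the out-of-place
    form `m = beta*m + (1-beta)*g`.
    """
    m *= beta               # scale existing buffer in place
    m += (1.0 - beta) * g   # accumulate the gradient term in place
    return m

# The in-place update returns the same buffer, not a copy.
m = np.ones(4)
out = momentum_update_inplace(m, np.full(4, 2.0))
assert out is m  # result written into the existing allocation

# int32 dot products wrap around once the accumulated sum exceeds
# 2^31 - 1; casting to int64 before reducing gives the exact value.
a = np.full(10, 2**28, dtype=np.int32)
ones = np.ones(10, dtype=np.int64)
safe = np.dot(a.astype(np.int64), ones)
print(int(safe))  # 10 * 2**28 = 2684354560, which exceeds 2**31 - 1
```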
Documentation
- Rewritten "How It Works" with Mermaid flowchart, algorithm intuition, and speedup table
- H100 + L40S benchmark tables in README and Sphinx docs
- Updated quickstart with `pip install torchgw` and `grad_mode` examples
- Reproducible benchmark script: python examples/benchmark_scale.py
Install
pip install torchgw==0.4.2

Full changelog: CHANGELOG.md