Releases: chansigit/torchgw
v0.4.2 — Memory Optimization & Large-Scale Benchmarks
Performance & Scaling
- ~40% peak GPU memory reduction via in-place momentum updates and `sink_dtype` lambda construction
- Max scale on L40S 48GB: 35k x 40k (was 25k x 30k)
- Max scale on H100 80GB: 45k x 45k / 40k x 50k, 17-18s, Spearman >= 0.999
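For context, the agreement figure quoted above is a Spearman rank correlation between the large-scale result and a reference solution. A minimal pure-Python sketch of the metric (ties ignored; this is illustrative, not torchgw's benchmarking code):

```python
# Hedged sketch: Spearman rank correlation, the metric quoted above.
# Ties are not handled; real benchmarks would use e.g. scipy.stats.spearmanr.
def rank(values):
    """Return the rank (0-based position in sorted order) of each value."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    for pos, i in enumerate(order):
        ranks[i] = float(pos)
    return ranks

def spearman(x, y):
    """Pearson correlation of the rank vectors of x and y."""
    rx, ry = rank(x), rank(y)
    n = len(x)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = sum((a - mx) ** 2 for a in rx) ** 0.5
    sy = sum((b - my) ** 2 for b in ry) ** 0.5
    return cov / (sx * sy)

print(spearman([1, 2, 3, 4], [10, 20, 30, 40]))  # close to 1.0: same ordering
```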
Fixes
- int32 overflow in `torch.dot` for transport plans with N*K > 2^31
- Graceful float64 output fallback when GPU memory is exhausted
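To illustrate the class of bug fixed here: once a flattened transport plan exceeds 2^31 - 1 elements, a 32-bit index wraps negative. A pure-Python simulation of the wraparound (not the actual `torch.dot` code path):

```python
# Hedged illustration: why N*K > 2^31 - 1 breaks 32-bit indexing.
INT32_MAX = 2**31 - 1
N, K = 45_000, 50_000
flat_size = N * K                 # 2_250_000_000 elements in the flattened plan
print(flat_size > INT32_MAX)      # True: too large for a 32-bit index

def to_int32(x: int) -> int:
    """Simulate two's-complement int32 wraparound of a Python integer."""
    x &= 0xFFFFFFFF
    return x - 2**32 if x >= 2**31 else x

print(to_int32(flat_size))        # negative: the wrapped (corrupted) index
```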
Documentation
- Rewritten "How It Works" with Mermaid flowchart, algorithm intuition, and speedup table
- H100 + L40S benchmark tables in README and Sphinx docs
- Updated quickstart with `pip install torchgw` and `grad_mode` examples
- Reproducible benchmark script: `python examples/benchmark_scale.py`
Install
`pip install torchgw==0.4.2`
Full changelog: CHANGELOG.md
v0.4.1 — Exact Differentiable Gradients
Exact Differentiable Gradients via Implicit Differentiation
This release fixes a correctness bug in the differentiable backward pass and replaces it with an exact gradient computation.
What changed
The previous `differentiable=True` backward used a "frozen-potentials" approximation (`grad_C = -grad_T * T / ε`) that ignored how the Sinkhorn dual potentials depend on the cost matrix. This produced gradients with cosine similarity as low as 0.07 against the true gradient, essentially random at small ε.
The new default (grad_mode="implicit") solves the adjoint system derived from the implicit function theorem at the Sinkhorn fixed point, giving exact gradients with the same O(NK) memory cost.
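In outline, the implicit scheme differentiates through the Sinkhorn fixed point itself. A sketch of the idea (notation is illustrative, not lifted from the codebase):

```latex
% Entropic transport plan at the Sinkhorn fixed point, dual potentials (f, g):
T_{ij} = a_i\, b_j\, \exp\!\big( (f_i + g_j - C_{ij}) / \varepsilon \big)
% The marginal constraints define the fixed point implicitly:
F(f, g; C) = 0
% Implicit function theorem: sensitivity of the potentials to the cost is
\frac{\partial (f, g)}{\partial C}
  = -\left( \frac{\partial F}{\partial (f, g)} \right)^{-1}
     \frac{\partial F}{\partial C}
% so the backward pass solves one adjoint linear system with the Sinkhorn
% Jacobian (reduced to a K x K block via the Schur complement), rather than
% treating f and g as constants the way the old approximation did.
```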
Highlights
- `grad_mode="implicit"` (default): exact gradient via a Schur-complement solve on the Sinkhorn Jacobian. Memory-efficient: O(NK + K²).
- `grad_mode="unrolled"`: exact gradient via unrolled PyTorch autograd. Fallback for extreme ε.
- Numerically stable: Schur complement formulated on J^T (eigenvalues in [0, 2]) with a rank-1 null-space correction. Stable across all ε values.
- 111 tests passing, including 8 new gradient correctness tests (finite differences, implicit vs unrolled comparison, descent direction).
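The finite-difference style of check mentioned above can be sketched in a few lines: compare an analytic gradient against central differences on a toy objective. This is illustrative only, not torchgw's test suite:

```python
# Hedged sketch: validating an analytic gradient with central finite
# differences, the same idea as the new gradient correctness tests.
def finite_diff_grad(f, x, h=1e-6):
    """Central-difference gradient of scalar function f at point x."""
    grad = []
    for i in range(len(x)):
        xp = list(x); xp[i] += h
        xm = list(x); xm[i] -= h
        grad.append((f(xp) - f(xm)) / (2 * h))
    return grad

f = lambda v: v[0] ** 2 + 3 * v[1]       # toy objective
analytic = lambda v: [2 * v[0], 3.0]     # its exact gradient
x = [1.5, -2.0]
fd = finite_diff_grad(f, x)
assert all(abs(a - b) < 1e-4 for a, b in zip(analytic(x), fd))
```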
Usage
# Default: exact gradients (no change needed)
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat, differentiable=True)
# Explicit mode selection
T = sampled_gw(..., differentiable=True, grad_mode="implicit") # default
T = sampled_gw(..., differentiable=True, grad_mode="unrolled")  # alternative
Install
`pip install torchgw==0.4.1`
Full changelog: CHANGELOG.md
Algorithm docs: Differentiable Sinkhorn theory