Releases: chansigit/torchgw

v0.4.2 — Memory Optimization & Large-Scale Benchmarks

10 Apr 00:50

Performance & Scaling

  • ~40% peak GPU memory reduction via in-place momentum update and sink_dtype Lambda construction
  • Max scale on L40S 48GB: 35k x 40k (was 25k x 30k)
  • Max scale on H100 80GB: 45k x 45k / 40k x 50k, 17-18s, Spearman >= 0.999
  • Fix torch.dot int32 overflow for transport plans with N*K > 2^31
  • Graceful float64 output fallback when GPU memory is exhausted
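The int32 overflow fix can be illustrated with a chunked reduction. This is a hedged sketch, not the actual torchgw internals: the idea is that some GPU dot kernels index with 32-bit integers, so a flattened reduction over more than 2^31 elements can overflow, and splitting the dot product into fixed-size chunks sidesteps that limit.

```python
import numpy as np

def chunked_dot(a, b, chunk=2**20):
    """Dot product computed in fixed-size chunks.

    Avoids any single reduction whose flattened index range exceeds
    int32, which can overflow in some kernels when N*K > 2**31.
    (Illustrative sketch; torchgw's fix may differ in detail.)
    """
    a = a.ravel()
    b = b.ravel()
    total = 0.0
    for start in range(0, a.size, chunk):
        total += float(np.dot(a[start:start + chunk], b[start:start + chunk]))
    return total

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000)
y = rng.standard_normal(10_000)
# Chunked result matches the plain dot product on a small example.
assert abs(chunked_dot(x, y, chunk=1024) - float(np.dot(x, y))) < 1e-8
```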

Documentation

  • Rewritten "How It Works" with Mermaid flowchart, algorithm intuition, and speedup table
  • H100 + L40S benchmark tables in README and Sphinx docs
  • Updated quickstart with pip install instructions and grad_mode examples
  • Reproducible benchmark script: python examples/benchmark_scale.py

Install

pip install torchgw==0.4.2

Full changelog: CHANGELOG.md

v0.4.1 — Exact Differentiable Gradients

09 Apr 21:55

Exact Differentiable Gradients via Implicit Differentiation

This release replaces the approximate differentiable backward pass, which contained a correctness bug, with an exact gradient computation.

What changed

The previous differentiable=True backward used a "frozen-potentials" approximation (grad_C = -grad_T * T / ε) that ignored how Sinkhorn dual potentials depend on the cost matrix. This produced gradients with cosine similarity as low as 0.07 against the true gradient — essentially random at small ε.
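The frozen-potentials approximation amounts to differentiating only the exp(-C/ε) factor of the transport plan while treating the Sinkhorn dual potentials as constants. A minimal sketch of that formula (function name is hypothetical, not the torchgw internals):

```python
import numpy as np

def frozen_potentials_grad(grad_T, T, eps):
    # Only the exp(-C/eps) factor is differentiated, giving the
    # elementwise rule grad_C = -(grad_T * T) / eps. The term that
    # accounts for how the dual potentials respond to changes in C
    # is dropped, which is why the result can be badly wrong at small eps.
    return -(grad_T * T) / eps

T = np.array([[0.4, 0.1], [0.1, 0.4]])
grad_T = np.array([[1.0, 0.0], [0.0, 1.0]])
print(frozen_potentials_grad(grad_T, T, eps=0.1))  # [[-4.  -0.] [-0. -4.]]
```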

The new default (grad_mode="implicit") solves the adjoint system derived from the implicit function theorem at the Sinkhorn fixed point, giving exact gradients with the same O(NK) memory cost.
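The implicit-function-theorem idea can be shown on a generic scalar fixed point. This is a simplified illustration, not the torchgw adjoint system: for x* = f(x*, θ), differentiating through the fixed-point condition gives dx*/dθ = (∂f/∂θ) / (1 − ∂f/∂x), so the gradient of a loss on x* can be computed exactly without unrolling the iterations.

```python
import numpy as np

def fixed_point(theta, iters=200):
    # Iterate x <- f(x, theta) = 0.5*cos(x) + theta to convergence;
    # f is a contraction (|f'| <= 0.5), standing in for Sinkhorn.
    x = 0.0
    for _ in range(iters):
        x = 0.5 * np.cos(x) + theta
    return x

def loss(theta):
    return fixed_point(theta) ** 2

def implicit_grad(theta):
    # Implicit function theorem at the fixed point x* = f(x*, theta):
    # dx*/dtheta = (df/dtheta) / (1 - df/dx), then chain rule through x^2.
    x = fixed_point(theta)
    df_dx = -0.5 * np.sin(x)   # partial of f w.r.t. x at the fixed point
    df_dtheta = 1.0            # partial of f w.r.t. theta
    dx_dtheta = df_dtheta / (1.0 - df_dx)
    return 2.0 * x * dx_dtheta

# The implicit gradient matches a central finite difference of the loss.
theta, h = 0.3, 1e-6
fd = (loss(theta + h) - loss(theta - h)) / (2 * h)
assert abs(implicit_grad(theta) - fd) < 1e-5
```

The same structure scales up: in the real solver the scalar division by (1 − ∂f/∂x) becomes a linear solve against the Sinkhorn Jacobian, which is where the Schur complement below comes in.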

Highlights

  • grad_mode="implicit" (default): exact gradient via Schur complement solve on the Sinkhorn Jacobian. Memory-efficient: O(NK + K²).
  • grad_mode="unrolled": exact gradient via unrolled PyTorch autograd. Fallback for extreme ε.
  • Numerically stable across all ε values: the Schur complement is formulated on J^T (eigenvalues in [0, 2]) with a rank-1 null-space correction.
  • 111 tests passing, including 8 new gradient correctness tests (finite differences, implicit vs unrolled comparison, descent direction).
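A finite-difference check of the kind used in the new gradient tests can be sketched generically (this is not the torchgw test code, just the standard pattern):

```python
import numpy as np

def finite_diff_grad(f, x, h=1e-6):
    # Central-difference gradient of a scalar function f at x,
    # perturbing one coordinate at a time.
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e.flat[i] = h
        g.flat[i] = (f(x + e) - f(x - e)) / (2 * h)
    return g

# Sanity check on a function with a known gradient: d/dx sum(x^2) = 2x.
f = lambda x: float(np.sum(x ** 2))
x = np.array([1.0, -2.0, 3.0])
assert np.allclose(finite_diff_grad(f, x), 2 * x, atol=1e-5)
```

In the actual tests the analytic gradient from grad_mode="implicit" is compared against this kind of numerical estimate, and against grad_mode="unrolled".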

Usage

# Default: exact gradients (no change needed)
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat, differentiable=True)

# Explicit mode selection
T = sampled_gw(..., differentiable=True, grad_mode="implicit")   # default
T = sampled_gw(..., differentiable=True, grad_mode="unrolled")   # alternative

Install

pip install torchgw==0.4.1

Full changelog: CHANGELOG.md
Algorithm docs: Differentiable Sinkhorn theory