All notable changes to TorchGW are documented in this file.
Exact differentiable gradients via implicit differentiation. Fixes a correctness bug where the previous "envelope theorem" backward produced gradients with up to 30x error (cosine similarity as low as 0.07).
- Default gradient computation for `differentiable=True` is now implicit differentiation (exact) instead of the old frozen-potentials approximation. No API change needed — the new default is strictly better.
- `grad_mode` parameter for `sampled_gw` — controls how gradients are computed when `differentiable=True` (see the toy sketch after this list):
  - `"implicit"` (default): exact gradient via the adjoint system at the Sinkhorn fixed point, solved via a Schur complement on the Sinkhorn Jacobian. Memory: O(NK + K^2). Same speed as the old approximate mode.
  - `"unrolled"`: exact gradient via unrolled PyTorch autograd. Memory: O(NK * sinkhorn_iters). Useful as a fallback at extremely small epsilon.
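For intuition, here is a generic toy contrast between the two backends on a scalar fixed point z = F(z, c). This is not TorchGW code, only the differentiation pattern each `grad_mode` uses: `"unrolled"` backpropagates through every iteration, while `"implicit"` differentiates only the fixed-point equation via the implicit function theorem.

```python
# Toy illustration of the two grad_mode backends on a scalar fixed point z = F(z, c).
# Not TorchGW internals -- just the generic "unrolled" vs "implicit" pattern.
import torch

def F(z, c):
    return 0.5 * torch.tanh(c * z) + 0.1    # a contraction in z for |c| < 2

c = torch.tensor(0.8, requires_grad=True)

# "unrolled": keep the autograd graph of every iteration (memory grows with iteration count).
z = torch.zeros(())
for _ in range(60):
    z = F(z, c)
(grad_unrolled,) = torch.autograd.grad(z, c)

# "implicit": iterate to the fixed point without a graph, then apply the implicit
# function theorem  dz*/dc = (1 - dF/dz)^{-1} dF/dc  at z* (constant memory).
with torch.no_grad():
    z_star = torch.zeros(())
    for _ in range(60):
        z_star = F(z_star, c)
z_in = z_star.clone().requires_grad_(True)
out = F(z_in, c)
dF_dz, dF_dc = torch.autograd.grad(out, (z_in, c))
grad_implicit = dF_dc / (1.0 - dF_dz)

print(float(grad_unrolled), float(grad_implicit))   # agree up to fixed-point tolerance
```

In the solver, the same idea is applied to the pair of Sinkhorn potentials, which is where the Schur complement solve and the O(NK + K^2) memory figure come from.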
- Gradient correctness — The old backward formula `grad_C = -grad_T * T / reg` treated the Sinkhorn potentials as constants ("frozen potentials"). This is a first-order approximation that ignores how the potentials depend on C through the Sinkhorn iterations. The new implicit-differentiation backward solves the adjoint system derived from the implicit function theorem at the Sinkhorn fixed point, giving exact gradients (see the comparison sketch after this list).
- Adjoint solver stability — The initial implementation used fixed-point iteration for the adjoint system, which diverges when the spectral radius is >= 1 (common at small epsilon). Replaced with a Schur complement direct solve on the Sinkhorn Jacobian J^T (eigenvalues in [0, 2], well-conditioned). The null space from the potentials' constant ambiguity is removed via a rank-1 correction (`11^T/K`).
- Warning for non-differentiable pure GW — `differentiable=True` with `fgw_alpha=0` now emits a warning, since gradients cannot flow through precomputed graph distances.
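A self-contained comparison (not TorchGW code) of why the frozen-potentials formula is only approximate: autograd through an unrolled log-domain Sinkhorn gives the true gradient of that computation, while the old closed form `-grad_T * T / reg` points in a visibly different direction at small epsilon.

```python
# Standalone sketch: exact gradient (autograd through unrolled log-domain Sinkhorn)
# vs. the old frozen-potentials formula grad_C = -grad_T * T / reg.
import math
import torch

torch.manual_seed(0)
N, K, reg, iters = 6, 5, 0.05, 3000
C = torch.rand(N, K, dtype=torch.float64, requires_grad=True)
log_a = torch.full((N,), -math.log(N), dtype=torch.float64)   # uniform marginals
log_b = torch.full((K,), -math.log(K), dtype=torch.float64)

M = -C / reg
log_u = torch.zeros(N, dtype=torch.float64)
log_v = torch.zeros(K, dtype=torch.float64)
for _ in range(iters):
    log_u = log_a - torch.logsumexp(M + log_v, dim=1)           # row half-step
    log_v = log_b - torch.logsumexp(M + log_u[:, None], dim=0)  # column half-step
T = torch.exp(M + log_u[:, None] + log_v)

W = torch.rand(N, K, dtype=torch.float64)   # arbitrary upstream gradient grad_T
(T * W).sum().backward()                    # exact: differentiates through the iterations

grad_exact = C.grad
grad_frozen = -W * T.detach() / reg         # old formula: potentials treated as constants
cos = (grad_exact * grad_frozen).sum() / (grad_exact.norm() * grad_frozen.norm())
print(float(cos))   # typically well below 1 at small reg -- the approximation is poor
```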
- `_SinkhornAutograd` renamed to `_SinkhornApproximate` (frozen-potentials, used internally for semi-relaxed only)
- New `_SinkhornImplicit` class (exact VJP via adjoint)
- New `_sinkhorn_unrolled` function (exact VJP via autograd)
- `_sinkhorn_differentiable` rewritten as dispatcher for the three backends
- 8 new gradient correctness tests in `tests/test_sinkhorn_grad.py`: implicit vs unrolled (rel_err < 2%), implicit vs finite differences, descent direction, non-uniform marginals, approximate formula verification, `grad_mode` validation
- Integration tests for `grad_mode` through `sampled_gw`
- 111 total tests, all passing
- `docs/algorithm.rst`: full derivation of implicit differentiation (Jacobian structure, adjoint equation, Schur complement, null-space handling)
- README: updated news, differentiable mode example with `grad_mode`
Major performance and robustness release. 3-6x faster on typical workloads, with Triton GPU kernel acceleration, mixed precision support, and comprehensive numerical stability fixes.
- Triton fused Sinkhorn kernels — Custom GPU kernels for the Sinkhorn row/column logsumexp updates, reducing kernel launches from ~6 to 1 per half-step. 2-5x speedup on the Sinkhorn portion (5001×6001 fp32: 261ms → 50ms). Includes fused transport plan materialization and fused marginal error check. Falls back to PyTorch automatically when Triton is unavailable. (`_triton_sinkhorn.py`)
- Sinkhorn warm-start — Reuse log-domain potentials (`log_u`, `log_v`) from the previous GW iteration as initial values. Reduces Sinkhorn convergence from ~10 to ~3-5 steps (see the first sketch after this list).
- GPU sampling — Replace CPU numpy sampling with `torch.multinomial` on GPU. Transfers 2×M integers instead of the full N×K transport plan per iteration.
- Mixed precision — New `mixed_precision=True` parameter runs Sinkhorn in float32 (safe in log domain) while keeping marginals and output in float64. Up to 1.7x faster on A100/L40S; larger gains expected on consumer GPUs.
- Dijkstra caching — `DijkstraProvider` caches per-node SSSP results across iterations with FIFO eviction (max 2000 rows per side). Avoids redundant computation when the same anchor nodes are re-sampled.
- Cost plateau early stopping — GW cost EMA + patience-based convergence detection. Stops when the smoothed cost stops improving, rather than waiting for the noisy `||T - T_prev||` to drop below `tol` (which may never happen due to sampling noise). Example: dijkstra 1000×1200 stops at 97 iters instead of running all 500. (See the second sketch after this list.)
- Parallel all-pairs Dijkstra — `PrecomputedProvider` runs source and target graph Dijkstra in parallel via process-based parallelism (scipy holds the GIL). 1.2-1.5x speedup on large graphs (≥2000 total nodes).
- Reduced CUDA sync points — Convergence checks batched every 5 iterations; augmented penalty computed on GPU without `.item()` sync.
- Sinkhorn convergence check via logsumexp — Avoids materializing the full N×K matrix for the marginal error computation.
- Pre-allocated augmented cost matrix — Reused across iterations instead of re-allocated each step.
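A minimal log-domain Sinkhorn sketch, with assumed helper name and conventions (not the library's internal code), showing the row/column logsumexp half-steps that the Triton kernels fuse and how warm-started potentials are passed back in:

```python
# Minimal log-domain Sinkhorn with warm-started potentials (names are illustrative).
import torch

def sinkhorn_log(C, log_a, log_b, reg, n_iters=10, log_u=None, log_v=None):
    """C: (N, K) cost, log_a/log_b: log marginals, reg: entropic epsilon."""
    M = -C / reg
    if log_u is None:
        log_u = torch.zeros_like(log_a)
    if log_v is None:
        log_v = torch.zeros_like(log_b)
    for _ in range(n_iters):
        # Each half-step is one row/column logsumexp -- the op the Triton kernel fuses.
        log_u = log_a - torch.logsumexp(M + log_v[None, :], dim=1)
        log_v = log_b - torch.logsumexp(M + log_u[:, None], dim=0)
    T = torch.exp(M + log_u[:, None] + log_v[None, :])   # plan materialized once
    return T, log_u, log_v

# Warm start across GW iterations: pass the previous (log_u, log_v) back in, so each
# inner Sinkhorn needs only a few steps instead of restarting from zeros.
# T, log_u, log_v = sinkhorn_log(C_next, log_a, log_b, reg, log_u=log_u, log_v=log_v)
```

And an illustrative version of the cost-plateau criterion; the parameter names here (`beta`, `patience`, `min_rel_improvement`) are placeholders, not the solver's actual API:

```python
# Illustrative EMA + patience plateau check (names are assumptions, not the real API).
def make_plateau_check(beta=0.9, patience=10, min_rel_improvement=1e-3):
    state = {"ema": None, "best": float("inf"), "bad_iters": 0}

    def should_stop(cost):
        # Exponential moving average smooths the noisy per-iteration GW cost.
        state["ema"] = cost if state["ema"] is None else beta * state["ema"] + (1 - beta) * cost
        if state["ema"] < state["best"] * (1 - min_rel_improvement):
            state["best"] = state["ema"]
            state["bad_iters"] = 0
        else:
            state["bad_iters"] += 1        # no meaningful improvement this iteration
        return state["bad_iters"] >= patience

    return should_stop

# Usage inside a solver loop (schematic):
# stop = make_plateau_check()
# for it in range(max_iters):
#     cost = ...            # current (noisy) GW cost
#     if stop(cost):
#         break
```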
- `mixed_precision` parameter for `sampled_gw` and `sampled_lowrank_gw`
- `lambda_ema_beta` parameter for cost matrix EMA smoothing (variance reduction)
- Verbose Sinkhorn output (`verbose=True` prints per-iteration marginal errors)
- `sample_pairs_gpu()` — GPU-native weighted sampling function (see the sketch below)
- `sample_pairs_from_plan()` now accepts optional `rng` parameter for reproducibility
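A hedged sketch of the `torch.multinomial`-based sampling approach; the real `sample_pairs_gpu()` signature may differ.

```python
# Illustrative GPU-native weighted pair sampling from a transport plan T of shape (N, K).
import torch

def sample_pairs_sketch(T, n_samples, generator=None):
    N, K = T.shape
    probs = T.reshape(-1).float()   # float32 cast before multinomial (needed on some versions/devices)
    idx = torch.multinomial(probs, n_samples, replacement=True, generator=generator)
    rows, cols = idx // K, idx % K  # decode flat indices into (row, col) pairs -- 2*M integers, not the N*K plan
    return rows, cols
```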
- Numerical stability
  - `torch.log(a + 1e-300)` replaced with `.clamp(min=1e-30)` — the 1e-300 constant vanishes in float32, providing no protection against log(0) (see the snippet below)
  - Regularization decay capped at 10x to prevent instability with large epsilon values
  - Low-rank mirror descent: enforce `gamma * reg >= 1` to prevent exponential overflow
  - Handle all-inf distance matrices (fully disconnected subgraphs) without crashing
  - `sample_pairs_gpu` casts to float32 before `torch.multinomial` (required on some PyTorch versions/devices)
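A quick demonstration of the underflow behavior behind the first fix:

```python
import torch

# 1e-300 underflows to zero in float32, so it cannot guard against log(0):
a = torch.tensor(0.0)                    # float32
print(torch.log(a + 1e-300))             # tensor(-inf)
# clamping to a float32-representable floor does protect the log:
print(torch.log(a.clamp(min=1e-30)))     # tensor(-69.0776)
```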
- Correctness
  - kNN graph symmetrized via `.maximum(.T)` — `kneighbors_graph` returns directed edges (see the snippet below)
  - Semi-relaxed Sinkhorn: correct KL proximal blend `tau * new + (1-tau) * old` instead of `tau * new`, which discarded history
  - `differentiable=True` + `semi_relaxed=True` now raises `NotImplementedError` (envelope theorem gradient is invalid for unbalanced Sinkhorn)
  - Detach `T_prev` in differentiable mode to prevent computation graph accumulation across GW iterations (OOM after many iterations)
  - `joint_embedding`: prevent index out-of-bounds when `out_dim > k_svds`
  - `lambda_ema_beta=0.0` now disables EMA (previously locked to first iteration's cost)
  - Dijkstra cache eviction safety: never evict keys needed by the current request
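The symmetrization fix in scikit-learn/scipy terms (illustrative, not TorchGW's exact code):

```python
# kneighbors_graph returns a directed kNN graph; symmetrize by keeping an edge
# whenever it appears in either direction.
import numpy as np
from sklearn.neighbors import kneighbors_graph

X = np.random.rand(200, 3)
A = kneighbors_graph(X, n_neighbors=5, mode="distance")   # CSR, i -> its 5 neighbors
A = A.maximum(A.T)                                         # undirected: max of both directions
```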
- Compatibility
  - `scipy.sparse.linalg.cg`: auto-detect `tol` vs `rtol` parameter name for SciPy 1.10-1.17+ compatibility (see the sketch below)
  - `sampled_lowrank_gw`: `semi_relaxed` validation moved to function start (fail-fast)
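One possible way to handle the `tol`/`rtol` rename; the library's actual helper may differ.

```python
# Detect which keyword this SciPy's cg accepts, then pass the tolerance under that name.
import inspect
from scipy.sparse.linalg import cg

_CG_TOL_KW = "rtol" if "rtol" in inspect.signature(cg).parameters else "tol"

def cg_compat(A, b, rtol=1e-8, **kwargs):
    return cg(A, b, **{_CG_TOL_KW: rtol}, **kwargs)
```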
- API consistency
  - `sampled_lowrank_gw` now accepts `mixed_precision` parameter
  - Removed unused `semi_relaxed` / `rho` / `**kwargs` from `sinkhorn_lowrank` signature
  - `sample_pairs_from_plan` returns `(rows, cols)` arrays instead of `list[tuple]`
- 72 tests covering all solver modes, mixed precision, early stopping, Dijkstra cache, differentiable gradients, boundary values, and semi-relaxed mode
- Test suite runs in ~18s (down from ~68s before optimizations)
- `docs/optimization-log.md` — Detailed optimization history with benchmarks
- `docs/improvements.md` — Updated future directions (torch.compile, cuGraph, Triton extensions)
Initial public release.
- Sampled Gromov-Wasserstein solver (`sampled_gw`) with log-domain Sinkhorn
- Low-rank solver (`sampled_lowrank_gw`) via mirror descent + Dykstra
- Three distance modes: `dijkstra`, `precomputed`, `landmark`
- Fused Gromov-Wasserstein (`fgw_alpha` blending)
- Multiscale warm-start via farthest-point sampling
- Differentiable transport plans (`differentiable=True`)
- Semi-relaxed mode for unbalanced transport
- Joint manifold embedding (`joint_embedding`)
- kNN graph construction with component stitching (`build_knn_graph`)