Pure PyTorch | Triton GPU Kernels | Differentiable | Up to 175x faster than POT
TorchGW aligns two point clouds by matching their internal distance structures -- even when they live in different dimensions. Instead of the full O(NK(N+K)) GW cost, it samples M anchor pairs each iteration and approximates the cost in O(NKM), enabling GPU-accelerated alignment at scales where standard solvers are impractical.
Use cases: single-cell multi-omics integration, cross-domain graph matching, shape correspondence, manifold alignment.
v0.4.1 (2026-04-09) -- Exact differentiable gradients via implicit differentiation. The previous "envelope theorem" backward pass was a frozen-potentials approximation with up to 30x gradient error; it has been replaced by an adjoint system solved via a Schur complement on the Sinkhorn Jacobian. New `grad_mode` parameter (`"implicit"` by default, `"unrolled"` as an alternative). See the algorithm docs for the math.
v0.4.0 (2026-04-07) -- Triton fused Sinkhorn (2-5x GPU speedup), mixed precision, smart early stopping, Sinkhorn warm-start, Dijkstra caching, and 15 numerical stability fixes. See CHANGELOG.md.
pip install -e .
Requirements: numpy, scipy, scikit-learn, torch>=2.0, joblib.
Triton ships with PyTorch (Linux CUDA builds), and TorchGW uses it for GPU kernel fusion automatically. POT is not required.
from torchgw import sampled_gw
# Basic usage
T = sampled_gw(X_source, X_target)
# Recommended for large-scale (fastest)
T = sampled_gw(X_source, X_target, distance_mode="landmark", mixed_precision=True)
Minimal working example:
import torch
from torchgw import sampled_gw
X = torch.randn(500, 3) # source: 500 points in 3D
Y = torch.randn(600, 5) # target: 600 points in 5D (dimensions may differ)
T = sampled_gw(X, Y, epsilon=0.005, M=80, max_iter=200)
# T is a (500, 600) transport plan: T[i,j] = coupling weight between X[i] and Y[j]
print(f"Transport plan: {T.shape}, total mass: {T.sum():.4f}")
Spiral (2D) to Swiss roll (3D) alignment, mixed_precision=True, landmark distances:
NVIDIA H100 80GB HBM3:
| Scale | Time | Spearman rho | GPU Memory |
|---|---|---|---|
| 4,000 x 5,000 | 0.8 s | 0.999 | 0.7 GB |
| 10,000 x 12,000 | 4.1 s | 0.999 | 3.9 GB |
| 20,000 x 25,000 | 4.6 s | 0.999 | 16 GB |
| 30,000 x 35,000 | 9.3 s | 0.999 | 34 GB |
| 40,000 x 50,000 | 17 s | 0.999 | 64 GB |
| 45,000 x 45,000 | 18 s | 0.999 | 65 GB |
NVIDIA L40S 48GB:
| Scale | Time | Spearman rho | GPU Memory |
|---|---|---|---|
| 4,000 x 5,000 | 2.4 s | 0.999 | 1.1 GB |
| 10,000 x 12,000 | 3.0 s | 0.999 | 6.7 GB |
| 20,000 x 25,000 | 12 s | 0.999 | 18 GB |
| 30,000 x 35,000 | 25 s | 0.999 | 34 GB |
| 35,000 x 40,000 | 34 s | 0.999 | 45 GB |
Alignment quality (Spearman rho >= 0.999) is maintained at every scale tested. At 4,000 x 5,000, TorchGW is ~175x faster than POT (1.0 s vs 183 s). The maximum problem size is bounded by the GPU memory needed for the dense N*K transport plan (~80% VRAM utilization).
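Because the dense N x K plan dominates memory, a back-of-envelope estimate helps predict whether a problem will fit on a given GPU. The sketch below is plain arithmetic and not part of TorchGW; `n_buffers` is a hypothetical knob for how many N x K-sized arrays are alive at once. For reference, one float64 plan at 45,000 x 45,000 is 16.2 GB, while the H100 table reports 65 GB at that size, which suggests (an inference from the numbers, not a documented breakdown) that several N x K-sized buffers coexist during a solve.

```python
import math


def dense_matrix_gb(n: int, k: int, bytes_per_entry: int = 8) -> float:
    """Size in GB (1e9 bytes) of one dense n x k matrix.

    bytes_per_entry: 8 for float64, 4 for float32.
    """
    return n * k * bytes_per_entry / 1e9


def max_square_size(vram_gb: float, n_buffers: int, bytes_per_entry: int = 8) -> int:
    """Largest n such that `n_buffers` dense n x n matrices fit in `vram_gb`.

    n_buffers is hypothetical: how many N x K-sized arrays a solver keeps
    alive at the same time. It is not a TorchGW parameter.
    """
    return math.isqrt(int(vram_gb * 1e9 / (n_buffers * bytes_per_entry)))


print(f"{dense_matrix_gb(45_000, 45_000):.1f} GB")     # 16.2 GB (float64)
print(f"{dense_matrix_gb(45_000, 45_000, 4):.1f} GB")  # 8.1 GB (float32)
```

Halving the entry size (mixed precision) or keeping fewer full-size buffers are the two levers such an estimate exposes.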
Reproduce
python examples/benchmark_scale.py
Choose a distance mode based on your data scale:
| Mode | Best for | Per-iteration cost | Memory | Notes |
|---|---|---|---|---|
| "precomputed" | N < 5k | O(NM) lookup | O(N^2) | All-pairs Dijkstra upfront |
| "dijkstra" | 5k-50k | O(MN log N) | O(NM) | On-the-fly with caching |
| "landmark" | any scale | O(NMd) GPU | O(Nd) | Recommended (the API default is "dijkstra") |
# Small scale: precompute all distances once
T = sampled_gw(X, Y, distance_mode="precomputed")
# Bring your own distance matrices
T = sampled_gw(dist_source=D_X, dist_target=D_Y, distance_mode="precomputed")
# Large scale (recommended)
T = sampled_gw(X, Y, distance_mode="landmark", n_landmarks=50)
Best performance settings
T = sampled_gw(
X, Y,
distance_mode="landmark", # avoids expensive all-pairs Dijkstra
mixed_precision=True, # float32 Sinkhorn (up to 2x faster on GPU)
M=80, # more samples = better cost estimate
epsilon=0.005, # moderate regularization
)
Fused Gromov-Wasserstein
Blend structural (graph distance) and feature (linear) costs:
C_feat = torch.cdist(features_src, features_tgt)
T = sampled_gw(X, Y, fgw_alpha=0.5, C_linear=C_feat)
# Pure Wasserstein (no graph distances needed)
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat)
Semi-relaxed transport
For unbalanced datasets (e.g., cell types present in one sample but not the other):
T = sampled_gw(X, Y, semi_relaxed=True, rho=1.0)
# Source marginal enforced, target marginal relaxed via KL penalty
Multi-scale warm start
Speeds up convergence by solving a coarse problem first:
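T = sampled_gw(X, Y, multiscale=True, n_coarse=200)
Note: GW has symmetric local optima (on a symmetric shape, for example, a mirrored alignment has the same GW cost as the correct one), so the multiscale warm start works best on data without strong symmetries.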
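The README does not describe how the coarse solution seeds the fine problem. A natural lifting step, sketched below as an assumption, clusters each cloud into n_coarse groups (clustering and the coarse solve are not shown) and spreads each coarse entry over the member pairs in proportion to their weights: T[i, j] = T_c[c(i), c'(j)] * (a_i / A_c(i)) * (b_j / B_c'(j)), where A and B are cluster masses. If the coarse plan's marginals equal the cluster masses, the lifted plan has exactly the fine marginals a and b. `lift_coarse_plan` is a hypothetical helper.

```python
import numpy as np


def lift_coarse_plan(T_coarse, labels_src, labels_tgt, a, b):
    """Spread a coarse (P x Q) plan over fine points to get an (N x K) plan.

    labels_src[i] in [0, P) is the cluster of source point i (likewise for
    the target). T_coarse's row/column sums must equal the cluster masses.
    """
    A = np.bincount(labels_src, weights=a, minlength=T_coarse.shape[0])  # (P,)
    B = np.bincount(labels_tgt, weights=b, minlength=T_coarse.shape[1])  # (Q,)
    row_share = a / A[labels_src]  # fraction of its cluster's mass held by i
    col_share = b / B[labels_tgt]
    return T_coarse[np.ix_(labels_src, labels_tgt)] * np.outer(row_share, col_share)


# Toy example: 6 source points in 3 clusters, 4 target points in 2 clusters.
a = np.full(6, 1 / 6)
b = np.full(4, 1 / 4)
labels_src = np.array([0, 0, 1, 1, 2, 2])
labels_tgt = np.array([0, 1, 1, 0])
T_c = np.array([[1 / 3, 0.0],
                [0.0, 1 / 3],
                [1 / 6, 1 / 6]])  # rows sum to 1/3, columns to 1/2
T0 = lift_coarse_plan(T_c, labels_src, labels_tgt, a, b)  # (6, 4) warm start
```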
Differentiable mode
Use GW transport as a differentiable layer (exact gradients via implicit differentiation):
C_feat = torch.cdist(encoder(X), encoder(Y))
T = sampled_gw(fgw_alpha=1.0, C_linear=C_feat, differentiable=True)
loss = (C_feat.detach() * T).sum()
loss.backward() # exact gradients flow to encoder parameters
# For memory-constrained settings, unrolled autograd is also available:
T = sampled_gw(..., differentiable=True, grad_mode="unrolled")
Low-rank Sinkhorn (N, K > 50k)
For very large problems where the N*K transport plan does not fit in memory:
from torchgw import sampled_lowrank_gw
T = sampled_lowrank_gw(X, Y, rank=30, distance_mode="landmark", n_landmarks=50)
Memory: O((N+K)*rank) instead of O(NK).
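The memory saving comes from never materializing the plan. In the Scetbon, Cuturi & Peyré factorization, the plan is P = Q diag(1/g) R^T with Q of shape (N, r), R of shape (K, r), and g of length r, so products with P can be evaluated factor by factor. The sketch below shows only this evaluation step, with hypothetical helper names; fitting Q, R, and g (presumably what `lr_max_iter` and `lr_dykstra_max_iter` control) is not shown.

```python
import numpy as np


def lowrank_matvec(Q, g, R, x):
    """P @ x for P = Q diag(1/g) R^T, without forming the N x K matrix.

    Cost is O((N + K) * r) time and memory instead of O(N * K).
    """
    return Q @ ((R.T @ x) / g)


def lowrank_rmatvec(Q, g, R, y):
    """P^T @ y, evaluated the same way."""
    return R @ ((Q.T @ y) / g)


# Toy feasible factors: each point's mass sits in one of r = 2 groups.
a = np.full(6, 1 / 6)
b = np.full(4, 1 / 4)
Q = np.zeros((6, 2))
Q[np.arange(6), [0, 0, 0, 1, 1, 1]] = a
R = np.zeros((4, 2))
R[np.arange(4), [0, 1, 0, 1]] = b
g = Q.sum(axis=0)  # equals R.sum(axis=0) = [0.5, 0.5]
row_marginal = lowrank_matvec(Q, g, R, np.ones(4))     # equals a
col_marginal = lowrank_rmatvec(Q, g, R, np.ones(6))    # equals b
```

For N = K = 100,000 and rank 30, a dense float64 plan would take 80 GB, while Q and R together take about 48 MB.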
sampled_gw(
X_source, X_target, # (N, D) and (K, D') feature matrices
*,
distance_mode="dijkstra", # "precomputed" | "dijkstra" | "landmark"
fgw_alpha=0.0, # 0 = pure GW, 1 = Wasserstein, (0,1) = Fused GW
C_linear=None, # (N, K) feature cost matrix for FGW
M=50, # anchor pairs per iteration
epsilon=0.001, # entropic regularization
max_iter=500, tol=1e-5, # convergence control
mixed_precision=False, # float32 Sinkhorn for GPU speed
semi_relaxed=False, # relax target marginal
differentiable=False, # keep autograd graph
multiscale=False, # coarse-to-fine warm start
log=False, # return (T, info_dict)
... # see docs for full parameter list
) -> Tensor # (N, K) transport plan
sampled_lowrank_gw: same interface plus rank, lr_max_iter, lr_dykstra_max_iter.
Uses the Scetbon, Cuturi & Peyré (2021) factorization.
When to use: only when N*K exceeds GPU memory. At smaller scales, `sampled_gw` is faster.
Full API documentation: chansigit.github.io/torchgw
Gromov-Wasserstein finds a coupling between two point clouds by comparing distances within each space rather than distances across spaces. This means the two inputs can live in completely different dimensions -- a 2D spiral can be aligned to a 3D Swiss roll, or a gene expression matrix to a chromatin accessibility matrix.
Standard GW solvers compute the full N×N and K×K pairwise distance matrices and spend O(NK(N+K)) operations per step assembling the N×K cost matrix, which is prohibitive beyond a few thousand points. TorchGW replaces this with a stochastic approximation: sample M anchor pairs from the current transport plan, compute distances only from those anchors, and build a Monte Carlo estimate of the cost matrix in O(NKM), making each iteration orders of magnitude cheaper.
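For the squared loss, the GW cost matrix at the current plan is Λ[i, j] = Σ_kl (D_X[i,k] - D_Y[j,l])² T[k,l]. Drawing M anchor pairs (k_m, l_m) with probability proportional to T turns this sum into an average over anchors, which expands into the three terms listed in step 3 of the flowchart. The NumPy sketch below assumes full distance matrices are available (as in "precomputed" mode) and that T has unit mass; the function names are hypothetical.

```python
import numpy as np


def sample_anchor_pairs(T, M, rng):
    """Draw M index pairs (k, l) with probability proportional to T[k, l]."""
    p = T.ravel() / T.sum()
    flat = rng.choice(T.size, size=M, p=p)
    return np.unravel_index(flat, T.shape)  # (rows, cols), each of length M


def sampled_gw_cost(dist_src, dist_tgt, rows, cols):
    """Monte Carlo estimate of the squared-loss GW cost matrix (N x K).

    D_left = dist_src[:, rows] (N x M) and D_tgt = dist_tgt[:, cols] (K x M):
    Lambda[i, j] = mean_m (D_left[i, m] - D_tgt[j, m])**2, expanded as
    mean(D_left**2) - (2/M) * D_left @ D_tgt.T + mean(D_tgt**2).
    """
    D_left = dist_src[:, rows]
    D_tgt = dist_tgt[:, cols]
    M = len(rows)
    return ((D_left ** 2).mean(axis=1)[:, None]
            - (2.0 / M) * D_left @ D_tgt.T
            + (D_tgt ** 2).mean(axis=1)[None, :])


rng = np.random.default_rng(0)
Xs, Ys = rng.normal(size=(7, 2)), rng.normal(size=(5, 3))   # different dims
dist_src = np.linalg.norm(Xs[:, None] - Xs[None], axis=-1)  # (7, 7)
dist_tgt = np.linalg.norm(Ys[:, None] - Ys[None], axis=-1)  # (5, 5)
T = np.full((7, 5), 1 / 35)                                  # uniform plan
rows, cols = sample_anchor_pairs(T, M=20, rng=rng)
Lam = sampled_gw_cost(dist_src, dist_tgt, rows, cols)        # (7, 5)
```

Only M columns of each distance matrix are touched, and the single matrix product dominates the O(NKM) cost.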
flowchart TB
subgraph inputs [" "]
direction LR
X["Source X\n(N points, D dims)"]
Y["Target Y\n(K points, D' dims)"]
end
X --> G1["Build kNN graph"]
Y --> G2["Build kNN graph"]
G1 --> loop
G2 --> loop
subgraph loop ["GW Main Loop — repeat until converged"]
S["1. Sample M anchor pairs (i,j)\nfrom current T\n(GPU multinomial)"]
D["2. Compute graph distances\nfrom anchors\nD_left (N×M), D_tgt (K×M)"]
C["3. Assemble GW cost matrix (N×K)\nΛ = mean(D²_left) − 2/M · D_left·D_tgt' + mean(D²_tgt)"]
K["4. Sinkhorn projection → T_new\n(Triton fused kernels, log-domain)"]
M["5. Momentum blend\nT ← (1−α)T + α·T_new\n+ warm-start potentials"]
CV["6. Converged?\n(cost plateau detection)"]
S --> D --> C --> K --> M --> CV
CV -- "no" --> S
end
CV -- "yes" --> T["T* (N × K)\noptimal transport plan"]
style inputs fill:none,stroke:none
style loop fill:#f8f9fa,stroke:#dee2e6,stroke-width:2px
style T fill:#d4edda,stroke:#28a745,stroke-width:2px
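Steps 5 and 6 are short to state in code. Blending two plans that share the same marginals yields another valid plan, which is why the momentum step is safe. The stopping rule below is an assumption: the README says only that iteration stops when an EMA of the GW cost plateaus, so `beta`, `rel_tol`, and `patience` are hypothetical settings, not TorchGW parameters.

```python
import numpy as np


def momentum_blend(T: np.ndarray, T_new: np.ndarray, alpha: float) -> np.ndarray:
    """Step 5: T <- (1 - alpha) * T + alpha * T_new.

    A convex combination of two plans with the same marginals has those
    marginals too, so the blended T remains a valid coupling.
    """
    return (1.0 - alpha) * T + alpha * T_new


class PlateauDetector:
    """Step 6 (hypothetical rule): stop when an EMA of the noisy cost stalls.

    Stops after `patience` consecutive updates in which the EMA failed to
    improve on its best value by at least `rel_tol` (relative).
    """

    def __init__(self, beta: float = 0.9, rel_tol: float = 1e-4, patience: int = 10):
        self.beta, self.rel_tol, self.patience = beta, rel_tol, patience
        self.ema = None
        self.best = None
        self.stalled = 0

    def update(self, cost: float) -> bool:
        """Feed one iteration's cost; return True when it is time to stop."""
        if self.ema is None:
            self.ema = self.best = cost
            return False
        self.ema = self.beta * self.ema + (1.0 - self.beta) * cost
        if self.ema < self.best - self.rel_tol * abs(self.best):
            self.best = self.ema
            self.stalled = 0
        else:
            self.stalled += 1
        return self.stalled >= self.patience
```

Smoothing the cost rather than comparing successive plans avoids being fooled by the sampling noise that makes ‖T - T_prev‖ fluctuate.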
| Technique | What it does | Speedup |
|---|---|---|
| Sampled cost | O(NKM) instead of O(NK(N+K)) per iteration | 10-100x |
| Triton Sinkhorn | Fused GPU kernels: single-pass logsumexp, no intermediate N*K allocations | 2-5x |
| Warm-start | Reuse Sinkhorn potentials (log u, log v) across GW iterations | 2-3x fewer Sinkhorn steps |
| Mixed precision | float32 Sinkhorn in log domain (numerically safe), float64 output | up to 2x on consumer GPUs |
| Dijkstra cache | Cache per-node shortest paths, FIFO eviction | avoids redundant graph traversals |
| Cost plateau detection | Stop when GW cost EMA plateaus, not when noisy ‖T-T_prev‖ < tol | saves 50-80% of max_iter |
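The warm-start row, and the semi-relaxed option described earlier, can both be illustrated with a plain log-domain Sinkhorn in NumPy. This is a generic sketch, not TorchGW's fused Triton kernel: the potentials f = eps * log u and g = eps * log v are returned so the next GW iteration can pass them back in, and an optional `rho` softens the target marginal with a KL penalty (the standard rho / (rho + eps) damping of the target update) while the source marginal stays exact. The function name and signature are hypothetical.

```python
import numpy as np
from scipy.special import logsumexp


def sinkhorn_log(C, a, b, eps, f=None, g=None, n_iter=200, rho=None):
    """Entropic OT in the log domain; returns (T, f, g).

    T[i, j] = exp((f[i] + g[j] - C[i, j]) / eps). Passing f, g from a
    previous call warm-starts the iterations. If rho is given, the target
    marginal is relaxed via a KL penalty of weight rho (semi-relaxed).
    """
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros(len(a)) if f is None else f
    g = np.zeros(len(b)) if g is None else g
    scale = 1.0 if rho is None else rho / (rho + eps)
    for _ in range(n_iter):
        # Target update (damped by `scale` in the semi-relaxed case).
        g = scale * eps * (log_b - logsumexp((f[:, None] - C) / eps, axis=0))
        # Source update: enforces the row marginals exactly.
        f = eps * (log_a - logsumexp((g[None, :] - C) / eps, axis=1))
    T = np.exp((f[:, None] + g[None, :] - C) / eps)
    return T, f, g


rng = np.random.default_rng(0)
C = rng.random((6, 4))
a, b = np.full(6, 1 / 6), np.full(4, 1 / 4)
T, f, g = sinkhorn_log(C, a, b, eps=0.5)  # cold start
# Next GW iteration: the cost moves slightly, so reuse the potentials.
C_next = C + 0.01 * rng.random(C.shape)
T_next, f_next, g_next = sinkhorn_log(C_next, a, b, eps=0.5, f=f, g=g, n_iter=20)
T_semi, _, _ = sinkhorn_log(C, a, b, eps=0.5, rho=1.0)  # semi-relaxed
```

Working with log-potentials and logsumexp keeps exp(-C/eps) from underflowing when eps is small, which is what makes the float32 path in the mixed-precision row numerically safe.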
See the algorithm documentation for the full mathematical formulation, including the semi-relaxed extension and differentiable gradient computation.
git clone https://github.com/chansigit/torchgw.git
cd torchgw
pip install -e ".[dev]"
pytest tests/ -v # 72 tests, ~18s
If you use TorchGW in your research, please cite:
@software{torchgw,
author = {Sijie Chen},
title = {TorchGW: Fast Sampled Gromov-Wasserstein Optimal Transport},
url = {https://github.com/chansigit/torchgw},
version = {0.4.1},
year = {2026},
}
This project is source-available.
It is free for academic and other non-commercial research and educational use under the terms of the included LICENSE.
Any commercial use — including any use by or on behalf of a for-profit entity, internal commercial research, product development, consulting, paid services, or deployment in commercial settings — requires a separate paid commercial license.
Copyright (c) 2026 The Board of Trustees of the Leland Stanford Junior University.
For commercial licensing inquiries, contact Stanford Office of Technology Licensing: otl@stanford.edu
See COMMERCIAL_LICENSE.md for details.

