
Commit 92fbb82 · add tread sampling

Parent: 4b59467

4 files changed: 184 additions & 9 deletions

File tree: README.md · dit.py · inference.py · rf.py

README.md · 16 additions & 7 deletions

```diff
@@ -11,20 +11,18 @@
   <b>CompVis Group @ LMU Munich</b> <br/>
 </p>
 
-[![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2501.04765)
+[![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/pdf/2501.04765)
 [![Project Page](https://img.shields.io/badge/Project-Page-blue)](https://compvis.github.io/tread/)
 
 This repository contains the official implementation of the paper "TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training".
 
 We propose TREAD, a new method to increase the efficiency of diffusion training by improving iteration speed and performance at the same time. For this, we use uni-directional token transportation to modulate the information flow in the network.
 
 <div align="center">
-  <img src="./docs/images/teaser.png" alt="teaser" style="width:50%;">
+  <img src="./docs/static/images/teaser.png" alt="teaser" style="width:50%;">
 </div>
 
-## 🚀 Usage
-
-### Training
+## 🚀 Training
 
 In order to train a diffusion model, we offer a minimalistic training script in `train.py`. In its simplest form it can be started using:
 
@@ -44,9 +42,20 @@ Under `model` one can decide between `dit` and `tread` which are the preconfigured
 
 In our paper, we show that TREAD can also work on other architectures. In practice, one needs to be more careful with the routing process in order to adhere to the characteristics of the specific architecture, as some have a spatial bias (RWKV, Mamba, etc.). For simplicity, we only provide code for the Transformer architecture, as it is the most widely used while being robust and easy to work with.
 
-### Sampling
+## 🖼️ Sampling
+
+For most experiments we use the [EDM](https://github.com/NVlabs/edm) training and sampling setup to stay consistent with prior art, and the FID calculation is done via the [ADM](https://github.com/openai/guided-diffusion) evaluation suite. We provide `fid.py` to evaluate our models during training using the same reference batches as ADM.
+
+## 💥 Guiding TREAD
+
+TREAD works great during _training_! How about _inference_? \
+It turns out TREAD can be applied during guided inference as well, gaining additional performance and reducing FLOPs at the same time! \
+Instead of dropping the class label (CFG), we can guide with a selection-rate delta. Since TREAD's training selection rate (0.5) generalizes to other rates, this can be tuned at inference time only.
+
+We demonstrate this in `rf.py`, which contains minimal flow matching code for training and sampling:
 
-For sampling, we use the [EDM](https://github.com/NVlabs/edm) sampling, and the FID calculation is done via the [ADM](https://github.com/openai/guided-diffusion) evaluation suite. We provide a `fid.py` to evaluate our models during training using the same reference batches as ADM.
+`sample`: normal sampling \
+`sample_tread`: TREAD sampling 🔥
 
 ## 🎓 Citation
```
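For clarity, the guidance rule implemented by `sample_tread` (added in `rf.py` below) combines two forward passes that differ only in their routing selection rate. With guidance scale $w$ (`cfg_scale`), the guided velocity is

$$v = v_{\gamma_2} + w\,\left(v_{\gamma_1} - v_{\gamma_2}\right),$$

with, e.g., $\gamma_1 = 0.3$ and $\gamma_2 = 0.7$; the class input is handled identically in both passes.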

dit.py · 4 additions & 2 deletions

```diff
@@ -228,20 +228,22 @@ def unpatchify(self, x):
 
     def forward(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         class_drop_prob = kwargs.get('class_drop_prob', 0.0)
+        force_routing = kwargs.get('force_routing', False)
+        overwrite_selection_ratio = kwargs.get('overwrite_selection_ratio', None)
 
         x = self.x_embedder(x) + self.pos_embed
         t = self.t_embedder(t)
         y = self.y_embedder(y, class_drop_prob)
         c = t + y
 
-        use_routing = self.training and self.enable_routing and self.routes
+        use_routing = (self.training and self.enable_routing and self.routes) or force_routing
         route_count = 0 if use_routing else None
         fp32_next = False
 
         for idx, block in enumerate(self.blocks):
             if use_routing and idx == self.routes[route_count]['start_layer_idx']:
                 x_D_last = x.clone()
-                mask_info = self.router.get_mask(x, mask_ratio=self.routes[route_count]['selection_ratio'])
+                mask_info = self.router.get_mask(x, mask_ratio=self.routes[route_count]['selection_ratio'] if overwrite_selection_ratio is None else overwrite_selection_ratio)
                 x = self.router.start_route(x, mask_info)
 
             if fp32_next:
```
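The `Router` called here is defined elsewhere in the repository and is not part of this diff. As a rough sketch only, here is one way `get_mask`/`start_route` could behave, assuming random token selection over a `(B, N, D)` sequence and treating `mask_ratio` as the fraction of tokens removed from the processed path (both are assumptions about the interface, not the repository's implementation):

```python
import torch

def get_mask(x: torch.Tensor, mask_ratio: float) -> dict:
    # x: (B, N, D) token sequence; pick the tokens that stay on the main path.
    B, N, _ = x.shape
    num_keep = max(1, int(N * (1.0 - mask_ratio)))   # tokens kept in the routed blocks
    scores = torch.rand(B, N, device=x.device)        # random per-token scores (assumption)
    keep_idx = scores.argsort(dim=1)[:, :num_keep]    # hypothetical mask info
    return {"keep_idx": keep_idx}

def start_route(x: torch.Tensor, mask_info: dict) -> torch.Tensor:
    # Gather only the kept tokens; the remaining tokens bypass the routed
    # blocks and would be merged back at the route's end layer.
    keep_idx = mask_info["keep_idx"]
    return x.gather(1, keep_idx.unsqueeze(-1).expand(-1, -1, x.size(-1)))
```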

inference.py · 52 additions & 0 deletions (new file)

```python
import torch
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="configs", config_name="config")
def main(cfg: DictConfig):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    diffuser = hydra.utils.instantiate(cfg.diffuser).to(device)

    model = hydra.utils.instantiate(cfg.model).to(device)
    # model = load_checkpoint(model, cfg.ckpt_path, device)
    model.eval()
    latents = torch.randn(1, 4, 32, 32, device=device)  # for ImageNet-256 + SD-VAE
    label = torch.randint(0, 1000, (1,), device=device)

    # normal sampling
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            diffuser.sample(
                unet=model,
                z=latents,
                y=label,
                sample_steps=40,
                cfg_scale=1.5,
            )

    # TREAD sampling
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            # Models trained with TREAD extrapolate to unseen selection ratios (here gamma).
            # We can use this for GUIDED sampling.
            # For example, gamma1=0.3 and gamma2=0.7 are not seen during training.
            # The model can still sample with these values.
            # This enables models trained with TREAD to achieve an on-the-fly trade-off
            # between quality and FLOPs.
            # One can use the delta between gamma1 and gamma2 to guide the sampling,
            # but it can also be combined with CFG (class dropout) or AutoGuidance (unet1 != unet2).
            diffuser.sample_tread(
                unet1=model,
                gamma1=0.3,
                unet2=model,
                gamma2=0.7,
                z=latents,
                y=label,
                sample_steps=40,
                cfg_scale=1.5,
            )

if __name__ == "__main__":
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    main()
```
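The `load_checkpoint` helper referenced in the commented-out line is not included in this commit. A hypothetical sketch of such a helper, assuming a plain `state_dict`-style checkpoint (the actual checkpoint layout, e.g. EMA weights, is not shown here):

```python
import torch
from torch import nn

def load_checkpoint(model: nn.Module, ckpt_path: str, device: torch.device) -> nn.Module:
    # Hypothetical sketch; the repository's real checkpoint format may differ.
    state = torch.load(ckpt_path, map_location=device)
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    model.load_state_dict(state)
    return model
```

With the Hydra decorator above, config groups can be overridden from the command line, e.g. `python inference.py model=tread` (config group names assumed from the README's `model` options).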

rf.py · 112 additions & 0 deletions (new file)

```python
import re

import torch
from torch import nn

"""
The class structure is inspired by: https://github.com/cloneofsimo/minRF/blob/main/rf.py
"""

def parse_int_list(s):
    # Parse "1-3,7" -> [1, 2, 3, 7]; lists pass through unchanged.
    if isinstance(s, list):
        return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            ranges.append(int(p))
    return ranges

class RF(nn.Module):
    def __init__(self, train_timestep_sampling: str = "logit_sigmoid", immiscible: bool = False, dts_lambda: float = 0.0):
        super().__init__()
        # immiscible and dts_lambda are unused in this minimal version.
        self.train_timestep_sampling = train_timestep_sampling

    def forward(self, unet: nn.Module, x: torch.Tensor, **kwargs) -> torch.Tensor:
        B = x.size(0)
        if self.train_timestep_sampling == "logit_sigmoid":
            t = torch.sigmoid(torch.randn(B, device=x.device))
        elif self.train_timestep_sampling == "uniform":
            t = torch.rand(B, device=x.device)
        else:
            raise ValueError(f'Unknown train timestep sampling method "{self.train_timestep_sampling}".')
        t_exp = t.view([B] + [1] * (x.ndim - 1))
        z1 = torch.randn_like(x)
        zt = (1 - t_exp) * x + t_exp * z1  # linear interpolation between data and noise
        vtheta = unet(zt, t, **kwargs)
        loss = ((z1 - x - vtheta) ** 2).mean(dim=list(range(1, x.ndim)))  # per-sample flow matching loss
        return loss

    def sample(
        self,
        unet: nn.Module,
        z: torch.Tensor,
        sample_steps: int = 50,
        cfg_scale: float = 1.0,
        **kwargs,
    ):
        B = z.size(0)
        dt_value = 1.0 / sample_steps
        dt = torch.full((B,), dt_value, device=z.device, dtype=z.dtype) \
            .view([B] + [1] * (z.ndim - 1))

        for i in range(sample_steps, 0, -1):
            t = torch.full((B,), i / sample_steps, device=z.device, dtype=z.dtype)

            do_guidance = cfg_scale > 1.0
            do_unconditional_only = cfg_scale == 0.0

            if do_unconditional_only:
                vtheta = unet(z, t, class_drop_prob=1.0, **kwargs)
            else:
                vtheta = unet(z, t, class_drop_prob=0.0, force_dfs=do_guidance, **kwargs)

            if do_guidance:
                vtheta_uncond = unet(z, t, class_drop_prob=1.0, force_dfs=do_guidance, **kwargs)
                vtheta = vtheta_uncond + cfg_scale * (vtheta - vtheta_uncond)  # classifier-free guidance

            z = z - dt * vtheta  # Euler step along the learned flow
        return z

    def sample_tread(
        self,
        unet1: nn.Module,
        gamma1: float,
        unet2: nn.Module,
        gamma2: float,
        z: torch.Tensor,
        sample_steps: int = 50,
        cfg_scale: float = 1.5,
        **kwargs,
    ):
        B = z.size(0)
        dt_value = 1.0 / sample_steps
        dt = torch.full((B,), dt_value, device=z.device, dtype=z.dtype) \
            .view([B] + [1] * (z.ndim - 1))

        assert cfg_scale > 1.0, "We assume cfg_scale > 1.0 for this TREAD sampling."
        for i in range(sample_steps, 0, -1):
            t = torch.full((B,), i / sample_steps, device=z.device, dtype=z.dtype)

            # Two passes that differ only in their routing selection rate;
            # the delta between them provides the guidance signal.
            vtheta1 = unet1(z, t, class_drop_prob=1.0, force_routing=True, overwrite_selection_ratio=gamma1, **kwargs)
            vtheta2 = unet2(z, t, class_drop_prob=1.0, force_routing=True, overwrite_selection_ratio=gamma2, **kwargs)
            vtheta = vtheta2 + cfg_scale * (vtheta1 - vtheta2)
            z = z - dt * vtheta

        return z

class LatentRF(RF):
    def __init__(self, ae: nn.Module, **kwargs):
        super().__init__(**kwargs)
        self.ae = ae

    def forward(self, unet: nn.Module, x: torch.Tensor, precomputed: bool = False, **kwargs) -> torch.Tensor:
        if not precomputed:
            with torch.no_grad():
                x = self.ae.encode(x)  # encode images to latents
        return super().forward(unet, x, **kwargs)

    def sample(self, unet: nn.Module, z: torch.Tensor, sample_steps: int = 50, **kwargs):
        latent = super().sample(unet, z, sample_steps=sample_steps, **kwargs)
        return self.ae.decode(latent)
```
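To illustrate how the `RF` class is meant to be driven, a minimal training-and-sampling sketch; the optimizer choice and the `model`/`loader` names are assumptions for illustration, not part of this commit:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rf = RF(train_timestep_sampling="logit_sigmoid").to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)  # model: a DiT-style net (assumed)

for x, y in loader:  # loader yields latents and class labels (assumed)
    loss = rf(model, x.to(device), y=y.to(device)).mean()  # RF.forward returns per-sample losses
    opt.zero_grad()
    loss.backward()
    opt.step()

# Guided TREAD sampling with a single network at two selection rates:
z = torch.randn(16, 4, 32, 32, device=device)
y = torch.randint(0, 1000, (16,), device=device)
samples = rf.sample_tread(unet1=model, gamma1=0.3, unet2=model, gamma2=0.7,
                          z=z, y=y, sample_steps=50, cfg_scale=1.5)
```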
