This repository contains the official implementation of the paper "TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training".
We propose TREAD, a new method that increases the efficiency of diffusion training by improving both iteration speed and performance. To this end, we use uni-directional token transportation to modulate the information flow in the network.
In order to train a diffusion model, we offer a minimalistic training script in `train.py`. In its simplest form it can be started using:
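(The exact command was lost in extraction; the following is an assumed minimal invocation, with no flags verified against the repository.)

```shell
# Assumed minimal launch; real runs will add configuration options,
# e.g. selecting between the preconfigured `dit` and `tread` models.
python train.py
```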
Under `model` one can decide between `dit` and `tread`, which are the preconfigured model variants.
In our paper, we show that TREAD can also work on other architectures. In practice, one needs to be more careful with the routing process in order to adhere to the characteristics of the specific architecture as some have a spatial bias (RWKV, Mamba, etc.). For simplicity, we only provide code for the Transformer architecture as it is the most widely used while being robust and easy to work with.
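To make the routing idea concrete, here is a minimal, illustrative sketch of uni-directional token routing in plain Python. This is not the repository's implementation; the function names and the random-selection scheme are assumptions.

```python
import random

def select_tokens(num_tokens, selection_rate=0.5, seed=0):
    """Pick a random subset of token indices to route through the
    expensive transformer blocks; the rest bypass them unchanged."""
    rng = random.Random(seed)
    n_keep = int(num_tokens * selection_rate)
    return sorted(rng.sample(range(num_tokens), n_keep))

def merge_tokens(tokens, processed, keep_idx):
    """Scatter the processed tokens back into their original positions.
    Uni-directional: information from the bypassed tokens re-enters the
    sequence only at the end of the route."""
    out = list(tokens)
    for i, idx in enumerate(keep_idx):
        out[idx] = processed[i]
    return out

# Toy usage: 8 scalar "tokens"; doubling stands in for the transformer blocks.
tokens = [float(i) for i in range(8)]
keep_idx = select_tokens(len(tokens), selection_rate=0.5)
processed = [tokens[i] * 2.0 for i in keep_idx]
merged = merge_tokens(tokens, processed, keep_idx)
```

Because only the selected tokens pass through the routed blocks, the per-step cost of those blocks scales with the selection rate rather than the full sequence length.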
## 🖼️ Sampling
For most experiments we use the [EDM](https://github.com/NVlabs/edm) training and sampling to stay consistent with prior art, and the FID calculation is done via the [ADM](https://github.com/openai/guided-diffusion) evaluation suite. We provide a `fid.py` to evaluate our models during training using the same reference batches as ADM.
## 💥 Guiding TREAD
TREAD works great during _training_! How about _inference_? \
It turns out TREAD can also be applied during guided inference to gain additional performance and reduce FLOPs at the same time! \
Instead of dropping the class label (as in CFG), we can guide with a selection-rate delta. Since TREAD's training selection rate (0.5) generalizes to other rates, this can be tuned at inference time only.
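As an illustration of the idea (the interface and the exact combination rule below are assumptions based on the CFG analogy, not the repository's code), the guided output could extrapolate between two selection rates the way CFG extrapolates between conditional and unconditional outputs:

```python
def rate_guided_output(model_fn, x, w=1.5, rate_lo=0.0, rate_hi=0.5):
    """CFG-style extrapolation between two selection rates.

    `model_fn(x, selection_rate)` is a hypothetical interface: rate_lo
    (no routing) plays the role of the "strong" branch, rate_hi the
    "weak" branch, and w is the guidance scale.
    """
    strong = model_fn(x, rate_lo)
    weak = model_fn(x, rate_hi)
    return weak + w * (strong - weak)

# Toy check with a stand-in model whose output shrinks with the rate:
# strong = 2.0, weak = 1.0, guided = 1.0 + 1.5 * (2.0 - 1.0) = 2.5
toy_model = lambda x, rate: x * (1.0 - rate)
out = rate_guided_output(toy_model, 2.0, w=1.5)
```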
We demonstrate this in `rf.py` which contains minimal flow matching code for training and sampling:
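(The original snippet was not preserved; a plausible invocation, with no flags verified against the repository, is:)

```shell
python rf.py
```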