Skip to content

Commit 43a15dc

Browse files
author
Jerry Xiong
committed
fix mask_ratio->selection_rate in dit.py
1 parent 019d147 commit 43a15dc

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

dit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, y: torch.Tensor, **kwargs) -
243243
for idx, block in enumerate(self.blocks):
244244
if use_routing and idx == self.routes[route_count]['start_layer_idx']:
245245
x_D_last = x.clone()
246-
ids_keep = self.router.get_mask(x, mask_ratio=self.routes[route_count]['selection_ratio'] if overwrite_selection_ratio is None else overwrite_selection_ratio)
246+
ids_keep = self.router.get_mask(x, selection_rate=self.routes[route_count]['selection_ratio'] if overwrite_selection_ratio is None else overwrite_selection_ratio)
247247
x = self.router.start_route(x, ids_keep)
248248

249249
if fp32_next:

0 commit comments

Comments
 (0)