Skip to content

Commit 47b6c9f

Browse files
committed
fix: resolve top_k shadowing and None-guard bug in sample_from_logits (PR shiyu-coder#238)
1 parent 67b630e commit 47b6c9f

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

model/kronos.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,8 @@ def top_k_top_p_filtering(
372372

373373
def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
374374
logits = logits / temperature
375-
if top_k is not None or top_p is not None:
376-
if top_k > 0 or top_p < 1.0:
377-
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
375+
if (top_k is not None and top_k > 0) or (top_p is not None and top_p < 1.0):
376+
logits = top_k_top_p_filtering(logits, top_k=top_k or 0, top_p=top_p or 1.0)
378377

379378
probs = F.softmax(logits, dim=-1)
380379

0 commit comments

Comments
 (0)