fix: three silent bugs in sampling and quantization + unit tests #238
Open

korbonits wants to merge 1 commit into shiyu-coder:master
Conversation
- sample_from_logits: guard `top_k`/`top_p` None comparisons before checking `> 0` / `< 1.0` (TypeError when one arg is None)
- sample_from_logits: replace `top_k(probs, k=1)` with `torch.topk(probs, k=1)` — the parameter shadowed the function, making greedy decoding (`sample_logits=False`) uncallable
- BinarySphericalQuantizer.get_codebook_entry / get_group_codebook_entry: `h, w = int(...)` raises TypeError because a scalar int is not iterable; fix to `h = w = int(...)`
- FixedEmbedding: `w.require_grad = False` silently set a spurious attribute instead of disabling gradient tracking; fix to `w.requires_grad = False`

Add 65 offline unit tests (tests/test_sampling.py, tests/test_bsq.py, tests/test_model_shapes.py) that cover all four fixes as regression guards, plus shape/behaviour tests for every major building block and both top-level models.

Add pyproject.toml with uv project definition, ruff (lint + format), ty (type checker), and pytest config. Add uv.lock for reproducible installs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
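The corrected guard from the first fix can be sketched torch-free. `should_filter` is a hypothetical helper name used here for illustration; in the actual code the condition appears inline in `sample_from_logits`:

```python
def should_filter(top_k=None, top_p=None):
    """Return True only when top-k or top-p filtering is actually enabled.

    The buggy guard `top_k is not None or top_p is not None` entered the
    filtering branch with e.g. top_k=None, top_p=0.9 and then evaluated
    `top_k > 0`, raising TypeError. Checking each argument for None before
    comparing avoids that.
    """
    return (top_k is not None and top_k > 0) or (top_p is not None and top_p < 1.0)


print(should_filter(top_k=None, top_p=0.9))  # True, no TypeError
print(should_filter(top_k=0, top_p=None))    # False: top_k=0 means "disabled"
print(should_filter())                       # False: both defaults off
```

Note that `top_k=0` and `top_p=1.0` are treated the same as None, matching the `> 0` / `< 1.0` semantics of the fixed guard.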
aidan46 added a commit to aidan46/kronos that referenced this pull request on Apr 14, 2026
Summary
Three bugs found via static type-checking (`ty`) during a tooling audit.

Bugs fixed
1. model/kronos.py — `sample_from_logits`: TypeError when one of `top_k`/`top_p` is None

   The guard `if top_k is not None or top_p is not None` allows entering the block with e.g. `top_k=None, top_p=0.9`, then immediately hitting `top_k > 0`, which raises `TypeError: '>' not supported between instances of 'NoneType' and 'int'`.

   Fix: use `(top_k is not None and top_k > 0) or (top_p is not None and top_p < 1.0)`.

2. model/kronos.py — greedy decoding (`sample_logits=False`) was uncallable

   `_, x = top_k(probs, k=1, dim=-1)` called the parameter `top_k` (which is None by default) as a function instead of `torch.topk`.

   Fix: `_, x = torch.topk(probs, k=1, dim=-1)`.

3. model/module.py — `BinarySphericalQuantizer`: TypeError in `get_codebook_entry`/`get_group_codebook_entry`

   `h, w = int(z_q.shape[1] ** 0.5)` attempts to unpack a scalar int into two variables, raising `TypeError: cannot unpack non-iterable int`.

   Fix: `h = w = int(z_q.shape[1] ** 0.5)`.

4. `w.require_grad = False` typo in `FixedEmbedding`

   The misspelling silently created a new attribute on the tensor without actually disabling gradient tracking.

   Fix: `w.requires_grad = False`.

Tests added
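Two of the regression guards can be illustrated with torch-free sketches (plain-Python stand-ins for illustration, not the actual test code):

```python
# Bug 3: tuple-unpacking a scalar int fails; chained assignment is the fix.
n_tokens = 64
try:
    h, w = int(n_tokens ** 0.5)      # buggy form: an int is not iterable
except TypeError:
    pass
h = w = int(n_tokens ** 0.5)         # fixed form binds both names to 8
assert (h, w) == (8, 8)

# Bug 4: a misspelled attribute is silently created instead of updating the
# real one. `Param` is a hypothetical stand-in for a torch tensor here.
class Param:
    def __init__(self):
        self.requires_grad = True

w_emb = Param()
w_emb.require_grad = False           # typo: creates a NEW attribute, no effect
assert w_emb.requires_grad is True   # gradient tracking is still enabled
w_emb.requires_grad = False          # correct spelling actually disables it
assert w_emb.requires_grad is False
```

Both failure modes are silent in the sense that no error points at the real mistake: the first only surfaces when the method is called, and the second never errors at all.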
65 offline unit tests (no model downloads, ~0.8 s total) covering all four bugs as regression guards, plus shape/behaviour tests for every major building block:

- tests/test_sampling.py — `top_k_top_p_filtering`, `sample_from_logits` (bugs 1 & 2)
- tests/test_bsq.py — `BinarySphericalQuantizer`, `BSQuantizer` (bug 3)
- tests/test_model_shapes.py — `FixedEmbedding`, `TemporalEmbedding`, `TransformerBlock`, `DualHead`, `HierarchicalEmbedding`, `KronosTokenizer`, `Kronos` (bug 4)

Tooling
- pyproject.toml — project metadata, `uv` dependency management, `ruff` (lint + format), `ty` (type checker), `pytest` config
- uv.lock — reproducible installs (`uv sync` creates a venv with all deps including dev tools)

Test plan
uv sync --dev
uv run pytest tests/test_sampling.py tests/test_bsq.py tests/test_model_shapes.py -v  # 65 passed in < 1 s

The existing tests/test_kronos_regression.py is unchanged and continues to pass as before.

A separate PR (korbonits:chore/ruff-format) with ruff lint/formatting cleanup across all source files is available if wanted — kept separate to avoid polluting this diff.
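For completeness, the shadowing failure behind bug 2 also reproduces without torch; `buggy_greedy` is a hypothetical function written only to mirror the parameter name:

```python
def buggy_greedy(probs, top_k=None):
    # The parameter `top_k` shadows any same-named function in an enclosing
    # scope, so with the default None this call is None(probs, k=1) and
    # raises TypeError before any decoding happens.
    _, x = top_k(probs, k=1)
    return x


try:
    buggy_greedy([0.9, 0.1])
    shadowed = False
except TypeError:
    shadowed = True
assert shadowed  # qualifying the call as torch.topk(...) avoids the shadow
```

This is why the fix qualifies the call with its module (`torch.topk`) rather than renaming the parameter, keeping the public signature of `sample_from_logits` unchanged.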
🤖 Generated with Claude Code