
fix: three silent bugs in sampling and quantization + unit tests #238

Open
korbonits wants to merge 1 commit into shiyu-coder:master from korbonits:fix/bugs-tests-tooling

Conversation

@korbonits

Summary

Four bugs found via static type-checking (ty) during a tooling audit:

Bugs fixed

model/kronos.py: sample_from_logits

  1. TypeError when one of top_k/top_p is None
    The guard if top_k is not None or top_p is not None allows entering the block with, e.g., top_k=None, top_p=0.9, and then immediately hits top_k > 0, which raises TypeError: '>' not supported between instances of 'NoneType' and 'int'.
    Fix: use (top_k is not None and top_k > 0) or (top_p is not None and top_p < 1.0).

  2. Greedy decoding (sample_logits=False) always crashed
    _, x = top_k(probs, k=1, dim=-1) called the parameter top_k (None by default) as a function instead of torch.topk, raising TypeError: 'NoneType' object is not callable.
    Fix: _, x = torch.topk(probs, k=1, dim=-1).
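For the first bug, a minimal pure-Python sketch of the two guard variants; the `needs_filtering_*` helper names are hypothetical, not the repo's actual functions, but the guard expressions match the ones quoted above:

```python
# Illustrative stand-ins for the guard in sample_from_logits.

def needs_filtering_buggy(top_k=None, top_p=None):
    # Original guard: `or` lets execution reach `top_k > 0`
    # even when top_k is None, comparing NoneType with int.
    if top_k is not None or top_p is not None:
        return top_k > 0 or top_p < 1.0
    return False

def needs_filtering_fixed(top_k=None, top_p=None):
    # Fixed guard: every comparison is protected by its own None check.
    return (top_k is not None and top_k > 0) or (
        top_p is not None and top_p < 1.0
    )

print(needs_filtering_fixed(top_k=None, top_p=0.9))  # True

try:
    needs_filtering_buggy(top_k=None, top_p=0.9)
except TypeError as exc:
    print(f"buggy guard: {exc}")
```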
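The second bug, parameter shadowing, can be reproduced without torch. Here `topk` is a pure-Python stand-in for `torch.topk`, and the `greedy_pick_*` helpers are hypothetical, not the repo's API:

```python
import heapq

def topk(seq, k=1):
    # Pure-Python stand-in for torch.topk: returns (values, indices)
    # of the k largest elements, largest first.
    idx = heapq.nlargest(k, range(len(seq)), key=seq.__getitem__)
    return [seq[i] for i in idx], idx

def greedy_pick_buggy(probs, top_k=None):
    # The parameter `top_k` shadows the module-level function, so this
    # calls None: TypeError: 'NoneType' object is not callable.
    _, x = top_k(probs, k=1)
    return x

def greedy_pick_fixed(probs, top_k=None):
    # Refer to the function by a name the parameter cannot shadow --
    # in the real code, torch.topk instead of the bare top_k.
    _, x = topk(probs, k=1)
    return x

print(greedy_pick_fixed([0.1, 0.7, 0.2]))  # [1]
```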

model/module.py: BinarySphericalQuantizer

  3. TypeError in get_codebook_entry / get_group_codebook_entry
    h, w = int(z_q.shape[1] ** 0.5) attempts to unpack a scalar int into two variables, raising TypeError: cannot unpack non-iterable int object.
    Fix: h = w = int(z_q.shape[1] ** 0.5).

  4. w.require_grad = False typo in FixedEmbedding
    The misspelling silently created a new attribute on the tensor without actually disabling gradient tracking.
    Fix: w.requires_grad = False.
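The unpack bug is a one-liner to demonstrate; the `codebook_hw_*` helpers below are hypothetical sketches mirroring the shape computation, not the repo's methods:

```python
def codebook_hw_buggy(num_tokens):
    # int() returns a scalar; unpacking it into two names fails with
    # TypeError: cannot unpack non-iterable int object.
    h, w = int(num_tokens ** 0.5)
    return h, w

def codebook_hw_fixed(num_tokens):
    # Chained assignment binds both names to the same scalar.
    h = w = int(num_tokens ** 0.5)
    return h, w

print(codebook_hw_fixed(64))  # (8, 8)
```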
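The typo is easy to reproduce on any leaf tensor (sketch below assumes torch is installed): tensors accept arbitrary Python attribute assignment, so the misspelling never errors, it just does nothing:

```python
import torch

w = torch.randn(4, 8, requires_grad=True)

w.require_grad = False   # typo: silently attaches a new attribute
print(w.requires_grad)   # True -- gradient tracking is still enabled

w.requires_grad = False  # correct spelling actually disables it
print(w.requires_grad)   # False
```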

Tests added

65 offline unit tests (no model downloads, ~0.8 s total) covering all four bugs as regression guards, plus shape/behaviour tests for every major building block:

| File | Tests | Covers |
| --- | --- | --- |
| tests/test_sampling.py | 18 | top_k_top_p_filtering, sample_from_logits (bugs 1 & 2) |
| tests/test_bsq.py | 17 | BinarySphericalQuantizer, BSQuantizer (bug 3) |
| tests/test_model_shapes.py | 30 | FixedEmbedding, TemporalEmbedding, TransformerBlock, DualHead, HierarchicalEmbedding, KronosTokenizer, Kronos (bug 4) |

Tooling

  • pyproject.toml — project metadata, uv dependency management, ruff (lint + format), ty (type checker), pytest config
  • uv.lock — reproducible installs (uv sync creates a venv with all deps including dev tools)

Test plan

uv sync --dev
uv run pytest tests/test_sampling.py tests/test_bsq.py tests/test_model_shapes.py -v
# 65 passed in < 1 s

The existing tests/test_kronos_regression.py is unchanged and continues to pass as before.


A separate PR (korbonits:chore/ruff-format) with ruff lint/formatting cleanup across all source files is available if wanted — kept separate to avoid polluting this diff.

🤖 Generated with Claude Code

- sample_from_logits: guard top_k/top_p None comparisons before
  checking `> 0` / `< 1.0` (TypeError when one arg is None)
- sample_from_logits: replace `top_k(probs, k=1)` with
  `torch.topk(probs, k=1)` — the parameter shadowed the function,
  making greedy decoding (sample_logits=False) uncallable
- BinarySphericalQuantizer.get_codebook_entry /
  get_group_codebook_entry: `h, w = int(...)` raises TypeError
  because a scalar int is not iterable; fix to `h = w = int(...)`
- FixedEmbedding: `w.require_grad = False` silently set a spurious
  attribute instead of disabling gradient tracking; fix to
  `w.requires_grad = False`

Add 65 offline unit tests (tests/test_sampling.py,
tests/test_bsq.py, tests/test_model_shapes.py) that cover all
four bugs as regression guards, plus shape/behaviour tests for
every major building block and both top-level models.

Add pyproject.toml with uv project definition, ruff (lint +
format), ty (type checker), and pytest config. Add uv.lock for
reproducible installs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>