
fix: resolve 8 bugs in core model and inference code#244

Open
sjhddh wants to merge 1 commit into shiyu-coder:master from sjhddh:fix/model-bugs

Conversation


@sjhddh sjhddh commented Apr 13, 2026

Summary

A systematic audit of model/module.py and model/kronos.py uncovered 8 bugs, several with crash potential:

  • requires_grad typo in FixedEmbedding: w.require_grad = False is a silent no-op (should be requires_grad)
  • Gradient count mismatch in DifferentiableEntropyFunction.backward — returns 5 values for 4 forward args
  • int() unpacking crash in get_codebook_entry / get_group_codebook_entry: h, w = int(...) can't unpack a scalar
  • NameError when soft_entropy=False: avg_prob is referenced but never assigned in that branch
  • is_causal=True in cross-attention — semantically wrong (cross-attn should attend to all keys) and crashes when attn_mask is also provided
  • None comparison crash in sample_from_logits: top_k > 0 raises TypeError when top_k is None
  • In-place logit mutation in top_k_top_p_filtering — caller's tensor silently modified
  • Missing @torch.no_grad() on KronosPredictor.generate/predict/predict_batch
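Two of the crash fixes above follow the same defensive pattern and can be sketched in plain Python. The function names and signatures below are illustrative, not the repository's actual code; the real functions operate on torch tensors:

```python
def resolve_sampling_flags(top_k=None, top_p=None):
    """Hypothetical sketch of the sample_from_logits guard: check for None
    before comparing, since `None > 0` raises TypeError in Python 3."""
    apply_top_k = top_k is not None and top_k > 0
    apply_top_p = top_p is not None and top_p < 1.0
    return apply_top_k, apply_top_p


def filter_top_k(logits, k):
    """Hypothetical sketch of the non-mutating top-k filter: work on a copy
    so the caller's sequence is left untouched (torch code would use
    logits.clone() instead of mutating the argument in place)."""
    filtered = list(logits)  # copy instead of editing the caller's object
    threshold = sorted(filtered, reverse=True)[k - 1]
    return [x if x >= threshold else float("-inf") for x in filtered]
```

The same None-guard-then-compare shape applies wherever an optional numeric parameter defaults to None.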

Also:

  • Replaced wildcard from model.module import * with explicit imports
  • Fixed self-referential docstrings on indexes_to_codes / group_indexes_to_codes
  • Removed commented-out code blocks
  • Extracted shared _validate_and_prepare helper to deduplicate DataFrame validation

Test plan

  • Existing regression tests still pass (pytest tests/test_kronos_regression.py)
  • FixedEmbedding weights are now truly frozen (verify requires_grad == False)
  • sample_from_logits(logits, top_k=None, top_p=None) no longer crashes
  • top_k_top_p_filtering does not mutate the input tensor

🤖 Generated with Claude Code

- Fix requires_grad typo in FixedEmbedding (was no-op)
- Fix backward() gradient count mismatch in DifferentiableEntropyFunction
- Fix int() unpacking crash in codebook entry methods
- Fix undefined avg_prob when soft_entropy disabled
- Fix is_causal misuse in cross-attention (semantic bug + crash risk)
- Fix sample_from_logits crash on None top_k/top_p
- Prevent in-place mutation of logits in top_k_top_p_filtering
- Add @torch.no_grad() to inference methods
- Replace wildcard import with explicit imports
- Extract shared DataFrame validation helper

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>