Skip to content

Commit f6e23eb

Browse files
authored
fix bug in sentence piece (#77)
* fix bug in sentence piece

* fmt
1 parent 7eadc74 commit f6e23eb

5 files changed

Lines changed: 139 additions & 16 deletions

File tree

src/decode/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ pub mod tokens;
55

66
pub use ctc::{ctc_greedy_decode, CtcDecoderResult};
77
pub use greedy::GreedyDecoder;
8-
pub use sentencepiece::sentencepiece_to_text;
8+
pub use sentencepiece::{parse_byte_token, sentencepiece_to_text};
99
pub use tokens::{load_vocab, SymbolTable};

src/decode/sentencepiece.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,84 @@ pub fn sentencepiece_to_text(tokens: &[&str]) -> String {
1111
// Clean up contraction spacing (e.g. "can 't" → "can't")
1212
text.replace(" '", "'")
1313
}
14+
15+
/// Parse a byte-level BPE token like `<0xE5>` into its byte value.
///
/// SentencePiece tokenizers emit these for characters outside the base vocabulary
/// (e.g. CJK characters are split into individual UTF-8 bytes).
///
/// Returns `None` for anything that is not exactly `<0x` + two hex digits + `>`.
pub fn parse_byte_token(token: &str) -> Option<u8> {
    // Peel off the fixed framing; either strip failing means "not a byte token".
    let hex = token.strip_prefix("<0x")?.strip_suffix('>')?;
    // Exactly two hex digits — rejects "<0x>", "<0xAAA>", etc.
    if hex.len() != 2 {
        return None;
    }
    u8::from_str_radix(hex, 16).ok()
}
27+
28+
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolve a token stream into text: `<0xNN>` byte tokens contribute their
    /// raw byte, every other token contributes its UTF-8 bytes directly.
    /// Mirrors the byte-reassembly loop the model decoders perform, so the two
    /// CJK tests below exercise exactly that logic without duplicating it.
    fn reassemble(tokens: &[&str]) -> String {
        let mut bytes: Vec<u8> = Vec::new();
        for token in tokens {
            if let Some(byte_val) = parse_byte_token(token) {
                bytes.push(byte_val);
            } else {
                bytes.extend(token.as_bytes());
            }
        }
        String::from_utf8_lossy(&bytes).into_owned()
    }

    #[test]
    fn test_parse_byte_token_valid() {
        assert_eq!(parse_byte_token("<0xE5>"), Some(0xE5));
        assert_eq!(parse_byte_token("<0xB0>"), Some(0xB0));
        assert_eq!(parse_byte_token("<0xBC>"), Some(0xBC));
        // Boundary values of the byte range.
        assert_eq!(parse_byte_token("<0x00>"), Some(0x00));
        assert_eq!(parse_byte_token("<0xFF>"), Some(0xFF));
    }

    #[test]
    fn test_parse_byte_token_invalid() {
        assert_eq!(parse_byte_token("hello"), None);
        assert_eq!(parse_byte_token("<|en|>"), None);
        assert_eq!(parse_byte_token("<0x>"), None);
        assert_eq!(parse_byte_token("<0xEE"), None); // missing >
        assert_eq!(parse_byte_token("<0xGG>"), None); // invalid hex
    }

    #[test]
    fn test_byte_tokens_reassemble_chinese() {
        // 尼 = E5 B0 BC; the non-byte tokens pass through unchanged.
        let tokens = ["<0xE5>", "<0xB0>", "<0xBC>", "豪", "。"];
        assert_eq!(reassemble(&tokens), "尼豪。");
    }

    #[test]
    fn test_byte_tokens_full_cjk_sequence() {
        // 你好 = E4 BD A0 E5 A5 BD — every token is a byte token.
        let tokens = ["<0xE4>", "<0xBD>", "<0xA0>", "<0xE5>", "<0xA5>", "<0xBD>"];
        assert_eq!(reassemble(&tokens), "你好");
    }

    #[test]
    fn test_sentencepiece_to_text_basic() {
        // NOTE(review): the leading char may be U+2581 (▁) lost in extraction
        // of this snippet — confirm against the repository source.
        let tokens = vec![" Hello", " world"];
        assert_eq!(sentencepiece_to_text(&tokens), "Hello world");
    }

    #[test]
    fn test_sentencepiece_to_text_contractions() {
        // Contraction spacing is collapsed: "can 't" → "can't".
        let tokens = vec![" can", " 't"];
        assert_eq!(sentencepiece_to_text(&tokens), "can't");
    }
}

src/onnx/cohere/mod.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use ort::session::SessionInputValue;
1010
use ort::value::DynValue;
1111

1212
use super::{session, Quantization};
13-
use crate::decode::{load_vocab, sentencepiece_to_text, GreedyDecoder};
13+
use crate::decode::{load_vocab, parse_byte_token, GreedyDecoder};
1414
use crate::{
1515
ModelCapabilities, SpeechModel, TranscribeError, TranscribeOptions, TranscriptionResult,
1616
};
@@ -294,7 +294,7 @@ impl CohereModel {
294294
}
295295

296296
fn decode_ids(&self, token_ids: &[i64]) -> String {
297-
let pieces = token_ids
297+
let tokens: Vec<&str> = token_ids
298298
.iter()
299299
.filter_map(|&id| self.vocab.get(id as usize))
300300
.filter(|token| {
@@ -304,9 +304,24 @@ impl CohereModel {
304304
&& token.as_str() != "<pad>"
305305
})
306306
.map(|token| token.as_str())
307-
.collect::<Vec<_>>();
307+
.collect();
308+
309+
// Handle byte-level BPE tokens (<0xNN>) by collecting into a byte buffer.
310+
// SentencePiece tokenizers emit these for characters outside the base vocabulary
311+
// (e.g. CJK characters are split into individual UTF-8 bytes).
312+
let mut bytes: Vec<u8> = Vec::new();
313+
for token in &tokens {
314+
if let Some(byte_val) = parse_byte_token(token) {
315+
bytes.push(byte_val);
316+
} else {
317+
bytes.extend(token.as_bytes());
318+
}
319+
}
308320

309-
sentencepiece_to_text(&pieces)
321+
let text = String::from_utf8_lossy(&bytes);
322+
let text = text.trim();
323+
// Clean up contraction spacing (e.g. "can 't" → "can't")
324+
text.replace(" '", "'")
310325
}
311326

312327
fn decoder_input_name(&self, preferred: &str, fallbacks: &[&str]) -> String {

src/onnx/moonshine/model.rs

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::fs::File;
77
use std::io::BufReader;
88
use std::path::Path;
99

10-
use crate::decode::GreedyDecoder;
10+
use crate::decode::{parse_byte_token, GreedyDecoder};
1111
use crate::onnx::session;
1212
use crate::onnx::Quantization;
1313
use crate::{
@@ -417,7 +417,7 @@ impl MoonshineTokenizer {
417417
let mut bytes: Vec<u8> = Vec::new();
418418

419419
for token in &tokens {
420-
if let Some(byte_val) = Self::parse_byte_token(token) {
420+
if let Some(byte_val) = parse_byte_token(token) {
421421
bytes.push(byte_val);
422422
} else {
423423
let decoded = token.replace('\u{2581}', " ");
@@ -430,13 +430,4 @@ impl MoonshineTokenizer {
430430

431431
Ok(text.to_string())
432432
}
433-
434-
fn parse_byte_token(token: &str) -> Option<u8> {
435-
if token.starts_with("<0x") && token.ends_with('>') && token.len() == 6 {
436-
let hex = &token[3..5];
437-
u8::from_str_radix(hex, 16).ok()
438-
} else {
439-
None
440-
}
441-
}
442433
}

tests/cohere.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,39 @@ fn test_cohere_german() {
6565
result.text
6666
);
6767
}
68+
69+
#[test]
fn test_cohere_chinese() {
    // Integration test: requires a local model and sample; skips otherwise.
    let model_path = PathBuf::from("models/cohere-int4");
    let audio_path = PathBuf::from("samples/chinese.wav");

    if !common::require_paths(&[&model_path, &audio_path]) {
        return;
    }

    let mut model =
        CohereModel::load(&model_path, &Quantization::Int4).expect("Failed to load Cohere model");

    let options = transcribe_rs::TranscribeOptions {
        language: Some("zh".into()),
        ..Default::default()
    };
    let result = model
        .transcribe_file(&audio_path, &options)
        .expect("Failed to transcribe Chinese audio with Cohere model");

    println!("Chinese transcription: {}", result.text);

    assert!(
        !result.text.trim().is_empty(),
        "Cohere Chinese transcription should not be empty"
    );

    // Output should contain actual Chinese characters, not byte tokens like <0xE5>
    assert!(
        !result.text.contains("<0x"),
        "Chinese transcription should not contain raw byte tokens, got: '{}'",
        result.text
    );
}

0 commit comments

Comments
 (0)