Skip to content

Commit 2d7ac18

Browse files
praxeo and cjpais authored
adds cohere-transcribe INT4/INT8 via onnx runtime (#75)
* feat: add cohere onnx integration * fix: scope encoder_outputs to release mutable borrow before decoder loop * eliminate per-token KV cache clones in decoder loop * add common greedy decoder, cleanup lib * int4 quant in lib * minor * Update README.md --------- Co-authored-by: praxeo <praxeo@users.noreply.github.com> Co-authored-by: CJ Pais <cj@cjpais.com>
1 parent 2eea053 commit 2d7ac18

13 files changed

Lines changed: 689 additions & 63 deletions

File tree

Cargo.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ required-features = ["onnx"]
107107
name = "canary"
108108
required-features = ["onnx"]
109109

110+
[[example]]
111+
name = "cohere"
112+
required-features = ["onnx"]
113+
110114
[[example]]
111115
name = "whisperfile"
112116
required-features = ["whisperfile"]
@@ -143,6 +147,10 @@ required-features = ["onnx"]
143147
name = "canary"
144148
required-features = ["onnx"]
145149

150+
[[test]]
151+
name = "cohere"
152+
required-features = ["onnx"]
153+
146154
[[test]]
147155
name = "whisperfile"
148156
required-features = ["whisperfile"]

README.md

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# transcribe-rs
22

3-
Multi-engine speech-to-text library for Rust. Supports Parakeet, Canary, Moonshine, SenseVoice, GigaAM, Whisper, Whisperfile, and OpenAI.
3+
Multi-engine speech-to-text library for Rust. Supports Parakeet, Canary, Cohere, Moonshine, SenseVoice, GigaAM, Whisper, Whisperfile, and OpenAI.
44

55
## Breaking Changes in 0.3.0
66

@@ -24,7 +24,7 @@ No features are enabled by default. Pick the engines you need:
2424

2525
| Feature | Engines |
2626
|---------|---------|
27-
| `onnx` | Parakeet, Canary, Moonshine, SenseVoice, GigaAM (via ONNX Runtime) |
27+
| `onnx` | Parakeet, Canary, Cohere, Moonshine, SenseVoice, GigaAM (via ONNX Runtime) |
2828
| `whisper-cpp` | Whisper (local, GGML via whisper.cpp with Metal/Vulkan/CUDA) |
2929
| `whisperfile` | Whisperfile (local server wrapper) |
3030
| `openai` | OpenAI API (remote, async) |
@@ -143,6 +143,30 @@ Model variant (Flash vs V2) is auto-detected from vocabulary size. Flash models
143143
- **ITN** (inverse text normalization) — enabled by default. Converts spoken numbers to written form (e.g. "one hundred twenty three" becomes "123"). Set `use_itn: false` to disable. Only supported on V2 models; silently ignored on Flash.
144144
- **Translation** — set `target_language` to translate between supported languages.
145145

146+
### Cohere
147+
148+
```rust
149+
use transcribe_rs::onnx::cohere::{CohereModel, CohereParams};
150+
use transcribe_rs::onnx::Quantization;
151+
use std::path::PathBuf;
152+
153+
let mut model = CohereModel::load(
154+
&PathBuf::from("models/cohere-int4"),
155+
&Quantization::Int4,
156+
)?;
157+
158+
let samples = transcribe_rs::audio::read_wav_samples(&PathBuf::from("audio.wav"))?;
159+
let result = model.transcribe_with(
160+
&samples,
161+
&CohereParams {
162+
language: Some("en".to_string()),
163+
..Default::default()
164+
},
165+
)?;
166+
```
167+
168+
Available in int4 and int8 quantizations.
169+
146170
### SenseVoice
147171

148172
```rust
@@ -295,6 +319,8 @@ All audio input must be **16 kHz, mono, 16-bit PCM WAV**.
295319
| Canary 180M Flash | [HuggingFace](https://huggingface.co/istupakov/canary-180m-flash-onnx) |
296320
| Canary 1B Flash | [HuggingFace](https://huggingface.co/istupakov/canary-1b-flash-onnx) |
297321
| Canary 1B v2 | [HuggingFace](https://huggingface.co/istupakov/canary-1b-v2-onnx) |
322+
| Cohere (int4) | [HuggingFace](https://huggingface.co/cstr/cohere-transcribe-onnx-int4) |
323+
| Cohere (int8) | [HuggingFace](https://huggingface.co/tristanripke/cohere-transcribe-onnx-int8) |
298324
| SenseVoice (int8) | [blob.handy.computer](https://blob.handy.computer/sense-voice-int8.tar.gz) / [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models) |
299325
| Moonshine | [blob.handy.computer (base)](https://blob.handy.computer/moonshine-base.tar.gz), [blob.handy.computer (tiny streaming en)](https://blob.handy.computer/moonshine-tiny-streaming-en.tar.gz), [blob.handy.computer (small streaming en)](https://blob.handy.computer/moonshine-small-streaming-en.tar.gz), [blob.handy.computer (medium streaming en)](https://blob.handy.computer/moonshine-medium-streaming-en.tar.gz) |
300326
| GigaAM | [HuggingFace](https://huggingface.co/istupakov/gigaam-v3-onnx/tree/main) |
@@ -321,6 +347,16 @@ models/canary-1b-v2/
321347
└── vocab.txt
322348
```
323349

350+
**Cohere** (directory):
351+
```
352+
models/cohere-int4/
353+
├── cohere-encoder.int4.onnx
354+
├── cohere-encoder.int4.onnx.data
355+
├── cohere-decoder.int4.onnx
356+
├── cohere-decoder.int4.onnx.data
357+
└── tokens.txt
358+
```
359+
324360
**SenseVoice** (directory):
325361
```
326362
models/sense-voice/
@@ -375,6 +411,7 @@ Each engine has an example in `examples/`. Run with the appropriate feature flag
375411
```bash
376412
cargo run --example parakeet --features onnx
377413
cargo run --example canary --features onnx
414+
cargo run --example cohere --features onnx
378415
cargo run --example sense_voice --features onnx
379416
cargo run --example moonshine --features onnx
380417
cargo run --example moonshine_streaming --features onnx

examples/canary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
3434
positional
3535
.get(1)
3636
.map(|s| s.as_str())
37-
.unwrap_or("samples/jfk.wav"),
37+
.unwrap_or("samples/dots.wav"),
3838
);
3939

4040
let audio_duration = get_audio_duration(&wav_path)?;

examples/cohere.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use std::path::PathBuf;
2+
use std::time::Instant;
3+
4+
use transcribe_rs::onnx::cohere::CohereModel;
5+
use transcribe_rs::onnx::Quantization;
6+
use transcribe_rs::SpeechModel;
7+
8+
fn get_audio_duration(path: &PathBuf) -> Result<f64, Box<dyn std::error::Error>> {
9+
let reader = hound::WavReader::open(path)?;
10+
let spec = reader.spec();
11+
let duration = reader.duration() as f64 / spec.sample_rate as f64;
12+
Ok(duration)
13+
}
14+
15+
fn main() -> Result<(), Box<dyn std::error::Error>> {
16+
env_logger::init();
17+
18+
let args: Vec<String> = std::env::args().collect();
19+
let quant = args.get(1).map(|s| s.as_str()).unwrap_or("int8");
20+
21+
let (model_path, quantization) = match quant {
22+
"int4" => ("models/cohere-int4", Quantization::Int4),
23+
"int8" => ("models/cohere-int8", Quantization::Int8),
24+
other => {
25+
eprintln!("Unknown quantization: {other}. Use 'int4' or 'int8'.");
26+
std::process::exit(1);
27+
}
28+
};
29+
30+
let model_path = PathBuf::from(model_path);
31+
let wav_path = PathBuf::from("samples/dots.wav");
32+
33+
let audio_duration = get_audio_duration(&wav_path)?;
34+
println!("Audio duration: {:.2}s", audio_duration);
35+
36+
println!("Using Cohere ONNX engine ({quant})");
37+
println!("Loading model: {:?}", model_path);
38+
39+
let load_start = Instant::now();
40+
let mut model = CohereModel::load(&model_path, &quantization)?;
41+
let load_duration = load_start.elapsed();
42+
println!("Model loaded in {:.2?}", load_duration);
43+
44+
println!("Transcribing file: {:?}", wav_path);
45+
let transcribe_start = Instant::now();
46+
let result = model.transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default())?;
47+
let transcribe_duration = transcribe_start.elapsed();
48+
println!("Transcription completed in {:.2?}", transcribe_duration);
49+
50+
let speedup_factor = audio_duration / transcribe_duration.as_secs_f64();
51+
println!(
52+
"Real-time speedup: {:.2}x faster than real-time",
53+
speedup_factor
54+
);
55+
println!("Transcription result:\n{}", result.text);
56+
57+
Ok(())
58+
}

src/decode/greedy.rs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/// Greedy autoregressive token selection with repetition detection.
///
/// Wraps the common argmax + EOS + repeat-guard pattern shared by all
/// autoregressive decoder engines (Canary, Moonshine, Cohere).
///
/// Each engine still owns its KV cache and decoder session — this struct
/// only handles token selection and stopping decisions.
#[derive(Debug, Clone)]
pub struct GreedyDecoder {
    /// Token id that terminates decoding when selected.
    eos_id: i64,
    /// Longest tolerated streak of identical consecutive tokens before
    /// decoding is aborted.
    max_consecutive_repeats: usize,
    /// Most recently emitted token; -1 before the first step (valid token
    /// ids are non-negative, so -1 can never match a real token).
    last_token: i64,
    /// Length of the current streak of identical consecutive tokens.
    consecutive_count: usize,
}

/// Default streak limit used by `GreedyDecoder::new`; override with
/// `with_max_repeats`.
const DEFAULT_MAX_CONSECUTIVE_REPEATS: usize = 8;
17+
18+
impl GreedyDecoder {
19+
pub fn new(eos_id: i64) -> Self {
20+
Self {
21+
eos_id,
22+
max_consecutive_repeats: DEFAULT_MAX_CONSECUTIVE_REPEATS,
23+
last_token: -1,
24+
consecutive_count: 0,
25+
}
26+
}
27+
28+
pub fn with_max_repeats(mut self, n: usize) -> Self {
29+
self.max_consecutive_repeats = n;
30+
self
31+
}
32+
33+
/// Given logits for the last decoder position, pick the next token.
34+
///
35+
/// Returns `Some(token_id)` to continue decoding, or `None` to stop
36+
/// (EOS reached or repetition limit hit).
37+
pub fn next_token(&mut self, logits: &[f32]) -> Option<i64> {
38+
let token = argmax(logits) as i64;
39+
40+
if token == self.eos_id {
41+
return None;
42+
}
43+
44+
if token == self.last_token {
45+
self.consecutive_count += 1;
46+
if self.consecutive_count > self.max_consecutive_repeats {
47+
log::warn!(
48+
"Greedy decode: token {} repeated {} consecutive times, stopping",
49+
token,
50+
self.consecutive_count
51+
);
52+
return None;
53+
}
54+
} else {
55+
self.consecutive_count = 1;
56+
}
57+
58+
self.last_token = token;
59+
Some(token)
60+
}
61+
}
62+
63+
/// Index of the largest value in `logits`.
///
/// Ties keep the earliest index. NaN entries never win (every `>`
/// comparison against NaN is false), so an all-NaN or empty slice
/// falls back to index 0.
fn argmax(logits: &[f32]) -> usize {
    let mut best = (0usize, f32::NEG_INFINITY);
    for (idx, &val) in logits.iter().enumerate() {
        if val > best.1 {
            best = (idx, val);
        }
    }
    best.0
}
74+
75+
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_argmax() {
        assert_eq!(argmax(&[1.0, 3.0, 2.0]), 1);
        assert_eq!(argmax(&[-1.0, -3.0, -0.5]), 2);
        assert_eq!(argmax(&[5.0]), 0);
    }

    #[test]
    fn test_eos_stops() {
        // Token 2 is EOS and carries the highest logit, so decoding halts.
        let mut decoder = GreedyDecoder::new(2);
        assert_eq!(decoder.next_token(&[0.0, 0.0, 10.0, 0.0]), None);
    }

    #[test]
    fn test_normal_token() {
        let mut decoder = GreedyDecoder::new(2);
        assert_eq!(decoder.next_token(&[0.0, 10.0, 0.0, 0.0]), Some(1));
    }

    #[test]
    fn test_repeat_limit() {
        let mut decoder = GreedyDecoder::new(99).with_max_repeats(3);
        let logits = [0.0, 10.0, 0.0]; // token 1 always wins
        // Three consecutive picks are tolerated...
        for _ in 0..3 {
            assert_eq!(decoder.next_token(&logits), Some(1));
        }
        // ...the fourth exceeds the limit and stops decoding.
        assert_eq!(decoder.next_token(&logits), None);
    }

    #[test]
    fn test_repeat_resets_on_different_token() {
        let mut decoder = GreedyDecoder::new(99).with_max_repeats(3);
        let pick_one = [0.0, 10.0, 0.0];
        let pick_zero = [10.0, 0.0, 0.0];

        assert_eq!(decoder.next_token(&pick_one), Some(1)); // streak = 1
        assert_eq!(decoder.next_token(&pick_one), Some(1)); // streak = 2
        // A different token resets the streak counter.
        assert_eq!(decoder.next_token(&pick_zero), Some(0));
        for _ in 0..3 {
            assert_eq!(decoder.next_token(&pick_one), Some(1));
        }
        assert_eq!(decoder.next_token(&pick_one), None); // streak = 4 > 3 → stop
    }

    #[test]
    fn test_nan_handling() {
        let mut decoder = GreedyDecoder::new(99);
        // NaN logits — argmax uses `>` which is false for NaN, so index 0 wins
        assert_eq!(decoder.next_token(&[f32::NAN, f32::NAN, f32::NAN]), Some(0));
    }
}

src/decode/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
mod ctc;
2+
mod greedy;
23
mod sentencepiece;
34
pub mod tokens;
45

56
pub use ctc::{ctc_greedy_decode, CtcDecoderResult};
7+
pub use greedy::GreedyDecoder;
68
pub use sentencepiece::sentencepiece_to_text;
79
pub use tokens::{load_vocab, SymbolTable};

src/onnx/canary/decoder.rs

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use ort::value::ValueType;
44
use ort::value::{DynValue, Tensor};
55

66
use super::vocab::Vocab;
7+
use crate::decode::GreedyDecoder;
78
use crate::TranscribeError;
89

910
pub fn decode_autoregressive(
@@ -26,6 +27,7 @@ pub fn decode_autoregressive(
2627
let mut decoder_mems: DynValue = Tensor::from_array(empty_cache)?.into_dyn();
2728

2829
let eos_id = vocab.eos_token_id();
30+
let mut greedy = GreedyDecoder::new(eos_id);
2931
let mut all_tokens = prompt_tokens;
3032

3133
// Limit decode steps so total tokens (prompt + generated) stays within
@@ -58,7 +60,7 @@ pub fn decode_autoregressive(
5860
])?;
5961

6062
// Extract logits in a scoped borrow, then release before remove()
61-
let next_token = {
63+
let last_logits = {
6264
let (logits_shape, logits_data) =
6365
outputs["logits"].try_extract_tensor::<f32>().map_err(|e| {
6466
TranscribeError::Inference(format!("Failed to extract logits: {e}"))
@@ -68,18 +70,19 @@ pub fn decode_autoregressive(
6870
let vocab_size = logits_shape[2] as usize;
6971

7072
let last_step_offset = (seq_len - 1) * vocab_size;
71-
let last_logits = &logits_data[last_step_offset..last_step_offset + vocab_size];
73+
logits_data[last_step_offset..last_step_offset + vocab_size].to_vec()
74+
};
7275

73-
argmax(last_logits) as i64
76+
let next_token = match greedy.next_token(&last_logits) {
77+
Some(t) => t,
78+
None => {
79+
log::debug!("Decode stopped at step {}", step);
80+
break;
81+
}
7482
};
7583

7684
log::debug!("Step {}: predicted token ID {}", step, next_token);
7785

78-
if next_token == eos_id {
79-
log::debug!("EOS token reached at step {}", step);
80-
break;
81-
}
82-
8386
all_tokens.push(next_token);
8487

8588
// Take the KV cache directly from outputs (Arc clone, no data copy)
@@ -129,27 +132,3 @@ fn extract_decoder_mems_shape(decoder: &Session) -> Result<(usize, usize), Trans
129132
))),
130133
}
131134
}
132-
133-
fn argmax(slice: &[f32]) -> usize {
134-
let mut max_idx = 0;
135-
let mut max_val = f32::NEG_INFINITY;
136-
for (i, &v) in slice.iter().enumerate() {
137-
if v > max_val {
138-
max_val = v;
139-
max_idx = i;
140-
}
141-
}
142-
max_idx
143-
}
144-
145-
#[cfg(test)]
146-
mod tests {
147-
use super::*;
148-
149-
#[test]
150-
fn test_argmax() {
151-
assert_eq!(argmax(&[1.0, 3.0, 2.0]), 1);
152-
assert_eq!(argmax(&[-1.0, -3.0, -0.5]), 2);
153-
assert_eq!(argmax(&[5.0]), 0);
154-
}
155-
}

0 commit comments

Comments (0)