Skip to content

Commit 057884a

Browse files
committed
fix: align Cohere Transcribe API with repo conventions
- Add Quantization parameter to load() for API consistency with other ONNX models (ignored since no quantized variants exist)
- Change max_new_tokens default from 448 to 256 to match Python reference
- Use #[derive(Default)] for CohereTranscribeParams per PORTING.md
- Add HuggingFace download links in README
- Accept optional CLI arg in example for testing custom wav files
- Update doc comments with accurate audio length behavior
1 parent b60be4d commit 057884a

4 files changed

Lines changed: 31 additions & 19 deletions

File tree

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,13 @@ let result = model.transcribe_file(&PathBuf::from("audio.wav"), &transcribe_rs::
216216

217217
```rust
218218
use transcribe_rs::onnx::cohere_transcribe::{CohereTranscribeModel, CohereTranscribeParams};
219+
use transcribe_rs::onnx::Quantization;
219220
use std::path::PathBuf;
220221

221-
let mut model = CohereTranscribeModel::load(&PathBuf::from("models/cohere-transcribe"))?;
222+
let mut model = CohereTranscribeModel::load(
223+
&PathBuf::from("models/cohere-transcribe"),
224+
&Quantization::default(),
225+
)?;
222226

223227
let samples = transcribe_rs::audio::read_wav_samples(&PathBuf::from("audio.wav"))?;
224228
let result = model.transcribe_with(
@@ -317,7 +321,7 @@ All audio input must be **16 kHz, mono, 16-bit PCM WAV**.
317321
| SenseVoice (int8) | [blob.handy.computer](https://blob.handy.computer/sense-voice-int8.tar.gz) / [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models) |
318322
| Moonshine | [blob.handy.computer (base)](https://blob.handy.computer/moonshine-base.tar.gz), [blob.handy.computer (tiny streaming en)](https://blob.handy.computer/moonshine-tiny-streaming-en.tar.gz), [blob.handy.computer (small streaming en)](https://blob.handy.computer/moonshine-small-streaming-en.tar.gz), [blob.handy.computer (medium streaming en)](https://blob.handy.computer/moonshine-medium-streaming-en.tar.gz) |
319323
| GigaAM | [HuggingFace](https://huggingface.co/istupakov/gigaam-v3-onnx/tree/main) |
320-
| Cohere Transcribe | Model-specific ONNX export plus `vocab.txt` |
324+
| Cohere Transcribe | [HuggingFace](https://huggingface.co/eschmidbauer/cohere-transcribe-03-2026-onnx) (ONNX export of [CohereLabs/cohere-transcribe-03-2026](https://huggingface.co/CohereLabs/cohere-transcribe-03-2026)) |
321325
| Whisper (GGML) | [HuggingFace](https://huggingface.co/ggerganov/whisper.cpp/tree/main) |
322326
| Whisperfile binary | [GitHub](https://github.com/mozilla-ai/llamafile/releases/download/0.9.3/whisperfile-0.9.3) |
323327

examples/cohere_transcribe.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::path::PathBuf;
22
use std::time::Instant;
33

44
use transcribe_rs::onnx::cohere_transcribe::{CohereTranscribeModel, CohereTranscribeParams};
5+
use transcribe_rs::onnx::Quantization;
56

67
fn get_audio_duration(path: &PathBuf) -> Result<f64, Box<dyn std::error::Error>> {
78
let reader = hound::WavReader::open(path)?;
@@ -14,14 +15,17 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
1415
env_logger::init();
1516

1617
let model_path = PathBuf::from("models/cohere-transcribe");
17-
let wav_path = PathBuf::from("samples/jfk.wav");
18+
let wav_path = std::env::args()
19+
.nth(1)
20+
.map(PathBuf::from)
21+
.unwrap_or_else(|| PathBuf::from("samples/jfk.wav"));
1822

1923
let audio_duration = get_audio_duration(&wav_path)?;
2024
println!("Audio duration: {:.2}s", audio_duration);
2125

2226
// Load
2327
let load_start = Instant::now();
24-
let mut model = CohereTranscribeModel::load(&model_path)?;
28+
let mut model = CohereTranscribeModel::load(&model_path, &Quantization::default())?;
2529
println!("Model loaded in {:.2?}", load_start.elapsed());
2630

2731
// Transcribe

src/onnx/cohere_transcribe/mod.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use ort::value::{DynValue, Tensor};
1010

1111
use self::decoder::decode_autoregressive;
1212
use self::vocab::Vocab;
13+
use super::Quantization;
1314
use crate::features::{compute_mel, MelConfig, WindowType};
1415
use crate::{
1516
ModelCapabilities, SpeechModel, TranscribeError, TranscribeOptions, TranscriptionResult,
@@ -51,21 +52,12 @@ fn mel_config() -> MelConfig {
5152
}
5253

5354
/// Per-model inference parameters for Cohere Transcribe.
54-
#[derive(Debug, Clone)]
55+
#[derive(Debug, Clone, Default)]
5556
pub struct CohereTranscribeParams {
5657
/// Source language (ISO-639-1, e.g. "en"). Defaults to "en".
5758
pub language: Option<String>,
58-
/// Maximum number of new tokens to generate. Defaults to 448.
59-
pub max_new_tokens: usize,
60-
}
61-
62-
impl Default for CohereTranscribeParams {
63-
fn default() -> Self {
64-
Self {
65-
language: None,
66-
max_new_tokens: 448,
67-
}
68-
}
59+
/// Maximum number of new tokens to generate. Defaults to 256.
60+
pub max_new_tokens: Option<usize>,
6961
}
7062

7163
/// Cohere Transcribe speech model backed by ONNX sessions.
@@ -87,7 +79,10 @@ pub struct CohereTranscribeModel {
8779

8880
impl CohereTranscribeModel {
8981
/// Load a Cohere Transcribe model from `model_dir`.
90-
pub fn load(model_dir: &Path) -> Result<Self, TranscribeError> {
82+
///
83+
/// The `quantization` parameter is accepted for API consistency but currently
84+
/// ignored since no quantized variants of this model are available.
85+
pub fn load(model_dir: &Path, _quantization: &Quantization) -> Result<Self, TranscribeError> {
9186
if !model_dir.exists() {
9287
return Err(TranscribeError::ModelNotFound(model_dir.to_path_buf()));
9388
}
@@ -139,12 +134,19 @@ impl CohereTranscribeModel {
139134
}
140135

141136
/// Transcribe with model-specific parameters.
137+
///
138+
/// The upstream config specifies `max_audio_clip_s = 35`, but the ONNX encoder
139+
/// accepts longer audio. The autoregressive decoder is the practical limit:
140+
/// it stops at EOS or `max_new_tokens` (default 256), so very long audio may
141+
/// be truncated. For long-form transcription, use a chunked transcriber
142+
/// (e.g. `VadChunkedTranscriber` or `EnergyAdaptiveTranscriber`).
142143
pub fn transcribe_with(
143144
&mut self,
144145
samples: &[f32],
145146
params: &CohereTranscribeParams,
146147
) -> Result<TranscriptionResult, TranscribeError> {
147148
let language = params.language.as_deref().unwrap_or("en");
149+
let max_new_tokens = params.max_new_tokens.unwrap_or(256);
148150
let total_start = Instant::now();
149151

150152
// Step 1: Compute mel features
@@ -212,7 +214,7 @@ impl CohereTranscribeModel {
212214
src_len,
213215
prompt_tokens,
214216
&self.vocab,
215-
params.max_new_tokens,
217+
max_new_tokens,
216218
)?;
217219

218220
log::debug!("Decoding completed in {:.2?}", decode_start.elapsed());

tests/cohere_transcribe.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod common;
22

33
use std::path::PathBuf;
44
use transcribe_rs::onnx::cohere_transcribe::CohereTranscribeModel;
5+
use transcribe_rs::onnx::Quantization;
56
use transcribe_rs::SpeechModel;
67

78
#[test]
@@ -15,7 +16,8 @@ fn test_cohere_transcribe() {
1516
return;
1617
}
1718

18-
let mut model = CohereTranscribeModel::load(&model_dir).expect("Failed to load model");
19+
let mut model =
20+
CohereTranscribeModel::load(&model_dir, &Quantization::default()).expect("Failed to load model");
1921

2022
let result = model
2123
.transcribe_file(&wav_path, &transcribe_rs::TranscribeOptions::default())

0 commit comments

Comments (0)