Skip to content

Commit 2d2a5c6

Browse files
committed
add tensor-rt
1 parent dcfa730 commit 2d2a5c6

4 files changed

Lines changed: 38 additions & 3 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "transcribe-rs"
3-
version = "0.3.9"
3+
version = "0.3.10"
44
build = "build.rs"
55
edition = "2021"
66
description = "A simple library to help you transcribe audio"
@@ -37,12 +37,13 @@ vad-silero = ["dep:ort", "dep:ndarray"]
3737
# Note: ort-cuda pulls in the CUDA execution provider, which adds ~800 MB+
3838
# to the ORT binary and requires a CUDA toolkit / cuDNN installation at runtime.
3939
ort-cuda = ["onnx", "ort/cuda"]
40+
ort-tensorrt = ["ort-cuda", "ort/tensorrt"]
4041
ort-directml = ["onnx", "ort/directml"]
4142
ort-rocm = ["onnx", "ort/rocm"]
4243
ort-coreml = ["onnx", "ort/coreml"]
4344
ort-webgpu = ["onnx", "ort/webgpu"]
4445
ort-tracing = ["onnx", "ort/tracing"]
45-
ort-accel = ["ort-cuda", "ort-directml", "ort-rocm", "ort-coreml", "ort-webgpu"]
46+
ort-accel = ["ort-cuda", "ort-tensorrt", "ort-directml", "ort-rocm", "ort-coreml", "ort-webgpu"]
4647

4748
# Convenience
4849
all = ["onnx", "whisper-cpp", "whisperfile", "openai"]

src/accel.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ pub enum OrtAccelerator {
3838
CpuOnly = 1,
3939
/// NVIDIA CUDA (requires `ort-cuda` feature; adds ~800 MB to binary size).
4040
Cuda = 2,
41+
/// NVIDIA TensorRT (requires `ort-tensorrt` feature; builds on CUDA with optimised graph compilation).
42+
#[serde(rename = "tensorrt", alias = "tensor_rt")]
43+
TensorRt = 7,
4144
/// Microsoft DirectML (Windows).
4245
#[serde(rename = "directml", alias = "direct_ml")]
4346
DirectMl = 3,
@@ -77,6 +80,9 @@ impl OrtAccelerator {
7780
#[cfg(feature = "ort-cuda")]
7881
v.push(OrtAccelerator::Cuda);
7982

83+
#[cfg(feature = "ort-tensorrt")]
84+
v.push(OrtAccelerator::TensorRt);
85+
8086
#[cfg(feature = "ort-directml")]
8187
v.push(OrtAccelerator::DirectMl);
8288

@@ -101,6 +107,7 @@ impl OrtAccelerator {
101107
4 => Self::Rocm,
102108
5 => Self::CoreMl,
103109
6 => Self::WebGpu,
110+
7 => Self::TensorRt,
104111
_ => Self::Auto,
105112
}
106113
}
@@ -118,6 +125,7 @@ impl fmt::Display for OrtAccelerator {
118125
Self::Auto => "auto",
119126
Self::CpuOnly => "cpu",
120127
Self::Cuda => "cuda",
128+
Self::TensorRt => "tensorrt",
121129
Self::DirectMl => "directml",
122130
Self::Rocm => "rocm",
123131
Self::CoreMl => "coreml",
@@ -135,6 +143,7 @@ impl FromStr for OrtAccelerator {
135143
"auto" => Ok(Self::Auto),
136144
"cpu" | "cpu_only" | "cpuonly" => Ok(Self::CpuOnly),
137145
"cuda" => Ok(Self::Cuda),
146+
"tensorrt" | "trt" | "tensor_rt" => Ok(Self::TensorRt),
138147
"directml" | "dml" => Ok(Self::DirectMl),
139148
"rocm" => Ok(Self::Rocm),
140149
"coreml" | "core_ml" => Ok(Self::CoreMl),
@@ -322,6 +331,7 @@ mod tests {
322331
OrtAccelerator::Auto,
323332
OrtAccelerator::CpuOnly,
324333
OrtAccelerator::Cuda,
334+
OrtAccelerator::TensorRt,
325335
OrtAccelerator::DirectMl,
326336
OrtAccelerator::Rocm,
327337
OrtAccelerator::CoreMl,
@@ -347,6 +357,10 @@ mod tests {
347357
"cpu_only".parse::<OrtAccelerator>().unwrap(),
348358
OrtAccelerator::CpuOnly
349359
);
360+
assert_eq!(
361+
"trt".parse::<OrtAccelerator>().unwrap(),
362+
OrtAccelerator::TensorRt
363+
);
350364
}
351365

352366
#[test]
@@ -360,6 +374,7 @@ mod tests {
360374
(OrtAccelerator::Auto, "\"auto\""),
361375
(OrtAccelerator::CpuOnly, "\"cpu\""),
362376
(OrtAccelerator::Cuda, "\"cuda\""),
377+
(OrtAccelerator::TensorRt, "\"tensorrt\""),
363378
(OrtAccelerator::DirectMl, "\"directml\""),
364379
(OrtAccelerator::Rocm, "\"rocm\""),
365380
(OrtAccelerator::CoreMl, "\"coreml\""),
@@ -380,6 +395,8 @@ mod tests {
380395
assert_eq!(old_cpu, OrtAccelerator::CpuOnly);
381396
let old_dml: OrtAccelerator = serde_json::from_str("\"direct_ml\"").unwrap();
382397
assert_eq!(old_dml, OrtAccelerator::DirectMl);
398+
let old_trt: OrtAccelerator = serde_json::from_str("\"tensor_rt\"").unwrap();
399+
assert_eq!(old_trt, OrtAccelerator::TensorRt);
383400
}
384401

385402
#[test]

src/onnx/session.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use ort::ep::CoreML;
44
use ort::ep::DirectML;
55
#[cfg(feature = "ort-rocm")]
66
use ort::ep::ROCm;
7+
#[cfg(feature = "ort-tensorrt")]
8+
use ort::ep::TensorRT;
79
#[cfg(feature = "ort-webgpu")]
810
use ort::ep::WebGPU;
911
use ort::ep::CPU;
@@ -33,6 +35,18 @@ fn execution_providers() -> Vec<ort::ep::ExecutionProviderDispatch> {
3335
"Accelerator set to CUDA but ort-cuda feature is not enabled; falling back to CPU"
3436
);
3537
}
38+
OrtAccelerator::TensorRt => {
39+
#[cfg(feature = "ort-tensorrt")]
40+
{
41+
eps.push(TensorRT::default().build());
42+
// Register CUDA after TensorRT as a fallback for ops TensorRT does not support
43+
eps.push(CUDA::default().build());
44+
}
45+
#[cfg(not(feature = "ort-tensorrt"))]
46+
log::warn!(
47+
"Accelerator set to TensorRT but ort-tensorrt feature is not enabled; falling back to CPU"
48+
);
49+
}
3650
OrtAccelerator::DirectMl => {
3751
#[cfg(feature = "ort-directml")]
3852
eps.push(DirectML::default().build());
@@ -71,6 +85,9 @@ fn execution_providers() -> Vec<ort::ep::ExecutionProviderDispatch> {
7185
// to opt in.
7286
// Ref: https://onnxruntime.ai/docs/execution-providers/DirectML-ExecutionProvider.html
7387
// https://onnxruntime.ai/docs/execution-providers/WebGPU-ExecutionProvider.html
88+
// Register TensorRT before CUDA so TensorRT takes priority; CUDA serves as the fallback for ops TensorRT does not support.
89+
#[cfg(feature = "ort-tensorrt")]
90+
eps.push(TensorRT::default().build());
7491
#[cfg(feature = "ort-cuda")]
7592
eps.push(CUDA::default().build());
7693
#[cfg(feature = "ort-rocm")]

0 commit comments

Comments
 (0)