@@ -38,6 +38,9 @@ pub enum OrtAccelerator {
3838 CpuOnly = 1 ,
3939 /// NVIDIA CUDA (requires `ort-cuda` feature; adds ~800 MB to binary size).
4040 Cuda = 2 ,
41+ /// NVIDIA TensorRT (requires `ort-tensorrt` feature; builds on CUDA with optimised graph compilation).
42+ #[ serde( rename = "tensorrt" , alias = "tensor_rt" ) ]
43+ TensorRt = 7 ,
4144 /// Microsoft DirectML (Windows).
4245 #[ serde( rename = "directml" , alias = "direct_ml" ) ]
4346 DirectMl = 3 ,
@@ -77,6 +80,9 @@ impl OrtAccelerator {
7780 #[ cfg( feature = "ort-cuda" ) ]
7881 v. push ( OrtAccelerator :: Cuda ) ;
7982
83+ #[ cfg( feature = "ort-tensorrt" ) ]
84+ v. push ( OrtAccelerator :: TensorRt ) ;
85+
8086 #[ cfg( feature = "ort-directml" ) ]
8187 v. push ( OrtAccelerator :: DirectMl ) ;
8288
@@ -101,6 +107,7 @@ impl OrtAccelerator {
101107 4 => Self :: Rocm ,
102108 5 => Self :: CoreMl ,
103109 6 => Self :: WebGpu ,
110+ 7 => Self :: TensorRt ,
104111 _ => Self :: Auto ,
105112 }
106113 }
@@ -118,6 +125,7 @@ impl fmt::Display for OrtAccelerator {
118125 Self :: Auto => "auto" ,
119126 Self :: CpuOnly => "cpu" ,
120127 Self :: Cuda => "cuda" ,
128+ Self :: TensorRt => "tensorrt" ,
121129 Self :: DirectMl => "directml" ,
122130 Self :: Rocm => "rocm" ,
123131 Self :: CoreMl => "coreml" ,
@@ -135,6 +143,7 @@ impl FromStr for OrtAccelerator {
135143 "auto" => Ok ( Self :: Auto ) ,
136144 "cpu" | "cpu_only" | "cpuonly" => Ok ( Self :: CpuOnly ) ,
137145 "cuda" => Ok ( Self :: Cuda ) ,
146+ "tensorrt" | "trt" | "tensor_rt" => Ok ( Self :: TensorRt ) ,
138147 "directml" | "dml" => Ok ( Self :: DirectMl ) ,
139148 "rocm" => Ok ( Self :: Rocm ) ,
140149 "coreml" | "core_ml" => Ok ( Self :: CoreMl ) ,
@@ -322,6 +331,7 @@ mod tests {
322331 OrtAccelerator :: Auto ,
323332 OrtAccelerator :: CpuOnly ,
324333 OrtAccelerator :: Cuda ,
334+ OrtAccelerator :: TensorRt ,
325335 OrtAccelerator :: DirectMl ,
326336 OrtAccelerator :: Rocm ,
327337 OrtAccelerator :: CoreMl ,
@@ -347,6 +357,10 @@ mod tests {
347357 "cpu_only" . parse:: <OrtAccelerator >( ) . unwrap( ) ,
348358 OrtAccelerator :: CpuOnly
349359 ) ;
360+ assert_eq ! (
361+ "trt" . parse:: <OrtAccelerator >( ) . unwrap( ) ,
362+ OrtAccelerator :: TensorRt
363+ ) ;
350364 }
351365
352366 #[ test]
@@ -360,6 +374,7 @@ mod tests {
360374 ( OrtAccelerator :: Auto , "\" auto\" " ) ,
361375 ( OrtAccelerator :: CpuOnly , "\" cpu\" " ) ,
362376 ( OrtAccelerator :: Cuda , "\" cuda\" " ) ,
377+ ( OrtAccelerator :: TensorRt , "\" tensorrt\" " ) ,
363378 ( OrtAccelerator :: DirectMl , "\" directml\" " ) ,
364379 ( OrtAccelerator :: Rocm , "\" rocm\" " ) ,
365380 ( OrtAccelerator :: CoreMl , "\" coreml\" " ) ,
@@ -380,6 +395,8 @@ mod tests {
380395 assert_eq ! ( old_cpu, OrtAccelerator :: CpuOnly ) ;
381396 let old_dml: OrtAccelerator = serde_json:: from_str ( "\" direct_ml\" " ) . unwrap ( ) ;
382397 assert_eq ! ( old_dml, OrtAccelerator :: DirectMl ) ;
398+ let old_trt: OrtAccelerator = serde_json:: from_str ( "\" tensor_rt\" " ) . unwrap ( ) ;
399+ assert_eq ! ( old_trt, OrtAccelerator :: TensorRt ) ;
383400 }
384401
385402 #[ test]
0 commit comments