Skip to content

Commit ed1ec4a

Browse files
authored
add xnnpack support (#78)
1 parent 2d2a5c6 commit ed1ec4a

3 files changed

Lines changed: 51 additions & 2 deletions

File tree

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ ort-directml = ["onnx", "ort/directml"]
4242
ort-rocm = ["onnx", "ort/rocm"]
4343
ort-coreml = ["onnx", "ort/coreml"]
4444
ort-webgpu = ["onnx", "ort/webgpu"]
45+
ort-xnnpack = ["onnx", "ort/xnnpack"]
4546
ort-tracing = ["onnx", "ort/tracing"]
46-
ort-accel = ["ort-cuda", "ort-tensorrt", "ort-directml", "ort-rocm", "ort-coreml", "ort-webgpu"]
47+
ort-accel = ["ort-cuda", "ort-tensorrt", "ort-directml", "ort-rocm", "ort-coreml", "ort-webgpu", "ort-xnnpack"]
4748

4849
# Convenience
4950
all = ["onnx", "whisper-cpp", "whisperfile", "openai"]

src/accel.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ pub enum OrtAccelerator {
5252
/// WebGPU via Dawn (Windows, Linux, WebAssembly).
5353
#[serde(rename = "webgpu")]
5454
WebGpu = 6,
55+
/// XNNPACK CPU acceleration (ARM, x86_64). Optimised for Conv/Gemm/MatMul
56+
/// kernels; uses its own threadpool independent of the session intra-op pool.
57+
#[serde(rename = "xnnpack")]
58+
Xnnpack = 8,
5559
}
5660

5761
static ORT_ACCELERATOR: AtomicU8 = AtomicU8::new(OrtAccelerator::Auto as u8);
@@ -95,6 +99,9 @@ impl OrtAccelerator {
9599
#[cfg(feature = "ort-webgpu")]
96100
v.push(OrtAccelerator::WebGpu);
97101

102+
#[cfg(feature = "ort-xnnpack")]
103+
v.push(OrtAccelerator::Xnnpack);
104+
98105
v
99106
}
100107

@@ -108,6 +115,7 @@ impl OrtAccelerator {
108115
5 => Self::CoreMl,
109116
6 => Self::WebGpu,
110117
7 => Self::TensorRt,
118+
8 => Self::Xnnpack,
111119
_ => Self::Auto,
112120
}
113121
}
@@ -130,6 +138,7 @@ impl fmt::Display for OrtAccelerator {
130138
Self::Rocm => "rocm",
131139
Self::CoreMl => "coreml",
132140
Self::WebGpu => "webgpu",
141+
Self::Xnnpack => "xnnpack",
133142
};
134143
f.write_str(s)
135144
}
@@ -148,6 +157,7 @@ impl FromStr for OrtAccelerator {
148157
"rocm" => Ok(Self::Rocm),
149158
"coreml" | "core_ml" => Ok(Self::CoreMl),
150159
"webgpu" | "web_gpu" => Ok(Self::WebGpu),
160+
"xnnpack" => Ok(Self::Xnnpack),
151161
other => Err(format!("unknown ORT accelerator: {other}")),
152162
}
153163
}
@@ -336,6 +346,7 @@ mod tests {
336346
OrtAccelerator::Rocm,
337347
OrtAccelerator::CoreMl,
338348
OrtAccelerator::WebGpu,
349+
OrtAccelerator::Xnnpack,
339350
] {
340351
let s = pref.to_string();
341352
let parsed: OrtAccelerator = s.parse().unwrap();
@@ -379,6 +390,7 @@ mod tests {
379390
(OrtAccelerator::Rocm, "\"rocm\""),
380391
(OrtAccelerator::CoreMl, "\"coreml\""),
381392
(OrtAccelerator::WebGpu, "\"webgpu\""),
393+
(OrtAccelerator::Xnnpack, "\"xnnpack\""),
382394
] {
383395
let json = serde_json::to_string(&pref).unwrap();
384396
assert_eq!(json, expected, "serialize {:?}", pref);

src/onnx/session.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use ort::ep::ROCm;
88
use ort::ep::TensorRT;
99
#[cfg(feature = "ort-webgpu")]
1010
use ort::ep::WebGPU;
11+
#[cfg(feature = "ort-xnnpack")]
12+
use ort::ep::XNNPACK;
1113
use ort::ep::CPU;
1214
#[cfg(feature = "ort-cuda")]
1315
use ort::ep::CUDA;
@@ -77,6 +79,27 @@ fn execution_providers() -> Vec<ort::ep::ExecutionProviderDispatch> {
7779
"Accelerator set to WebGPU but ort-webgpu feature is not enabled; falling back to CPU"
7880
);
7981
}
82+
OrtAccelerator::Xnnpack => {
83+
#[cfg(feature = "ort-xnnpack")]
84+
{
85+
// XNNPACK manages its own threadpool. Configure it with the
86+
// available logical core count; the session-level intra-op
87+
// pool is forced to 1 in build_session() when XNNPACK is
88+
// active to avoid contention.
89+
let n = std::thread::available_parallelism()
90+
.map(|n| n.get())
91+
.unwrap_or(1);
92+
if let Some(nz) = core::num::NonZeroUsize::new(n) {
93+
eps.push(XNNPACK::default().with_intra_op_num_threads(nz).build());
94+
} else {
95+
eps.push(XNNPACK::default().build());
96+
}
97+
}
98+
#[cfg(not(feature = "ort-xnnpack"))]
99+
log::warn!(
100+
"Accelerator set to XNNPACK but ort-xnnpack feature is not enabled; falling back to CPU"
101+
);
102+
}
80103
OrtAccelerator::Auto => {
81104
// Add compiled-in GPU EPs in priority order.
82105
// DirectML and WebGPU are excluded from Auto because they require
@@ -113,6 +136,14 @@ fn requires_sequential_session() -> bool {
113136
|| (pref == OrtAccelerator::WebGpu && cfg!(feature = "ort-webgpu"))
114137
}
115138

139+
/// Returns true if the XNNPACK EP is selected and compiled in. XNNPACK runs
140+
/// its own threadpool, so the session intra-op pool should be reduced to a
141+
/// single non-spinning thread to avoid contention.
142+
fn is_xnnpack_active() -> bool {
143+
let pref = get_ort_accelerator();
144+
pref == OrtAccelerator::Xnnpack && cfg!(feature = "ort-xnnpack")
145+
}
146+
116147
/// Internal session builder with full control over threading and EP selection.
117148
fn build_session(
118149
path: &Path,
@@ -122,7 +153,12 @@ fn build_session(
122153
let mut builder =
123154
Session::builder()?.with_optimization_level(GraphOptimizationLevel::Level3)?;
124155

125-
if let Some(n) = intra_threads {
156+
if is_xnnpack_active() {
157+
// See ort::ep::XNNPACK docs: disable session intra-op spinning and
158+
// force a single intra-op thread when XNNPACK is the active EP.
159+
builder = builder.with_intra_op_spinning(false)?;
160+
builder = builder.with_intra_threads(1)?;
161+
} else if let Some(n) = intra_threads {
126162
if n > 0 {
127163
builder = builder.with_intra_threads(n)?;
128164
}

0 commit comments

Comments (0)