Skip to content

Commit 7d7d89d

Browse files
committed
feat(ollama): add health check, timeout, and model verification
Harden Ollama provider with configurable timeout, temperature, and num_predict. Add health_check() via /api/tags and verify_model() to differentiate connection failures from missing models. Detect connection and timeout errors with dedicated error variants.
1 parent 4315561 commit 7d7d89d

8 files changed

Lines changed: 736 additions & 135 deletions

File tree

Cargo.lock

Lines changed: 591 additions & 131 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ license = "GPL-3.0-only"
1313
[dependencies]
1414
# CLI
1515
clap = { version = "4.5", features = ["derive", "env"] }
16+
clap_complete = "4.5"
1617

1718
# Async runtime
1819
tokio = { version = "1.43", features = ["rt-multi-thread", "macros", "signal", "sync"] }
@@ -43,12 +44,17 @@ tree-sitter-javascript = "0.25"
4344

4445
# Error handling
4546
thiserror = "2.0"
47+
miette = { version = "7.6", features = ["fancy"] }
4648

4749
# Terminal UI
4850
dialoguer = "0.12"
4951
console = "0.16"
5052
indicatif = "0.18"
5153

54+
# Logging
55+
tracing = "0.1"
56+
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
57+
5258
# Utilities
5359
regex = "1.12"
5460

src/app.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ impl App {
135135
));
136136

137137
let provider = llm::create_provider(&self.config)?;
138+
provider.verify().await?;
138139

139140
// Setup streaming output
140141
let (tx, mut rx) = mpsc::channel::<String>(64);

src/config.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,18 @@ pub struct Config {
8484
#[serde(default = "default_max_context_chars")]
8585
pub max_context_chars: usize,
8686

87+
/// Request timeout in seconds (default 300)
88+
#[serde(default = "default_timeout_secs")]
89+
pub timeout_secs: u64,
90+
91+
/// LLM temperature (0.0-2.0, default 0.3)
92+
#[serde(default = "default_temperature")]
93+
pub temperature: f32,
94+
95+
/// Maximum tokens to generate (default 256)
96+
#[serde(default = "default_num_predict")]
97+
pub num_predict: u32,
98+
8799
/// Commit message format options
88100
#[serde(default)]
89101
pub format: CommitFormat,
@@ -105,6 +117,15 @@ fn default_max_diff_lines() -> usize {
105117
fn default_max_file_lines() -> usize {
106118
100
107119
}
120+
fn default_timeout_secs() -> u64 {
121+
300
122+
}
123+
fn default_temperature() -> f32 {
124+
0.3
125+
}
126+
fn default_num_predict() -> u32 {
127+
256
128+
}
108129

109130
impl Default for Config {
110131
fn default() -> Self {
@@ -116,6 +137,9 @@ impl Default for Config {
116137
max_diff_lines: default_max_diff_lines(),
117138
max_file_lines: default_max_file_lines(),
118139
max_context_chars: default_max_context_chars(),
140+
timeout_secs: default_timeout_secs(),
141+
temperature: default_temperature(),
142+
num_predict: default_num_predict(),
119143
format: CommitFormat::default(),
120144
}
121145
}

src/error.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ pub enum Error {
2424
#[error("Potential secrets detected: {patterns:?}. Use --allow-secrets to proceed.")]
2525
SecretsDetected { patterns: Vec<String> },
2626

27+
#[error("Cannot connect to Ollama at {host}. Is it running?")]
28+
OllamaNotRunning { host: String },
29+
30+
#[error("Model '{model}' not found. Available: {}", available.join(", "))]
31+
ModelNotFound {
32+
model: String,
33+
available: Vec<String>,
34+
},
35+
2736
#[error("Provider '{provider}' error: {message}")]
2837
Provider { provider: String, message: String },
2938

src/services/llm/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ impl LlmBackend {
3333
Self::Ollama(p) => p.name(),
3434
}
3535
}
36+
37+
/// Verify provider connectivity and model availability
38+
pub async fn verify(&self) -> Result<()> {
39+
match self {
40+
Self::Ollama(p) => p.verify_model().await,
41+
}
42+
}
3643
}
3744

3845
pub fn create_provider(config: &Config) -> Result<LlmBackend> {

src/services/llm/ollama.rs

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
//
33
// SPDX-License-Identifier: GPL-3.0-only
44

5+
use std::time::Duration;
6+
57
use reqwest::Client;
68
use serde::{Deserialize, Serialize};
79
use tokio::sync::mpsc;
@@ -15,6 +17,8 @@ pub struct OllamaProvider {
1517
client: Client,
1618
host: String,
1719
model: String,
20+
temperature: f32,
21+
num_predict: u32,
1822
}
1923

2024
#[derive(Serialize)]
@@ -23,6 +27,13 @@ struct GenerateRequest {
2327
prompt: String,
2428
system: String,
2529
stream: bool,
30+
options: OllamaOptions,
31+
}
32+
33+
#[derive(Serialize)]
34+
struct OllamaOptions {
35+
temperature: f32,
36+
num_predict: u32,
2637
}
2738

2839
const SYSTEM_PROMPT: &str = r#"You are a commit message generator. Analyze git diffs and output JSON commit messages.
@@ -42,16 +53,79 @@ struct GenerateResponse {
4253
done: bool,
4354
}
4455

56+
#[derive(Deserialize)]
57+
struct TagsResponse {
58+
models: Vec<ModelInfo>,
59+
}
60+
61+
#[derive(Deserialize)]
62+
struct ModelInfo {
63+
name: String,
64+
}
65+
4566
impl OllamaProvider {
4667
pub fn new(config: &Config) -> Self {
68+
let client = Client::builder()
69+
.timeout(Duration::from_secs(config.timeout_secs))
70+
.build()
71+
.unwrap_or_default();
72+
4773
Self {
48-
client: Client::new(),
74+
client,
4975
// Sanitize: remove trailing slashes to avoid //api/generate
5076
host: config.ollama_host.trim_end_matches('/').to_string(),
5177
model: config.model.clone(),
78+
temperature: config.temperature,
79+
num_predict: config.num_predict,
5280
}
5381
}
5482

83+
/// Check Ollama connectivity and return available model names
84+
pub async fn health_check(&self) -> Result<Vec<String>> {
85+
let url = format!("{}/api/tags", self.host);
86+
87+
let response = self.client.get(&url).send().await.map_err(|e| {
88+
if e.is_connect() {
89+
Error::OllamaNotRunning {
90+
host: self.host.clone(),
91+
}
92+
} else {
93+
Error::Provider {
94+
provider: "ollama".into(),
95+
message: e.to_string(),
96+
}
97+
}
98+
})?;
99+
100+
let tags: TagsResponse = response.json().await.map_err(|e| Error::Provider {
101+
provider: "ollama".into(),
102+
message: format!("failed to parse /api/tags response: {e}"),
103+
})?;
104+
105+
Ok(tags.models.into_iter().map(|m| m.name).collect())
106+
}
107+
108+
/// Verify that the configured model is available
109+
pub async fn verify_model(&self) -> Result<()> {
110+
let available = self.health_check().await?;
111+
112+
// Ollama model names may include `:latest` tag
113+
let model_matches = available.iter().any(|name| {
114+
name == &self.model
115+
|| name == &format!("{}:latest", self.model)
116+
|| name.strip_suffix(":latest") == Some(&self.model)
117+
});
118+
119+
if !model_matches {
120+
return Err(Error::ModelNotFound {
121+
model: self.model.clone(),
122+
available,
123+
});
124+
}
125+
126+
Ok(())
127+
}
128+
55129
pub async fn generate(
56130
&self,
57131
prompt: &str,
@@ -68,12 +142,29 @@ impl OllamaProvider {
68142
prompt: prompt.to_string(),
69143
system: SYSTEM_PROMPT.to_string(),
70144
stream: true,
145+
options: OllamaOptions {
146+
temperature: self.temperature,
147+
num_predict: self.num_predict,
148+
},
71149
})
72150
.send()
73151
.await
74-
.map_err(|e| Error::Provider {
75-
provider: "ollama".into(),
76-
message: e.to_string(),
152+
.map_err(|e| {
153+
if e.is_connect() {
154+
Error::OllamaNotRunning {
155+
host: self.host.clone(),
156+
}
157+
} else if e.is_timeout() {
158+
Error::Provider {
159+
provider: "ollama".into(),
160+
message: "request timed out".into(),
161+
}
162+
} else {
163+
Error::Provider {
164+
provider: "ollama".into(),
165+
message: e.to_string(),
166+
}
167+
}
77168
})?;
78169

79170
if !response.status().is_success() {

tests/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ fn default_config_values() {
1616
assert_eq!(config.max_diff_lines, 500);
1717
assert_eq!(config.max_file_lines, 100);
1818
assert_eq!(config.max_context_chars, 24000);
19+
assert_eq!(config.timeout_secs, 300);
20+
assert!((config.temperature - 0.3).abs() < f32::EPSILON);
21+
assert_eq!(config.num_predict, 256);
1922
assert!(config.format.include_body);
2023
assert!(config.format.include_scope);
2124
assert!(config.format.lowercase_subject);

0 commit comments

Comments (0)