22//
33// SPDX-License-Identifier: GPL-3.0-only
44
5+ use std:: time:: Duration ;
6+
57use reqwest:: Client ;
68use serde:: { Deserialize , Serialize } ;
79use tokio:: sync:: mpsc;
@@ -15,6 +17,8 @@ pub struct OllamaProvider {
1517 client : Client ,
1618 host : String ,
1719 model : String ,
20+ temperature : f32 ,
21+ num_predict : u32 ,
1822}
1923
2024#[ derive( Serialize ) ]
@@ -23,6 +27,13 @@ struct GenerateRequest {
2327 prompt : String ,
2428 system : String ,
2529 stream : bool ,
30+ options : OllamaOptions ,
31+ }
32+
33+ #[ derive( Serialize ) ]
34+ struct OllamaOptions {
35+ temperature : f32 ,
36+ num_predict : u32 ,
2637}
2738
2839const SYSTEM_PROMPT : & str = r#"You are a commit message generator. Analyze git diffs and output JSON commit messages.
@@ -42,16 +53,79 @@ struct GenerateResponse {
4253 done : bool ,
4354}
4455
56+ #[ derive( Deserialize ) ]
57+ struct TagsResponse {
58+ models : Vec < ModelInfo > ,
59+ }
60+
61+ #[ derive( Deserialize ) ]
62+ struct ModelInfo {
63+ name : String ,
64+ }
65+
4566impl OllamaProvider {
4667 pub fn new ( config : & Config ) -> Self {
68+ let client = Client :: builder ( )
69+ . timeout ( Duration :: from_secs ( config. timeout_secs ) )
70+ . build ( )
71+ . unwrap_or_default ( ) ;
72+
4773 Self {
48- client : Client :: new ( ) ,
74+ client,
4975 // Sanitize: remove trailing slashes to avoid //api/generate
5076 host : config. ollama_host . trim_end_matches ( '/' ) . to_string ( ) ,
5177 model : config. model . clone ( ) ,
78+ temperature : config. temperature ,
79+ num_predict : config. num_predict ,
5280 }
5381 }
5482
83+ /// Check Ollama connectivity and return available model names
84+ pub async fn health_check ( & self ) -> Result < Vec < String > > {
85+ let url = format ! ( "{}/api/tags" , self . host) ;
86+
87+ let response = self . client . get ( & url) . send ( ) . await . map_err ( |e| {
88+ if e. is_connect ( ) {
89+ Error :: OllamaNotRunning {
90+ host : self . host . clone ( ) ,
91+ }
92+ } else {
93+ Error :: Provider {
94+ provider : "ollama" . into ( ) ,
95+ message : e. to_string ( ) ,
96+ }
97+ }
98+ } ) ?;
99+
100+ let tags: TagsResponse = response. json ( ) . await . map_err ( |e| Error :: Provider {
101+ provider : "ollama" . into ( ) ,
102+ message : format ! ( "failed to parse /api/tags response: {e}" ) ,
103+ } ) ?;
104+
105+ Ok ( tags. models . into_iter ( ) . map ( |m| m. name ) . collect ( ) )
106+ }
107+
108+ /// Verify that the configured model is available
109+ pub async fn verify_model ( & self ) -> Result < ( ) > {
110+ let available = self . health_check ( ) . await ?;
111+
112+ // Ollama model names may include `:latest` tag
113+ let model_matches = available. iter ( ) . any ( |name| {
114+ name == & self . model
115+ || name == & format ! ( "{}:latest" , self . model)
116+ || name. strip_suffix ( ":latest" ) == Some ( & self . model )
117+ } ) ;
118+
119+ if !model_matches {
120+ return Err ( Error :: ModelNotFound {
121+ model : self . model . clone ( ) ,
122+ available,
123+ } ) ;
124+ }
125+
126+ Ok ( ( ) )
127+ }
128+
55129 pub async fn generate (
56130 & self ,
57131 prompt : & str ,
@@ -68,12 +142,29 @@ impl OllamaProvider {
68142 prompt : prompt. to_string ( ) ,
69143 system : SYSTEM_PROMPT . to_string ( ) ,
70144 stream : true ,
145+ options : OllamaOptions {
146+ temperature : self . temperature ,
147+ num_predict : self . num_predict ,
148+ } ,
71149 } )
72150 . send ( )
73151 . await
74- . map_err ( |e| Error :: Provider {
75- provider : "ollama" . into ( ) ,
76- message : e. to_string ( ) ,
152+ . map_err ( |e| {
153+ if e. is_connect ( ) {
154+ Error :: OllamaNotRunning {
155+ host : self . host . clone ( ) ,
156+ }
157+ } else if e. is_timeout ( ) {
158+ Error :: Provider {
159+ provider : "ollama" . into ( ) ,
160+ message : "request timed out" . into ( ) ,
161+ }
162+ } else {
163+ Error :: Provider {
164+ provider : "ollama" . into ( ) ,
165+ message : e. to_string ( ) ,
166+ }
167+ }
77168 } ) ?;
78169
79170 if !response. status ( ) . is_success ( ) {
0 commit comments