Skip to content

Commit 513b890

Browse files
committed
feat(app): add group rationale and post-generation validation retry
- infer_group_rationale() generates GROUP_REASON: for per-group prompts based on file categories and commit type - validate_and_retry() checks LLM output against evidence flags via CommitValidator, re-prompts once with CORRECTIONS on violation - Wire rayon-based analyzer and concurrent git content fetching
1 parent 29c8282 commit 513b890

1 file changed

Lines changed: 129 additions & 24 deletions

File tree

src/app.rs

Lines changed: 129 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
//
33
// SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
44

5-
use std::collections::HashMap;
65
use std::io::IsTerminal;
76
use std::path::PathBuf;
87

@@ -15,14 +14,15 @@ use tracing::{debug, warn};
1514

1615
use crate::cli::{Cli, Commands, HookAction};
1716
use crate::config::Config;
18-
use crate::domain::{ChangeStatus, CodeSymbol, StagedChanges};
17+
use crate::domain::PromptContext;
18+
use crate::domain::{ChangeStatus, CodeSymbol, CommitType, FileCategory, StagedChanges};
1919
use crate::error::{Error, Result};
2020
use crate::services::{
2121
analyzer::AnalyzerService,
2222
context::ContextBuilder,
2323
git::GitService,
2424
llm, safety,
25-
sanitizer::CommitSanitizer,
25+
sanitizer::{CommitSanitizer, CommitValidator},
2626
splitter::{CommitSplitter, SplitSuggestion},
2727
};
2828

@@ -119,27 +119,14 @@ impl App {
119119
// Step 3: Pre-fetch file content and analyze with tree-sitter
120120
self.print_status("Extracting code symbols...");
121121

122-
let mut analyzer = AnalyzerService::new()?;
122+
let analyzer = AnalyzerService::new()?;
123123

124-
// Pre-fetch all file content asynchronously, then pass as sync maps
124+
// Fetch all file content concurrently (async I/O via tokio JoinSet)
125125
let file_paths: Vec<PathBuf> = changes.files.iter().map(|f| f.path.clone()).collect();
126-
let mut staged_map: HashMap<PathBuf, String> = HashMap::new();
127-
let mut head_map: HashMap<PathBuf, String> = HashMap::new();
126+
let (staged_map, head_map) = git.fetch_file_contents(&file_paths).await;
128127

129-
for path in &file_paths {
130-
if let Some(content) = git.get_staged_content(path).await {
131-
staged_map.insert(path.clone(), content);
132-
}
133-
if let Some(content) = git.get_head_content(path).await {
134-
head_map.insert(path.clone(), content);
135-
}
136-
}
137-
138-
let symbols = analyzer.extract_symbols(
139-
&changes.files,
140-
&|path| staged_map.get(path).cloned(),
141-
&|path| head_map.get(path).cloned(),
142-
);
128+
// Parse symbols in parallel across CPU cores (rayon)
129+
let symbols = analyzer.extract_symbols(&changes.files, &staged_map, &head_map);
143130

144131
debug!(count = symbols.len(), "symbols extracted");
145132

@@ -252,7 +239,14 @@ impl App {
252239
candidate = i + 1,
253240
"sanitizing LLM response"
254241
);
255-
match CommitSanitizer::sanitize(&raw_message, &self.config.format) {
242+
243+
// Validate against evidence and retry once if violations found
244+
let raw_to_sanitize = self
245+
.validate_and_retry(&raw_message, &context, &provider, &prompt)
246+
.await
247+
.unwrap_or(raw_message);
248+
249+
match CommitSanitizer::sanitize(&raw_to_sanitize, &self.config.format) {
256250
Ok(msg) => candidates.push(msg),
257251
Err(e) => {
258252
warn!(candidate = i + 1, error = %e, "failed to sanitize candidate");
@@ -483,7 +477,11 @@ impl App {
483477
.cloned()
484478
.collect();
485479

486-
let context = ContextBuilder::build(&sub_changes, &sub_symbols, &self.config);
480+
let mut context = ContextBuilder::build(&sub_changes, &sub_symbols, &self.config);
481+
context.group_rationale = Some(Self::infer_group_rationale(
482+
&sub_changes,
483+
&group.commit_type,
484+
));
487485
let prompt = context.to_prompt();
488486

489487
if self.cli.show_prompt {
@@ -529,7 +527,13 @@ impl App {
529527
group = i + 1,
530528
"sanitizing split group response"
531529
);
532-
let message = CommitSanitizer::sanitize(&raw_message, &self.config.format)?;
530+
531+
let raw_to_sanitize = self
532+
.validate_and_retry(&raw_message, &context, &provider, &prompt)
533+
.await
534+
.unwrap_or(raw_message);
535+
536+
let message = CommitSanitizer::sanitize(&raw_to_sanitize, &self.config.format)?;
533537
commit_messages.push((message, group.files.clone()));
534538
}
535539

@@ -578,6 +582,55 @@ impl App {
578582
Ok(())
579583
}
580584

585+
/// Generate a short rationale describing why files were grouped together.
586+
fn infer_group_rationale(changes: &StagedChanges, commit_type: &CommitType) -> String {
587+
let file_count = changes.files.len();
588+
let categories: Vec<_> = changes.files.iter().map(|f| f.category).collect();
589+
590+
// All same category?
591+
if categories.iter().all(|c| *c == categories[0]) {
592+
let cat = match categories[0] {
593+
FileCategory::Docs => "documentation",
594+
FileCategory::Test => "test",
595+
FileCategory::Config => "configuration",
596+
FileCategory::Build => "build/CI",
597+
FileCategory::Source => "source",
598+
FileCategory::Other => "miscellaneous",
599+
};
600+
return format!(
601+
"{} {} changes across {} files",
602+
commit_type.as_str(),
603+
cat,
604+
file_count
605+
);
606+
}
607+
608+
// Mixed categories
609+
let source_count = categories
610+
.iter()
611+
.filter(|c| **c == FileCategory::Source)
612+
.count();
613+
let test_count = categories
614+
.iter()
615+
.filter(|c| **c == FileCategory::Test)
616+
.count();
617+
618+
if source_count > 0 && test_count > 0 {
619+
format!(
620+
"{} changes in {} source + {} test files",
621+
commit_type.as_str(),
622+
source_count,
623+
test_count
624+
)
625+
} else {
626+
format!(
627+
"{} changes across {} files",
628+
commit_type.as_str(),
629+
file_count
630+
)
631+
}
632+
}
633+
581634
fn display_split_suggestion(
582635
groups: &[crate::services::splitter::CommitGroup],
583636
changes: &StagedChanges,
@@ -953,6 +1006,58 @@ fi
9531006
Ok(())
9541007
}
9551008

1009+
// ─── Post-Generation Validation ───
1010+
1011+
/// Validate raw LLM output against evidence flags. If violations exist,
1012+
/// re-prompt once with corrections appended. Returns the corrected raw
1013+
/// output, or None if no retry was needed or the retry also failed.
1014+
async fn validate_and_retry(
1015+
&self,
1016+
raw: &str,
1017+
context: &PromptContext,
1018+
provider: &llm::LlmBackend,
1019+
original_prompt: &str,
1020+
) -> Option<String> {
1021+
let structured = CommitSanitizer::parse_structured(raw)?;
1022+
1023+
let violations = CommitValidator::validate(
1024+
&structured,
1025+
context.has_bug_evidence,
1026+
context.is_mechanical,
1027+
context.public_api_removed_count,
1028+
context.is_dependency_only,
1029+
);
1030+
1031+
if violations.is_empty() {
1032+
return None; // No violations, use original
1033+
}
1034+
1035+
debug!(
1036+
violations = violations.len(),
1037+
"evidence violations detected, retrying with corrections"
1038+
);
1039+
1040+
let corrections = CommitValidator::format_corrections(&violations);
1041+
let retry_prompt = format!("{}\n{}", original_prompt, corrections);
1042+
1043+
let (tx, mut rx) = mpsc::channel::<String>(64);
1044+
let cancel = self.cancel_token.clone();
1045+
let drain_handle = tokio::spawn(async move { while rx.recv().await.is_some() {} });
1046+
1047+
match provider.generate(&retry_prompt, tx, cancel).await {
1048+
Ok(retry_raw) if !retry_raw.trim().is_empty() => {
1049+
let _ = drain_handle.await;
1050+
debug!("retry succeeded, using corrected output");
1051+
Some(retry_raw)
1052+
}
1053+
_ => {
1054+
let _ = drain_handle.await;
1055+
warn!("retry failed or empty, using original output");
1056+
None
1057+
}
1058+
}
1059+
}
1060+
9561061
// ─── Output Helpers ───
9571062

9581063
fn print_status(&self, msg: &str) {

0 commit comments

Comments
 (0)