Skip to content

Commit 242e97d

Browse files
samcconeLUCI
authored andcommitted
Implement command forgiveness with autocorrect
Similar to `git`, when a user types an unknown command like `repo tart`, we now use `difflib.get_close_matches` to suggest similar commands. If `help.autocorrect` is set in the git config, it will optionally prompt the user to automatically run the assumed command, or wait for a configured delay before executing it. Verification Steps: 1. Created a dummy repo project locally. 2. Verified `help.autocorrect=0|false|off|no|show` suggests command and exits. 3. Verified `help.autocorrect=1|true|on|yes|immediate` automatically runs suggestion. 4. Verified `help.autocorrect=<number>` runs after `<number>*0.1` seconds. 5. Verified `help.autocorrect=never` exits immediately without suggestions. 6. Verified `help.autocorrect=prompt` asks user to accept [y/n] and handles correctly. BUG: b/489753302 Change-Id: I6dcd63229cbd7badf5404459b48690c68f5b4857 Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/558021 Tested-by: Sam Saccone <samccone@google.com> Commit-Queue: Sam Saccone <samccone@google.com> Reviewed-by: Mike Frysinger <vapier@google.com>
1 parent ade45de commit 242e97d

2 files changed

Lines changed: 279 additions & 13 deletions

File tree

main.py

Lines changed: 113 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
which takes care of execing this entry point.
2020
"""
2121

22+
import difflib
2223
import getpass
2324
import json
2425
import netrc
@@ -29,6 +30,7 @@
2930
import sys
3031
import textwrap
3132
import time
33+
from typing import Optional
3234
import urllib.request
3335

3436
from repo_logging import RepoLogger
@@ -292,6 +294,102 @@ def _Run(self, name, gopts, argv):
292294
result = run()
293295
return result
294296

297+
def _autocorrect_command_name(
298+
self, name: str, config: RepoConfig
299+
) -> Optional[str]:
300+
"""Autocorrect command name based on user's git config."""
301+
302+
close_commands = difflib.get_close_matches(
303+
name, self.commands.keys(), n=5, cutoff=0.7
304+
)
305+
306+
if not close_commands:
307+
logger.error(
308+
"repo: '%s' is not a repo command. See 'repo help'.", name
309+
)
310+
return None
311+
312+
assumed = close_commands[0]
313+
autocorrect = config.GetString("help.autocorrect")
314+
315+
# If there are multiple close matches, git won't automatically run one.
316+
# We'll always prompt instead of guessing.
317+
if len(close_commands) > 1:
318+
autocorrect = "prompt"
319+
320+
# Handle git configuration boolean values:
321+
# 0, "false", "off", "no", "show": show suggestion (default)
322+
# 1, "true", "on", "yes", "immediate": run suggestion immediately
323+
# "never": don't run or show any suggested command
324+
# "prompt": show the suggestion and prompt for confirmation
325+
# positive number > 1: run suggestion after specified deciseconds
326+
if autocorrect is None:
327+
autocorrect = "0"
328+
329+
autocorrect = autocorrect.lower()
330+
331+
if autocorrect in ("0", "false", "off", "no", "show"):
332+
autocorrect = 0
333+
elif autocorrect in ("true", "on", "yes", "immediate"):
334+
autocorrect = -1 # immediate
335+
elif autocorrect == "never":
336+
return None
337+
elif autocorrect == "prompt":
338+
logger.warning(
339+
"You called a repo command named "
340+
"'%s', which does not exist.",
341+
name,
342+
)
343+
try:
344+
resp = input(f"Run '{assumed}' instead [y/N]? ")
345+
if resp.lower().startswith("y"):
346+
return assumed
347+
except (KeyboardInterrupt, EOFError):
348+
pass
349+
return None
350+
else:
351+
try:
352+
autocorrect = int(autocorrect)
353+
except ValueError:
354+
autocorrect = 0
355+
356+
if autocorrect != 0:
357+
if autocorrect < 0:
358+
logger.warning(
359+
"You called a repo command named "
360+
"'%s', which does not exist.\n"
361+
"Continuing assuming that "
362+
"you meant '%s'.",
363+
name,
364+
assumed,
365+
)
366+
else:
367+
delay = autocorrect * 0.1
368+
logger.warning(
369+
"You called a repo command named "
370+
"'%s', which does not exist.\n"
371+
"Continuing in %.1f seconds, assuming "
372+
"that you meant '%s'.",
373+
name,
374+
delay,
375+
assumed,
376+
)
377+
try:
378+
time.sleep(delay)
379+
except KeyboardInterrupt:
380+
return None
381+
return assumed
382+
383+
logger.error(
384+
"repo: '%s' is not a repo command. See 'repo help'.", name
385+
)
386+
logger.warning(
387+
"The most similar command%s\n\t%s",
388+
"s are" if len(close_commands) > 1 else " is",
389+
"\n\t".join(close_commands),
390+
)
391+
return None
392+
295393
def _RunLong(self, name, gopts, argv, git_trace2_event_log):
296394
"""Execute the (longer running) requested subcommand."""
297395
result = 0
@@ -306,20 +404,22 @@ def _RunLong(self, name, gopts, argv, git_trace2_event_log):
306404
outer_client=outer_client,
307405
)
308406

309-
try:
310-
cmd = self.commands[name](
311-
repodir=self.repodir,
312-
client=repo_client,
313-
manifest=repo_client.manifest,
314-
outer_client=outer_client,
315-
outer_manifest=outer_client.manifest,
316-
git_event_log=git_trace2_event_log,
317-
)
318-
except KeyError:
319-
logger.error(
320-
"repo: '%s' is not a repo command. See 'repo help'.", name
407+
if name not in self.commands:
408+
corrected_name = self._autocorrect_command_name(
409+
name, outer_client.globalConfig
321410
)
322-
return 1
411+
if not corrected_name:
412+
return 1
413+
name = corrected_name
414+
415+
cmd = self.commands[name](
416+
repodir=self.repodir,
417+
client=repo_client,
418+
manifest=repo_client.manifest,
419+
outer_client=outer_client,
420+
outer_manifest=outer_client.manifest,
421+
git_event_log=git_trace2_event_log,
422+
)
323423

324424
Editor.globalConfig = cmd.client.globalConfig
325425

tests/test_main.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright (C) 2026 The Android Open Source Project
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the main repo script and subcommand routing."""
16+
17+
from unittest import mock
18+
19+
import pytest
20+
21+
from main import _Repo
22+
23+
24+
@pytest.fixture(name="repo")
25+
def fixture_repo():
26+
repo = _Repo("repodir")
27+
# Overriding the command list here ensures that we are only testing
28+
# against a fixed set of commands, reducing fragility to new
29+
# subcommands being added to the main repo tool.
30+
repo.commands = {"start": None, "sync": None, "smart": None}
31+
return repo
32+
33+
34+
@pytest.fixture(name="mock_config")
35+
def fixture_mock_config():
36+
return mock.MagicMock()
37+
38+
39+
@mock.patch("time.sleep")
40+
def test_autocorrect_delay(mock_sleep, repo, mock_config):
41+
"""Test autocorrect with positive delay."""
42+
mock_config.GetString.return_value = "10"
43+
44+
res = repo._autocorrect_command_name("tart", mock_config)
45+
46+
mock_config.GetString.assert_called_with("help.autocorrect")
47+
mock_sleep.assert_called_with(1.0)
48+
assert res == "start"
49+
50+
51+
@mock.patch("time.sleep")
52+
def test_autocorrect_delay_one(mock_sleep, repo, mock_config):
53+
"""Test autocorrect with '1' (0.1s delay, not immediate)."""
54+
mock_config.GetString.return_value = "1"
55+
56+
res = repo._autocorrect_command_name("tart", mock_config)
57+
58+
mock_sleep.assert_called_with(0.1)
59+
assert res == "start"
60+
61+
62+
@mock.patch("time.sleep", side_effect=KeyboardInterrupt())
63+
def test_autocorrect_delay_interrupt(mock_sleep, repo, mock_config):
64+
"""Test autocorrect handles KeyboardInterrupt during delay."""
65+
mock_config.GetString.return_value = "10"
66+
67+
res = repo._autocorrect_command_name("tart", mock_config)
68+
69+
mock_sleep.assert_called_with(1.0)
70+
assert res is None
71+
72+
73+
@mock.patch("time.sleep")
74+
def test_autocorrect_immediate(mock_sleep, repo, mock_config):
75+
"""Test autocorrect with immediate/negative delay."""
76+
# Test numeric negative.
77+
mock_config.GetString.return_value = "-1"
78+
res = repo._autocorrect_command_name("tart", mock_config)
79+
mock_sleep.assert_not_called()
80+
assert res == "start"
81+
82+
# Test string boolean "true".
83+
mock_config.GetString.return_value = "true"
84+
res = repo._autocorrect_command_name("tart", mock_config)
85+
mock_sleep.assert_not_called()
86+
assert res == "start"
87+
88+
# Test string boolean "yes".
89+
mock_config.GetString.return_value = "YES"
90+
res = repo._autocorrect_command_name("tart", mock_config)
91+
mock_sleep.assert_not_called()
92+
assert res == "start"
93+
94+
# Test string boolean "immediate".
95+
mock_config.GetString.return_value = "Immediate"
96+
res = repo._autocorrect_command_name("tart", mock_config)
97+
mock_sleep.assert_not_called()
98+
assert res == "start"
99+
100+
101+
def test_autocorrect_zero_or_show(repo, mock_config):
102+
"""Test autocorrect with zero delay (suggestions only)."""
103+
# Test numeric zero.
104+
mock_config.GetString.return_value = "0"
105+
res = repo._autocorrect_command_name("tart", mock_config)
106+
assert res is None
107+
108+
# Test string boolean "false".
109+
mock_config.GetString.return_value = "False"
110+
res = repo._autocorrect_command_name("tart", mock_config)
111+
assert res is None
112+
113+
# Test string boolean "show".
114+
mock_config.GetString.return_value = "show"
115+
res = repo._autocorrect_command_name("tart", mock_config)
116+
assert res is None
117+
118+
119+
def test_autocorrect_never(repo, mock_config):
120+
"""Test autocorrect with 'never'."""
121+
mock_config.GetString.return_value = "never"
122+
res = repo._autocorrect_command_name("tart", mock_config)
123+
assert res is None
124+
125+
126+
@mock.patch("builtins.input", return_value="y")
127+
def test_autocorrect_prompt_yes(mock_input, repo, mock_config):
128+
"""Test autocorrect with prompt and user answers yes."""
129+
mock_config.GetString.return_value = "prompt"
130+
131+
res = repo._autocorrect_command_name("tart", mock_config)
132+
133+
assert res == "start"
134+
135+
136+
@mock.patch("builtins.input", return_value="n")
137+
def test_autocorrect_prompt_no(mock_input, repo, mock_config):
138+
"""Test autocorrect with prompt and user answers no."""
139+
mock_config.GetString.return_value = "prompt"
140+
141+
res = repo._autocorrect_command_name("tart", mock_config)
142+
143+
assert res is None
144+
145+
146+
@mock.patch("builtins.input", return_value="y")
147+
def test_autocorrect_multiple_candidates(mock_input, repo, mock_config):
148+
"""Test autocorrect with multiple matches forces a prompt."""
149+
mock_config.GetString.return_value = "10" # Normally just delay
150+
151+
# 'snart' matches both 'start' and 'smart' with > 0.7 ratio
152+
res = repo._autocorrect_command_name("snart", mock_config)
153+
154+
# Because there are multiple candidates, it should prompt
155+
mock_input.assert_called_once()
156+
assert res == "start"
157+
158+
159+
@mock.patch("builtins.input", side_effect=KeyboardInterrupt())
160+
def test_autocorrect_prompt_interrupt(mock_input, repo, mock_config):
161+
"""Test autocorrect with prompt and user interrupts."""
162+
mock_config.GetString.return_value = "prompt"
163+
164+
res = repo._autocorrect_command_name("tart", mock_config)
165+
166+
assert res is None

0 commit comments

Comments
 (0)