fix(search): keep partial results on search timeout (#36142)

Treat search command budget timeouts as soft truncation so partial results survive, while real search failures still return structured errors.
This commit is contained in:
Teknium 2026-06-13 14:35:21 -07:00 committed by GitHub
parent 069bfd6545
commit 1fa761f8de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 193 additions and 16 deletions

View file

@ -241,10 +241,11 @@ class SearchResult:
counts: Dict[str, int] = field(default_factory=dict)
total_count: int = 0
truncated: bool = False
limit_reason: Optional[str] = None
error: Optional[str] = None
def to_dict(self) -> dict:
result = {"total_count": self.total_count}
result: dict[str, object] = {"total_count": self.total_count}
if self.matches:
result["matches"] = [
{"path": m.path, "line": m.line_number, "content": m.content}
@ -256,6 +257,8 @@ class SearchResult:
result["counts"] = self.counts
if self.truncated:
result["truncated"] = True
if self.limit_reason:
result["limit_reason"] = self.limit_reason
if self.error:
result["error"] = self.error
return result
@ -285,6 +288,16 @@ class ExecuteResult:
exit_code: int = 0
_SEARCH_TIMEOUT_MARKER_RE = re.compile(r"\n?\[Command timed out after \d+s\]\s*$")
def _search_stdout_and_limit(result: ExecuteResult) -> tuple[str, Optional[str]]:
"""Return stdout cleaned for parsing and a limit reason for search timeouts."""
if result.exit_code == 124:
return _SEARCH_TIMEOUT_MARKER_RE.sub("", result.stdout), "search_timeout"
return result.stdout, None
def _split_tool_diagnostics(output: str) -> tuple[str, str]:
"""Separate rg/grep diagnostic lines from real match output.
@ -1967,15 +1980,17 @@ class ShellFileOperations(FileOperations):
f"-printf '%T@ %p\\n' 2>/dev/null | sort -rn{pagination_expr}"
result = self._exec(cmd, timeout=60)
stdout, limit_reason = _search_stdout_and_limit(result)
if not result.stdout.strip():
if not stdout.strip() and not limit_reason:
# Try without -printf (BSD find compatibility -- macOS)
cmd_simple = f"find {self._escape_shell_arg(path)}{hidden_filter_expr} -type f -name {self._escape_shell_arg(search_pattern)} " \
f"2>/dev/null | sort -rn{pagination_expr}"
result = self._exec(cmd_simple, timeout=60)
stdout, limit_reason = _search_stdout_and_limit(result)
files = []
for line in result.stdout.strip().split('\n'):
for line in stdout.strip().split('\n'):
if not line:
continue
parts = line.split(' ', 1)
@ -2003,7 +2018,9 @@ class ShellFileOperations(FileOperations):
return SearchResult(
files=files,
total_count=len(files)
total_count=len(files),
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
def _search_files_rg(self, pattern: str, path: str, limit: int, offset: int) -> SearchResult:
@ -2029,9 +2046,10 @@ class ShellFileOperations(FileOperations):
f"| head -n {fetch_limit}"
)
result = self._exec(cmd_sorted, timeout=60)
all_files = [f for f in result.stdout.strip().split('\n') if f]
stdout, limit_reason = _search_stdout_and_limit(result)
all_files = [f for f in stdout.strip().split('\n') if f]
if not all_files:
if not all_files and not limit_reason:
# --sortr may have failed on older rg; retry without it.
cmd_plain = (
f"rg --files -g {self._escape_shell_arg(glob_pattern)} "
@ -2039,14 +2057,16 @@ class ShellFileOperations(FileOperations):
f"| head -n {fetch_limit}"
)
result = self._exec(cmd_plain, timeout=60)
all_files = [f for f in result.stdout.strip().split('\n') if f]
stdout, limit_reason = _search_stdout_and_limit(result)
all_files = [f for f in stdout.strip().split('\n') if f]
page = all_files[offset:offset + limit]
return SearchResult(
files=page,
total_count=len(all_files),
truncated=len(all_files) >= fetch_limit,
truncated=len(all_files) >= fetch_limit or bool(limit_reason),
limit_reason=limit_reason,
)
def _search_content(self, pattern: str, path: str, file_glob: Optional[str],
@ -2102,12 +2122,13 @@ class ShellFileOperations(FileOperations):
# introduce false errors on a successful-but-truncated search.
cmd = "set -o pipefail; " + " ".join(cmd_parts)
result = self._exec(cmd, timeout=60)
stdout, limit_reason = _search_stdout_and_limit(result)
# _exec merges stderr into stdout (stderr=subprocess.STDOUT), so rg's
# diagnostic lines ("rg: <file>: <error>", "rg: regex parse error:")
# are interleaved with match output. Split them out: diagnostics must
# not be parsed as matches, and on a hard error they ARE the message.
diagnostics, payload = _split_tool_diagnostics(result.stdout)
diagnostics, payload = _split_tool_diagnostics(stdout)
# rg exit codes: 0=matches found, 1=no matches, 2=error. rg returns 2
# even on partial errors (e.g. one unreadable file in a tree that
@ -2124,7 +2145,12 @@ class ShellFileOperations(FileOperations):
all_files = [f for f in stdout.strip().split('\n') if f]
total = len(all_files)
page = all_files[offset:offset + limit]
return SearchResult(files=page, total_count=total)
return SearchResult(
files=page,
total_count=total,
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
elif output_mode == "count":
counts = {}
@ -2136,7 +2162,12 @@ class ShellFileOperations(FileOperations):
counts[parts[0]] = int(parts[1])
except ValueError:
pass
return SearchResult(counts=counts, total_count=sum(counts.values()))
return SearchResult(
counts=counts,
total_count=sum(counts.values()),
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
else:
# Parse content matches and context lines.
@ -2177,7 +2208,8 @@ class ShellFileOperations(FileOperations):
return SearchResult(
matches=page,
total_count=total,
truncated=total > offset + limit
truncated=total > offset + limit or bool(limit_reason),
limit_reason=limit_reason,
)
def _search_with_grep(self, pattern: str, path: str, file_glob: Optional[str],
@ -2218,12 +2250,13 @@ class ShellFileOperations(FileOperations):
# pipefail does not turn truncated results into false errors.
cmd = "set -o pipefail; " + " ".join(cmd_parts)
result = self._exec(cmd, timeout=60)
stdout, limit_reason = _search_stdout_and_limit(result)
# _exec merges stderr into stdout, so grep's diagnostic lines
# ("grep: <file>: <error>") are interleaved with matches. Split them
# out so they're never parsed as matches and so a hard error has a
# clean message.
diagnostics, payload = _split_tool_diagnostics(result.stdout)
diagnostics, payload = _split_tool_diagnostics(stdout)
# grep exit codes: 0=matches found, 1=no matches, 2=error. grep
# returns 2 on partial errors (e.g. an unreadable file) even when
@ -2238,7 +2271,12 @@ class ShellFileOperations(FileOperations):
all_files = [f for f in stdout.strip().split('\n') if f]
total = len(all_files)
page = all_files[offset:offset + limit]
return SearchResult(files=page, total_count=total)
return SearchResult(
files=page,
total_count=total,
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
elif output_mode == "count":
counts = {}
@ -2250,7 +2288,12 @@ class ShellFileOperations(FileOperations):
counts[parts[0]] = int(parts[1])
except ValueError:
pass
return SearchResult(counts=counts, total_count=sum(counts.values()))
return SearchResult(
counts=counts,
total_count=sum(counts.values()),
truncated=bool(limit_reason),
limit_reason=limit_reason,
)
else:
# grep match lines: "file:lineno:content" (colon)
@ -2288,5 +2331,6 @@ class ShellFileOperations(FileOperations):
return SearchResult(
matches=page,
total_count=total,
truncated=total > offset + limit
truncated=total > offset + limit or bool(limit_reason),
limit_reason=limit_reason,
)