agent-zero/plugins/_text_editor/tools/text_editor.py
Alessandro 4c2bc3d783 Add context-based patch_text support to text_editor
Introduces patch_text editing for the Docker-local text_editor, sharing request validation and freshness-state logic with text_editor_remote while preserving legacy line-number edits. Adds anchored context patching, safer state handling after context edits, updated model guidance, live remote wrapper reuse, and focused regression coverage for chained patches and Python replacement cases.
2026-04-21 18:18:59 +02:00

381 lines
12 KiB
Python

from helpers.tool import Tool, Response
from helpers.extension import call_extensions_async
from helpers import plugins, runtime
from plugins._text_editor.helpers.file_ops import (
FileInfo,
read_file,
write_file,
validate_edits,
apply_patch,
apply_context_patch_file,
file_info,
)
from plugins._text_editor.helpers.patch_request import parse_patch_request
from plugins._text_editor.helpers.patch_state import (
LOCAL_FRESHNESS_KEY,
apply_patch_post_state,
check_patch_freshness,
mark_file_state_stale,
record_file_state,
)
# Key used in agent.data to store file state for patch validation
# Value: {path: {"mtime": float, "total_lines": int}}
_MTIME_KEY = LOCAL_FRESHNESS_KEY
class TextEditor(Tool):
async def execute(self, **kwargs):
if self.method == "read":
return await self._read(**kwargs)
elif self.method == "write":
return await self._write(**kwargs)
elif self.method == "patch":
return await self._patch(**kwargs)
return Response(
message=f"unknown method '{self.name}:{self.method}'",
break_loop=False,
)
# ------------------------------------------------------------------
# READ
# ------------------------------------------------------------------
async def _read(self, path: str = "", **kwargs) -> Response:
if not path:
return self._error("read", path, "path is required")
cfg = _get_config(self.agent)
line_from = int(kwargs.get("line_from", 1))
raw_to = kwargs.get("line_to")
line_to = int(raw_to) if raw_to is not None else None
result = await runtime.call_development_function(
read_file,
path,
line_from=line_from,
line_to=line_to,
max_line_tokens=cfg["max_line_tokens"],
default_line_count=cfg["default_line_count"],
max_total_read_tokens=cfg["max_total_read_tokens"],
)
if result["error"]:
return self._error("read", path, result["error"])
info = await runtime.call_development_function(file_info, path)
record_file_state(
self.agent,
info,
key=_MTIME_KEY,
total_lines=result["total_lines"],
)
# Extension point
ext_data = {
"content": result["content"],
"warnings": result["warnings"],
}
await call_extensions_async(
"text_editor_read_after", agent=self.agent, data=ext_data
)
msg = self.agent.read_prompt(
"fw.text_editor.read_ok.md",
path=info["expanded"],
total_lines=str(result["total_lines"]),
warnings=ext_data["warnings"],
content=ext_data["content"],
)
return Response(message=msg, break_loop=False)
# ------------------------------------------------------------------
# WRITE
# ------------------------------------------------------------------
async def _write(
self, path: str = "", content: str | None = "", **kwargs
) -> Response:
if not path:
return self._error("write", path, "path is required")
# Extension point
ext_data = {"path": path, "content": content}
await call_extensions_async(
"text_editor_write_before", agent=self.agent, data=ext_data
)
result = await runtime.call_development_function(
write_file, ext_data["path"], ext_data["content"]
)
if result["error"]:
return self._error("write", path, result["error"])
# Extension point
await call_extensions_async(
"text_editor_write_after", agent=self.agent,
data={"path": path, "total_lines": result["total_lines"]},
)
info = await runtime.call_development_function(file_info, path)
record_file_state(
self.agent,
info,
key=_MTIME_KEY,
total_lines=result["total_lines"],
)
cfg = _get_config(self.agent)
read_result = await runtime.call_development_function(
read_file,
info["expanded"],
line_from=1,
line_to=result["total_lines"],
max_line_tokens=cfg["max_line_tokens"],
max_total_read_tokens=cfg["max_total_read_tokens"],
)
msg = self.agent.read_prompt(
"fw.text_editor.write_ok.md",
path=info["expanded"],
total_lines=str(result["total_lines"]),
content=read_result["content"],
)
return Response(message=msg, break_loop=False)
# ------------------------------------------------------------------
# PATCH
# ------------------------------------------------------------------
async def _patch(
self, path: str = "", edits=None, patch_text=None, **kwargs
) -> Response:
if not path:
return self._error("patch", path, "path is required")
patch_request, err = parse_patch_request(
edits,
patch_text,
missing_error="",
)
if err:
return self._error("patch", path, err)
info = await runtime.call_development_function(file_info, path)
if not info["is_file"]:
return self._error("patch", path, "file not found")
expanded = info["expanded"]
if patch_request and patch_request.mode == "patch_text":
return await self._patch_context(
path, expanded, patch_request.patch_text
)
return await self._patch_edits(
path,
expanded,
info,
patch_request.edits if patch_request else edits,
)
async def _patch_edits(
self, path: str, expanded: str, info: FileInfo, edits
) -> Response:
freshness_code = check_patch_freshness(self.agent, info, key=_MTIME_KEY)
if freshness_code:
return self._error(
"patch",
path,
_freshness_error_message(self.agent, info, freshness_code),
)
parsed, err = validate_edits(edits)
if err:
return self._error("patch", path, err)
# Extension point
ext_data = {"path": expanded, "edits": parsed}
await call_extensions_async(
"text_editor_patch_before", agent=self.agent, data=ext_data
)
try:
total_lines = await runtime.call_development_function(
apply_patch, ext_data["path"], ext_data["edits"]
)
except Exception as exc:
return self._error("patch", path, str(exc))
# Extension point
await call_extensions_async(
"text_editor_patch_after", agent=self.agent,
data={"path": expanded, "total_lines": total_lines},
)
# Refresh file info after patch for updated mtime
post_info = await runtime.call_development_function(
file_info, expanded
)
apply_patch_post_state(
self.agent,
post_info,
ext_data["edits"],
key=_MTIME_KEY,
total_lines=total_lines,
)
patch_content = await _read_patch_region(
expanded, ext_data["edits"], total_lines, _get_config(self.agent)
)
msg = self.agent.read_prompt(
"fw.text_editor.patch_ok.md",
path=expanded,
edit_count=str(len(edits or [])),
total_lines=str(total_lines),
content=patch_content,
)
return Response(message=msg, break_loop=False)
async def _patch_context(
self, path: str, expanded: str, patch_text
) -> Response:
patch_text = str(patch_text)
if not patch_text.strip():
return self._error("patch", path, "patch_text must not be empty")
# Extension point
ext_data = {
"path": expanded,
"patch_text": patch_text,
"edits": [],
"mode": "patch_text",
}
await call_extensions_async(
"text_editor_patch_before", agent=self.agent, data=ext_data
)
try:
result = await runtime.call_development_function(
apply_context_patch_file,
ext_data["path"],
ext_data["patch_text"],
)
except Exception as exc:
return self._error("patch", path, str(exc))
total_lines = result["total_lines"]
# Extension point
await call_extensions_async(
"text_editor_patch_after", agent=self.agent,
data={
"path": ext_data["path"],
"total_lines": total_lines,
"hunk_count": result["hunk_count"],
"mode": "patch_text",
},
)
post_info = await runtime.call_development_function(
file_info, ext_data["path"]
)
mark_file_state_stale(self.agent, post_info, key=_MTIME_KEY)
patch_content = await _read_context_patch_region(
ext_data["path"], result, _get_config(self.agent)
)
msg = self.agent.read_prompt(
"fw.text_editor.patch_ok.md",
path=ext_data["path"],
edit_count=str(result["hunk_count"]),
total_lines=str(total_lines),
content=patch_content,
)
return Response(message=msg, break_loop=False)
# ------------------------------------------------------------------
# Shared error helper
# ------------------------------------------------------------------
def _error(self, action: str, path: str, error: str) -> Response:
msg = self.agent.read_prompt(
f"fw.text_editor.{action}_error.md", path=path, error=error
)
return Response(message=msg, break_loop=False)
# ------------------------------------------------------------------
# Standalone helpers
# ------------------------------------------------------------------
async def _read_patch_region(
path: str, edits: list[dict], total_lines: int, cfg: dict
) -> str:
if not edits:
return ""
min_from = min(e["from"] for e in edits)
added = sum(
e["content"].count("\n")
+ (1 if e["content"] and not e["content"].endswith("\n") else 0)
for e in edits if e.get("content")
)
removed = sum(
max(e["to"] - e["from"] + 1, 0)
for e in edits if not e.get("insert")
)
max_to = max(e["to"] for e in edits)
end_line = max_to + added - removed + 3
result = await runtime.call_development_function(
read_file,
path,
line_from=max(min_from - 1, 1),
line_to=min(end_line, total_lines),
max_line_tokens=cfg["max_line_tokens"],
max_total_read_tokens=cfg["max_total_read_tokens"],
)
return result["content"]
async def _read_context_patch_region(
path: str, result: dict, cfg: dict
) -> str:
total_lines = int(result["total_lines"])
if total_lines <= 0:
return ""
line_from = min(max(int(result["line_from"]), 1), total_lines)
line_to = min(max(int(result["line_to"]), line_from) + 3, total_lines)
read_result = await runtime.call_development_function(
read_file,
path,
line_from=max(line_from - 1, 1),
line_to=line_to,
max_line_tokens=cfg["max_line_tokens"],
max_total_read_tokens=cfg["max_total_read_tokens"],
)
return read_result["content"]
def _freshness_error_message(agent, info: FileInfo, code: str) -> str:
prompt = (
"fw.text_editor.patch_stale_read.md"
if code == "patch_stale_read"
else "fw.text_editor.patch_need_read.md"
)
return agent.read_prompt(prompt, path=info["expanded"])
# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------
def _get_config(agent) -> dict:
config = plugins.get_plugin_config("_text_editor", agent=agent) or {}
return {
"max_line_tokens": int(config.get("max_line_tokens", 500)),
"default_line_count": int(config.get("default_line_count", 100)),
"max_total_read_tokens": int(config.get("max_total_read_tokens", 4000)),
}