mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-19 07:59:34 +00:00
Merge pull request #1596 from ruizanthony/fix/code-execution-pty-fd-leak
fix(code_execution): close PTY file descriptors
This commit is contained in:
commit
48977bffc5
3 changed files with 151 additions and 20 deletions
|
|
@ -21,8 +21,15 @@ class LocalInteractiveSession:
|
|||
|
||||
async def close(self):
|
||||
if self.session:
|
||||
self.session.kill()
|
||||
# self.session.wait()
|
||||
session = self.session
|
||||
self.session = None
|
||||
try:
|
||||
await session.close()
|
||||
except Exception:
|
||||
try:
|
||||
session.kill()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def send_command(self, command: str):
|
||||
if not self.session:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ class TTYSession:
|
|||
self.echo = echo # ← store preference
|
||||
self._proc = None
|
||||
self._buf: asyncio.Queue = None # type: ignore
|
||||
self._pump_task = None
|
||||
self._pty_master = None
|
||||
self._pty_master_ref = None
|
||||
|
||||
def __del__(self):
|
||||
# Simple cleanup on object destruction
|
||||
|
|
@ -46,30 +49,92 @@ class TTYSession:
|
|||
self._proc = await _spawn_posix_pty(
|
||||
self.cmd, self.cwd, self.env, self.echo
|
||||
) # ← pass echo
|
||||
self._pty_master_ref = getattr(self._proc, "_pty_master_ref", None)
|
||||
self._pty_master = (
|
||||
self._pty_master_ref.get("fd")
|
||||
if self._pty_master_ref is not None
|
||||
else getattr(self._proc, "_pty_master", None)
|
||||
)
|
||||
self._pump_task = asyncio.create_task(self._pump_stdout())
|
||||
|
||||
async def close(self):
|
||||
# Cancel the pump task if it exists
|
||||
if hasattr(self, "_pump_task") and self._pump_task:
|
||||
if self._pump_task:
|
||||
self._pump_task.cancel()
|
||||
try:
|
||||
await self._pump_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Terminate the process if it exists
|
||||
if self._proc:
|
||||
self._proc.terminate()
|
||||
await self._proc.wait()
|
||||
try:
|
||||
if getattr(self._proc, "returncode", None) is None:
|
||||
self._proc.terminate()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await self._proc.wait()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._release_pty_master()
|
||||
self._proc = None
|
||||
self._pump_task = None
|
||||
|
||||
def _release_pty_master(self):
|
||||
"""Release the POSIX PTY master exactly once.
|
||||
|
||||
The fd number is invalidated before os.close() so that a concurrent or
|
||||
later cleanup path cannot close the same integer after the OS has reused
|
||||
it for another file/socket.
|
||||
"""
|
||||
ref = self._pty_master_ref
|
||||
master = ref.get("fd") if ref is not None else self._pty_master
|
||||
if master is None:
|
||||
self._pty_master = None
|
||||
return
|
||||
if ref is not None:
|
||||
ref["fd"] = None
|
||||
self._pty_master = None
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.remove_reader(master)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
os.close(master)
|
||||
except OSError:
|
||||
pass
|
||||
self._pty_master_ref = None
|
||||
|
||||
async def send(self, data: str | bytes):
|
||||
if self._proc is None:
|
||||
raise RuntimeError("TTYSpawn is not started")
|
||||
if not _IS_WIN:
|
||||
master = (
|
||||
self._pty_master_ref.get("fd")
|
||||
if self._pty_master_ref is not None
|
||||
else self._pty_master
|
||||
)
|
||||
if master is None:
|
||||
raise RuntimeError("TTYSpawn PTY is closed")
|
||||
if getattr(self._proc, "returncode", None) is not None:
|
||||
raise RuntimeError("TTYSpawn process has exited")
|
||||
if isinstance(data, str):
|
||||
data = data.encode(self.encoding)
|
||||
self._proc.stdin.write(data) # type: ignore
|
||||
await self._proc.stdin.drain() # type: ignore
|
||||
try:
|
||||
self._proc.stdin.write(data) # type: ignore
|
||||
await self._proc.stdin.drain() # type: ignore
|
||||
except OSError as e:
|
||||
if e.errno in (errno.EBADF, errno.EIO, errno.EINVAL):
|
||||
self._release_pty_master()
|
||||
raise RuntimeError("TTYSpawn PTY is closed") from e
|
||||
raise
|
||||
|
||||
async def sendline(self, line: str):
|
||||
await self.send(line + "\n")
|
||||
|
|
@ -99,6 +164,7 @@ class TTYSession:
|
|||
except ProcessLookupError:
|
||||
# Child already gone – treat as successfully killed
|
||||
pass
|
||||
self._release_pty_master()
|
||||
|
||||
async def read(self, timeout=None):
|
||||
# Return any decoded text the child produced, or None on timeout
|
||||
|
|
@ -173,10 +239,34 @@ async def _spawn_posix_pty(cmd, cwd, env, echo):
|
|||
|
||||
loop = asyncio.get_running_loop()
|
||||
reader = asyncio.StreamReader()
|
||||
master_ref = {"fd": master}
|
||||
|
||||
def _release_master_fd():
|
||||
cur = master_ref.get("fd")
|
||||
if cur is None:
|
||||
return
|
||||
# Invalidate before close so later cleanup cannot close a reused fd.
|
||||
master_ref["fd"] = None
|
||||
try:
|
||||
proc._pty_master = None # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
loop.remove_reader(cur)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
os.close(cur)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _on_data():
|
||||
cur = master_ref.get("fd")
|
||||
if cur is None:
|
||||
reader.feed_eof()
|
||||
return
|
||||
try:
|
||||
data = os.read(master, 1 << 16)
|
||||
data = os.read(cur, 1 << 16)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EIO: # EIO == EOF on some systems
|
||||
raise
|
||||
|
|
@ -185,19 +275,24 @@ async def _spawn_posix_pty(cmd, cwd, env, echo):
|
|||
reader.feed_data(data)
|
||||
else:
|
||||
reader.feed_eof()
|
||||
loop.remove_reader(master)
|
||||
_release_master_fd()
|
||||
|
||||
loop.add_reader(master, _on_data)
|
||||
|
||||
class _Stdin:
|
||||
def write(self, d):
|
||||
os.write(master, d)
|
||||
cur = master_ref.get("fd")
|
||||
if cur is None:
|
||||
raise OSError(errno.EBADF, "PTY master closed")
|
||||
os.write(cur, d)
|
||||
|
||||
async def drain(self):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
proc.stdin = _Stdin() # type: ignore
|
||||
proc.stdout = reader
|
||||
proc._pty_master = master # type: ignore[attr-defined]
|
||||
proc._pty_master_ref = master_ref # type: ignore[attr-defined]
|
||||
return proc
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import errno
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
import shlex
|
||||
|
|
@ -15,6 +16,17 @@ from plugins._code_execution.helpers.shell_local import LocalInteractiveSession
|
|||
from plugins._code_execution.helpers.shell_ssh import SSHInteractiveSession
|
||||
|
||||
|
||||
def _is_closed_pty_error(exc: BaseException) -> bool:
|
||||
if isinstance(exc, RuntimeError) and "TTYSpawn PTY is closed" in str(exc):
|
||||
return True
|
||||
if isinstance(exc, OSError) and exc.errno in (errno.EBADF, errno.EIO, errno.EINVAL):
|
||||
return True
|
||||
cause = getattr(exc, "__cause__", None)
|
||||
if cause and cause is not exc:
|
||||
return _is_closed_pty_error(cause)
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShellWrap:
|
||||
id: int
|
||||
|
|
@ -194,12 +206,12 @@ class CodeExecution(Tool):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
if i == 1:
|
||||
PrintStyle.error(str(e))
|
||||
if _is_closed_pty_error(e) and i == 0:
|
||||
PrintStyle.warning(f"Terminal session {session} was closed; resetting and retrying once.")
|
||||
await self.prepare_state(cfg, reset=True, session=session)
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
PrintStyle.error(str(e))
|
||||
raise
|
||||
|
||||
def format_command_for_output(self, command: str):
|
||||
short_cmd = command[:250]
|
||||
|
|
@ -245,9 +257,19 @@ class CodeExecution(Tool):
|
|||
|
||||
while True:
|
||||
await asyncio.sleep(sleep_time)
|
||||
full_output, partial_output = await self.state.shells[session].session.read_output(
|
||||
timeout=1, reset_full_output=reset_full_output
|
||||
)
|
||||
try:
|
||||
full_output, partial_output = await self.state.shells[session].session.read_output(
|
||||
timeout=1, reset_full_output=reset_full_output
|
||||
)
|
||||
except Exception as e:
|
||||
if _is_closed_pty_error(e):
|
||||
await self.prepare_state(cfg, reset=True, session=session)
|
||||
self.mark_session_idle(session)
|
||||
sysinfo = "Terminal session was closed and has been reset. Please run the command again."
|
||||
response = self.agent.read_prompt("fw.code.info.md", info=sysinfo)
|
||||
self.log.update(content=prefix + response)
|
||||
return response
|
||||
raise
|
||||
reset_full_output = False # only reset once
|
||||
|
||||
await self.agent.handle_intervention()
|
||||
|
|
@ -364,9 +386,16 @@ class CodeExecution(Tool):
|
|||
prompt_patterns = cfg["prompt_patterns"]
|
||||
dialog_patterns = cfg["dialog_patterns"]
|
||||
|
||||
full_output, _ = await self.state.shells[session].session.read_output(
|
||||
timeout=1, reset_full_output=reset_full_output
|
||||
)
|
||||
try:
|
||||
full_output, _ = await self.state.shells[session].session.read_output(
|
||||
timeout=1, reset_full_output=reset_full_output
|
||||
)
|
||||
except Exception as e:
|
||||
if _is_closed_pty_error(e):
|
||||
await self.prepare_state(cfg, reset=True, session=session)
|
||||
self.mark_session_idle(session)
|
||||
return None
|
||||
raise
|
||||
truncated_output = self.fix_full_output(full_output)
|
||||
self.set_progress(truncated_output)
|
||||
heading = self.get_heading_from_output(truncated_output, 0)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue