mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-27 17:05:53 +00:00
Merge branch 'main' into chatbox-ux
This commit is contained in:
commit
bc139b71e8
2 changed files with 162 additions and 10 deletions
|
|
@ -1,5 +1,9 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
from camel.toolkits.terminal_toolkit import TerminalToolkit as BaseTerminalToolkit
|
||||
from camel.toolkits.terminal_toolkit.terminal_toolkit import _to_plain
|
||||
from app.component.environment import env
|
||||
|
|
@ -12,6 +16,8 @@ from app.service.task import process_task
|
|||
@auto_listen_toolkit(BaseTerminalToolkit)
|
||||
class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):
|
||||
agent_name: str = Agents.developer_agent
|
||||
_thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
_thread_local = threading.local()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -31,6 +37,11 @@ class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):
|
|||
self.agent_name = agent_name
|
||||
if working_directory is None:
|
||||
working_directory = env("file_save_path", os.path.expanduser("~/.eigent/terminal/"))
|
||||
if TerminalToolkit._thread_pool is None:
|
||||
TerminalToolkit._thread_pool = ThreadPoolExecutor(
|
||||
max_workers=1,
|
||||
thread_name_prefix="terminal_toolkit"
|
||||
)
|
||||
super().__init__(
|
||||
timeout=timeout,
|
||||
working_directory=working_directory,
|
||||
|
|
@ -55,16 +66,57 @@ class TerminalToolkit(BaseTerminalToolkit, AbstractToolkit):
|
|||
|
||||
def _update_terminal_output(self, output: str):
|
||||
task_lock = get_task_lock(self.api_task_id)
|
||||
# This method will be called during init. At that time, the process_task_id parameter does not exist, so it is set to be empty default
|
||||
process_task_id = process_task.get("")
|
||||
task = asyncio.create_task(
|
||||
task_lock.put_queue(
|
||||
ActionTerminalData(
|
||||
action=Action.terminal,
|
||||
process_task_id=process_task_id,
|
||||
data=output,
|
||||
)
|
||||
|
||||
# Create the coroutine
|
||||
coro = task_lock.put_queue(
|
||||
ActionTerminalData(
|
||||
action=Action.terminal,
|
||||
process_task_id=process_task_id,
|
||||
data=output,
|
||||
)
|
||||
)
|
||||
if hasattr(task_lock, "add_background_task"):
|
||||
task_lock.add_background_task(task)
|
||||
|
||||
# Try to get the current event loop, if none exists, create a new one in a thread
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an async context, schedule the coroutine
|
||||
task = loop.create_task(coro)
|
||||
if hasattr(task_lock, "add_background_task"):
|
||||
task_lock.add_background_task(task)
|
||||
except RuntimeError:
|
||||
self._thread_pool.submit(self._run_coro_in_thread, coro,task_lock)
|
||||
|
||||
@staticmethod
|
||||
def _run_coro_in_thread(coro,task_lock):
|
||||
"""
|
||||
Execute coro in the thread pool, with each thread bound to a long-term event loop
|
||||
"""
|
||||
if not hasattr(TerminalToolkit._thread_local, "loop"):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
TerminalToolkit._thread_local.loop = loop
|
||||
else:
|
||||
loop = TerminalToolkit._thread_local.loop
|
||||
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
TerminalToolkit._thread_local.loop = loop
|
||||
|
||||
try:
|
||||
task = loop.create_task(coro)
|
||||
if hasattr(task_lock, "add_background_task"):
|
||||
task_lock.add_background_task(task)
|
||||
loop.run_until_complete(task)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Failed to execute coroutine in thread pool: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def shutdown(cls):
|
||||
if cls._thread_pool:
|
||||
cls._thread_pool.shutdown(wait=True)
|
||||
cls._thread_pool = None
|
||||
|
|
|
|||
100
backend/tests/unit/utils/test_terminal_toolkit.py
Normal file
100
backend/tests/unit/utils/test_terminal_toolkit.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import pytest
|
||||
from app.service.task import task_locks, TaskLock
|
||||
from app.utils.toolkit.terminal_toolkit import TerminalToolkit
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTerminalToolkit:
|
||||
"""Test to verify the RuntimeError: no running event loop."""
|
||||
|
||||
def test_no_runtime_error_in_sync_context(self):
|
||||
"""Test no running event loop."""
|
||||
test_api_task_id = "test_api_task_123"
|
||||
|
||||
if test_api_task_id not in task_locks:
|
||||
task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
|
||||
toolkit = TerminalToolkit("test_api_task_123")
|
||||
|
||||
# This should NOT raise RuntimeError: no running event loop
|
||||
# This simulates the exact scenario from the error traceback
|
||||
try:
|
||||
toolkit._write_to_log("/tmp/test.log", "Test output")
|
||||
time.sleep(0.1) # Give thread time to complete
|
||||
|
||||
except RuntimeError as e:
|
||||
if "no running event loop" in str(e):
|
||||
pytest.fail("RuntimeError: no running event loop should not be raised - the fix is not working!")
|
||||
else:
|
||||
raise # Re-raise if it's a different RuntimeError
|
||||
|
||||
def test_multiple_calls_no_runtime_error(self):
|
||||
"""Test that multiple calls don't raise RuntimeError."""
|
||||
test_api_task_id = "test_api_task_123"
|
||||
|
||||
if test_api_task_id not in task_locks:
|
||||
task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
|
||||
toolkit = TerminalToolkit("test_api_task_123")
|
||||
|
||||
# Make multiple calls - none should raise RuntimeError
|
||||
try:
|
||||
for i in range(5):
|
||||
toolkit._write_to_log(f"/tmp/test_{i}.log", f"Output {i}")
|
||||
time.sleep(0.2) # Give threads time to complete
|
||||
except RuntimeError as e:
|
||||
if "no running event loop" in str(e):
|
||||
pytest.fail("RuntimeError: no running event loop should not be raised!")
|
||||
else:
|
||||
raise
|
||||
|
||||
def test_thread_safety_no_runtime_error(self):
|
||||
"""Test thread safety without RuntimeError."""
|
||||
test_api_task_id = "test_api_task_123"
|
||||
|
||||
if test_api_task_id not in task_locks:
|
||||
task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
|
||||
toolkit = TerminalToolkit("test_api_task_123")
|
||||
|
||||
# Create multiple threads that call _write_to_log
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(
|
||||
target=toolkit._write_to_log,
|
||||
args=(f"/tmp/test_{i}.log", f"Thread {i} output")
|
||||
)
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
time.sleep(0.2) # Give async operations time to complete
|
||||
|
||||
# Should not have raised any RuntimeError
|
||||
|
||||
def test_async_context_still_works(self):
|
||||
"""Test that async context still works without RuntimeError."""
|
||||
test_api_task_id = "test_api_task_123"
|
||||
|
||||
if test_api_task_id not in task_locks:
|
||||
task_locks[test_api_task_id] = TaskLock(id=test_api_task_id, queue=asyncio.Queue(), human_input={})
|
||||
toolkit = TerminalToolkit("test_api_task_123")
|
||||
|
||||
async def test_async_context():
|
||||
toolkit._write_to_log("/tmp/async_test.log", "Async context test")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Should work in async context without RuntimeError
|
||||
try:
|
||||
asyncio.run(test_async_context())
|
||||
except RuntimeError as e:
|
||||
if "no running event loop" in str(e):
|
||||
pytest.fail("RuntimeError: no running event loop should not be raised in async context!")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue