mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-27 08:36:24 +00:00
Update toolkit_listen.py
This commit is contained in:
parent
dc1c352d0a
commit
30fa9fc4cc
1 changed files with 20 additions and 3 deletions
|
|
@ -33,6 +33,20 @@ import logging
|
|||
logger = logging.getLogger("toolkit_listen")
|
||||
|
||||
|
||||
def _filter_kwargs_for_callable(func: Callable[..., Any], kwargs: dict) -> dict:
|
||||
"""Drop unexpected kwargs unless the callable accepts **kwargs."""
|
||||
if not kwargs:
|
||||
return kwargs
|
||||
try:
|
||||
sig = signature(func)
|
||||
except (TypeError, ValueError):
|
||||
return kwargs
|
||||
if any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values()):
|
||||
return kwargs
|
||||
allowed = set(sig.parameters.keys())
|
||||
return {k: v for k, v in kwargs.items() if k in allowed}
|
||||
|
||||
|
||||
def _safe_put_queue(task_lock, data):
|
||||
"""Safely put data to the queue, handling both sync and async contexts"""
|
||||
try:
|
||||
|
|
@ -154,7 +168,8 @@ def listen_toolkit(
|
|||
error = None
|
||||
res = None
|
||||
try:
|
||||
res = await func(*args, **kwargs)
|
||||
safe_kwargs = _filter_kwargs_for_callable(func, kwargs)
|
||||
res = await func(*args, **safe_kwargs)
|
||||
except Exception as e:
|
||||
error = e
|
||||
|
||||
|
|
@ -212,7 +227,8 @@ def listen_toolkit(
|
|||
# Check if api_task_id exists
|
||||
if not hasattr(toolkit, 'api_task_id'):
|
||||
logger.warning(f"[listen_toolkit] {toolkit.__class__.__name__} missing api_task_id, calling method directly")
|
||||
return func(*args, **kwargs)
|
||||
safe_kwargs = _filter_kwargs_for_callable(func, kwargs)
|
||||
return func(*args, **safe_kwargs)
|
||||
|
||||
task_lock = get_task_lock(toolkit.api_task_id)
|
||||
|
||||
|
|
@ -260,7 +276,8 @@ def listen_toolkit(
|
|||
error = None
|
||||
res = None
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
safe_kwargs = _filter_kwargs_for_callable(func, kwargs)
|
||||
res = func(*args, **safe_kwargs)
|
||||
# Safety check: if the result is a coroutine, this is a programming error
|
||||
if asyncio.iscoroutine(res):
|
||||
error_msg = f"Async function {func.__name__} was incorrectly called in sync context. This is a bug - the function should be marked as async or should not return a coroutine."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue