mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-25 14:46:19 +00:00
188 lines
6.6 KiB
Python
188 lines
6.6 KiB
Python
# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ========= Copyright 2023-2026 @ CAMEL-AI.org. All Rights Reserved. =========
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Optional, Set
|
|
|
|
from colorama import Fore
|
|
|
|
from camel.societies.workforce.base import BaseNode
|
|
from camel.societies.workforce.task_channel import TaskChannel
|
|
from camel.societies.workforce.utils import check_if_running
|
|
from camel.tasks.task import Task, TaskState
|
|
|
|
# Module-level logger used by the Worker class below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Worker(BaseNode, ABC):
    r"""A worker node that works on tasks. It is the basic unit of task
    processing in the workforce system.

    A worker pulls tasks assigned to it from its :class:`TaskChannel`,
    processes each one in its own :class:`asyncio.Task`, and hands the task
    back to the channel together with its final state.

    Args:
        description (str): Description of the node.
        node_id (Optional[str]): ID of the node. If not provided, it will
            be generated automatically. (default: :obj:`None`)
    """

    def __init__(
        self,
        description: str,
        node_id: Optional[str] = None,
    ) -> None:
        super().__init__(description, node_id=node_id)
        # IDs of the tasks this worker is processing right now.
        self._active_task_ids: Set[str] = set()
        # Handles of the in-flight ``_process_single_task`` coroutines. The
        # event loop keeps only weak references to tasks, so holding them
        # here also keeps them alive until the listen loop reaps them.
        self._running_tasks: Set[asyncio.Task] = set()

    def __repr__(self):
        return f"Worker node {self.node_id} ({self.description})"

    @abstractmethod
    async def _process_task(
        self, task: Task, dependencies: List[Task]
    ) -> TaskState:
        r"""Processes a task based on its dependencies.

        Subclasses implement the actual work here. Any ``Exception`` raised
        is caught by :meth:`_process_single_task`, which then marks the task
        as failed.

        Args:
            task (Task): The task to process.
            dependencies (List[Task]): The tasks this task depends on.

        Returns:
            'DONE' if the task is successfully processed,
            'FAILED' if the processing fails.
        """
        pass

    async def _get_assigned_task(self) -> Task:
        r"""Get a task assigned to this node from the channel.

        Requires :meth:`set_channel` to have been called first.
        """
        return await self._channel.get_assigned_task_by_assignee(self.node_id)

    @staticmethod
    def _get_dep_tasks_info(dependencies: List[Task]) -> str:
        r"""Format dependency tasks as text, one line per task.

        Args:
            dependencies (List[Task]): The dependency tasks to describe.

        Returns:
            str: Newline-joined lines of each task's id, content and
                result; an empty string when there are no dependencies.
        """
        return "\n".join(
            f"id: {dep_task.id}, content: {dep_task.content}. "
            f"result: {dep_task.result}."
            for dep_task in dependencies
        )

    @check_if_running(False)
    def set_channel(self, channel: TaskChannel):
        r"""Attach the channel this worker takes tasks from and returns
        them to.

        Args:
            channel (TaskChannel): The task channel to use.
        """
        self._channel = channel

    async def _process_single_task(self, task: Task) -> None:
        r"""Process a single task and handle its completion/failure.

        If processing succeeds, the task's state is set to whatever
        :meth:`_process_task` returned. If an ``Exception`` is raised, the
        error is stored in ``task.result`` and the task is marked
        ``FAILED``. In both cases the task is then handed back via the
        channel's ``return_task``. :class:`asyncio.CancelledError` is not an
        ``Exception`` subclass, so a cancelled task is not returned here.

        Args:
            task (Task): The task to process.
        """
        try:
            self._active_task_ids.add(task.id)
            print(
                f"{Fore.YELLOW}{self} get task {task.id}: {task.content}"
                f"{Fore.RESET}"
            )

            # Process the task
            task_state = await self._process_task(task, task.dependencies)

            # Update the result and status of the task
            task.set_state(task_state)

            await self._channel.return_task(task.id)
        except Exception as e:
            # NOTE: this also catches a failure of the return_task call
            # above; the task is then re-marked FAILED and returned again.
            logger.error(f"Error processing task {task.id}: {e}")
            # Store error information in task result
            task.result = f"{type(e).__name__}: {e!s}"
            task.set_state(TaskState.FAILED)
            await self._channel.return_task(task.id)
        finally:
            self._active_task_ids.discard(task.id)

    @check_if_running(False)
    async def _listen_to_channel(self):
        r"""Continuously listen to the channel and process assigned tasks.

        This method supports parallel task execution without artificial limits:
        every assigned task is processed in its own :class:`asyncio.Task`.
        The channel is polled with a 1 second timeout so the loop regularly
        re-checks the stop flag and reaps finished tasks. After the stop flag
        is cleared, any tasks still tracked are awaited before returning.
        """
        self._running = True
        logger.info(f"{self} started.")

        while self._running:
            try:
                # Clean up completed tasks
                completed_tasks = [t for t in self._running_tasks if t.done()]
                for completed_task in completed_tasks:
                    self._running_tasks.discard(completed_task)
                    # A cancelled task has no result to inspect. Awaiting it
                    # would raise CancelledError, a BaseException that the
                    # handlers below do not catch, and would tear down this
                    # listen loop.
                    if completed_task.cancelled():
                        continue
                    # Check for exceptions in completed tasks
                    try:
                        await completed_task
                    except Exception as e:
                        logger.error(f"Task processing failed: {e}")

                # Try to get a new task (with short timeout to avoid blocking)
                try:
                    task = await asyncio.wait_for(
                        self._get_assigned_task(), timeout=1.0
                    )
                except asyncio.TimeoutError:
                    # No tasks available, continue loop
                    if not self._running_tasks:
                        # No tasks running and none available, short sleep
                        await asyncio.sleep(0.1)
                    continue

                # Create and start processing task
                task_coroutine = asyncio.create_task(
                    self._process_single_task(task)
                )
                self._running_tasks.add(task_coroutine)

            except Exception as e:
                logger.error(
                    f"Error in worker {self.node_id} listen loop: {e}"
                )
                await asyncio.sleep(0.1)

        # Wait for all remaining tasks to complete when stopping
        if self._running_tasks:
            logger.info(
                f"{self} stopping, waiting for {len(self._running_tasks)} "
                f"tasks to complete..."
            )
            await asyncio.gather(*self._running_tasks, return_exceptions=True)

        logger.info(f"{self} stopped.")

    @check_if_running(False)
    async def start(self):
        r"""Start the worker.

        Runs :meth:`_listen_to_channel` until the worker is stopped.
        """
        await self._listen_to_channel()

    @check_if_running(True)
    def stop(self):
        r"""Forcefully stop the worker.

        Cancels all running tasks immediately and sets the stop flag.
        The listen loop exits once its current channel poll (1 second
        timeout) returns. Cancelled tasks are generally not handed back to
        the channel, because ``_process_single_task`` only handles
        ``Exception`` subclasses.
        """
        # First cancel all running tasks to interrupt ongoing work
        for running in list(self._running_tasks):
            if not running.done():
                running.cancel()

        # Drop the handles: every tracked task is now cancelled or already
        # finished, so the listen loop has nothing left to wait for.
        self._running_tasks.clear()

        # Set stop flag to exit the listen loop
        self._running = False
|