diff --git a/agent.py b/agent.py
index f2140efea..cb767a5a0 100644
--- a/agent.py
+++ b/agent.py
@@ -335,6 +335,7 @@ class Agent:
                         await self.call_extensions("before_main_llm_call", loop_data=self.loop_data)

                         async def reasoning_callback(chunk: str, full: str):
+                            await self.handle_intervention()
                             if chunk == full:
                                 printer.print("Reasoning: ")  # start of reasoning
                             # Pass chunk and full data to extensions for processing
@@ -349,6 +350,7 @@ class Agent:
                             await self.handle_reasoning_stream(stream_data["full"])

                         async def stream_callback(chunk: str, full: str):
+                            await self.handle_intervention()
                             # output the agent response stream
                             if chunk == full:
                                 printer.print("Response: ")  # start of response
@@ -804,6 +806,7 @@ class Agent:
         )

     async def handle_reasoning_stream(self, stream: str):
+        await self.handle_intervention()
         await self.call_extensions(
             "reasoning_stream",
             loop_data=self.loop_data,
@@ -811,6 +814,7 @@ class Agent:
         )

     async def handle_response_stream(self, stream: str):
+        await self.handle_intervention()
         try:
             if len(stream) < 25:
                 return  # no reason to try
diff --git a/models.py b/models.py
index 02997e1c5..9fe322eae 100644
--- a/models.py
+++ b/models.py
@@ -23,6 +23,7 @@ from python.helpers.dotenv import load_dotenv
 from python.helpers.providers import get_provider_config
 from python.helpers.rate_limiter import RateLimiter
 from python.helpers.tokens import approximate_tokens
+from python.helpers import dirty_json

 from langchain_core.language_models.chat_models import SimpleChatModel
 from langchain_core.outputs.chat_generation import ChatGenerationChunk
@@ -90,6 +91,7 @@ class ChatChunk(TypedDict):
 rate_limiters: dict[str, RateLimiter] = {}
 api_keys_round_robin: dict[str, int] = {}

+
 def get_api_key(service: str) -> str:
     # get api key for the service
     key = (
@@ -116,7 +118,14 @@ def get_rate_limiter(
     limiter.limits["output"] = output or 0
     return limiter

-async def apply_rate_limiter(model_config: ModelConfig|None, input_text: str, rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None):
+
+async def apply_rate_limiter(
+    model_config: ModelConfig | None,
+    input_text: str,
+    rate_limiter_callback: (
+        Callable[[str, str, int, int], Awaitable[bool]] | None
+    ) = None,
+):
     if not model_config:
         return
     limiter = get_rate_limiter(
@@ -131,25 +140,41 @@ async def apply_rate_limiter(
     await limiter.wait(rate_limiter_callback)
     return limiter

-def apply_rate_limiter_sync(model_config: ModelConfig|None, input_text: str, rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None):
+
+def apply_rate_limiter_sync(
+    model_config: ModelConfig | None,
+    input_text: str,
+    rate_limiter_callback: (
+        Callable[[str, str, int, int], Awaitable[bool]] | None
+    ) = None,
+):
     if not model_config:
         return
     import asyncio, nest_asyncio
+
     nest_asyncio.apply()
-    return asyncio.run(apply_rate_limiter(model_config, input_text, rate_limiter_callback))
+    return asyncio.run(
+        apply_rate_limiter(model_config, input_text, rate_limiter_callback)
+    )


 class LiteLLMChatWrapper(SimpleChatModel):
     model_name: str
     provider: str
     kwargs: dict = {}
-    
+
     class Config:
         arbitrary_types_allowed = True
         extra = "allow"  # Allow extra attributes
         validate_assignment = False  # Don't validate on assignment

-    def __init__(self, model: str, provider: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
+    def __init__(
+        self,
+        model: str,
+        provider: str,
+        model_config: Optional[ModelConfig] = None,
+        **kwargs: Any,
+    ):
         model_value = f"{provider}/{model}"
         super().__init__(model_name=model_value, provider=provider, kwargs=kwargs)  # type: ignore
         # Set A0 model config as instance attribute after parent init
@@ -158,7 +183,7 @@ class LiteLLMChatWrapper(SimpleChatModel):
     @property
     def _llm_type(self) -> str:
         return "litellm-chat"
-    
+
     def _convert_messages(self, messages: List[BaseMessage]) -> List[dict]:
         result = []
         # Map LangChain message types to LiteLLM roles
@@ -169,7 +194,9 @@ class LiteLLMChatWrapper(SimpleChatModel):
             "tool": "tool",
         }
         for m in messages:
-            role = role_mapping.get(m.type, m.type)
+            m_type = getattr(m, "type", getattr(m, "role", ""))
+            role = role_mapping.get(m_type, m_type)
+            content = getattr(m, "content", getattr(m, "text", ""))
             message_dict = {"role": role, "content": m.content}

             # Handle tool calls for AI messages
@@ -215,12 +242,12 @@ class LiteLLMChatWrapper(SimpleChatModel):
         **kwargs: Any,
     ) -> str:
         import asyncio
-        
+
         msgs = self._convert_messages(messages)
-        
+
         # Apply rate limiting if configured
         apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
-        
+
         # Call the model
         resp = completion(
             model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
         )
@@ -238,12 +265,12 @@ class LiteLLMChatWrapper(SimpleChatModel):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         import asyncio
-        
+
         msgs = self._convert_messages(messages)
-        
+
         # Apply rate limiting if configured
         apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
-        
+
         for chunk in completion(
             model=self.model_name,
             messages=msgs,
@@ -266,11 +293,10 @@ class LiteLLMChatWrapper(SimpleChatModel):
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         msgs = self._convert_messages(messages)
-        
+
         # Apply rate limiting if configured
         await apply_rate_limiter(self.a0_model_conf, str(msgs))
-        
-        
+
         response = await acompletion(
             model=self.model_name,
             messages=msgs,
@@ -294,7 +320,9 @@ class LiteLLMChatWrapper(SimpleChatModel):
         response_callback: Callable[[str, str], Awaitable[None]] | None = None,
         reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
         tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
-        rate_limiter_callback: Callable[[str, str, int, int], Awaitable[bool]] | None = None,
+        rate_limiter_callback: (
+            Callable[[str, str, int, int], Awaitable[bool]] | None
+        ) = None,
         **kwargs: Any,
     ) -> Tuple[str, str]:

@@ -312,7 +340,9 @@ class LiteLLMChatWrapper(SimpleChatModel):
         msgs_conv = self._convert_messages(messages)

         # Apply rate limiting if configured
-        limiter = await apply_rate_limiter(self.a0_model_conf, str(msgs_conv), rate_limiter_callback)
+        limiter = await apply_rate_limiter(
+            self.a0_model_conf, str(msgs_conv), rate_limiter_callback
+        )

         # call model
         _completion = await acompletion(
@@ -360,7 +390,27 @@ class LiteLLMChatWrapper(SimpleChatModel):
         return response, reasoning


-class BrowserCompatibleChatWrapper(LiteLLMChatWrapper):
+class AsyncAIChatReplacement:
+    class _Completions:
+        def __init__(self, wrapper):
+            self._wrapper = wrapper
+
+        async def create(self, *args, **kwargs):
+            # call the async _acall method on the wrapper
+            return await self._wrapper._acall(*args, **kwargs)
+
+    class _Chat:
+        def __init__(self, wrapper):
+            self.completions = AsyncAIChatReplacement._Completions(wrapper)
+
+    def __init__(self, wrapper, *args, **kwargs):
+        self._wrapper = wrapper
+        self.chat = AsyncAIChatReplacement._Chat(wrapper)
+
+
+from browser_use.llm import ChatOllama, ChatOpenRouter, ChatGoogle, ChatAnthropic, ChatGroq, ChatOpenAI
+
+class BrowserCompatibleChatWrapper(ChatOpenRouter):
     """
     A wrapper for browser agent that can filter/sanitize
     messages before sending them to the LLM.
@@ -368,31 +418,61 @@ class BrowserCompatibleChatWrapper(LiteLLMChatWrapper):

     def __init__(self, *args, **kwargs):
         turn_off_logging()
-        super().__init__(*args, **kwargs)
+        # Create the underlying LiteLLM wrapper
+        self._wrapper = LiteLLMChatWrapper(*args, **kwargs)
         # Browser-use may expect a 'model' attribute
-        self.model = self.model_name
+        self.model = self._wrapper.model_name
+        self.kwargs = self._wrapper.kwargs

-    def _call(
+    @property
+    def model_name(self) -> str:
+        return self._wrapper.model_name
+
+    @property
+    def provider(self) -> str:
+        return self._wrapper.provider
+
+    def get_client(self, *args, **kwargs):  # type: ignore
+        return AsyncAIChatReplacement(self, *args, **kwargs)
+
+    async def _acall(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
-        turn_off_logging()
-        result = super()._call(messages, stop, run_manager, **kwargs)
-        return result
+    ):
+        # Apply rate limiting if configured
+        apply_rate_limiter_sync(self._wrapper.a0_model_conf, str(messages))

-    async def _astream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        turn_off_logging()
-        async for chunk in super()._astream(messages, stop, run_manager, **kwargs):
-            yield chunk
+        # Call the model
+        try:
+            model = kwargs.pop("model", None)
+            kwrgs = {**self._wrapper.kwargs, **kwargs}
+
+            # hack from browser-use to fix json schema for gemini
+            if "response_format" in kwrgs and "json_schema" in kwrgs["response_format"] and model.startswith("gemini/"):
+                kwrgs["response_format"]["json_schema"] = ChatGoogle("")._fix_gemini_schema(self._wrapper.kwargs)
+
+            resp = await acompletion(
+                model=self._wrapper.model_name,
+                messages=messages,
+                stop=stop,
+                **kwrgs,
+            )
+        except Exception as e:
+            raise e
+
+        # another hack for browser-use post process invalid jsons
+        try:
+            if "response_format" in kwrgs and "json_schema" in kwrgs["response_format"] or "json_object" in kwrgs["response_format"]:
+                if resp.choices[0].message.content is not None and not resp.choices[0].message.content.startswith("{"):  # type: ignore
+                    js = dirty_json.parse(resp.choices[0].message.content)  # type: ignore
+                    resp.choices[0].message.content = dirty_json.stringify(js)  # type: ignore
+        except Exception as e:
+            pass
+
+        return resp


@@ -400,15 +480,21 @@ class LiteLLMEmbeddingWrapper(Embeddings):
     kwargs: dict = {}
     a0_model_conf: Optional[ModelConfig] = None

-    def __init__(self, model: str, provider: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
+    def __init__(
+        self,
+        model: str,
+        provider: str,
+        model_config: Optional[ModelConfig] = None,
+        **kwargs: Any,
+    ):
         self.model_name = f"{provider}/{model}" if provider != "openai" else model
         self.kwargs = kwargs
         self.a0_model_conf = model_config
-        
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         # Apply rate limiting if configured
         apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
-        
+
         resp = embedding(model=self.model_name, input=texts, **self.kwargs)
         return [
             item.get("embedding") if isinstance(item, dict) else item.embedding  # type: ignore
@@ -418,7 +504,7 @@ class LiteLLMEmbeddingWrapper(Embeddings):
     def embed_query(self, text: str) -> List[float]:
         # Apply rate limiting if configured
         apply_rate_limiter_sync(self.a0_model_conf, text)
-        
+
         resp = embedding(model=self.model_name, input=[text], **self.kwargs)
         item = resp.data[0]  # type: ignore
         return item.get("embedding") if isinstance(item, dict) else item.embedding  # type: ignore
@@ -427,7 +513,13 @@
 class LocalSentenceTransformerWrapper(Embeddings):
     """Local wrapper for sentence-transformers models to avoid HuggingFace API calls"""

-    def __init__(self, provider: str, model: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
+    def __init__(
+        self,
+        provider: str,
+        model: str,
+        model_config: Optional[ModelConfig] = None,
+        **kwargs: Any,
+    ):
         # Clean common user-input mistakes
         model = model.strip().strip('"').strip("'")

@@ -449,18 +541,18 @@ class LocalSentenceTransformerWrapper(Embeddings):
         self.model = SentenceTransformer(model, **st_kwargs)
         self.model_name = model
         self.a0_model_conf = model_config
-        
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         # Apply rate limiting if configured
         apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
-        
+
         embeddings = self.model.encode(texts, convert_to_tensor=False)  # type: ignore
         return embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings  # type: ignore

     def embed_query(self, text: str) -> List[float]:
         # Apply rate limiting if configured
         apply_rate_limiter_sync(self.a0_model_conf, text)
-        
+
         embedding = self.model.encode([text], convert_to_tensor=False)  # type: ignore
         result = (
             embedding[0].tolist() if hasattr(embedding[0], "tolist") else embedding[0]
@@ -485,10 +577,17 @@ def _get_litellm_chat(
     provider_name, model_name, kwargs = _adjust_call_args(
         provider_name, model_name, kwargs
     )
-    return cls(provider=provider_name, model=model_name, model_config=model_config, **kwargs)
+    return cls(
+        provider=provider_name, model=model_name, model_config=model_config, **kwargs
+    )


-def _get_litellm_embedding(model_name: str, provider_name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any):
+def _get_litellm_embedding(
+    model_name: str,
+    provider_name: str,
+    model_config: Optional[ModelConfig] = None,
+    **kwargs: Any,
+):
     # Check if this is a local sentence-transformers model
     if provider_name == "huggingface" and model_name.startswith(
         "sentence-transformers/"
@@ -498,7 +597,10 @@ def _get_litellm_embedding(model_name: str, provider_name: str, model_config: Op
             provider_name, model_name, kwargs
         )
         return LocalSentenceTransformerWrapper(
-            provider=provider_name, model=model_name, model_config=model_config, **kwargs
+            provider=provider_name,
+            model=model_name,
+            model_config=model_config,
+            **kwargs,
         )

     # use api key from kwargs or env
@@ -511,7 +613,9 @@ def _get_litellm_embedding(model_name: str, provider_name: str, model_config: Op
     provider_name, model_name, kwargs = _adjust_call_args(
         provider_name, model_name, kwargs
     )
-    return LiteLLMEmbeddingWrapper(model=model_name, provider=provider_name, model_config=model_config, **kwargs)
+    return LiteLLMEmbeddingWrapper(
+        model=model_name, provider=provider_name, model_config=model_config, **kwargs
+    )


 def _parse_chunk(chunk: Any) -> ChatChunk:
@@ -599,10 +703,14 @@ def _merge_provider_defaults(
     return provider_name, kwargs


-def get_chat_model(provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any) -> LiteLLMChatWrapper:
+def get_chat_model(
+    provider: str, name: str, model_config: Optional[ModelConfig] = None, **kwargs: Any
+) -> LiteLLMChatWrapper:
     orig = provider.lower()
     provider_name, kwargs = _merge_provider_defaults("chat", orig, kwargs)
-    return _get_litellm_chat(LiteLLMChatWrapper, name, provider_name, model_config, **kwargs)
+    return _get_litellm_chat(
+        LiteLLMChatWrapper, name, provider_name, model_config, **kwargs
+    )


 def get_browser_model(
diff --git a/python/tools/browser_agent.py b/python/tools/browser_agent.py
index a72ddc197..d040c9aa1 100644
--- a/python/tools/browser_agent.py
+++ b/python/tools/browser_agent.py
@@ -48,6 +48,7 @@ class State:
                 accept_downloads=True,
                 downloads_dir=files.get_abs_path("tmp/downloads"),
                 downloads_path=files.get_abs_path("tmp/downloads"),
+                allowed_domains=["*"],
                 executable_path=pw_binary,
                 keep_alive=True,
                 minimum_wait_page_load_time=1.0,
@@ -143,6 +144,7 @@ class State:
                 ),
                 controller=controller,
                 enable_memory=False,  # Disable memory to avoid state conflicts
+                llm_timeout=3000,  # TODO rem
                 sensitive_data=cast(dict[str, str | dict[str, str]] | None, secrets_dict or {}),  # Pass secrets
             )
         except Exception as e:
@@ -382,7 +384,7 @@
 def get_use_agent_log(use_agent: browser_use.Agent | None):
     result = ["🚦 Starting task"]
     if use_agent:
-        action_results = use_agent.state.history.action_results()
+        action_results = use_agent.history.action_results() or []
         short_log = []
         for item in action_results:
             # final results
diff --git a/requirements.txt b/requirements.txt
index 1a367004b..b2b8e2fd2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 a2wsgi==1.10.8
 ansio==0.0.1
-browser-use==0.2.5
+browser-use==0.5.11
 docker==7.1.0
 duckduckgo-search==6.1.12
 faiss-cpu==1.11.0
@@ -19,11 +19,11 @@ langchain-unstructured[all-docs]==0.1.6
 openai-whisper==20240930
 lxml_html_clean==0.3.1
 markdown==3.7
-mcp==1.9.0
+mcp==1.13.1
 newspaper3k==0.2.8
 paramiko==3.5.0
 playwright==1.52.0
-pypdf==4.3.1
+pypdf==6.0.0
 python-dotenv==1.1.0
 pytz==2024.2
 sentence-transformers==3.0.1
@@ -33,7 +33,7 @@ unstructured-client==0.31.0
 webcolors==24.6.0
 nest-asyncio==1.6.0
 crontab==1.0.1
-litellm==1.76
+litellm==1.75.3
 markdownify==1.1.0
 pymupdf==1.25.3
 pytesseract==0.3.13