diff --git a/README.md b/README.md index b0a6a5a3e..d38f9fa34 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,9 @@ # `Agent Zero` -[![Thanks to Sponsors](https://img.shields.io/badge/GitHub%20Sponsors-Thanks%20to%20Sponsors-FF69B4?style=for-the-badge&logo=githubsponsors&logoColor=white)](https://github.com/sponsors/frdel) [![Join our Skool Community](https://img.shields.io/badge/Skool-Join%20our%20Community-4A90E2?style=for-the-badge&logo=skool&logoColor=white)](https://www.skool.com/agent-zero) [![Join our Discord](https://img.shields.io/badge/Discord-Join%20our%20server-5865F2?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/B8KZKNsPpj) [![Subscribe on YouTube](https://img.shields.io/badge/YouTube-Subscribe-red?style=for-the-badge&logo=youtube&logoColor=white)](https://www.youtube.com/@AgentZeroFW) [![Connect on LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/jan-tomasek/) [![Follow on X.com](https://img.shields.io/badge/X.com-Follow-1DA1F2?style=for-the-badge&logo=x&logoColor=white)](https://x.com/JanTomasekDev) +[![Thanks to Sponsors](https://img.shields.io/badge/GitHub%20Sponsors-Thanks%20to%20Sponsors-FF69B4?style=for-the-badge&logo=githubsponsors&logoColor=white)](https://github.com/sponsors/frdel) [![Join our Skool Community](https://img.shields.io/badge/Skool-Join%20our%20Community-4A90E2?style=for-the-badge&logo=skool&logoColor=white)](https://www.skool.com/agent-zero) [![Join our Discord](https://img.shields.io/badge/Discord-Join%20our%20server-5865F2?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/B8KZKNsPpj) [![Subscribe on YouTube](https://img.shields.io/badge/YouTube-Subscribe-red?style=for-the-badge&logo=youtube&logoColor=white)](https://www.youtube.com/@AgentZeroFW) [![Connect on LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/jan-tomasek/) [![Follow on Warpcast](https://img.shields.io/badge/Warpcast-Follow-5A32F3?style=for-the-badge)](https://warpcast.com/agent-zero) + +> **Note:** Agent Zero does not use Twitter/X. Any Twitter/X accounts claiming to represent this project are fake. [Installation](./docs/installation.md) • [How to update](./docs/installation.md#how-to-update-agent-zero) • @@ -14,12 +16,15 @@ +[![Showcase](/docs/res/showcase-thumb.png)](https://youtu.be/lazLNcEYsiQ) + + + See [www.agent-zero.ai](https://agent-zero.ai) for more info [![Browser Agent](/docs/res/web_screenshot.jpg)](https://agent-zero.ai) -[![Browser Agent](/docs/res/081_vid.png)](https://youtu.be/quv145buW74) > [!NOTE] > **🎉 v0.8.1 Release**: Now featuring a browser agent capable of using Chromium for web interactions! This enables Agent Zero to browse the web, gather information, and interact with web content autonomously. diff --git a/agent.py b/agent.py index c5775d131..d6bc73973 100644 --- a/agent.py +++ b/agent.py @@ -10,8 +10,14 @@ import models from langchain_core.prompt_values import ChatPromptValue from python.helpers import extract_tools, rate_limiter, files, errors, history, tokens from python.helpers.print_style import PrintStyle -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.messages import HumanMessage, SystemMessage, AIMessage +from langchain_core.prompts import ( + ChatPromptTemplate, + MessagesPlaceholder, + HumanMessagePromptTemplate, + StringPromptTemplate, +) +from langchain_core.prompts.image import ImagePromptTemplate +from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.llms import BaseLLM from langchain_core.embeddings import Embeddings @@ -19,6 +25,7 @@ import python.helpers.log as Log from python.helpers.dirty_json import DirtyJson from python.helpers.defer import DeferredTask from typing import Callable +from python.helpers.history import OutputMessage class AgentContext: @@ -89,7 +96,7 @@ class AgentContext: else: current_agent = self.agent0 - self.task =self.run_task(current_agent.monologue) + self.task = self.run_task(current_agent.monologue) return self.task def communicate(self, msg: "UserMessage", broadcast_level: int = 1): @@ -128,9 +135,9 @@ class AgentContext: async def _process_chain(self, agent: "Agent", msg: "UserMessage|str", user=True): try: msg_template = ( - await agent.hist_add_user_message(msg) # type: ignore + agent.hist_add_user_message(msg) # type: ignore if user - else await agent.hist_add_tool_result( + else agent.hist_add_tool_result( tool_name="call_subordinate", tool_result=msg # type: ignore ) ) @@ -187,7 +194,7 @@ class AgentConfig: @dataclass class UserMessage: message: str - attachments: list[str] + attachments: list[str] = field(default_factory=list[str]) class LoopData: @@ -281,9 +288,6 @@ class Agent: printer.stream(chunk) self.log_from_stream(full, log) - # store as last context window content - self.set_data(Agent.DATA_NAME_CTX_WINDOW, prompt.format()) - agent_response = await self.call_chat_model( prompt, callback=stream_callback ) @@ -294,10 +298,10 @@ class Agent: self.loop_data.last_response == agent_response ): # if assistant_response is the same as last message in history, let him know # Append the assistant's response to the history - await self.hist_add_ai_response(agent_response) + self.hist_add_ai_response(agent_response) # Append warning message to the history warning_msg = self.read_prompt("fw.msg_repeat.md") - await self.hist_add_warning(message=warning_msg) + self.hist_add_warning(message=warning_msg) PrintStyle(font_color="orange", padding=True).print( warning_msg ) @@ -305,7 +309,7 @@ class Agent: else: # otherwise proceed with tool # Append the assistant's response to the history - await self.hist_add_ai_response(agent_response) + self.hist_add_ai_response(agent_response) # process tools requested in agent message tools_result = await self.process_tools(agent_response) if tools_result: # final response of message loop available @@ -317,7 +321,7 @@ class Agent: except RepairableException as e: # Forward repairable errors to the LLM, maybe it can fix them error_message = errors.format_error(e) - await self.hist_add_warning(error_message) + self.hist_add_warning(error_message) PrintStyle(font_color="red", padding=True).print(error_message) self.context.log.log(type="error", content=error_message) except Exception as e: @@ -356,19 +360,31 @@ class Agent: extras += history.Message(False, content=extra).output() loop_data.extras_temporary.clear() - # combine history and extras - history_combined = history.group_outputs_abab(loop_data.history_output + extras) - - # convert history to LLM format - history_langchain = history.output_langchain(history_combined) + # convert history + extras to LLM format + history_langchain: list[BaseMessage] = history.output_langchain( + loop_data.history_output + extras + ) # build chain from system prompt, message history and model + system_text = "\n\n".join(loop_data.system) prompt = ChatPromptTemplate.from_messages( [ - SystemMessage(content="\n\n".join(loop_data.system)), + SystemMessage(content=system_text), *history_langchain, ] ) + + # store as last context window content + self.set_data( + Agent.DATA_NAME_CTX_WINDOW, + { + "text": prompt.format(), + "tokens": self.history.get_tokens() + + tokens.approximate_tokens(system_text) + + tokens.approximate_tokens(history.output_text(extras)), + }, + ) + return prompt def handle_critical_exception(self, exception: Exception): @@ -435,12 +451,12 @@ class Agent: def set_data(self, field: str, value): self.data[field] = value - def hist_add_message(self, ai: bool, content: history.MessageContent): - return self.history.add_message(ai=ai, content=content) - - async def hist_add_user_message( - self, message: UserMessage, intervention: bool = False + def hist_add_message( + self, ai: bool, content: history.MessageContent, tokens: int = 0 ): + return self.history.add_message(ai=ai, content=content, tokens=tokens) + + def hist_add_user_message(self, message: UserMessage, intervention: bool = False): self.history.new_topic() # user message starts a new topic in history # load message template based on intervention @@ -470,16 +486,16 @@ class Agent: self.last_user_message = msg return msg - async def hist_add_ai_response(self, message: str): + def hist_add_ai_response(self, message: str): self.loop_data.last_response = message content = self.parse_prompt("fw.ai_response.md", message=message) return self.hist_add_message(True, content=content) - async def hist_add_warning(self, message: history.MessageContent): + def hist_add_warning(self, message: history.MessageContent): content = self.parse_prompt("fw.warning.md", message=message) return self.hist_add_message(False, content=content) - async def hist_add_tool_result(self, tool_name: str, tool_result: str): + def hist_add_tool_result(self, tool_name: str, tool_result: str): content = self.parse_prompt( "fw.tool_result.md", tool_name=tool_name, tool_result=tool_result ) @@ -613,9 +629,9 @@ class Agent: msg = self.intervention self.intervention = None # reset the intervention message if progress.strip(): - await self.hist_add_ai_response(progress) + self.hist_add_ai_response(progress) # append the intervention message - await self.hist_add_user_message(msg, intervention=True) + self.hist_add_user_message(msg, intervention=True) raise InterventionException(msg) async def wait_if_paused(self): @@ -642,7 +658,7 @@ class Agent: return response.message else: msg = self.read_prompt("fw.msg_misformat.md") - await self.hist_add_warning(msg) + self.hist_add_warning(msg) PrintStyle(font_color="red", padding=True).print(msg) self.context.log.log( type="error", content=f"{self.agent_name}: Message misformat" diff --git a/docker/run/DockerfileKali b/docker/run/DockerfileKali new file mode 100644 index 000000000..f5808cdce --- /dev/null +++ b/docker/run/DockerfileKali @@ -0,0 +1,33 @@ +# Use the latest slim version of Kali Linux +FROM kalilinux/kali-rolling + +# Check if the argument is provided, else throw an error +ARG BRANCH +RUN if [ -z "$BRANCH" ]; then echo "ERROR: BRANCH is not set!" >&2; exit 1; fi +ENV BRANCH=$BRANCH + +# Copy contents of the project to /a0 +COPY ./fs/ / + +# pre installation steps +RUN bash /ins/pre_install.sh $BRANCH +RUN bash /ins/pre_install_kali.sh $BRANCH + +# install additional software +RUN bash /ins/install_additional.sh $BRANCH + +# install A0 +RUN bash /ins/install_A0.sh $BRANCH + +# cleanup repo and install A0 without caching, this speeds up builds +ARG CACHE_DATE=none +RUN echo "cache buster $CACHE_DATE" && bash /ins/install_A02.sh $BRANCH + +# post installation steps +RUN bash /ins/post_install.sh $BRANCH + +# Expose ports +EXPOSE 22 80 + +# initialize runtime +CMD ["/bin/bash", "-c", "/bin/bash /exe/initialize.sh $BRANCH"] \ No newline at end of file diff --git a/docker/run/build.txt b/docker/run/build.txt index 5dfde71e7..94b6ed213 100644 --- a/docker/run/build.txt +++ b/docker/run/build.txt @@ -4,6 +4,9 @@ docker build -t agent-zero-run:local --build-arg BRANCH=development --build-arg # local image without cache docker build -t agent-zero-run:local --build-arg BRANCH=development --no-cache . +# local image from Kali +docker build -f ./DockerfileKali -t agent-zero-run:hacking --build-arg BRANCH=main --build-arg CACHE_DATE=$(date +%Y-%m-%d:%H:%M:%S) . + # dockerhub push: docker login diff --git a/docker/run/fs/ins/install_additional.sh b/docker/run/fs/ins/install_additional.sh index e857ef471..477ef6e85 100644 --- a/docker/run/fs/ins/install_additional.sh +++ b/docker/run/fs/ins/install_additional.sh @@ -1,4 +1,7 @@ #!/bin/bash +# install playwright +bash /ins/install_playwright.sh "$@" + # searxng bash /ins/install_searxng.sh "$@" \ No newline at end of file diff --git a/docker/run/fs/ins/install_playwright.sh b/docker/run/fs/ins/install_playwright.sh index cd5755424..eca6ab98a 100644 --- a/docker/run/fs/ins/install_playwright.sh +++ b/docker/run/fs/ins/install_playwright.sh @@ -7,4 +7,12 @@ pip install playwright # install chromium with dependencies -playwright install --with-deps chromium-headless-shell +# for kali-based +if [ "$@" = "hacking" ]; then + apt-get install -y fonts-unifont libnss3 libnspr4 + playwright install chromium-headless-shell +else + # for debian based + playwright install --with-deps chromium-headless-shell +fi + diff --git a/docker/run/fs/ins/install_searxng.sh b/docker/run/fs/ins/install_searxng.sh index fdc7775ad..ee9ad3d1c 100644 --- a/docker/run/fs/ins/install_searxng.sh +++ b/docker/run/fs/ins/install_searxng.sh @@ -2,7 +2,7 @@ # Install necessary packages apt-get install -y \ - python3-dev python3-babel python3-venv \ + python3.12-dev python3-babel python3.12-venv \ uwsgi uwsgi-plugin-python3 \ git build-essential libxslt-dev zlib1g-dev libffi-dev libssl-dev diff --git a/docker/run/fs/ins/post_install.sh b/docker/run/fs/ins/post_install.sh index c1a7e1d18..410ff19a7 100644 --- a/docker/run/fs/ins/post_install.sh +++ b/docker/run/fs/ins/post_install.sh @@ -1,8 +1,5 @@ #!/bin/bash -# install playwright -bash /ins/install_playwright.sh "$@" - # Cleanup package list rm -rf /var/lib/apt/lists/* apt-get clean \ No newline at end of file diff --git a/docker/run/fs/ins/pre_install.sh b/docker/run/fs/ins/pre_install.sh index 74a78844f..48d6ed5ae 100644 --- a/docker/run/fs/ins/pre_install.sh +++ b/docker/run/fs/ins/pre_install.sh @@ -2,9 +2,8 @@ # Update and install necessary packages apt-get update && apt-get install -y \ - python3 \ - python3-pip \ - python3-venv \ + python3.12 \ + python3.12-venv \ nodejs \ npm \ openssh-server \ @@ -14,5 +13,18 @@ apt-get update && apt-get install -y \ git \ ffmpeg -# prepare SSH daemon -bash /ins/setup_ssh.sh "$@" \ No newline at end of file +# Configure system alternatives so that /usr/bin/python3 points to Python 3.12 +sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 +sudo update-alternatives --set python3 /usr/bin/python3.12 + +# Update pip3 symlink: if pip3.12 exists, point pip3 to it; +# otherwise, install pip using Python 3.12's ensurepip. +if [ -f /usr/bin/pip3.12 ]; then + sudo ln -sf /usr/bin/pip3.12 /usr/bin/pip3 +else + python3 -m ensurepip --upgrade + python3 -m pip install --upgrade pip +fi + +# Prepare SSH daemon +bash /ins/setup_ssh.sh "$@" diff --git a/docker/run/fs/ins/pre_install_kali.sh b/docker/run/fs/ins/pre_install_kali.sh new file mode 100644 index 000000000..dffcbdff3 --- /dev/null +++ b/docker/run/fs/ins/pre_install_kali.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# ubuntu based dependencies for playwright +# moved to install_playwright.sh +# apt-get update && apt-get install -y fonts-unifont fonts-ubuntu + diff --git a/docker/run/fs/ins/setup_ssh.sh b/docker/run/fs/ins/setup_ssh.sh index 958323623..331e905fc 100644 --- a/docker/run/fs/ins/setup_ssh.sh +++ b/docker/run/fs/ins/setup_ssh.sh @@ -1,6 +1,6 @@ #!/bin/bash # Set up SSH -mkdir /var/run/sshd && \ +mkdir -p /var/run/sshd && \ # echo 'root:toor' | chpasswd && \ sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config \ No newline at end of file diff --git a/docs/res/showcase-thumb.png b/docs/res/showcase-thumb.png new file mode 100644 index 000000000..c5e6d9d15 Binary files /dev/null and b/docs/res/showcase-thumb.png differ diff --git a/initialize.py b/initialize.py index ead07d2f2..0003b0223 100644 --- a/initialize.py +++ b/initialize.py @@ -13,6 +13,7 @@ def initialize(): provider=models.ModelProvider[current_settings["chat_model_provider"]], name=current_settings["chat_model_name"], ctx_length=current_settings["chat_model_ctx_length"], + vision=current_settings["chat_model_vision"], limit_requests=current_settings["chat_model_rl_requests"], limit_input=current_settings["chat_model_rl_input"], limit_output=current_settings["chat_model_rl_output"], diff --git a/models.py b/models.py index 4b517f1e2..09d0705a6 100644 --- a/models.py +++ b/models.py @@ -20,7 +20,7 @@ from langchain_huggingface import ( HuggingFaceEndpoint, ) from langchain_google_genai import ( - GoogleGenerativeAI, + ChatGoogleGenerativeAI, HarmBlockThreshold, HarmCategory, embeddings as google_embeddings, @@ -267,7 +267,7 @@ def get_google_chat( ): if not api_key: api_key = get_api_key("google") - return GoogleGenerativeAI(model=model_name, google_api_key=api_key, safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE}, **kwargs) # type: ignore + return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, safety_settings={HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE}, **kwargs) # type: ignore def get_google_embedding( @@ -277,7 +277,7 @@ def get_google_embedding( ): if not api_key: api_key = get_api_key("google") - return google_embeddings.GoogleGenerativeAIEmbeddings(model=model_name, api_key=api_key, **kwargs) # type: ignore + return google_embeddings.GoogleGenerativeAIEmbeddings(model=model_name, google_api_key=api_key, **kwargs) # type: ignore # Mistral models diff --git a/prompts/default/agent.system.instruments.md b/prompts/default/agent.system.instruments.md index 677907270..ece095724 100644 --- a/prompts/default/agent.system.instruments.md +++ b/prompts/default/agent.system.instruments.md @@ -1,4 +1,5 @@ # Instruments -- following are instruments at disposal: +- following are instruments at disposal +- do not overly rely on them they might not be relevant -{{instruments}} \ No newline at end of file +{{instruments}} diff --git a/prompts/default/agent.system.memories.md b/prompts/default/agent.system.memories.md index f7ab13fa5..ae32bf361 100644 --- a/prompts/default/agent.system.memories.md +++ b/prompts/default/agent.system.memories.md @@ -1,4 +1,5 @@ # Memories on the topic -- following are your memories about current topic: +- following are memories about current topic +- do not overly rely on them they might not be relevant {{memories}} \ No newline at end of file diff --git a/prompts/default/agent.system.solutions.md b/prompts/default/agent.system.solutions.md index 0d8743d15..926f35b51 100644 --- a/prompts/default/agent.system.solutions.md +++ b/prompts/default/agent.system.solutions.md @@ -1,4 +1,5 @@ # Solutions from the past -- following are your memories about successful solutions of related problems: +- following are memories about successful solutions of related problems +- do not overly rely on them they might not be relevant {{solutions}} \ No newline at end of file diff --git a/prompts/default/agent.system.tool.code_exe.md b/prompts/default/agent.system.tool.code_exe.md index 9be20508f..cd7f2a8e1 100644 --- a/prompts/default/agent.system.tool.code_exe.md +++ b/prompts/default/agent.system.tool.code_exe.md @@ -3,10 +3,9 @@ execute terminal commands python nodejs code for computation or software tasks place code in "code" arg; escape carefully and indent properly select "runtime" arg: "terminal" "python" "nodejs" "output" "reset" -for dialogues (Y/N etc.), use "terminal" runtime next step, send answer +select "session" number, 0 default, others for multitasking if code runs long, use "output" to wait, "reset" to kill process use "pip" "npm" "apt-get" in "terminal" to install packages -important: never use implicit print/output—it doesn't work! to output, use print() or console.log() if tool outputs error, adjust code before retrying; knowledge_tool can help important: check code for placeholders or demo data; replace with real variables; don't reuse snippets @@ -26,6 +25,7 @@ usage: "tool_name": "code_execution_tool", "tool_args": { "runtime": "python", + "session": 0, "code": "import os\nprint(os.getcwd())", } } @@ -41,6 +41,7 @@ usage: "tool_name": "code_execution_tool", "tool_args": { "runtime": "terminal", + "session": 0, "code": "apt-get install zip", } } @@ -55,6 +56,7 @@ usage: "tool_name": "code_execution_tool", "tool_args": { "runtime": "output", + "session": 0, } } ~~~ @@ -68,6 +70,7 @@ usage: "tool_name": "code_execution_tool", "tool_args": { "runtime": "reset", + "session": 0, } } ~~~ \ No newline at end of file diff --git a/prompts/default/agent.system.tools_vision.md b/prompts/default/agent.system.tools_vision.md new file mode 100644 index 000000000..dd04e8b26 --- /dev/null +++ b/prompts/default/agent.system.tools_vision.md @@ -0,0 +1,18 @@ +## "Multimodal (Vision) Agent Tools" available: + +### vision_load: +load image data to LLM +use paths arg for attachments + +**Example usage**: +```json +{ + "thoughts": [ + "I need to see the image...", + ], + "tool_name": "vision_load", + "tool_args": { + "paths": ["/path/to/image.png"], + } +} +``` \ No newline at end of file diff --git a/prompts/default/fw.tool_result.md b/prompts/default/fw.tool_result.md index 1943cc022..ef41f23bf 100644 --- a/prompts/default/fw.tool_result.md +++ b/prompts/default/fw.tool_result.md @@ -1,6 +1,6 @@ -~~~json +```json { "tool_name": {{tool_name}}, "tool_result": {{tool_result}} } -~~~ \ No newline at end of file +``` diff --git a/python/api/ctx_window_get.py b/python/api/ctx_window_get.py index a8e20db94..16a4438b7 100644 --- a/python/api/ctx_window_get.py +++ b/python/api/ctx_window_get.py @@ -9,6 +9,10 @@ class GetCtxWindow(ApiHandler): context = self.get_context(ctxid) agent = context.streaming_agent or context.agent0 window = agent.get_data(agent.DATA_NAME_CTX_WINDOW) - size = tokens.approximate_tokens(window) + if not window or not isinstance(window, dict): + return {"content": "", "tokens": 0} - return {"content": window, "tokens": size} + text = window["text"] + tokens = window["tokens"] + + return {"content": text, "tokens": tokens} diff --git a/python/api/history_get.py b/python/api/history_get.py index 579a32890..9b1359c49 100644 --- a/python/api/history_get.py +++ b/python/api/history_get.py @@ -8,8 +8,8 @@ class GetHistory(ApiHandler): ctxid = input.get("context", []) context = self.get_context(ctxid) agent = context.streaming_agent or context.agent0 - history = agent.history.output() - size = tokens.approximate_tokens(agent.history.output_text()) + history = agent.history.output_text() + size = agent.history.get_tokens() return { "history": history, diff --git a/python/extensions/system_prompt/_10_system_prompt.py b/python/extensions/system_prompt/_10_system_prompt.py index e7544b2ec..68f7d6dd8 100644 --- a/python/extensions/system_prompt/_10_system_prompt.py +++ b/python/extensions/system_prompt/_10_system_prompt.py @@ -12,11 +12,17 @@ class SystemPrompt(Extension): system_prompt.append(main) system_prompt.append(tools) + def get_main_prompt(agent: Agent): return get_prompt("agent.system.main.md", agent) + def get_tools_prompt(agent: Agent): - return get_prompt("agent.system.tools.md", agent) + prompt = get_prompt("agent.system.tools.md", agent) + if agent.config.chat_model.vision: + prompt += '\n' + get_prompt("agent.system.tools_vision.md", agent) + return prompt + def get_prompt(file: str, agent: Agent): # variables for system prompts @@ -26,4 +32,4 @@ def get_prompt(file: str, agent: Agent): "date_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "agent_name": agent.agent_name, } - return agent.read_prompt(file, **vars) \ No newline at end of file + return agent.read_prompt(file, **vars) diff --git a/python/helpers/dirty_json.py b/python/helpers/dirty_json.py index 28054ff10..b65c120db 100644 --- a/python/helpers/dirty_json.py +++ b/python/helpers/dirty_json.py @@ -103,9 +103,14 @@ class DirtyJson: return None def _match(self, text: str) -> bool: - cnt = len(text) - if self._peek(cnt).lower() == text.lower(): - self._advance(cnt) + # first char should match current char + if not self.current_char or self.current_char.lower() != text[0].lower(): + return False + + # peek remaining chars + remaining = len(text) - 1 + if self._peek(remaining).lower() == text[1:].lower(): + self._advance(len(text)) return True return False @@ -187,6 +192,13 @@ class DirtyJson: self._skip_whitespace() if self.current_char == ',': self._advance() + # handle trailing commas, end of array + self._skip_whitespace() + if self.current_char is None or self.current_char == ']': + if self.current_char == ']': + self._advance() + self.stack.pop() + return elif self.current_char != ']': self.stack.pop() return @@ -245,30 +257,6 @@ class DirtyJson: except ValueError: return float(number_str) - def _parse_true(self): - self._advance() - for char in 'rue': - if self.current_char != char: - return None - self._advance() - return True - - def _parse_false(self): - self._advance() - for char in 'alse': - if self.current_char != char: - return None - self._advance() - return False - - def _parse_null(self): - self._advance() - for char in 'ull': - if self.current_char != char: - return None - self._advance() - return None - def _parse_unquoted_string(self): result = "" while self.current_char is not None and self.current_char not in [':', ',', '}', ']']: diff --git a/python/helpers/files.py b/python/helpers/files.py index 91bdd0566..bc0e97eb9 100644 --- a/python/helpers/files.py +++ b/python/helpers/files.py @@ -1,6 +1,7 @@ from fnmatch import fnmatch import json import os, re +import base64 import re import shutil @@ -45,6 +46,32 @@ def read_file(_relative_path, _backup_dirs=None, _encoding="utf-8", **kwargs): return content +def read_file_bin(_relative_path, _backup_dirs=None): + # init backup dirs + if _backup_dirs is None: + _backup_dirs = [] + + # get absolute path + absolute_path = find_file_in_dirs(_relative_path, _backup_dirs) + + # read binary content + with open(absolute_path, "rb") as f: + return f.read() + + +def read_file_base64(_relative_path, _backup_dirs=None): + # init backup dirs + if _backup_dirs is None: + _backup_dirs = [] + + # get absolute path + absolute_path = find_file_in_dirs(_relative_path, _backup_dirs) + + # read binary content and encode to base64 + with open(absolute_path, "rb") as f: + return base64.b64encode(f.read()).decode('utf-8') + + def replace_placeholders_text(_content: str, **kwargs): # Replace placeholders with values from kwargs for key, value in kwargs.items(): @@ -175,6 +202,15 @@ def write_file_bin(relative_path: str, content: bytes): f.write(content) +def write_file_base64(relative_path: str, content: str): + # decode base64 string to bytes + data = base64.b64decode(content) + abs_path = get_abs_path(relative_path) + os.makedirs(os.path.dirname(abs_path), exist_ok=True) + with open(abs_path, "wb") as f: + f.write(data) + + def delete_file(relative_path: str): abs_path = get_abs_path(relative_path) if os.path.exists(abs_path): diff --git a/python/helpers/history.py b/python/helpers/history.py index 766e01849..c09ef61a5 100644 --- a/python/helpers/history.py +++ b/python/helpers/history.py @@ -1,12 +1,13 @@ from abc import abstractmethod import asyncio from collections import OrderedDict +from collections.abc import Mapping import json import math -from typing import Coroutine, Literal, TypedDict, cast +from typing import Coroutine, Literal, TypedDict, cast, Union, Dict, List, Any, override from python.helpers import messages, tokens, settings, call_llm from enum import Enum -from langchain_core.messages import HumanMessage, SystemMessage, AIMessage +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage BULK_MERGE_COUNT = 3 TOPICS_KEEP_COUNT = 3 @@ -15,14 +16,22 @@ HISTORY_TOPIC_RATIO = 0.3 HISTORY_BULK_RATIO = 0.2 TOPIC_COMPRESS_RATIO = 0.65 LARGE_MESSAGE_TO_TOPIC_RATIO = 0.25 +RAW_MESSAGE_OUTPUT_TEXT_TRIM = 100 -MessageContent = ( - list["MessageContent"] - | OrderedDict[str, "MessageContent"] - | list[OrderedDict[str, "MessageContent"]] - | str - | list[str] -) + +class RawMessage(TypedDict): + raw_content: "MessageContent" + preview: str | None + + +MessageContent = Union[ + List["MessageContent"], + Dict[str, "MessageContent"], + List[Dict[str, "MessageContent"]], + str, + List[str], + RawMessage, +] class OutputMessage(TypedDict): @@ -34,9 +43,9 @@ class Record: def __init__(self): pass + @abstractmethod def get_tokens(self) -> int: - out = self.output_text() - return tokens.approximate_tokens(out) + pass @abstractmethod async def compress(self) -> bool: @@ -67,10 +76,25 @@ class Record: class Message(Record): - def __init__(self, ai: bool, content: MessageContent): + def __init__(self, ai: bool, content: MessageContent, tokens: int = 0): self.ai = ai self.content = content - self.summary: MessageContent = "" + self.summary: str = "" + self.tokens: int = tokens or self.calculate_tokens() + + @override + def get_tokens(self) -> int: + if not self.tokens: + self.tokens = self.calculate_tokens() + return self.tokens + + def calculate_tokens(self): + text = self.output_text() + return tokens.approximate_tokens(text) + + def set_summary(self, summary: str): + self.summary = summary + self.tokens = self.calculate_tokens() async def compress(self): return False @@ -90,12 +114,15 @@ class Message(Record): "ai": self.ai, "content": self.content, "summary": self.summary, + "tokens": self.tokens, } @staticmethod def from_dict(data: dict, history: "History"): - msg = Message(ai=data["ai"], content=data.get("content", "Content lost")) + content = data.get("content", "Content lost") + msg = Message(ai=data["ai"], content=content) msg.summary = data.get("summary", "") + msg.tokens = data.get("tokens", 0) return msg @@ -105,8 +132,16 @@ class Topic(Record): self.summary: str = "" self.messages: list[Message] = [] - def add_message(self, ai: bool, content: MessageContent): - msg = Message(ai=ai, content=content) + def get_tokens(self): + if self.summary: + return tokens.approximate_tokens(self.summary) + else: + return sum(msg.get_tokens() for msg in self.messages) + + def add_message( + self, ai: bool, content: MessageContent, tokens: int = 0 + ) -> Message: + msg = Message(ai=ai, content=content, tokens=tokens) self.messages.append(msg) return msg @@ -115,7 +150,7 @@ class Topic(Record): return [OutputMessage(ai=False, content=self.summary)] else: msgs = [m for r in self.messages for m in r.output()] - return group_outputs_abab(msgs) + return msgs async def summarize(self): self.summary = await self.summarize_messages(self.messages) @@ -126,27 +161,36 @@ class Topic(Record): msg_max_size = ( set["chat_model_ctx_length"] * set["chat_model_ctx_history"] - * HISTORY_TOPIC_RATIO + * CURRENT_TOPIC_RATIO * LARGE_MESSAGE_TO_TOPIC_RATIO ) large_msgs = [] for m in (m for m in self.messages if not m.summary): + # TODO refactor this out = m.output() text = output_text(out) - tok = tokens.approximate_tokens(text) + tok = m.get_tokens() leng = len(text) if tok > msg_max_size: large_msgs.append((m, tok, leng, out)) large_msgs.sort(key=lambda x: x[1], reverse=True) for msg, tok, leng, out in large_msgs: trim_to_chars = leng * (msg_max_size / tok) - trunc = messages.truncate_dict_by_ratio( - self.history.agent, - out[0]["content"], - trim_to_chars * 1.15, - trim_to_chars * 0.85, - ) - msg.summary = trunc + # raw messages will be replaced as a whole, they would become invalid when truncated + if _is_raw_message(out[0]["content"]): + msg.set_summary( + "Message content replaced to save space in context window" + ) + + # regular messages will be truncated + else: + trunc = messages.truncate_dict_by_ratio( + self.history.agent, + out[0]["content"], + trim_to_chars * 1.15, + trim_to_chars * 0.85, + ) + msg.set_summary(_json_dumps(trunc)) return True return False @@ -172,6 +216,7 @@ class Topic(Record): return False async def summarize_messages(self, messages: list[Message]): + # FIXME: vision bytes are sent to utility LLM, send summary instead msg_txt = [m.output_text() for m in messages] summary = await self.history.agent.call_utility_model( system=self.history.agent.read_prompt("fw.topic_summary.sys.md"), @@ -191,9 +236,9 @@ class Topic(Record): @staticmethod def from_dict(data: dict, history: "History"): topic = Topic(history=history) - topic.summary = data["summary"] + topic.summary = data.get("summary", "") topic.messages = [ - Message.from_dict(m, history=history) for m in data["messages"] + Message.from_dict(m, history=history) for m in data.get("messages", []) ] return topic @@ -211,7 +256,7 @@ class Bulk(Record): return [OutputMessage(ai=False, content=self.summary)] else: msgs = [m for r in self.records for m in r.output()] - return group_outputs_abab(msgs) + return msgs async def compress(self): return False @@ -250,8 +295,15 @@ class History(Record): self.current = Topic(history=self) self.agent: Agent = agent + def get_tokens(self) -> int: + return ( + self.get_bulks_tokens() + + self.get_topics_tokens() + + self.get_current_topic_tokens() + ) + def is_over_limit(self): - limit = get_ctx_size_for_history() + limit = _get_ctx_size_for_history() total = self.get_tokens() return total > limit @@ -264,15 +316,10 @@ class History(Record): def get_current_topic_tokens(self) -> int: return self.current.get_tokens() - def get_tokens(self) -> int: - return ( - self.get_bulks_tokens() - + self.get_topics_tokens() - + self.get_current_topic_tokens() - ) - - def add_message(self, ai: bool, content: MessageContent): - return self.current.add_message(ai, content=content) + def add_message( + self, ai: bool, content: MessageContent, tokens: int = 0 + ) -> Message: + return self.current.add_message(ai, content=content, tokens=tokens) def new_topic(self): if self.current.messages: @@ -284,7 +331,6 @@ class History(Record): result += [m for b in self.bulks for m in b.output()] result += [m for t in self.topics for m in t.output()] result += self.current.output() - result = group_outputs_abab(result) return result @staticmethod @@ -304,7 +350,7 @@ class History(Record): def serialize(self): data = self.to_dict() - return json.dumps(data) + return _json_dumps(data) async def compress(self): compressed = False @@ -314,7 +360,7 @@ class History(Record): self.get_topics_tokens(), self.get_bulks_tokens(), ) - total = get_ctx_size_for_history() + total = _get_ctx_size_for_history() ratios = [ (curr, CURRENT_TOPIC_RATIO, "current_topic"), (hist, HISTORY_TOPIC_RATIO, "history_topic"), @@ -389,25 +435,46 @@ class History(Record): def deserialize_history(json_data: str, agent) -> History: history = History(agent=agent) if json_data: - data = json.loads(json_data) + data = _json_loads(json_data) history = History.from_dict(data, history=history) return history -def get_ctx_size_for_history() -> int: +def _get_ctx_size_for_history() -> int: set = settings.get_settings() return int(set["chat_model_ctx_length"] * set["chat_model_ctx_history"]) -def serialize_output(output: OutputMessage, ai_label="ai", human_label="human"): - return f'{ai_label if output["ai"] else human_label}: {serialize_content(output["content"])}' +def _stringify_output(output: OutputMessage, ai_label="ai", human_label="human"): + return f'{ai_label if output["ai"] else human_label}: {_stringify_content(output["content"])}' -def serialize_content(content: MessageContent) -> str: +def _stringify_content(content: MessageContent) -> str: + # already a string if isinstance(content, str): return content + + # raw messages return preview or trimmed json + if _is_raw_message(content): + preview: str = content.get("preview", "") # type: ignore + if preview: + return preview + text = _json_dumps(content) + if len(text) > RAW_MESSAGE_OUTPUT_TEXT_TRIM: + return text[:RAW_MESSAGE_OUTPUT_TEXT_TRIM] + "... TRIMMED" + return text + + # regular messages of non-string are dumped as json + return _json_dumps(content) + + +def _output_content_langchain(content: MessageContent): + if isinstance(content, str): + return content + if _is_raw_message(content): + return content["raw_content"] # type: ignore try: - return json.dumps(content) + return _json_dumps(content) except Exception as e: raise e @@ -418,51 +485,73 @@ def group_outputs_abab(outputs: list[OutputMessage]) -> list[OutputMessage]: if result and result[-1]["ai"] == out["ai"]: result[-1] = OutputMessage( ai=result[-1]["ai"], - content=merge_outputs(result[-1]["content"], out["content"]), + content=_merge_outputs(result[-1]["content"], out["content"]), ) else: result.append(out) return result +def group_messages_abab(messages: list[BaseMessage]) -> list[BaseMessage]: + result = [] + for msg in messages: + if result and isinstance(result[-1], type(msg)): + # create new instance of the same type with merged content + result[-1] = type(result[-1])(content=_merge_outputs(result[-1].content, msg.content)) # type: ignore + else: + result.append(msg) + return result + + def output_langchain(messages: list[OutputMessage]): result = [] for m in messages: if m["ai"]: - result.append(AIMessage(content=serialize_content(m["content"]))) + # result.append(AIMessage(content=serialize_content(m["content"]))) + result.append(AIMessage(_output_content_langchain(content=m["content"]))) # type: ignore else: - result.append(HumanMessage(content=serialize_content(m["content"]))) + # result.append(HumanMessage(content=serialize_content(m["content"]))) + result.append(HumanMessage(_output_content_langchain(content=m["content"]))) # type: ignore + # ensure message type alternation + result = group_messages_abab(result) return result def output_text(messages: list[OutputMessage], ai_label="ai", human_label="human"): - return "\n".join(serialize_output(o, ai_label, human_label) for o in messages) + return "\n".join(_stringify_output(o, ai_label, human_label) for o in messages) -def merge_outputs(a: MessageContent, b: MessageContent) -> MessageContent: +def _merge_outputs(a: MessageContent, b: MessageContent) -> MessageContent: + if isinstance(a, str) and isinstance(b, str): + return a + b + if not isinstance(a, list): a = [a] if not isinstance(b, list): b = [b] - return a + b # type: ignore - # return merge_properties(a, b) + + return cast(MessageContent, a + b) -def merge_properties(a: MessageContent, b: MessageContent) -> MessageContent: - if isinstance(a, list): - if isinstance(b, list): - return a + b # type: ignore +def _merge_properties( + a: Dict[str, MessageContent], b: Dict[str, MessageContent] +) -> Dict[str, MessageContent]: + result = a.copy() + for k, v in b.items(): + if k in result: + result[k] = _merge_outputs(result[k], v) else: - return a + [b] - elif isinstance(b, list): - return [a] + b # type: ignore - elif isinstance(a, dict) and isinstance(b, dict): - for key, value in b.items(): - if key in a: - a[key] = merge_properties(a[key], value) - else: - a[key] = value - return a - elif isinstance(a, str) and isinstance(b, str): - return a + b - raise ValueError(f"Cannot merge {a} and {b}") + result[k] = v + return result + + +def _is_raw_message(obj: object) -> bool: + return isinstance(obj, Mapping) and "raw_content" in obj + + +def _json_dumps(obj): + return json.dumps(obj, ensure_ascii=False) + + +def _json_loads(obj): + return json.loads(obj) diff --git a/python/helpers/images.py b/python/helpers/images.py new file mode 100644 index 000000000..9377f1e9c --- /dev/null +++ b/python/helpers/images.py @@ -0,0 +1,35 @@ +from PIL import Image +import io +import math + + +def compress_image(image_data: bytes, *, max_pixels: int = 256_000, quality: int = 50) -> bytes: + """Compress an image by scaling it down and converting to JPEG with quality settings. + + Args: + image_data: Raw image bytes + max_pixels: Maximum number of pixels in the output image (width * height) + quality: JPEG quality setting (1-100) + + Returns: + Compressed image as bytes + """ + # load image from bytes + img = Image.open(io.BytesIO(image_data)) + + # calculate scaling factor to get to max_pixels + current_pixels = img.width * img.height + if current_pixels > max_pixels: + scale = math.sqrt(max_pixels / current_pixels) + new_width = int(img.width * scale) + new_height = int(img.height * scale) + img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # convert to RGB if needed (for JPEG) + if img.mode in ('RGBA', 'P'): + img = img.convert('RGB') + + # save as JPEG with compression + output = io.BytesIO() + img.save(output, format='JPEG', quality=quality, optimize=True) + return output.getvalue() diff --git a/python/helpers/runtime.py b/python/helpers/runtime.py index 6709c6fdf..189d45288 100644 --- a/python/helpers/runtime.py +++ b/python/helpers/runtime.py @@ -2,6 +2,9 @@ import argparse import inspect from typing import TypeVar, Callable, Awaitable, Union, overload, cast from python.helpers import dotenv, rfc, settings +import asyncio +import threading +import queue T = TypeVar('T') R = TypeVar('R') @@ -102,3 +105,22 @@ def _get_rfc_url() -> str: url = url+":"+str(set["rfc_port_http"]) url += "/rfc" return url + + +def call_development_function_sync(func: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args, **kwargs) -> T: + # run async function in sync manner + result_queue = queue.Queue() + + def run_in_thread(): + result = asyncio.run(call_development_function(func, *args, **kwargs)) + result_queue.put(result) + + thread = threading.Thread(target=run_in_thread) + thread.start() + thread.join(timeout=30) # wait for thread with timeout + + if thread.is_alive(): + raise TimeoutError("Function call timed out after 30 seconds") + + result = result_queue.get_nowait() + return cast(T, result) diff --git a/python/helpers/settings.py b/python/helpers/settings.py index 228330753..fc7a41ce2 100644 --- a/python/helpers/settings.py +++ b/python/helpers/settings.py @@ -16,6 +16,7 @@ class Settings(TypedDict): chat_model_kwargs: dict[str, str] chat_model_ctx_length: int chat_model_ctx_history: float + chat_model_vision: bool chat_model_rl_requests: int chat_model_rl_input: int chat_model_rl_output: int @@ -149,6 +150,17 @@ def convert_out(settings: Settings) -> SettingsOutput: } ) + + chat_model_fields.append( + { + "id": "chat_model_vision", + "title": "Supports Vision", + "description": "Models capable of Vision can for example natively see the content of image attachments.", + "type": "switch", + "value": settings["chat_model_vision"], + } + ) + chat_model_fields.append( { "id": "chat_model_rl_requests", @@ -777,6 +789,7 @@ def get_default_settings() -> Settings: chat_model_kwargs={ "temperature": "0" }, chat_model_ctx_length=120000, chat_model_ctx_history=0.7, + chat_model_vision=False, chat_model_rl_requests=0, chat_model_rl_input=0, chat_model_rl_output=0, diff --git a/python/helpers/tool.py b/python/helpers/tool.py index a50ec8295..fef3058ac 100644 --- a/python/helpers/tool.py +++ b/python/helpers/tool.py @@ -1,14 +1,15 @@ from abc import abstractmethod from dataclasses import dataclass + from agent import Agent from python.helpers.print_style import PrintStyle -from python.helpers import messages + @dataclass class Response: message:str - break_loop:bool - + break_loop: bool + class Tool: def __init__(self, agent: Agent, name: str, args: dict[str,str], message: str, **kwargs) -> None: @@ -29,10 +30,10 @@ class Tool: PrintStyle(font_color="#85C1E9", bold=True).stream(self.nice_key(key)+": ") PrintStyle(font_color="#85C1E9", padding=isinstance(value,str) and "\n" in value).stream(value) PrintStyle().print() - + async def after_execution(self, response: Response, **kwargs): text = response.message.strip() - await self.agent.hist_add_tool_result(self.name, text) + self.agent.hist_add_tool_result(self.name, text) PrintStyle(font_color="#1B4F72", background_color="white", padding=True, bold=True).print(f"{self.agent.agent_name}: Response from tool '{self.name}'") PrintStyle(font_color="#85C1E9").print(response.message) self.log.update(content=response.message) @@ -44,4 +45,4 @@ class Tool: words = key.split('_') words = [words[0].capitalize()] + [word.lower() for word in words[1:]] result = ' '.join(words) - return result \ No newline at end of file + return result diff --git a/python/helpers/whisper.py b/python/helpers/whisper.py index f92e4c8b5..0f644487d 100644 --- a/python/helpers/whisper.py +++ b/python/helpers/whisper.py @@ -30,7 +30,7 @@ async def _preload(model_name:str): is_updating_model = True if not _model or _model_name != model_name: PrintStyle.standard(f"Loading Whisper model: {model_name}") - _model = whisper.load_model(model_name) + _model = whisper.load_model(name=model_name) # type: ignore _model_name = model_name finally: is_updating_model = False diff --git a/python/tools/call_subordinate.py b/python/tools/call_subordinate.py index 73eb9b574..3a50ffdc4 100644 --- a/python/tools/call_subordinate.py +++ b/python/tools/call_subordinate.py @@ -18,7 +18,7 @@ class Delegation(Tool): # add user message to subordinate agent subordinate: Agent = self.agent.get_data(Agent.DATA_NAME_SUBORDINATE) - await subordinate.hist_add_user_message(UserMessage(message=message, attachments=[])) + subordinate.hist_add_user_message(UserMessage(message=message, attachments=[])) # run subordinate monologue result = await subordinate.monologue() # result diff --git a/python/tools/code_execution_tool.py b/python/tools/code_execution_tool.py index 6a0686875..593e734fa 100644 --- a/python/tools/code_execution_tool.py +++ b/python/tools/code_execution_tool.py @@ -12,7 +12,7 @@ from python.helpers.docker import DockerContainerManager @dataclass class State: - shell: LocalInteractiveSession | SSHInteractiveSession + shells: dict[int, LocalInteractiveSession | SSHInteractiveSession] docker: DockerContainerManager | None @@ -27,19 +27,26 @@ class CodeExecution(Tool): # os.chdir(files.get_abs_path("./work_dir")) #change CWD to work_dir runtime = self.args.get("runtime", "").lower().strip() + session = int(self.args.get("session", 0)) if runtime == "python": - response = await self.execute_python_code(self.args["code"]) + response = await self.execute_python_code( + code=self.args["code"], session=session + ) elif runtime == "nodejs": - response = await self.execute_nodejs_code(self.args["code"]) + response = await self.execute_nodejs_code( + code=self.args["code"], session=session + ) elif runtime == "terminal": - response = await self.execute_terminal_command(self.args["code"]) + response = await self.execute_terminal_command( + command=self.args["code"], session=session + ) elif runtime == "output": response = await self.get_terminal_output( - wait_with_output=5, wait_without_output=60 + session=session, wait_with_output=5, wait_without_output=60 ) elif runtime == "reset": - response = await self.reset_terminal() + response = await self.reset_terminal(session=session) else: response = self.agent.read_prompt( "fw.code_runtime_wrong.md", runtime=runtime @@ -72,11 +79,15 @@ class CodeExecution(Tool): # PrintStyle().print() def get_log_object(self): - return self.agent.context.log.log(type="code_exe", heading=f"{self.agent.agent_name}: Using tool '{self.name}'", content="", kvps=self.args) - + return self.agent.context.log.log( + type="code_exe", + heading=f"{self.agent.agent_name}: Using tool '{self.name}'", + content="", + kvps=self.args, + ) async def after_execution(self, response, **kwargs): - await self.agent.hist_add_tool_result(self.name, response.message) + self.agent.hist_add_tool_result(self.name, response.message) async def prepare_state(self, reset=False): self.state = self.agent.get_data("_cot_state") @@ -97,7 +108,11 @@ class CodeExecution(Tool): # initialize local or remote interactive shell insterface if self.agent.config.code_exec_ssh_enabled: - pswd = self.agent.config.code_exec_ssh_pass if self.agent.config.code_exec_ssh_pass else await rfc_exchange.get_root_password() + pswd = ( + self.agent.config.code_exec_ssh_pass + if self.agent.config.code_exec_ssh_pass + else await rfc_exchange.get_root_password() + ) shell = SSHInteractiveSession( self.agent.context.log, self.agent.config.code_exec_ssh_addr, @@ -108,42 +123,63 @@ class CodeExecution(Tool): else: shell = LocalInteractiveSession() - self.state = State(shell=shell, docker=docker) + self.state = State(shells={0: shell}, docker=docker) await shell.connect() self.agent.set_data("_cot_state", self.state) - async def execute_python_code(self, code: str, reset: bool = False): + async def execute_python_code(self, session: int, code: str, reset: bool = False): escaped_code = shlex.quote(code) command = f"ipython -c {escaped_code}" - return await self.terminal_session(command, reset) + return await self.terminal_session(session, command, reset) - async def execute_nodejs_code(self, code: str, reset: bool = False): + async def execute_nodejs_code(self, session: int, code: str, reset: bool = False): escaped_code = shlex.quote(code) command = f"node /exe/node_eval.js {escaped_code}" - return await self.terminal_session(command, reset) + return await self.terminal_session(session, command, reset) - async def execute_terminal_command(self, command: str, reset: bool = False): - return await self.terminal_session(command, reset) + async def execute_terminal_command( + self, session: int, command: str, reset: bool = False + ): + return await self.terminal_session(session, command, reset) - async def terminal_session(self, command: str, reset: bool = False): + async def terminal_session(self, session: int, command: str, reset: bool = False): await self.agent.handle_intervention() # wait for intervention and handle it, if paused # try again on lost connection for i in range(2): try: - + if reset: await self.reset_terminal() - self.state.shell.send_command(command) + if session not in self.state.shells: + if self.agent.config.code_exec_ssh_enabled: + pswd = ( + self.agent.config.code_exec_ssh_pass + if self.agent.config.code_exec_ssh_pass + else await rfc_exchange.get_root_password() + ) + shell = SSHInteractiveSession( + self.agent.context.log, + self.agent.config.code_exec_ssh_addr, + self.agent.config.code_exec_ssh_port, + self.agent.config.code_exec_ssh_user, + pswd, + ) + else: + shell = LocalInteractiveSession() + self.state.shells[session] = shell + await shell.connect() - PrintStyle(background_color="white", font_color="#1B4F72", bold=True).print( - f"{self.agent.agent_name} code execution output" - ) - return await self.get_terminal_output() + self.state.shells[session].send_command(command) + + PrintStyle( + background_color="white", font_color="#1B4F72", bold=True + ).print(f"{self.agent.agent_name} code execution output") + return await self.get_terminal_output(session) except Exception as e: - if i==1: + if i == 1: # try again on lost connection PrintStyle.error(str(e)) await self.prepare_state(reset=True) @@ -153,6 +189,7 @@ class CodeExecution(Tool): async def get_terminal_output( self, + session=0, reset_full_output=True, wait_with_output=3, wait_without_output=10, @@ -165,10 +202,10 @@ class CodeExecution(Tool): while max_exec_time <= 0 or time.time() - start_time < max_exec_time: await asyncio.sleep(SLEEP_TIME) # Wait for some output to be generated - full_output, partial_output = await self.state.shell.read_output( + full_output, partial_output = await self.state.shells[session].read_output( timeout=max_exec_time, reset_full_output=reset_full_output ) - reset_full_output = False # only reset once + reset_full_output = False # only reset once await self.agent.handle_intervention() # wait for intervention and handle it, if paused @@ -184,8 +221,10 @@ class CodeExecution(Tool): break return full_output - async def reset_terminal(self): - self.state.shell.close() + async def reset_terminal(self, session=0): + if session in self.state.shells: + self.state.shells[session].close() + del self.state.shells[session] await self.prepare_state(reset=True) response = self.agent.read_prompt("fw.code_reset.md") self.log.update(content=response) diff --git a/python/tools/input.py b/python/tools/input.py index 6e5a812f8..b23c9adee 100644 --- a/python/tools/input.py +++ b/python/tools/input.py @@ -20,4 +20,4 @@ class Input(Tool): return self.agent.context.log.log(type="code_exe", heading=f"{self.agent.agent_name}: Using tool '{self.name}'", content="", kvps=self.args) async def after_execution(self, response, **kwargs): - await self.agent.hist_add_tool_result(self.name, response.message) \ No newline at end of file + self.agent.hist_add_tool_result(self.name, response.message) \ No newline at end of file diff --git a/python/tools/vision_load.py b/python/tools/vision_load.py new file mode 100644 index 000000000..e1c455fa5 --- /dev/null +++ b/python/tools/vision_load.py @@ -0,0 +1,79 @@ +import base64 +from python.helpers.print_style import PrintStyle +from python.helpers.tool import Tool, Response +from python.helpers import runtime, files, images +from mimetypes import guess_type +from python.helpers import history + +# image optimization and token estimation for context window +MAX_PIXELS = 768_000 +QUALITY = 75 +TOKENS_ESTIMATE = 1500 + + +class VisionLoad(Tool): + async def execute(self, paths: list[str] = [], **kwargs) -> Response: + + self.images_dict = {} + template: list[dict[str, str]] = [] # type: ignore + + for path in paths: + if not await runtime.call_development_function(files.exists, str(path)): + continue + + if path not in self.images_dict: + mime_type, _ = guess_type(str(path)) + if mime_type and mime_type.startswith("image/"): + # Read binary file + file_content = await runtime.call_development_function( + files.read_file_base64, str(path) + ) + file_content = base64.b64decode(file_content) + # Compress and convert to JPEG + compressed = images.compress_image( + file_content, max_pixels=MAX_PIXELS, quality=QUALITY + ) + # Encode as base64 + file_content_b64 = base64.b64encode(compressed).decode("utf-8") + + # DEBUG: Save compressed image + # await runtime.call_development_function( + # files.write_file_base64, str(path), file_content_b64 + # ) + + # Construct the data URL (always JPEG after compression) + self.images_dict[path] = file_content_b64 + + return Response(message="dummy", break_loop=False) + + async def after_execution(self, response: Response, **kwargs): + + # build image data messages for LLMs, or error message + content = [] + if self.images_dict: + for _, image in self.images_dict.items(): + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{image}"}, + } + ) + # append as raw message content for LLMs with vision tokens estimate + msg = history.RawMessage(raw_content=content, preview="") + self.agent.hist_add_message( + False, content=msg, tokens=TOKENS_ESTIMATE * len(content) + ) + else: + self.agent.hist_add_tool_result(self.name, "No images processed") + + # print and log short version + message = ( + "No images processed" + if not self.images_dict + else f"{len(self.images_dict)} images processed" + ) + PrintStyle( + font_color="#1B4F72", background_color="white", padding=True, bold=True + ).print(f"{self.agent.agent_name}: Response from tool '{self.name}'") + PrintStyle(font_color="#85C1E9").print(message) + self.log.update(result=message) diff --git a/webui/js/history.js b/webui/js/history.js index e62fc615e..040de72fe 100644 --- a/webui/js/history.js +++ b/webui/js/history.js @@ -3,9 +3,10 @@ import { getContext } from "../index.js"; export async function openHistoryModal() { try { const hist = await window.sendJsonData("/history_get", { context: getContext() }); - const data = JSON.stringify(hist.history, null, 4); + // const data = JSON.stringify(hist.history, null, 4); + const data = hist.history const size = hist.tokens - await showEditorModal(data, "json", `History ~${size} tokens`, "Conversation history visible to the LLM. History is compressed to fit into the context window over time."); + await showEditorModal(data, "markdown", `History ~${size} tokens`, "Conversation history visible to the LLM. History is compressed to fit into the context window over time."); } catch (e) { window.toastFetchError("Error fetching history", e) return