mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-23 21:06:39 +00:00
554 lines
20 KiB
Python
554 lines
20 KiB
Python
import base64
|
|
import json
|
|
import os
|
|
import re
|
|
import socketserver
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from python.helpers import files, tokens
|
|
|
|
|
|
BASE54_ALPHABET = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstyz"
|
|
BASE54_INDEX = {char: idx for idx, char in enumerate(BASE54_ALPHABET)}
|
|
BASE54_BASE = len(BASE54_ALPHABET)
|
|
|
|
CONTROL_TAG_SHOW_SAVINGS = "**show savings**"
|
|
CONTROL_TAG_SHOW_TOTAL = "**show total**"
|
|
|
|
|
|
def _safe_count_tokens(text: str) -> int:
|
|
if not text:
|
|
return 0
|
|
try:
|
|
return tokens.count_tokens(text)
|
|
except Exception:
|
|
return max(1, len(text.split()))
|
|
|
|
|
|
def _token_stats(raw_text: str, encoded_text: str) -> Dict[str, int]:
|
|
raw_tokens = _safe_count_tokens(raw_text)
|
|
encoded_tokens = _safe_count_tokens(encoded_text)
|
|
saved = raw_tokens - encoded_tokens
|
|
if saved < 0:
|
|
saved = 0
|
|
return {
|
|
"raw": raw_tokens,
|
|
"encoded": encoded_tokens,
|
|
"saved": saved,
|
|
}
|
|
|
|
|
|
def _strip_control_tags(text: str) -> Tuple[str, bool, bool]:
|
|
show_savings = False
|
|
show_total = False
|
|
if not text:
|
|
return text, show_savings, show_total
|
|
|
|
if re.search(re.escape(CONTROL_TAG_SHOW_SAVINGS), text, flags=re.IGNORECASE):
|
|
show_savings = True
|
|
text = re.sub(
|
|
re.escape(CONTROL_TAG_SHOW_SAVINGS), "", text, flags=re.IGNORECASE
|
|
)
|
|
if re.search(re.escape(CONTROL_TAG_SHOW_TOTAL), text, flags=re.IGNORECASE):
|
|
show_total = True
|
|
show_savings = True
|
|
text = re.sub(
|
|
re.escape(CONTROL_TAG_SHOW_TOTAL), "", text, flags=re.IGNORECASE
|
|
)
|
|
|
|
return text.strip(), show_savings, show_total
|
|
|
|
|
|
def b54encode(payload: bytes) -> str:
|
|
if not payload:
|
|
return ""
|
|
num = int.from_bytes(payload, "big")
|
|
encoded: List[str] = []
|
|
while num > 0:
|
|
num, rem = divmod(num, BASE54_BASE)
|
|
encoded.append(BASE54_ALPHABET[rem])
|
|
pad = 0
|
|
for byte in payload:
|
|
if byte == 0:
|
|
pad += 1
|
|
else:
|
|
break
|
|
encoded_str = "".join(reversed(encoded)) if encoded else ""
|
|
return (BASE54_ALPHABET[0] * pad) + encoded_str
|
|
|
|
|
|
def b54decode(payload: str) -> bytes:
|
|
if payload == "":
|
|
return b""
|
|
num = 0
|
|
for char in payload:
|
|
if char not in BASE54_INDEX:
|
|
raise ValueError(f"Invalid base54 character: {char!r}")
|
|
num = num * BASE54_BASE + BASE54_INDEX[char]
|
|
pad = 0
|
|
for char in payload:
|
|
if char == BASE54_ALPHABET[0]:
|
|
pad += 1
|
|
else:
|
|
break
|
|
decoded = b""
|
|
if num > 0:
|
|
byte_len = (num.bit_length() + 7) // 8
|
|
decoded = num.to_bytes(byte_len, "big")
|
|
return (b"\x00" * pad) + decoded
|
|
|
|
|
|
def _b64encode_text(text: str, encoding: str) -> str:
|
|
return base64.b64encode(text.encode(encoding, errors="replace")).decode("ascii")
|
|
|
|
|
|
def _context_text(messages: List[Dict[str, str]]) -> str:
|
|
return "\n".join(f"{entry['role']}: {entry['text']}" for entry in messages).strip()
|
|
|
|
|
|
@dataclass
|
|
class ConversationState:
|
|
conversation_id: str
|
|
encoding: str = "utf-8"
|
|
language: str = "unknown"
|
|
messages: List[Dict[str, str]] = field(default_factory=list)
|
|
context_b64: str = ""
|
|
context_tokens: Dict[str, int] = field(default_factory=dict)
|
|
prompt_tokens_raw: int = 0
|
|
prompt_tokens_encoded: int = 0
|
|
response_tokens_raw: int = 0
|
|
response_tokens_encoded: int = 0
|
|
last_prompt_stats: Dict[str, int] = field(default_factory=dict)
|
|
last_response_stats: Dict[str, int] = field(default_factory=dict)
|
|
pending_show_savings: bool = False
|
|
pending_show_total: bool = False
|
|
|
|
def to_dict(self) -> Dict[str, Any]:
|
|
return {
|
|
"conversation_id": self.conversation_id,
|
|
"encoding": self.encoding,
|
|
"language": self.language,
|
|
"messages": self.messages,
|
|
"context_b64": self.context_b64,
|
|
"context_tokens": self.context_tokens,
|
|
"prompt_tokens_raw": self.prompt_tokens_raw,
|
|
"prompt_tokens_encoded": self.prompt_tokens_encoded,
|
|
"response_tokens_raw": self.response_tokens_raw,
|
|
"response_tokens_encoded": self.response_tokens_encoded,
|
|
}
|
|
|
|
@classmethod
|
|
def from_dict(cls, payload: Dict[str, Any]) -> "ConversationState":
|
|
return cls(
|
|
conversation_id=payload.get("conversation_id", ""),
|
|
encoding=payload.get("encoding", "utf-8"),
|
|
language=payload.get("language", "unknown"),
|
|
messages=payload.get("messages", []),
|
|
context_b64=payload.get("context_b64", ""),
|
|
context_tokens=payload.get("context_tokens", {}),
|
|
prompt_tokens_raw=payload.get("prompt_tokens_raw", 0),
|
|
prompt_tokens_encoded=payload.get("prompt_tokens_encoded", 0),
|
|
response_tokens_raw=payload.get("response_tokens_raw", 0),
|
|
response_tokens_encoded=payload.get("response_tokens_encoded", 0),
|
|
)
|
|
|
|
|
|
class ContextStore:
|
|
def __init__(self, dataset_path: str, refresh_interval: float = 5.0):
|
|
self.dataset_path = dataset_path
|
|
self.refresh_interval = refresh_interval
|
|
self._lock = threading.Lock()
|
|
self._dirty = False
|
|
self._conversations: Dict[str, ConversationState] = {}
|
|
self._stop_event = threading.Event()
|
|
self._load()
|
|
self._thread = threading.Thread(
|
|
target=self._maintenance_loop,
|
|
name="tcp-context-maintainer",
|
|
daemon=True,
|
|
)
|
|
self._thread.start()
|
|
|
|
def _load(self) -> None:
|
|
if not os.path.exists(self.dataset_path):
|
|
return
|
|
try:
|
|
with open(self.dataset_path, "r", encoding="utf-8") as handle:
|
|
data = json.load(handle)
|
|
except (OSError, json.JSONDecodeError):
|
|
return
|
|
conversations = data.get("conversations", {})
|
|
for conv_id, payload in conversations.items():
|
|
state = ConversationState.from_dict(payload)
|
|
if not state.conversation_id:
|
|
state.conversation_id = conv_id
|
|
self._conversations[conv_id] = state
|
|
|
|
def _maintenance_loop(self) -> None:
|
|
while not self._stop_event.wait(self.refresh_interval):
|
|
self._flush_if_dirty()
|
|
|
|
def _flush_if_dirty(self) -> None:
|
|
with self._lock:
|
|
if not self._dirty:
|
|
return
|
|
snapshot = self._snapshot_locked()
|
|
self._dirty = False
|
|
self._persist_snapshot(snapshot)
|
|
|
|
def _snapshot_locked(self) -> Dict[str, Any]:
|
|
for state in self._conversations.values():
|
|
self._refresh_context_locked(state)
|
|
return {
|
|
"updated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
"conversations": {
|
|
conv_id: state.to_dict()
|
|
for conv_id, state in self._conversations.items()
|
|
},
|
|
}
|
|
|
|
def _persist_snapshot(self, snapshot: Dict[str, Any]) -> None:
|
|
os.makedirs(os.path.dirname(self.dataset_path), exist_ok=True)
|
|
tmp_path = f"{self.dataset_path}.tmp"
|
|
with open(tmp_path, "w", encoding="utf-8") as handle:
|
|
json.dump(snapshot, handle, ensure_ascii=True, indent=2)
|
|
os.replace(tmp_path, self.dataset_path)
|
|
|
|
def stop(self) -> None:
|
|
self._stop_event.set()
|
|
self._thread.join(timeout=self.refresh_interval)
|
|
self._flush_if_dirty()
|
|
|
|
def get_or_create(
|
|
self,
|
|
conversation_id: Optional[str],
|
|
encoding: Optional[str],
|
|
language: Optional[str],
|
|
) -> ConversationState:
|
|
with self._lock:
|
|
if not conversation_id:
|
|
conversation_id = str(uuid.uuid4())
|
|
state = self._conversations.get(conversation_id)
|
|
if state is None:
|
|
state = ConversationState(conversation_id=conversation_id)
|
|
self._conversations[conversation_id] = state
|
|
if encoding:
|
|
state.encoding = encoding
|
|
if language:
|
|
state.language = language
|
|
return state
|
|
|
|
def list_contexts(self) -> Dict[str, Dict[str, Any]]:
|
|
with self._lock:
|
|
contexts = {}
|
|
for conv_id, state in self._conversations.items():
|
|
self._refresh_context_locked(state)
|
|
contexts[conv_id] = {
|
|
"context_b64": state.context_b64,
|
|
"encoding": state.encoding,
|
|
"language": state.language,
|
|
"context_tokens": state.context_tokens,
|
|
}
|
|
return contexts
|
|
|
|
def get_context(self, conversation_id: str) -> Optional[Dict[str, Any]]:
|
|
with self._lock:
|
|
state = self._conversations.get(conversation_id)
|
|
if not state:
|
|
return None
|
|
self._refresh_context_locked(state)
|
|
return {
|
|
"conversation_id": state.conversation_id,
|
|
"context_b64": state.context_b64,
|
|
"encoding": state.encoding,
|
|
"language": state.language,
|
|
"context_tokens": state.context_tokens,
|
|
}
|
|
|
|
def record_prompt(
|
|
self,
|
|
conversation_id: Optional[str],
|
|
text: str,
|
|
encoding: Optional[str],
|
|
language: Optional[str],
|
|
) -> Dict[str, Any]:
|
|
state = self.get_or_create(conversation_id, encoding, language)
|
|
clean_text, show_savings, show_total = _strip_control_tags(text)
|
|
encoded_prompt = _b64encode_text(clean_text, state.encoding)
|
|
prompt_stats = _token_stats(clean_text, encoded_prompt)
|
|
with self._lock:
|
|
state.messages.append({"role": "user", "text": clean_text})
|
|
state.prompt_tokens_raw += prompt_stats["raw"]
|
|
state.prompt_tokens_encoded += prompt_stats["encoded"]
|
|
state.last_prompt_stats = prompt_stats
|
|
state.pending_show_savings = show_savings or show_total
|
|
state.pending_show_total = show_total
|
|
self._refresh_context_locked(state)
|
|
self._dirty = True
|
|
response = {
|
|
"conversation_id": state.conversation_id,
|
|
"encoding": state.encoding,
|
|
"language": state.language,
|
|
"encoded_prompt_b64": encoded_prompt,
|
|
"context_b64": state.context_b64,
|
|
"context_tokens": state.context_tokens,
|
|
"prompt_tokens": prompt_stats,
|
|
"savings_request": {
|
|
"show_savings": state.pending_show_savings,
|
|
"show_total": state.pending_show_total,
|
|
},
|
|
}
|
|
return response
|
|
|
|
def record_response(
|
|
self,
|
|
conversation_id: str,
|
|
payload_b54: str,
|
|
) -> Dict[str, Any]:
|
|
with self._lock:
|
|
state = self._conversations.get(conversation_id)
|
|
if not state:
|
|
raise KeyError("Unknown conversation_id")
|
|
encoding = state.encoding
|
|
language = state.language
|
|
|
|
decoded_bytes = b54decode(payload_b54)
|
|
decoded_text = decoded_bytes.decode(encoding, errors="replace")
|
|
response_stats = _token_stats(decoded_text, payload_b54)
|
|
|
|
with self._lock:
|
|
state.messages.append({"role": "assistant", "text": decoded_text})
|
|
state.response_tokens_raw += response_stats["raw"]
|
|
state.response_tokens_encoded += response_stats["encoded"]
|
|
state.last_response_stats = response_stats
|
|
self._refresh_context_locked(state)
|
|
savings_payload = None
|
|
tagline = None
|
|
decoded_text_with_tagline = None
|
|
if state.pending_show_savings:
|
|
savings_payload = self._build_savings_payload(state)
|
|
tagline = self._format_tagline(
|
|
savings_payload,
|
|
include_total=state.pending_show_total,
|
|
)
|
|
decoded_text_with_tagline = (
|
|
decoded_text + "\n" + tagline if decoded_text else tagline
|
|
)
|
|
state.pending_show_savings = False
|
|
state.pending_show_total = False
|
|
self._dirty = True
|
|
response = {
|
|
"conversation_id": state.conversation_id,
|
|
"encoding": encoding,
|
|
"language": language,
|
|
"response_b54": payload_b54,
|
|
"decoded_text": decoded_text,
|
|
"response_tokens": response_stats,
|
|
"context_b64": state.context_b64,
|
|
"context_tokens": state.context_tokens,
|
|
}
|
|
if savings_payload:
|
|
response["savings"] = savings_payload
|
|
if tagline:
|
|
response["tagline"] = tagline
|
|
response["decoded_text_with_tagline"] = decoded_text_with_tagline
|
|
return response
|
|
|
|
def _refresh_context_locked(self, state: ConversationState) -> None:
|
|
context_text = _context_text(state.messages)
|
|
state.context_b64 = _b64encode_text(context_text, state.encoding)
|
|
state.context_tokens = _token_stats(context_text, state.context_b64)
|
|
|
|
def _build_savings_payload(self, state: ConversationState) -> Dict[str, Any]:
|
|
prompt_stats = state.last_prompt_stats or {"raw": 0, "encoded": 0, "saved": 0}
|
|
response_stats = state.last_response_stats or {
|
|
"raw": 0,
|
|
"encoded": 0,
|
|
"saved": 0,
|
|
}
|
|
context_stats = state.context_tokens or {"raw": 0, "encoded": 0, "saved": 0}
|
|
combined_saved = (
|
|
prompt_stats.get("saved", 0)
|
|
+ response_stats.get("saved", 0)
|
|
+ context_stats.get("saved", 0)
|
|
)
|
|
totals = {
|
|
"prompt": {
|
|
"raw": state.prompt_tokens_raw,
|
|
"encoded": state.prompt_tokens_encoded,
|
|
"saved": max(
|
|
0, state.prompt_tokens_raw - state.prompt_tokens_encoded
|
|
),
|
|
},
|
|
"response": {
|
|
"raw": state.response_tokens_raw,
|
|
"encoded": state.response_tokens_encoded,
|
|
"saved": max(
|
|
0, state.response_tokens_raw - state.response_tokens_encoded
|
|
),
|
|
},
|
|
"context": context_stats,
|
|
}
|
|
totals["combined_saved"] = (
|
|
totals["prompt"]["saved"]
|
|
+ totals["response"]["saved"]
|
|
+ totals["context"]["saved"]
|
|
)
|
|
return {
|
|
"prompt": prompt_stats,
|
|
"response": response_stats,
|
|
"context": context_stats,
|
|
"combined_saved": combined_saved,
|
|
"totals": totals,
|
|
}
|
|
|
|
def _format_tagline(self, savings: Dict[str, Any], include_total: bool) -> str:
|
|
prompt_saved = savings["prompt"]["saved"]
|
|
response_saved = savings["response"]["saved"]
|
|
context_saved = savings["context"]["saved"]
|
|
combined_saved = savings["combined_saved"]
|
|
tagline = (
|
|
"Token savings (prompt/response/context/combined): "
|
|
f"{prompt_saved}/{response_saved}/{context_saved}/{combined_saved}."
|
|
)
|
|
if include_total:
|
|
totals = savings.get("totals", {})
|
|
totals_prompt = totals.get("prompt", {}).get("saved", 0)
|
|
totals_response = totals.get("response", {}).get("saved", 0)
|
|
totals_context = totals.get("context", {}).get("saved", 0)
|
|
totals_combined = totals.get("combined_saved", 0)
|
|
tagline += (
|
|
" Total savings (prompt/response/context/combined): "
|
|
f"{totals_prompt}/{totals_response}/{totals_context}/{totals_combined}."
|
|
)
|
|
return tagline
|
|
|
|
|
|
class TokenCompressionProtocolProcessor:
|
|
def __init__(self, store: ContextStore):
|
|
self.store = store
|
|
|
|
def handle(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
|
action = payload.get("action")
|
|
if not action:
|
|
return {"ok": False, "error": "missing_action"}
|
|
|
|
if action == "prompt":
|
|
text = payload.get("text", "")
|
|
if not isinstance(text, str) or text == "":
|
|
return {"ok": False, "error": "missing_text"}
|
|
response = self.store.record_prompt(
|
|
conversation_id=payload.get("conversation_id"),
|
|
text=text,
|
|
encoding=payload.get("encoding"),
|
|
language=payload.get("language"),
|
|
)
|
|
return {"ok": True, "result": response}
|
|
|
|
if action == "response":
|
|
conversation_id = payload.get("conversation_id")
|
|
if not conversation_id:
|
|
return {"ok": False, "error": "missing_conversation_id"}
|
|
payload_b54 = payload.get("payload_b54", "")
|
|
if not isinstance(payload_b54, str) or payload_b54 == "":
|
|
return {"ok": False, "error": "missing_payload_b54"}
|
|
try:
|
|
response = self.store.record_response(
|
|
conversation_id=conversation_id,
|
|
payload_b54=payload_b54,
|
|
)
|
|
except KeyError:
|
|
return {"ok": False, "error": "unknown_conversation_id"}
|
|
except ValueError as exc:
|
|
return {"ok": False, "error": "invalid_base54", "detail": str(exc)}
|
|
return {"ok": True, "result": response}
|
|
|
|
if action == "context_get":
|
|
conversation_id = payload.get("conversation_id")
|
|
if conversation_id:
|
|
context = self.store.get_context(conversation_id)
|
|
if not context:
|
|
return {"ok": False, "error": "unknown_conversation_id"}
|
|
return {"ok": True, "result": context}
|
|
return {"ok": True, "result": {"contexts": self.store.list_contexts()}}
|
|
|
|
if action == "context_reset":
|
|
conversation_id = payload.get("conversation_id")
|
|
if not conversation_id:
|
|
return {"ok": False, "error": "missing_conversation_id"}
|
|
with self.store._lock:
|
|
if conversation_id in self.store._conversations:
|
|
del self.store._conversations[conversation_id]
|
|
self.store._dirty = True
|
|
return {"ok": True, "result": {"conversation_id": conversation_id}}
|
|
return {"ok": False, "error": "unknown_conversation_id"}
|
|
|
|
if action == "ping":
|
|
return {"ok": True, "result": {"message": "pong"}}
|
|
|
|
return {"ok": False, "error": "unknown_action"}
|
|
|
|
|
|
class TokenCompressionTCPServer(socketserver.ThreadingTCPServer):
|
|
allow_reuse_address = True
|
|
daemon_threads = True
|
|
|
|
def __init__(self, server_address, RequestHandlerClass, processor):
|
|
super().__init__(server_address, RequestHandlerClass)
|
|
self.processor = processor
|
|
|
|
|
|
class TokenCompressionRequestHandler(socketserver.StreamRequestHandler):
|
|
def handle(self) -> None:
|
|
while True:
|
|
raw_line = self.rfile.readline()
|
|
if not raw_line:
|
|
break
|
|
raw_line = raw_line.strip()
|
|
if not raw_line:
|
|
continue
|
|
try:
|
|
request = json.loads(raw_line.decode("utf-8"))
|
|
except json.JSONDecodeError as exc:
|
|
self._send({"ok": False, "error": "invalid_json", "detail": str(exc)})
|
|
continue
|
|
if not isinstance(request, dict):
|
|
self._send({"ok": False, "error": "invalid_payload"})
|
|
continue
|
|
response = self.server.processor.handle(request)
|
|
self._send(response)
|
|
|
|
def _send(self, payload: Dict[str, Any]) -> None:
|
|
encoded = json.dumps(payload, ensure_ascii=True).encode("utf-8") + b"\n"
|
|
self.wfile.write(encoded)
|
|
|
|
|
|
def run_tcp_server(
|
|
host: str = "127.0.0.1",
|
|
port: int = 7543,
|
|
dataset_path: Optional[str] = None,
|
|
refresh_interval: float = 5.0,
|
|
) -> None:
|
|
dataset_path = dataset_path or files.get_abs_path(
|
|
"memory", "token_compression_context.json"
|
|
)
|
|
store = ContextStore(dataset_path=dataset_path, refresh_interval=refresh_interval)
|
|
processor = TokenCompressionProtocolProcessor(store)
|
|
server = TokenCompressionTCPServer(
|
|
(host, port), TokenCompressionRequestHandler, processor
|
|
)
|
|
try:
|
|
server.serve_forever()
|
|
except KeyboardInterrupt:
|
|
pass
|
|
finally:
|
|
store.stop()
|
|
server.server_close()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tcp_server()
|