mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-24 05:26:39 +00:00
Add token compression TCP protocol server
Co-authored-by: nicsins <nicsins@gmail.com>
This commit is contained in:
parent
f3c41bca08
commit
5152e4d424
1 changed files with 554 additions and 0 deletions
554
python/helpers/token_compression_protocol.py
Normal file
554
python/helpers/token_compression_protocol.py
Normal file
|
|
@ -0,0 +1,554 @@
|
|||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import socketserver
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from python.helpers import files, tokens
|
||||
|
||||
|
||||
BASE54_ALPHABET = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstyz"
|
||||
BASE54_INDEX = {char: idx for idx, char in enumerate(BASE54_ALPHABET)}
|
||||
BASE54_BASE = len(BASE54_ALPHABET)
|
||||
|
||||
CONTROL_TAG_SHOW_SAVINGS = "**show savings**"
|
||||
CONTROL_TAG_SHOW_TOTAL = "**show total**"
|
||||
|
||||
|
||||
def _safe_count_tokens(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
try:
|
||||
return tokens.count_tokens(text)
|
||||
except Exception:
|
||||
return max(1, len(text.split()))
|
||||
|
||||
|
||||
def _token_stats(raw_text: str, encoded_text: str) -> Dict[str, int]:
|
||||
raw_tokens = _safe_count_tokens(raw_text)
|
||||
encoded_tokens = _safe_count_tokens(encoded_text)
|
||||
saved = raw_tokens - encoded_tokens
|
||||
if saved < 0:
|
||||
saved = 0
|
||||
return {
|
||||
"raw": raw_tokens,
|
||||
"encoded": encoded_tokens,
|
||||
"saved": saved,
|
||||
}
|
||||
|
||||
|
||||
def _strip_control_tags(text: str) -> Tuple[str, bool, bool]:
|
||||
show_savings = False
|
||||
show_total = False
|
||||
if not text:
|
||||
return text, show_savings, show_total
|
||||
|
||||
if re.search(re.escape(CONTROL_TAG_SHOW_SAVINGS), text, flags=re.IGNORECASE):
|
||||
show_savings = True
|
||||
text = re.sub(
|
||||
re.escape(CONTROL_TAG_SHOW_SAVINGS), "", text, flags=re.IGNORECASE
|
||||
)
|
||||
if re.search(re.escape(CONTROL_TAG_SHOW_TOTAL), text, flags=re.IGNORECASE):
|
||||
show_total = True
|
||||
show_savings = True
|
||||
text = re.sub(
|
||||
re.escape(CONTROL_TAG_SHOW_TOTAL), "", text, flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
return text.strip(), show_savings, show_total
|
||||
|
||||
|
||||
def b54encode(payload: bytes) -> str:
|
||||
if not payload:
|
||||
return ""
|
||||
num = int.from_bytes(payload, "big")
|
||||
encoded: List[str] = []
|
||||
while num > 0:
|
||||
num, rem = divmod(num, BASE54_BASE)
|
||||
encoded.append(BASE54_ALPHABET[rem])
|
||||
pad = 0
|
||||
for byte in payload:
|
||||
if byte == 0:
|
||||
pad += 1
|
||||
else:
|
||||
break
|
||||
encoded_str = "".join(reversed(encoded)) if encoded else ""
|
||||
return (BASE54_ALPHABET[0] * pad) + encoded_str
|
||||
|
||||
|
||||
def b54decode(payload: str) -> bytes:
|
||||
if payload == "":
|
||||
return b""
|
||||
num = 0
|
||||
for char in payload:
|
||||
if char not in BASE54_INDEX:
|
||||
raise ValueError(f"Invalid base54 character: {char!r}")
|
||||
num = num * BASE54_BASE + BASE54_INDEX[char]
|
||||
pad = 0
|
||||
for char in payload:
|
||||
if char == BASE54_ALPHABET[0]:
|
||||
pad += 1
|
||||
else:
|
||||
break
|
||||
decoded = b""
|
||||
if num > 0:
|
||||
byte_len = (num.bit_length() + 7) // 8
|
||||
decoded = num.to_bytes(byte_len, "big")
|
||||
return (b"\x00" * pad) + decoded
|
||||
|
||||
|
||||
def _b64encode_text(text: str, encoding: str) -> str:
|
||||
return base64.b64encode(text.encode(encoding, errors="replace")).decode("ascii")
|
||||
|
||||
|
||||
def _context_text(messages: List[Dict[str, str]]) -> str:
|
||||
return "\n".join(f"{entry['role']}: {entry['text']}" for entry in messages).strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationState:
|
||||
conversation_id: str
|
||||
encoding: str = "utf-8"
|
||||
language: str = "unknown"
|
||||
messages: List[Dict[str, str]] = field(default_factory=list)
|
||||
context_b64: str = ""
|
||||
context_tokens: Dict[str, int] = field(default_factory=dict)
|
||||
prompt_tokens_raw: int = 0
|
||||
prompt_tokens_encoded: int = 0
|
||||
response_tokens_raw: int = 0
|
||||
response_tokens_encoded: int = 0
|
||||
last_prompt_stats: Dict[str, int] = field(default_factory=dict)
|
||||
last_response_stats: Dict[str, int] = field(default_factory=dict)
|
||||
pending_show_savings: bool = False
|
||||
pending_show_total: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"conversation_id": self.conversation_id,
|
||||
"encoding": self.encoding,
|
||||
"language": self.language,
|
||||
"messages": self.messages,
|
||||
"context_b64": self.context_b64,
|
||||
"context_tokens": self.context_tokens,
|
||||
"prompt_tokens_raw": self.prompt_tokens_raw,
|
||||
"prompt_tokens_encoded": self.prompt_tokens_encoded,
|
||||
"response_tokens_raw": self.response_tokens_raw,
|
||||
"response_tokens_encoded": self.response_tokens_encoded,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, payload: Dict[str, Any]) -> "ConversationState":
|
||||
return cls(
|
||||
conversation_id=payload.get("conversation_id", ""),
|
||||
encoding=payload.get("encoding", "utf-8"),
|
||||
language=payload.get("language", "unknown"),
|
||||
messages=payload.get("messages", []),
|
||||
context_b64=payload.get("context_b64", ""),
|
||||
context_tokens=payload.get("context_tokens", {}),
|
||||
prompt_tokens_raw=payload.get("prompt_tokens_raw", 0),
|
||||
prompt_tokens_encoded=payload.get("prompt_tokens_encoded", 0),
|
||||
response_tokens_raw=payload.get("response_tokens_raw", 0),
|
||||
response_tokens_encoded=payload.get("response_tokens_encoded", 0),
|
||||
)
|
||||
|
||||
|
||||
class ContextStore:
|
||||
def __init__(self, dataset_path: str, refresh_interval: float = 5.0):
|
||||
self.dataset_path = dataset_path
|
||||
self.refresh_interval = refresh_interval
|
||||
self._lock = threading.Lock()
|
||||
self._dirty = False
|
||||
self._conversations: Dict[str, ConversationState] = {}
|
||||
self._stop_event = threading.Event()
|
||||
self._load()
|
||||
self._thread = threading.Thread(
|
||||
target=self._maintenance_loop,
|
||||
name="tcp-context-maintainer",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def _load(self) -> None:
|
||||
if not os.path.exists(self.dataset_path):
|
||||
return
|
||||
try:
|
||||
with open(self.dataset_path, "r", encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return
|
||||
conversations = data.get("conversations", {})
|
||||
for conv_id, payload in conversations.items():
|
||||
state = ConversationState.from_dict(payload)
|
||||
if not state.conversation_id:
|
||||
state.conversation_id = conv_id
|
||||
self._conversations[conv_id] = state
|
||||
|
||||
def _maintenance_loop(self) -> None:
|
||||
while not self._stop_event.wait(self.refresh_interval):
|
||||
self._flush_if_dirty()
|
||||
|
||||
def _flush_if_dirty(self) -> None:
|
||||
with self._lock:
|
||||
if not self._dirty:
|
||||
return
|
||||
snapshot = self._snapshot_locked()
|
||||
self._dirty = False
|
||||
self._persist_snapshot(snapshot)
|
||||
|
||||
def _snapshot_locked(self) -> Dict[str, Any]:
|
||||
for state in self._conversations.values():
|
||||
self._refresh_context_locked(state)
|
||||
return {
|
||||
"updated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"conversations": {
|
||||
conv_id: state.to_dict()
|
||||
for conv_id, state in self._conversations.items()
|
||||
},
|
||||
}
|
||||
|
||||
def _persist_snapshot(self, snapshot: Dict[str, Any]) -> None:
|
||||
os.makedirs(os.path.dirname(self.dataset_path), exist_ok=True)
|
||||
tmp_path = f"{self.dataset_path}.tmp"
|
||||
with open(tmp_path, "w", encoding="utf-8") as handle:
|
||||
json.dump(snapshot, handle, ensure_ascii=True, indent=2)
|
||||
os.replace(tmp_path, self.dataset_path)
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop_event.set()
|
||||
self._thread.join(timeout=self.refresh_interval)
|
||||
self._flush_if_dirty()
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
conversation_id: Optional[str],
|
||||
encoding: Optional[str],
|
||||
language: Optional[str],
|
||||
) -> ConversationState:
|
||||
with self._lock:
|
||||
if not conversation_id:
|
||||
conversation_id = str(uuid.uuid4())
|
||||
state = self._conversations.get(conversation_id)
|
||||
if state is None:
|
||||
state = ConversationState(conversation_id=conversation_id)
|
||||
self._conversations[conversation_id] = state
|
||||
if encoding:
|
||||
state.encoding = encoding
|
||||
if language:
|
||||
state.language = language
|
||||
return state
|
||||
|
||||
def list_contexts(self) -> Dict[str, Dict[str, Any]]:
|
||||
with self._lock:
|
||||
contexts = {}
|
||||
for conv_id, state in self._conversations.items():
|
||||
self._refresh_context_locked(state)
|
||||
contexts[conv_id] = {
|
||||
"context_b64": state.context_b64,
|
||||
"encoding": state.encoding,
|
||||
"language": state.language,
|
||||
"context_tokens": state.context_tokens,
|
||||
}
|
||||
return contexts
|
||||
|
||||
def get_context(self, conversation_id: str) -> Optional[Dict[str, Any]]:
|
||||
with self._lock:
|
||||
state = self._conversations.get(conversation_id)
|
||||
if not state:
|
||||
return None
|
||||
self._refresh_context_locked(state)
|
||||
return {
|
||||
"conversation_id": state.conversation_id,
|
||||
"context_b64": state.context_b64,
|
||||
"encoding": state.encoding,
|
||||
"language": state.language,
|
||||
"context_tokens": state.context_tokens,
|
||||
}
|
||||
|
||||
def record_prompt(
|
||||
self,
|
||||
conversation_id: Optional[str],
|
||||
text: str,
|
||||
encoding: Optional[str],
|
||||
language: Optional[str],
|
||||
) -> Dict[str, Any]:
|
||||
state = self.get_or_create(conversation_id, encoding, language)
|
||||
clean_text, show_savings, show_total = _strip_control_tags(text)
|
||||
encoded_prompt = _b64encode_text(clean_text, state.encoding)
|
||||
prompt_stats = _token_stats(clean_text, encoded_prompt)
|
||||
with self._lock:
|
||||
state.messages.append({"role": "user", "text": clean_text})
|
||||
state.prompt_tokens_raw += prompt_stats["raw"]
|
||||
state.prompt_tokens_encoded += prompt_stats["encoded"]
|
||||
state.last_prompt_stats = prompt_stats
|
||||
state.pending_show_savings = show_savings or show_total
|
||||
state.pending_show_total = show_total
|
||||
self._refresh_context_locked(state)
|
||||
self._dirty = True
|
||||
response = {
|
||||
"conversation_id": state.conversation_id,
|
||||
"encoding": state.encoding,
|
||||
"language": state.language,
|
||||
"encoded_prompt_b64": encoded_prompt,
|
||||
"context_b64": state.context_b64,
|
||||
"context_tokens": state.context_tokens,
|
||||
"prompt_tokens": prompt_stats,
|
||||
"savings_request": {
|
||||
"show_savings": state.pending_show_savings,
|
||||
"show_total": state.pending_show_total,
|
||||
},
|
||||
}
|
||||
return response
|
||||
|
||||
def record_response(
|
||||
self,
|
||||
conversation_id: str,
|
||||
payload_b54: str,
|
||||
) -> Dict[str, Any]:
|
||||
with self._lock:
|
||||
state = self._conversations.get(conversation_id)
|
||||
if not state:
|
||||
raise KeyError("Unknown conversation_id")
|
||||
encoding = state.encoding
|
||||
language = state.language
|
||||
|
||||
decoded_bytes = b54decode(payload_b54)
|
||||
decoded_text = decoded_bytes.decode(encoding, errors="replace")
|
||||
response_stats = _token_stats(decoded_text, payload_b54)
|
||||
|
||||
with self._lock:
|
||||
state.messages.append({"role": "assistant", "text": decoded_text})
|
||||
state.response_tokens_raw += response_stats["raw"]
|
||||
state.response_tokens_encoded += response_stats["encoded"]
|
||||
state.last_response_stats = response_stats
|
||||
self._refresh_context_locked(state)
|
||||
savings_payload = None
|
||||
tagline = None
|
||||
decoded_text_with_tagline = None
|
||||
if state.pending_show_savings:
|
||||
savings_payload = self._build_savings_payload(state)
|
||||
tagline = self._format_tagline(
|
||||
savings_payload,
|
||||
include_total=state.pending_show_total,
|
||||
)
|
||||
decoded_text_with_tagline = (
|
||||
decoded_text + "\n" + tagline if decoded_text else tagline
|
||||
)
|
||||
state.pending_show_savings = False
|
||||
state.pending_show_total = False
|
||||
self._dirty = True
|
||||
response = {
|
||||
"conversation_id": state.conversation_id,
|
||||
"encoding": encoding,
|
||||
"language": language,
|
||||
"response_b54": payload_b54,
|
||||
"decoded_text": decoded_text,
|
||||
"response_tokens": response_stats,
|
||||
"context_b64": state.context_b64,
|
||||
"context_tokens": state.context_tokens,
|
||||
}
|
||||
if savings_payload:
|
||||
response["savings"] = savings_payload
|
||||
if tagline:
|
||||
response["tagline"] = tagline
|
||||
response["decoded_text_with_tagline"] = decoded_text_with_tagline
|
||||
return response
|
||||
|
||||
def _refresh_context_locked(self, state: ConversationState) -> None:
|
||||
context_text = _context_text(state.messages)
|
||||
state.context_b64 = _b64encode_text(context_text, state.encoding)
|
||||
state.context_tokens = _token_stats(context_text, state.context_b64)
|
||||
|
||||
def _build_savings_payload(self, state: ConversationState) -> Dict[str, Any]:
|
||||
prompt_stats = state.last_prompt_stats or {"raw": 0, "encoded": 0, "saved": 0}
|
||||
response_stats = state.last_response_stats or {
|
||||
"raw": 0,
|
||||
"encoded": 0,
|
||||
"saved": 0,
|
||||
}
|
||||
context_stats = state.context_tokens or {"raw": 0, "encoded": 0, "saved": 0}
|
||||
combined_saved = (
|
||||
prompt_stats.get("saved", 0)
|
||||
+ response_stats.get("saved", 0)
|
||||
+ context_stats.get("saved", 0)
|
||||
)
|
||||
totals = {
|
||||
"prompt": {
|
||||
"raw": state.prompt_tokens_raw,
|
||||
"encoded": state.prompt_tokens_encoded,
|
||||
"saved": max(
|
||||
0, state.prompt_tokens_raw - state.prompt_tokens_encoded
|
||||
),
|
||||
},
|
||||
"response": {
|
||||
"raw": state.response_tokens_raw,
|
||||
"encoded": state.response_tokens_encoded,
|
||||
"saved": max(
|
||||
0, state.response_tokens_raw - state.response_tokens_encoded
|
||||
),
|
||||
},
|
||||
"context": context_stats,
|
||||
}
|
||||
totals["combined_saved"] = (
|
||||
totals["prompt"]["saved"]
|
||||
+ totals["response"]["saved"]
|
||||
+ totals["context"]["saved"]
|
||||
)
|
||||
return {
|
||||
"prompt": prompt_stats,
|
||||
"response": response_stats,
|
||||
"context": context_stats,
|
||||
"combined_saved": combined_saved,
|
||||
"totals": totals,
|
||||
}
|
||||
|
||||
def _format_tagline(self, savings: Dict[str, Any], include_total: bool) -> str:
|
||||
prompt_saved = savings["prompt"]["saved"]
|
||||
response_saved = savings["response"]["saved"]
|
||||
context_saved = savings["context"]["saved"]
|
||||
combined_saved = savings["combined_saved"]
|
||||
tagline = (
|
||||
"Token savings (prompt/response/context/combined): "
|
||||
f"{prompt_saved}/{response_saved}/{context_saved}/{combined_saved}."
|
||||
)
|
||||
if include_total:
|
||||
totals = savings.get("totals", {})
|
||||
totals_prompt = totals.get("prompt", {}).get("saved", 0)
|
||||
totals_response = totals.get("response", {}).get("saved", 0)
|
||||
totals_context = totals.get("context", {}).get("saved", 0)
|
||||
totals_combined = totals.get("combined_saved", 0)
|
||||
tagline += (
|
||||
" Total savings (prompt/response/context/combined): "
|
||||
f"{totals_prompt}/{totals_response}/{totals_context}/{totals_combined}."
|
||||
)
|
||||
return tagline
|
||||
|
||||
|
||||
class TokenCompressionProtocolProcessor:
|
||||
def __init__(self, store: ContextStore):
|
||||
self.store = store
|
||||
|
||||
def handle(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
action = payload.get("action")
|
||||
if not action:
|
||||
return {"ok": False, "error": "missing_action"}
|
||||
|
||||
if action == "prompt":
|
||||
text = payload.get("text", "")
|
||||
if not isinstance(text, str) or text == "":
|
||||
return {"ok": False, "error": "missing_text"}
|
||||
response = self.store.record_prompt(
|
||||
conversation_id=payload.get("conversation_id"),
|
||||
text=text,
|
||||
encoding=payload.get("encoding"),
|
||||
language=payload.get("language"),
|
||||
)
|
||||
return {"ok": True, "result": response}
|
||||
|
||||
if action == "response":
|
||||
conversation_id = payload.get("conversation_id")
|
||||
if not conversation_id:
|
||||
return {"ok": False, "error": "missing_conversation_id"}
|
||||
payload_b54 = payload.get("payload_b54", "")
|
||||
if not isinstance(payload_b54, str) or payload_b54 == "":
|
||||
return {"ok": False, "error": "missing_payload_b54"}
|
||||
try:
|
||||
response = self.store.record_response(
|
||||
conversation_id=conversation_id,
|
||||
payload_b54=payload_b54,
|
||||
)
|
||||
except KeyError:
|
||||
return {"ok": False, "error": "unknown_conversation_id"}
|
||||
except ValueError as exc:
|
||||
return {"ok": False, "error": "invalid_base54", "detail": str(exc)}
|
||||
return {"ok": True, "result": response}
|
||||
|
||||
if action == "context_get":
|
||||
conversation_id = payload.get("conversation_id")
|
||||
if conversation_id:
|
||||
context = self.store.get_context(conversation_id)
|
||||
if not context:
|
||||
return {"ok": False, "error": "unknown_conversation_id"}
|
||||
return {"ok": True, "result": context}
|
||||
return {"ok": True, "result": {"contexts": self.store.list_contexts()}}
|
||||
|
||||
if action == "context_reset":
|
||||
conversation_id = payload.get("conversation_id")
|
||||
if not conversation_id:
|
||||
return {"ok": False, "error": "missing_conversation_id"}
|
||||
with self.store._lock:
|
||||
if conversation_id in self.store._conversations:
|
||||
del self.store._conversations[conversation_id]
|
||||
self.store._dirty = True
|
||||
return {"ok": True, "result": {"conversation_id": conversation_id}}
|
||||
return {"ok": False, "error": "unknown_conversation_id"}
|
||||
|
||||
if action == "ping":
|
||||
return {"ok": True, "result": {"message": "pong"}}
|
||||
|
||||
return {"ok": False, "error": "unknown_action"}
|
||||
|
||||
|
||||
class TokenCompressionTCPServer(socketserver.ThreadingTCPServer):
|
||||
allow_reuse_address = True
|
||||
daemon_threads = True
|
||||
|
||||
def __init__(self, server_address, RequestHandlerClass, processor):
|
||||
super().__init__(server_address, RequestHandlerClass)
|
||||
self.processor = processor
|
||||
|
||||
|
||||
class TokenCompressionRequestHandler(socketserver.StreamRequestHandler):
|
||||
def handle(self) -> None:
|
||||
while True:
|
||||
raw_line = self.rfile.readline()
|
||||
if not raw_line:
|
||||
break
|
||||
raw_line = raw_line.strip()
|
||||
if not raw_line:
|
||||
continue
|
||||
try:
|
||||
request = json.loads(raw_line.decode("utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
self._send({"ok": False, "error": "invalid_json", "detail": str(exc)})
|
||||
continue
|
||||
if not isinstance(request, dict):
|
||||
self._send({"ok": False, "error": "invalid_payload"})
|
||||
continue
|
||||
response = self.server.processor.handle(request)
|
||||
self._send(response)
|
||||
|
||||
def _send(self, payload: Dict[str, Any]) -> None:
|
||||
encoded = json.dumps(payload, ensure_ascii=True).encode("utf-8") + b"\n"
|
||||
self.wfile.write(encoded)
|
||||
|
||||
|
||||
def run_tcp_server(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 7543,
|
||||
dataset_path: Optional[str] = None,
|
||||
refresh_interval: float = 5.0,
|
||||
) -> None:
|
||||
dataset_path = dataset_path or files.get_abs_path(
|
||||
"memory", "token_compression_context.json"
|
||||
)
|
||||
store = ContextStore(dataset_path=dataset_path, refresh_interval=refresh_interval)
|
||||
processor = TokenCompressionProtocolProcessor(store)
|
||||
server = TokenCompressionTCPServer(
|
||||
(host, port), TokenCompressionRequestHandler, processor
|
||||
)
|
||||
try:
|
||||
server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
store.stop()
|
||||
server.server_close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tcp_server()
|
||||
Loading…
Add table
Add a link
Reference in a new issue