agent-zero/python/helpers/token_compression_protocol.py
Cursor Agent 5152e4d424 Add token compression TCP protocol server
Co-authored-by: nicsins <nicsins@gmail.com>
2026-01-20 09:36:16 +00:00

554 lines
20 KiB
Python

import base64
import json
import logging
import os
import re
import socketserver
import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

from python.helpers import files, tokens
# Digit alphabet for the base54 codec: the Bitcoin base58 alphabet with the
# letters "u", "v", "w" and "x" removed (54 symbols). "1" is the zero digit,
# which the codec also uses to represent leading zero bytes.
BASE54_ALPHABET = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstyz"
BASE54_BASE = len(BASE54_ALPHABET)
# Reverse lookup: symbol -> digit value.
BASE54_INDEX = dict(zip(BASE54_ALPHABET, range(BASE54_BASE)))

# Prompt tags that request a token-savings tagline on the next response;
# "show total" additionally requests the running totals.
CONTROL_TAG_SHOW_SAVINGS = "**show savings**"
CONTROL_TAG_SHOW_TOTAL = "**show total**"
def _safe_count_tokens(text: str) -> int:
if not text:
return 0
try:
return tokens.count_tokens(text)
except Exception:
return max(1, len(text.split()))
def _token_stats(raw_text: str, encoded_text: str) -> Dict[str, int]:
    """Compare token counts of a raw text and its encoded form.

    Returns a dict with ``raw`` and ``encoded`` token counts and ``saved``,
    the difference clamped so it is never negative.
    """
    counts = {
        "raw": _safe_count_tokens(raw_text),
        "encoded": _safe_count_tokens(encoded_text),
    }
    counts["saved"] = max(0, counts["raw"] - counts["encoded"])
    return counts
def _strip_control_tags(text: str) -> Tuple[str, bool, bool]:
    """Remove savings control tags from *text*, matched case-insensitively.

    Returns ``(cleaned_text, show_savings, show_total)``. The cleaned text is
    stripped of surrounding whitespace. A "show total" tag also implies
    ``show_savings``. Falsy input is returned unchanged with both flags False.
    """
    if not text:
        return text, False, False
    # The savings tag is removed first, then the total tag, matching the
    # order in which the tags are searched.
    text, savings_hits = re.subn(
        re.escape(CONTROL_TAG_SHOW_SAVINGS), "", text, flags=re.IGNORECASE
    )
    text, total_hits = re.subn(
        re.escape(CONTROL_TAG_SHOW_TOTAL), "", text, flags=re.IGNORECASE
    )
    show_total = total_hits > 0
    show_savings = savings_hits > 0 or show_total
    return text.strip(), show_savings, show_total
def b54encode(payload: bytes) -> str:
    """Encode *payload* as a big-endian base54 string.

    The bytes are read as one unsigned big-endian integer and written in
    ``BASE54_ALPHABET`` digits. Each leading zero byte becomes one leading
    ``BASE54_ALPHABET[0]`` so that :func:`b54decode` can restore it. Empty
    input encodes to ``""``.
    """
    if not payload:
        return ""
    # Index of the first non-zero byte, i.e. the count of leading zero bytes.
    leading_zeros = next(
        (position for position, octet in enumerate(payload) if octet != 0),
        len(payload),
    )
    value = int.from_bytes(payload, "big")
    digits: List[str] = []
    while value > 0:
        value, digit = divmod(value, BASE54_BASE)
        digits.append(BASE54_ALPHABET[digit])
    digits.reverse()
    return BASE54_ALPHABET[0] * leading_zeros + "".join(digits)
def b54decode(payload: str) -> bytes:
    """Decode a base54 string produced by :func:`b54encode` back into bytes.

    Each leading ``BASE54_ALPHABET[0]`` becomes one leading zero byte.

    Raises:
        ValueError: if *payload* contains a character outside the alphabet.
    """
    if payload == "":
        return b""
    value = 0
    for symbol in payload:
        digit = BASE54_INDEX.get(symbol)
        if digit is None:
            raise ValueError(f"Invalid base54 character: {symbol!r}")
        value = value * BASE54_BASE + digit
    leading_zeros = len(payload) - len(payload.lstrip(BASE54_ALPHABET[0]))
    if value > 0:
        body = value.to_bytes((value.bit_length() + 7) // 8, "big")
    else:
        body = b""
    return b"\x00" * leading_zeros + body
def _b64encode_text(text: str, encoding: str) -> str:
return base64.b64encode(text.encode(encoding, errors="replace")).decode("ascii")
def _context_text(messages: List[Dict[str, str]]) -> str:
return "\n".join(f"{entry['role']}: {entry['text']}" for entry in messages).strip()
@dataclass
class ConversationState:
    """Per-conversation bookkeeping held by ``ContextStore``."""

    conversation_id: str
    # Codec used to turn prompt/context text into bytes and to decode replies.
    encoding: str = "utf-8"
    language: str = "unknown"
    # Transcript entries of the form {"role": ..., "text": ...}.
    messages: List[Dict[str, str]] = field(default_factory=list)
    # Base64 of the rendered transcript and its token statistics.
    context_b64: str = ""
    context_tokens: Dict[str, int] = field(default_factory=dict)
    # Running token totals across all prompts / responses.
    prompt_tokens_raw: int = 0
    prompt_tokens_encoded: int = 0
    response_tokens_raw: int = 0
    response_tokens_encoded: int = 0
    # Token statistics of the most recent prompt / response.
    last_prompt_stats: Dict[str, int] = field(default_factory=dict)
    last_response_stats: Dict[str, int] = field(default_factory=dict)
    # Savings tagline requested by the latest prompt, consumed by the next response.
    pending_show_savings: bool = False
    pending_show_total: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serialisable snapshot of every field.

        Containers are shallow-copied so the snapshot is decoupled from later
        mutation of this state. This matters because the store builds the
        snapshot under its lock but serialises it after releasing the lock.
        The last-turn stats and pending flags are included so a savings
        request survives a flush and reload.
        """
        return {
            "conversation_id": self.conversation_id,
            "encoding": self.encoding,
            "language": self.language,
            "messages": list(self.messages),
            "context_b64": self.context_b64,
            "context_tokens": dict(self.context_tokens),
            "prompt_tokens_raw": self.prompt_tokens_raw,
            "prompt_tokens_encoded": self.prompt_tokens_encoded,
            "response_tokens_raw": self.response_tokens_raw,
            "response_tokens_encoded": self.response_tokens_encoded,
            "last_prompt_stats": dict(self.last_prompt_stats),
            "last_response_stats": dict(self.last_response_stats),
            "pending_show_savings": self.pending_show_savings,
            "pending_show_total": self.pending_show_total,
        }

    @classmethod
    def from_dict(cls, payload: Dict[str, Any]) -> "ConversationState":
        """Rebuild a state from a :meth:`to_dict` snapshot.

        Missing keys fall back to field defaults, so snapshots written before
        the last-turn stats and pending flags were persisted still load.
        Containers are copied so the new state does not alias *payload*.
        """
        return cls(
            conversation_id=payload.get("conversation_id", ""),
            encoding=payload.get("encoding", "utf-8"),
            language=payload.get("language", "unknown"),
            messages=list(payload.get("messages") or []),
            context_b64=payload.get("context_b64", ""),
            context_tokens=dict(payload.get("context_tokens") or {}),
            prompt_tokens_raw=payload.get("prompt_tokens_raw", 0),
            prompt_tokens_encoded=payload.get("prompt_tokens_encoded", 0),
            response_tokens_raw=payload.get("response_tokens_raw", 0),
            response_tokens_encoded=payload.get("response_tokens_encoded", 0),
            last_prompt_stats=dict(payload.get("last_prompt_stats") or {}),
            last_response_stats=dict(payload.get("last_response_stats") or {}),
            pending_show_savings=bool(payload.get("pending_show_savings", False)),
            pending_show_total=bool(payload.get("pending_show_total", False)),
        )
class ContextStore:
    """Registry of conversation state, periodically persisted as JSON.

    The conversation mapping is guarded by ``self._lock``. Mutating calls set
    ``self._dirty``. A daemon thread started in ``__init__`` flushes a
    snapshot to ``dataset_path`` every ``refresh_interval`` seconds while
    dirty, and :meth:`stop` performs a final flush.
    """

    def __init__(self, dataset_path: str, refresh_interval: float = 5.0):
        self.dataset_path = dataset_path
        self.refresh_interval = refresh_interval
        self._lock = threading.Lock()
        self._dirty = False
        self._conversations: Dict[str, "ConversationState"] = {}
        self._stop_event = threading.Event()
        self._load()
        self._thread = threading.Thread(
            target=self._maintenance_loop,
            name="tcp-context-maintainer",
            daemon=True,
        )
        self._thread.start()

    def _load(self) -> None:
        """Populate conversations from ``dataset_path`` if it holds a snapshot.

        Loading is best-effort. A missing, unreadable, non-UTF-8 or malformed
        file, or a snapshot that is not a JSON object, leaves the store empty
        instead of aborting startup. Non-object conversation entries are
        skipped.
        """
        if not os.path.exists(self.dataset_path):
            return
        try:
            with open(self.dataset_path, "r", encoding="utf-8") as handle:
                data = json.load(handle)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return
        if not isinstance(data, dict):
            return
        conversations = data.get("conversations", {})
        if not isinstance(conversations, dict):
            return
        for conv_id, payload in conversations.items():
            if not isinstance(payload, dict):
                continue
            state = ConversationState.from_dict(payload)
            if not state.conversation_id:
                state.conversation_id = conv_id
            self._conversations[conv_id] = state

    def _maintenance_loop(self) -> None:
        """Flush dirty state every ``refresh_interval`` seconds until stopped."""
        while not self._stop_event.wait(self.refresh_interval):
            try:
                self._flush_if_dirty()
            except Exception:
                # An escaping exception would end this thread and silently
                # stop all future persistence. Log it instead; the dirty flag
                # was restored, so the next tick retries.
                logging.getLogger(__name__).exception(
                    "Failed to persist token compression context to %s",
                    self.dataset_path,
                )

    def _flush_if_dirty(self) -> None:
        """Write a snapshot to disk if anything changed since the last flush.

        The snapshot is built under the lock and written after releasing it.
        If writing fails, the dirty flag is restored so a later flush retries,
        and the error is re-raised.
        """
        with self._lock:
            if not self._dirty:
                return
            snapshot = self._snapshot_locked()
            self._dirty = False
        try:
            self._persist_snapshot(snapshot)
        except Exception:
            with self._lock:
                self._dirty = True
            raise

    def _snapshot_locked(self) -> Dict[str, Any]:
        """Build the persisted document. The caller must hold ``self._lock``."""
        for state in self._conversations.values():
            self._refresh_context_locked(state)
        return {
            "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
            "conversations": {
                conv_id: state.to_dict()
                for conv_id, state in self._conversations.items()
            },
        }

    def _persist_snapshot(self, snapshot: Dict[str, Any]) -> None:
        """Replace ``dataset_path`` with *snapshot* serialised as JSON.

        The data is written to a sibling ``.tmp`` file and moved into place
        with ``os.replace``, so the target is never left half-written. On
        failure the temp file is removed and the error re-raised.
        """
        directory = os.path.dirname(self.dataset_path)
        if directory:
            # A bare filename has no directory part, and os.makedirs("") raises.
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{self.dataset_path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as handle:
                json.dump(snapshot, handle, ensure_ascii=True, indent=2)
            os.replace(tmp_path, self.dataset_path)
        except Exception:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # Nothing was written, or the file is already gone.
            raise

    def stop(self) -> None:
        """Stop the maintenance thread and flush any outstanding changes.

        Waits up to one ``refresh_interval`` for the thread to exit.
        """
        self._stop_event.set()
        self._thread.join(timeout=self.refresh_interval)
        self._flush_if_dirty()

    @staticmethod
    def _check_text_encoding(encoding: Any) -> None:
        """Raise ValueError unless *encoding* is a str<->bytes codec.

        The probe uses the same ``errors="replace"`` handler as the real
        encode/decode calls, so codecs that reject that handler are refused
        here too.
        """
        try:
            "".encode(encoding, errors="replace")
            b"".decode(encoding, errors="replace")
        except (LookupError, TypeError, ValueError) as exc:
            raise ValueError(f"Unsupported text encoding: {encoding!r}") from exc

    def get_or_create(
        self,
        conversation_id: Optional[str],
        encoding: Optional[str],
        language: Optional[str],
    ) -> "ConversationState":
        """Return the state for *conversation_id*, creating it if needed.

        A falsy *conversation_id* gets a fresh UUID4. Truthy *encoding* and
        *language* values overwrite the stored ones.

        Raises:
            ValueError: if *encoding* is given but is not a usable text codec.
                This is checked before anything is stored. Otherwise the bad
                codec would be saved on the state and break every later
                context refresh and flush.
        """
        if encoding:
            self._check_text_encoding(encoding)
        with self._lock:
            if not conversation_id:
                conversation_id = str(uuid.uuid4())
            state = self._conversations.get(conversation_id)
            if state is None:
                state = ConversationState(conversation_id=conversation_id)
                self._conversations[conversation_id] = state
            if encoding:
                state.encoding = encoding
            if language:
                state.language = language
            return state

    def list_contexts(self) -> Dict[str, Dict[str, Any]]:
        """Return refreshed context info for every conversation, keyed by id."""
        with self._lock:
            contexts = {}
            for conv_id, state in self._conversations.items():
                self._refresh_context_locked(state)
                contexts[conv_id] = {
                    "context_b64": state.context_b64,
                    "encoding": state.encoding,
                    "language": state.language,
                    "context_tokens": state.context_tokens,
                }
            return contexts

    def get_context(self, conversation_id: str) -> Optional[Dict[str, Any]]:
        """Return refreshed context info for one conversation, or None if unknown."""
        with self._lock:
            state = self._conversations.get(conversation_id)
            if not state:
                return None
            self._refresh_context_locked(state)
            return {
                "conversation_id": state.conversation_id,
                "context_b64": state.context_b64,
                "encoding": state.encoding,
                "language": state.language,
                "context_tokens": state.context_tokens,
            }

    def record_prompt(
        self,
        conversation_id: Optional[str],
        text: str,
        encoding: Optional[str],
        language: Optional[str],
    ) -> Dict[str, Any]:
        """Record a user prompt and return its base64 form plus context stats.

        Savings control tags are stripped from *text* before it is stored.
        The flags they set are kept on the state for the next response.

        Raises:
            ValueError: if *encoding* is not a usable text codec.
        """
        state = self.get_or_create(conversation_id, encoding, language)
        clean_text, show_savings, show_total = _strip_control_tags(text)
        # Encoding and token counting run outside the lock.
        encoded_prompt = _b64encode_text(clean_text, state.encoding)
        prompt_stats = _token_stats(clean_text, encoded_prompt)
        with self._lock:
            state.messages.append({"role": "user", "text": clean_text})
            state.prompt_tokens_raw += prompt_stats["raw"]
            state.prompt_tokens_encoded += prompt_stats["encoded"]
            state.last_prompt_stats = prompt_stats
            state.pending_show_savings = show_savings or show_total
            state.pending_show_total = show_total
            self._refresh_context_locked(state)
            self._dirty = True
            response = {
                "conversation_id": state.conversation_id,
                "encoding": state.encoding,
                "language": state.language,
                "encoded_prompt_b64": encoded_prompt,
                "context_b64": state.context_b64,
                "context_tokens": state.context_tokens,
                "prompt_tokens": prompt_stats,
                "savings_request": {
                    "show_savings": state.pending_show_savings,
                    "show_total": state.pending_show_total,
                },
            }
        return response

    def record_response(
        self,
        conversation_id: str,
        payload_b54: str,
    ) -> Dict[str, Any]:
        """Decode and record a base54 assistant reply for *conversation_id*.

        If the preceding prompt requested savings, a savings payload and
        tagline are attached and the request flags are cleared.

        Raises:
            KeyError: if *conversation_id* is unknown.
            ValueError: if *payload_b54* contains non-base54 characters.
        """
        with self._lock:
            state = self._conversations.get(conversation_id)
            if not state:
                raise KeyError("Unknown conversation_id")
            encoding = state.encoding
            language = state.language
        decoded_bytes = b54decode(payload_b54)
        decoded_text = decoded_bytes.decode(encoding, errors="replace")
        response_stats = _token_stats(decoded_text, payload_b54)
        with self._lock:
            state.messages.append({"role": "assistant", "text": decoded_text})
            state.response_tokens_raw += response_stats["raw"]
            state.response_tokens_encoded += response_stats["encoded"]
            state.last_response_stats = response_stats
            self._refresh_context_locked(state)
            savings_payload = None
            tagline = None
            decoded_text_with_tagline = None
            if state.pending_show_savings:
                savings_payload = self._build_savings_payload(state)
                tagline = self._format_tagline(
                    savings_payload,
                    include_total=state.pending_show_total,
                )
                decoded_text_with_tagline = (
                    decoded_text + "\n" + tagline if decoded_text else tagline
                )
                # The savings request applies to this one response only.
                state.pending_show_savings = False
                state.pending_show_total = False
            self._dirty = True
            response = {
                "conversation_id": state.conversation_id,
                "encoding": encoding,
                "language": language,
                "response_b54": payload_b54,
                "decoded_text": decoded_text,
                "response_tokens": response_stats,
                "context_b64": state.context_b64,
                "context_tokens": state.context_tokens,
            }
            if savings_payload:
                response["savings"] = savings_payload
            if tagline:
                response["tagline"] = tagline
                response["decoded_text_with_tagline"] = decoded_text_with_tagline
        return response

    def _refresh_context_locked(self, state: "ConversationState") -> None:
        """Recompute the base64 context and its token stats. Hold ``self._lock``."""
        context_text = _context_text(state.messages)
        state.context_b64 = _b64encode_text(context_text, state.encoding)
        state.context_tokens = _token_stats(context_text, state.context_b64)

    def _build_savings_payload(self, state: "ConversationState") -> Dict[str, Any]:
        """Assemble last-turn and cumulative token savings for *state*."""
        prompt_stats = state.last_prompt_stats or {"raw": 0, "encoded": 0, "saved": 0}
        response_stats = state.last_response_stats or {
            "raw": 0,
            "encoded": 0,
            "saved": 0,
        }
        context_stats = state.context_tokens or {"raw": 0, "encoded": 0, "saved": 0}
        # NOTE(review): the context covers the whole transcript, including the
        # latest prompt and response, so adding all three may double count.
        # Confirm intent.
        combined_saved = (
            prompt_stats.get("saved", 0)
            + response_stats.get("saved", 0)
            + context_stats.get("saved", 0)
        )
        totals = {
            "prompt": {
                "raw": state.prompt_tokens_raw,
                "encoded": state.prompt_tokens_encoded,
                "saved": max(
                    0, state.prompt_tokens_raw - state.prompt_tokens_encoded
                ),
            },
            "response": {
                "raw": state.response_tokens_raw,
                "encoded": state.response_tokens_encoded,
                "saved": max(
                    0, state.response_tokens_raw - state.response_tokens_encoded
                ),
            },
            "context": context_stats,
        }
        totals["combined_saved"] = (
            totals["prompt"]["saved"]
            + totals["response"]["saved"]
            + totals["context"]["saved"]
        )
        return {
            "prompt": prompt_stats,
            "response": response_stats,
            "context": context_stats,
            "combined_saved": combined_saved,
            "totals": totals,
        }

    def _format_tagline(self, savings: Dict[str, Any], include_total: bool) -> str:
        """Render a one-line savings summary, optionally with running totals."""
        prompt_saved = savings["prompt"]["saved"]
        response_saved = savings["response"]["saved"]
        context_saved = savings["context"]["saved"]
        combined_saved = savings["combined_saved"]
        tagline = (
            "Token savings (prompt/response/context/combined): "
            f"{prompt_saved}/{response_saved}/{context_saved}/{combined_saved}."
        )
        if include_total:
            totals = savings.get("totals", {})
            totals_prompt = totals.get("prompt", {}).get("saved", 0)
            totals_response = totals.get("response", {}).get("saved", 0)
            totals_context = totals.get("context", {}).get("saved", 0)
            totals_combined = totals.get("combined_saved", 0)
            tagline += (
                " Total savings (prompt/response/context/combined): "
                f"{totals_prompt}/{totals_response}/{totals_context}/{totals_combined}."
            )
        return tagline
class TokenCompressionProtocolProcessor:
    """Route protocol requests (dicts with an ``action`` key) to a ContextStore.

    Replies are dicts with ``"ok"``, plus ``"result"`` on success or
    ``"error"`` (and sometimes ``"detail"``) on failure. Supported actions:
    ``prompt``, ``response``, ``context_get``, ``context_reset``, ``ping``.
    """

    def __init__(self, store: "ContextStore"):
        self.store = store

    @staticmethod
    def _is_valid_conversation_id(value: Any) -> bool:
        """Accept a falsy value (treated as absent) or a string.

        Other types would either crash dict lookups (unhashable values such as
        lists) or diverge from the string keys the JSON dataset reloads as.
        """
        return not value or isinstance(value, str)

    @staticmethod
    def _encoding_error(encoding: Any) -> Optional[str]:
        """Return an error detail if *encoding* is unusable, else None.

        Falsy values mean "keep the stored encoding" and are accepted. Others
        must be a text codec that supports the ``errors="replace"`` handler
        used when encoding and decoding.
        """
        if not encoding:
            return None
        if not isinstance(encoding, str):
            return f"encoding must be a string, got {type(encoding).__name__}"
        try:
            "".encode(encoding, errors="replace")
            b"".decode(encoding, errors="replace")
        except (LookupError, ValueError) as exc:
            return str(exc)
        return None

    def handle(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one decoded request and return the reply dict."""
        action = payload.get("action")
        if not action:
            return {"ok": False, "error": "missing_action"}
        if action == "prompt":
            text = payload.get("text", "")
            if not isinstance(text, str) or text == "":
                return {"ok": False, "error": "missing_text"}
            conversation_id = payload.get("conversation_id")
            if not self._is_valid_conversation_id(conversation_id):
                return {"ok": False, "error": "invalid_conversation_id"}
            encoding = payload.get("encoding")
            # Reject bad codecs before the store saves them on the state.
            encoding_error = self._encoding_error(encoding)
            if encoding_error is not None:
                return {"ok": False, "error": "invalid_encoding", "detail": encoding_error}
            response = self.store.record_prompt(
                conversation_id=conversation_id,
                text=text,
                encoding=encoding,
                language=payload.get("language"),
            )
            return {"ok": True, "result": response}
        if action == "response":
            conversation_id = payload.get("conversation_id")
            if not conversation_id:
                return {"ok": False, "error": "missing_conversation_id"}
            if not self._is_valid_conversation_id(conversation_id):
                return {"ok": False, "error": "invalid_conversation_id"}
            payload_b54 = payload.get("payload_b54", "")
            if not isinstance(payload_b54, str) or payload_b54 == "":
                return {"ok": False, "error": "missing_payload_b54"}
            try:
                response = self.store.record_response(
                    conversation_id=conversation_id,
                    payload_b54=payload_b54,
                )
            except KeyError:
                return {"ok": False, "error": "unknown_conversation_id"}
            except ValueError as exc:
                return {"ok": False, "error": "invalid_base54", "detail": str(exc)}
            return {"ok": True, "result": response}
        if action == "context_get":
            conversation_id = payload.get("conversation_id")
            if not self._is_valid_conversation_id(conversation_id):
                return {"ok": False, "error": "invalid_conversation_id"}
            if conversation_id:
                context = self.store.get_context(conversation_id)
                if not context:
                    return {"ok": False, "error": "unknown_conversation_id"}
                return {"ok": True, "result": context}
            return {"ok": True, "result": {"contexts": self.store.list_contexts()}}
        if action == "context_reset":
            conversation_id = payload.get("conversation_id")
            if not conversation_id:
                return {"ok": False, "error": "missing_conversation_id"}
            if not self._is_valid_conversation_id(conversation_id):
                return {"ok": False, "error": "invalid_conversation_id"}
            # NOTE(review): reaches into ContextStore internals; a public
            # store method for this would be cleaner.
            with self.store._lock:
                if conversation_id in self.store._conversations:
                    del self.store._conversations[conversation_id]
                    self.store._dirty = True
                    return {"ok": True, "result": {"conversation_id": conversation_id}}
            return {"ok": False, "error": "unknown_conversation_id"}
        if action == "ping":
            return {"ok": True, "result": {"message": "pong"}}
        return {"ok": False, "error": "unknown_action"}
class TokenCompressionTCPServer(socketserver.ThreadingTCPServer):
    """Threaded TCP server that carries a protocol processor for its handlers.

    Request handlers reach the processor as ``self.server.processor``.
    Worker threads are daemonic, and ``allow_reuse_address`` sets
    SO_REUSEADDR on the listening socket.
    """

    daemon_threads = True
    allow_reuse_address = True

    def __init__(self, server_address, RequestHandlerClass, processor):
        self.processor = processor
        super().__init__(server_address, RequestHandlerClass)
class TokenCompressionRequestHandler(socketserver.StreamRequestHandler):
    """Serve newline-delimited JSON requests over one TCP connection.

    Each non-blank input line must be a UTF-8 encoded JSON object, and each
    produces exactly one JSON reply line. Blank lines are skipped. The loop
    ends when the client closes its side (EOF).
    """

    def handle(self) -> None:
        while True:
            raw_line = self.rfile.readline()
            if not raw_line:
                break
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            try:
                request = json.loads(raw_line.decode("utf-8"))
            except (ValueError, RecursionError) as exc:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError
                # (bad UTF-8). RecursionError comes from absurdly deep nesting.
                # Either way, reply instead of dropping the connection.
                self._send({"ok": False, "error": "invalid_json", "detail": str(exc)})
                continue
            if not isinstance(request, dict):
                self._send({"ok": False, "error": "invalid_payload"})
                continue
            try:
                response = self.server.processor.handle(request)
            except Exception:
                # Connection boundary: log the traceback and keep serving this
                # client rather than killing the connection without a reply.
                logging.getLogger(__name__).exception(
                    "Unhandled error while processing token compression request"
                )
                response = {"ok": False, "error": "internal_error"}
            self._send(response)

    def _send(self, payload: Dict[str, Any]) -> None:
        """Write *payload* to the client as one ASCII-only JSON line."""
        encoded = json.dumps(payload, ensure_ascii=True).encode("utf-8") + b"\n"
        self.wfile.write(encoded)
def run_tcp_server(
    host: str = "127.0.0.1",
    port: int = 7543,
    dataset_path: Optional[str] = None,
    refresh_interval: float = 5.0,
) -> None:
    """Serve the token compression protocol on ``host:port`` until interrupted.

    Args:
        host: Interface to bind.
        port: TCP port to bind.
        dataset_path: JSON file for persisted conversation context. Defaults
            to ``memory/token_compression_context.json`` resolved through
            ``files.get_abs_path``.
        refresh_interval: Seconds between background flushes of dirty state.

    A KeyboardInterrupt ends serving quietly. The store is always stopped,
    and its final flush runs, even if binding the server fails.
    """
    dataset_path = dataset_path or files.get_abs_path(
        "memory", "token_compression_context.json"
    )
    store = ContextStore(dataset_path=dataset_path, refresh_interval=refresh_interval)
    try:
        processor = TokenCompressionProtocolProcessor(store)
        # The server is created inside the try so a failed bind (e.g. port
        # already in use) still stops the store's maintenance thread. The
        # with-block closes the listening socket on exit.
        with TokenCompressionTCPServer(
            (host, port), TokenCompressionRequestHandler, processor
        ) as server:
            try:
                server.serve_forever()
            except KeyboardInterrupt:
                pass
    finally:
        store.stop()
if __name__ == "__main__":
    # Script entry point: serve with the default host, port and dataset path.
    run_tcp_server()