mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 11:30:03 +00:00
384 lines
12 KiB
Python
384 lines
12 KiB
Python
"""Request utility functions for API route handlers.
|
|
|
|
This module contains optimization functions, quota detection, title generation detection,
|
|
prefix detection, and token counting utilities.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import tiktoken
|
|
|
|
from .models.anthropic import MessagesRequest
|
|
from utils.text import extract_text_from_content
|
|
|
|
# Module-level logger per the standard per-module logging convention.
logger = logging.getLogger(__name__)

# Shared tiktoken encoder used by get_token_count; cl100k_base gives an
# estimate only — it is not the provider's exact tokenizer.
ENCODER = tiktoken.get_encoding("cl100k_base")
|
|
|
|
|
|
def is_quota_check_request(request_data: MessagesRequest) -> bool:
    """Detect quota probe requests.

    A quota probe is a minimal request: ``max_tokens`` of 1, exactly one
    user message, and the word "quota" somewhere in its text.

    Args:
        request_data: The incoming request data

    Returns:
        True if this is a quota probe request
    """
    messages = request_data.messages
    looks_minimal = (
        request_data.max_tokens == 1
        and len(messages) == 1
        and messages[0].role == "user"
    )
    if not looks_minimal:
        return False

    body = extract_text_from_content(messages[0].content)
    return "quota" in body.lower()
|
|
|
|
|
|
def is_title_generation_request(request_data: MessagesRequest) -> bool:
    """Detect conversation title generation requests.

    Title generation requests carry the phrase "write a 5-10 word title"
    in the most recent user message.

    Args:
        request_data: The incoming request data

    Returns:
        True if this is a title generation request
    """
    if not request_data.messages:
        return False

    last_message = request_data.messages[-1]
    if last_message.role != "user":
        return False

    body = extract_text_from_content(last_message.content)
    return "write a 5-10 word title" in body.lower()
|
|
|
|
|
|
def extract_command_prefix(command: str) -> str:
    """Extract the command prefix for fast prefix detection.

    Parses a shell command safely, handling environment variable
    assignments and command injection attempts, and returns a short
    prefix suitable for quick identification.

    Args:
        command: The command string to analyze

    Returns:
        Command prefix (e.g., "git", "git commit", "npm install")
        or "none" if no valid command found
    """
    import shlex

    # Backtick / $( substitution embeds another command — flag it outright.
    if "`" in command or "$(" in command:
        return "command_injection_detected"

    try:
        tokens = shlex.split(command)
    except ValueError:
        # shlex rejects unbalanced quotes etc. — degrade to a naive split.
        naive = command.split()
        return naive[0] if naive else "none"

    if not tokens:
        return "none"

    # Peel off leading KEY=value assignments (environment variable prefix).
    env_assignments = []
    idx = 0
    while idx < len(tokens) and "=" in tokens[idx] and not tokens[idx].startswith("-"):
        env_assignments.append(tokens[idx])
        idx += 1

    remaining = tokens[idx:]
    if not remaining:
        return "none"

    head = remaining[0]

    # Commands whose first subcommand belongs in the prefix ("git commit").
    compound_commands = {"git", "npm", "docker", "kubectl", "cargo", "go", "pip", "yarn"}

    if head in compound_commands and len(remaining) > 1:
        subcommand = remaining[1]
        if not subcommand.startswith("-"):
            return f"{head} {subcommand}"
        return head

    if env_assignments:
        return " ".join(env_assignments) + " " + head
    return head
|
|
|
|
|
|
def is_prefix_detection_request(request_data: MessagesRequest) -> Tuple[bool, str]:
    """Detect fast prefix detection requests.

    These requests contain a <policy_spec> block and a "Command:" section
    from which the shell command to classify can be pulled.

    Args:
        request_data: The incoming request data

    Returns:
        Tuple of (is_prefix_request, command_string)
    """
    messages = request_data.messages
    if len(messages) != 1 or messages[0].role != "user":
        return False, ""

    text = extract_text_from_content(messages[0].content)

    if "<policy_spec>" not in text or "Command:" not in text:
        return False, ""

    try:
        marker = "Command:"
        start = text.rfind(marker) + len(marker)
        return True, text[start:].strip()
    except Exception:
        pass

    return False, ""
|
|
|
|
|
|
def is_suggestion_mode_request(request_data: MessagesRequest) -> bool:
    """Detect suggestion mode requests.

    Suggestion mode requests contain "[SUGGESTION MODE:" in a user
    message, used for auto-suggesting what the user might type next.

    Args:
        request_data: The incoming request data

    Returns:
        True if this is a suggestion mode request
    """
    return any(
        "[SUGGESTION MODE:" in extract_text_from_content(message.content)
        for message in request_data.messages
        if message.role == "user"
    )
|
|
|
|
|
|
def is_filepath_extraction_request(
    request_data: MessagesRequest,
) -> Tuple[bool, str, str]:
    """Detect filepath extraction requests.

    Filepath extraction requests are a single user message (no tools)
    containing "Command:" and "Output:" sections plus a filepaths marker,
    asking to extract file paths from command output.

    Args:
        request_data: The incoming request data

    Returns:
        Tuple of (is_filepath_request, command, output)
    """
    no_match = (False, "", "")

    # Must be a single user message with no tools attached.
    if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
        return no_match
    if request_data.tools:
        return no_match

    text = extract_text_from_content(request_data.messages[0].content)

    # Both section markers must be present.
    if "Command:" not in text or "Output:" not in text:
        return no_match

    # The prompt must actually ask for filepath extraction.
    lowered = text.lower()
    if "filepaths" not in lowered and "<filepaths>" not in lowered:
        return no_match

    try:
        cmd_begin = text.find("Command:") + len("Command:")
        out_marker = text.find("Output:", cmd_begin)
        if out_marker == -1:
            return no_match

        command = text[cmd_begin:out_marker].strip()
        output = text[out_marker + len("Output:"):].strip()

        # Trim trailing sections: keep only text before the first "<" tag
        # or blank line, if any.
        for stop in ("<", "\n\n"):
            if stop in output:
                output = output.split(stop)[0].strip()

        return True, command, output
    except Exception:
        return no_match
|
|
|
|
|
|
def extract_filepaths_from_command(command: str, output: str) -> str:
    """Extract file paths from a command locally without an API call.

    Determines whether the command reads file contents and extracts paths
    accordingly. Listing commands (ls/dir/find/...) only list files, so
    they yield an empty result. Reading commands (cat/head/tail/...)
    yield their file arguments. grep yields its file operands, skipping
    the search pattern.

    Args:
        command: The shell command that was executed
        output: The command's output

    Returns:
        Filepath extraction result in <filepaths> format
    """
    import shlex

    empty_result = "<filepaths>\n</filepaths>"

    # Commands that only list files (never read their contents).
    listing_commands = {
        "ls",
        "dir",
        "find",
        "tree",
        "pwd",
        "cd",
        "mkdir",
        "rmdir",
        "rm",
    }

    # Commands that read file contents.
    reading_commands = {"cat", "head", "tail", "less", "more", "bat", "type"}

    def _wrap(paths):
        # Render collected paths in the expected <filepaths> format.
        if not paths:
            return empty_result
        return "<filepaths>\n" + "\n".join(paths) + "\n</filepaths>"

    try:
        parts = shlex.split(command)
        if not parts:
            return empty_result

        # Base command name (strip any leading directory path, either slash).
        base_cmd = parts[0].split("/")[-1].split("\\")[-1].lower()

        # Listing commands - return empty.
        if base_cmd in listing_commands:
            return empty_result

        # Reading commands - extract file arguments.
        if base_cmd in reading_commands:
            filepaths = []
            skip_next = False
            for part in parts[1:]:
                if skip_next:
                    skip_next = False
                    continue
                if part.startswith("-"):
                    # FIX: head/tail take a separate argument after -n/-c
                    # ("head -n 5 f"); previously "5" was reported as a
                    # filepath. cat -n takes no argument, so only skip the
                    # following token for head/tail.
                    if base_cmd in {"head", "tail"} and part in {"-n", "-c"}:
                        skip_next = True
                    continue
                filepaths.append(part)
            return _wrap(filepaths)

        # grep with file arguments.
        if base_cmd == "grep":
            filepaths = []
            skip_next = False
            # FIX: the pattern was previously identified by positional index
            # (i > 1), so any flag before it ("grep -i foo file") caused the
            # pattern itself to be reported as a filepath. Track whether the
            # pattern has been consumed instead.
            pattern_consumed = False
            for part in parts[1:]:
                if skip_next:
                    skip_next = False
                    continue
                if part.startswith("-"):
                    # Flags that take a separate argument.
                    if part in {"-e", "-f", "-m", "-A", "-B", "-C"}:
                        skip_next = True
                        # -e/-f supply the pattern, so every positional that
                        # follows is a file operand.
                        if part in {"-e", "-f"}:
                            pattern_consumed = True
                    continue
                if not pattern_consumed:
                    # First positional argument is the search pattern.
                    pattern_consumed = True
                    continue
                filepaths.append(part)
            return _wrap(filepaths)

        # Default - return empty for unknown commands.
        return empty_result

    except Exception:
        return empty_result
|
|
|
|
|
|
def get_token_count(
    messages: List,
    system: Optional[Union[str, List]] = None,
    tools: Optional[List] = None,
) -> int:
    """Estimate token count for a request.

    Uses the shared tiktoken cl100k_base encoder (module-level ENCODER)
    to approximate token usage across the system prompt, messages, tool
    definitions, plus a small fixed per-message / per-tool overhead.

    Args:
        messages: List of message objects with content
        system: Optional system prompt (str or list of blocks)
        tools: Optional list of tool definitions

    Returns:
        Estimated total token count (always at least 1)
    """

    def _count(text: str) -> int:
        # Token length of a single string under the shared encoder.
        return len(ENCODER.encode(text))

    total = 0

    # System prompt: plain string or a list of text blocks.
    if system:
        if isinstance(system, str):
            total += _count(system)
        elif isinstance(system, list):
            total += sum(_count(b.text) for b in system if hasattr(b, "text"))

    # Message contents: plain strings or lists of typed content blocks.
    for msg in messages:
        content = msg.content
        if isinstance(content, str):
            total += _count(content)
            continue
        if not isinstance(content, list):
            continue
        for block in content:
            block_type = getattr(block, "type", None)

            if block_type == "text":
                total += _count(getattr(block, "text", ""))
            elif block_type == "thinking":
                total += _count(getattr(block, "thinking", ""))
            elif block_type == "tool_use":
                total += _count(getattr(block, "name", ""))
                total += _count(json.dumps(getattr(block, "input", {})))
                total += 10  # tool_use structural overhead
            elif block_type == "tool_result":
                result = getattr(block, "content", "")
                if isinstance(result, str):
                    total += _count(result)
                else:
                    total += _count(json.dumps(result))
                total += 5  # tool_result structural overhead

    # Tool definitions: name + description + JSON schema, plus overhead.
    if tools:
        for tool in tools:
            total += _count(
                tool.name + (tool.description or "") + json.dumps(tool.input_schema)
            )
        total += len(tools) * 5

    # Per-message framing overhead.
    total += len(messages) * 3

    return max(1, total)
|