Fixes for issues 113 and 116

Alishahryar1 2026-04-18 16:32:31 -07:00
parent 7468f53ab7
commit 835d0454e8
28 changed files with 807 additions and 83 deletions

View file

@@ -39,6 +39,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro
             http_read_timeout=settings.http_read_timeout,
             http_write_timeout=settings.http_write_timeout,
             http_connect_timeout=settings.http_connect_timeout,
+            enable_thinking=settings.enable_thinking,
         )
         return NvidiaNimProvider(config, nim_settings=settings.nim)
     if provider_type == "open_router":
@@ -56,6 +57,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro
             http_read_timeout=settings.http_read_timeout,
             http_write_timeout=settings.http_write_timeout,
             http_connect_timeout=settings.http_connect_timeout,
+            enable_thinking=settings.enable_thinking,
         )
         return OpenRouterProvider(config)
     if provider_type == "lmstudio":
@@ -68,6 +70,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro
             http_read_timeout=settings.http_read_timeout,
             http_write_timeout=settings.http_write_timeout,
             http_connect_timeout=settings.http_connect_timeout,
+            enable_thinking=settings.enable_thinking,
         )
         return LMStudioProvider(config)
     if provider_type == "llamacpp":
@@ -80,6 +83,7 @@ def _create_provider_for_type(provider_type: str, settings: Settings) -> BasePro
             http_read_timeout=settings.http_read_timeout,
             http_write_timeout=settings.http_write_timeout,
             http_connect_timeout=settings.http_connect_timeout,
+            enable_thinking=settings.enable_thinking,
         )
         return LlamaCppProvider(config)
     logger.error(
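
Every provider branch in this factory passes the same block of HTTP keyword arguments, so the fix threads the new `enable_thinking` setting through all four branches identically. A minimal sketch of how the repeated kwargs could be hoisted into one place; the helper itself is hypothetical and not part of this commit, only the `Settings` field names are taken from the hunks above:

# Hypothetical helper, not part of this commit: collects the kwargs that
# every provider branch above passes into its config object.
def _shared_config_kwargs(settings) -> dict:
    return {
        "http_read_timeout": settings.http_read_timeout,
        "http_write_timeout": settings.http_write_timeout,
        "http_connect_timeout": settings.http_connect_timeout,
        "enable_thinking": settings.enable_thinking,  # new in this commit
    }

# Each branch could then build its config as, for example:
# config = SomeProviderConfig(**_shared_config_kwargs(settings), ...)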

View file

@@ -14,7 +14,13 @@ from .anthropic import (
     TokenCountRequest,
     Tool,
 )
-from .responses import MessagesResponse, TokenCountResponse, Usage
+from .responses import (
+    MessagesResponse,
+    ModelResponse,
+    ModelsListResponse,
+    TokenCountResponse,
+    Usage,
+)
 
 __all__ = [
     "ContentBlockImage",
@@ -25,6 +31,8 @@ __all__ = [
     "Message",
     "MessagesRequest",
     "MessagesResponse",
+    "ModelResponse",
+    "ModelsListResponse",
     "Role",
     "SystemContent",
     "ThinkingConfig",

View file

@@ -11,6 +11,20 @@ class TokenCountResponse(BaseModel):
     input_tokens: int
 
 
+class ModelResponse(BaseModel):
+    created_at: str
+    display_name: str
+    id: str
+    type: Literal["model"] = "model"
+
+
+class ModelsListResponse(BaseModel):
+    data: list[ModelResponse]
+    first_id: str | None
+    has_more: bool
+    last_id: str | None
+
+
 class Usage(BaseModel):
     input_tokens: int
     output_tokens: int
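
A standalone usage sketch of the two new response models. The field values are illustrative, and `model_dump_json` assumes pydantic v2, which the `BaseModel` usage and `str | None` annotations in the hunk suggest:

from typing import Literal

from pydantic import BaseModel  # assumes pydantic v2


class ModelResponse(BaseModel):
    created_at: str
    display_name: str
    id: str
    type: Literal["model"] = "model"


class ModelsListResponse(BaseModel):
    data: list[ModelResponse]
    first_id: str | None
    has_more: bool
    last_id: str | None


# Build a one-entry listing and serialize it the way the route handler would.
sonnet = ModelResponse(
    id="claude-3-5-sonnet-20241022",
    display_name="Claude 3.5 Sonnet",
    created_at="2024-10-22T00:00:00Z",
)
listing = ModelsListResponse(
    data=[sonnet], first_id=sonnet.id, has_more=False, last_id=sonnet.id
)
print(listing.model_dump_json())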

View file

@@ -3,7 +3,7 @@
 import traceback
 import uuid
 
-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from fastapi.responses import StreamingResponse
 from loguru import logger
 
@@ -13,13 +13,57 @@ from providers.exceptions import InvalidRequestError, ProviderError
 from .dependencies import get_provider_for_type, get_settings, require_api_key
 from .models.anthropic import MessagesRequest, TokenCountRequest
-from .models.responses import TokenCountResponse
+from .models.responses import ModelResponse, ModelsListResponse, TokenCountResponse
 from .optimization_handlers import try_optimizations
 from .request_utils import get_token_count
 
 
 router = APIRouter()
 
+SUPPORTED_CLAUDE_MODELS = [
+    ModelResponse(
+        id="claude-opus-4-20250514",
+        display_name="Claude Opus 4",
+        created_at="2025-05-14T00:00:00Z",
+    ),
+    ModelResponse(
+        id="claude-sonnet-4-20250514",
+        display_name="Claude Sonnet 4",
+        created_at="2025-05-14T00:00:00Z",
+    ),
+    ModelResponse(
+        id="claude-haiku-4-20250514",
+        display_name="Claude Haiku 4",
+        created_at="2025-05-14T00:00:00Z",
+    ),
+    ModelResponse(
+        id="claude-3-opus-20240229",
+        display_name="Claude 3 Opus",
+        created_at="2024-02-29T00:00:00Z",
+    ),
+    ModelResponse(
+        id="claude-3-5-sonnet-20241022",
+        display_name="Claude 3.5 Sonnet",
+        created_at="2024-10-22T00:00:00Z",
+    ),
+    ModelResponse(
+        id="claude-3-haiku-20240307",
+        display_name="Claude 3 Haiku",
+        created_at="2024-03-07T00:00:00Z",
+    ),
+    ModelResponse(
+        id="claude-3-5-haiku-20241022",
+        display_name="Claude 3.5 Haiku",
+        created_at="2024-10-22T00:00:00Z",
+    ),
+]
+
+
+def _probe_response(allow: str) -> Response:
+    """Return an empty success response for compatibility probes."""
+    return Response(status_code=204, headers={"Allow": allow})
+
+
 # =============================================================================
 # Routes
 # =============================================================================
@@ -83,6 +127,12 @@ async def create_message(
         ) from e
 
 
+@router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"])
+async def probe_messages(_auth=Depends(require_api_key)):
+    """Respond to Claude compatibility probes for the messages endpoint."""
+    return _probe_response("POST, HEAD, OPTIONS")
+
+
 @router.post("/v1/messages/count_tokens")
 async def count_tokens(request_data: TokenCountRequest, _auth=Depends(require_api_key)):
     """Count tokens for a request."""
@@ -112,6 +162,12 @@ async def count_tokens(request_data: TokenCountRequest, _auth=Depends(require_ap
         ) from e
 
 
+@router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"])
+async def probe_count_tokens(_auth=Depends(require_api_key)):
+    """Respond to Claude compatibility probes for the token count endpoint."""
+    return _probe_response("POST, HEAD, OPTIONS")
+
+
 @router.get("/")
 async def root(
     settings: Settings = Depends(get_settings), _auth=Depends(require_api_key)
@@ -124,12 +180,35 @@ async def root(
     }
 
 
+@router.api_route("/", methods=["HEAD", "OPTIONS"])
+async def probe_root(_auth=Depends(require_api_key)):
+    """Respond to compatibility probes for the root endpoint."""
+    return _probe_response("GET, HEAD, OPTIONS")
+
+
 @router.get("/health")
 async def health():
     """Health check endpoint."""
     return {"status": "healthy"}
 
 
+@router.api_route("/health", methods=["HEAD", "OPTIONS"])
+async def probe_health():
+    """Respond to compatibility probes for the health endpoint."""
+    return _probe_response("GET, HEAD, OPTIONS")
+
+
+@router.get("/v1/models", response_model=ModelsListResponse)
+async def list_models(_auth=Depends(require_api_key)):
+    """List the Claude model ids this proxy advertises for compatibility."""
+    return ModelsListResponse(
+        data=SUPPORTED_CLAUDE_MODELS,
+        first_id=SUPPORTED_CLAUDE_MODELS[0].id if SUPPORTED_CLAUDE_MODELS else None,
+        has_more=False,
+        last_id=SUPPORTED_CLAUDE_MODELS[-1].id if SUPPORTED_CLAUDE_MODELS else None,
+    )
+
+
 @router.post("/stop")
 async def stop_cli(request: Request, _auth=Depends(require_api_key)):
     """Stop all CLI sessions and pending tasks."""