Mirror of https://github.com/MODSetter/SurfSense.git
Synced 2025-09-09 13:54:40 +00:00
Fixed all ruff lint and formatting errors
This commit is contained in:
parent 0a03c42cc5
commit d359a59f6d
85 changed files with 5520 additions and 3870 deletions
@@ -1,45 +1,43 @@
-from typing import Optional
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.future import select
-from langchain_community.chat_models import ChatLiteLLM
 import logging
 
-from app.db import User, LLMConfig
+from langchain_community.chat_models import ChatLiteLLM
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.future import select
+
+from app.db import LLMConfig, User
 
 logger = logging.getLogger(__name__)
 
 
 class LLMRole:
     LONG_CONTEXT = "long_context"
     FAST = "fast"
     STRATEGIC = "strategic"
 
 
 async def get_user_llm_instance(
-    session: AsyncSession,
-    user_id: str,
-    role: str
-) -> Optional[ChatLiteLLM]:
+    session: AsyncSession, user_id: str, role: str
+) -> ChatLiteLLM | None:
     """
     Get a ChatLiteLLM instance for a specific user and role.
 
     Args:
         session: Database session
         user_id: User ID
         role: LLM role ('long_context', 'fast', or 'strategic')
 
     Returns:
         ChatLiteLLM instance or None if not found
     """
     try:
         # Get user with their LLM preferences
-        result = await session.execute(
-            select(User).where(User.id == user_id)
-        )
+        result = await session.execute(select(User).where(User.id == user_id))
         user = result.scalars().first()
 
         if not user:
             logger.error(f"User {user_id} not found")
             return None
 
         # Get the appropriate LLM config ID based on role
         llm_config_id = None
         if role == LLMRole.LONG_CONTEXT:
@@ -51,24 +49,23 @@ async def get_user_llm_instance(
         else:
             logger.error(f"Invalid LLM role: {role}")
             return None
 
         if not llm_config_id:
             logger.error(f"No {role} LLM configured for user {user_id}")
             return None
 
         # Get the LLM configuration
         result = await session.execute(
             select(LLMConfig).where(
-                LLMConfig.id == llm_config_id,
-                LLMConfig.user_id == user_id
+                LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id
             )
         )
         llm_config = result.scalars().first()
 
         if not llm_config:
             logger.error(f"LLM config {llm_config_id} not found for user {user_id}")
             return None
 
         # Build the model string for litellm
         if llm_config.custom_provider:
             model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
@@ -76,7 +73,7 @@ async def get_user_llm_instance(
             # Map provider enum to litellm format
             provider_map = {
                 "OPENAI": "openai",
                 "ANTHROPIC": "anthropic",
                 "GROQ": "groq",
                 "COHERE": "cohere",
                 "GOOGLE": "gemini",
@@ -84,37 +81,48 @@ async def get_user_llm_instance(
                 "MISTRAL": "mistral",
                 # Add more mappings as needed
             }
-            provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower())
+            provider_prefix = provider_map.get(
+                llm_config.provider.value, llm_config.provider.value.lower()
+            )
             model_string = f"{provider_prefix}/{llm_config.model_name}"
 
         # Create ChatLiteLLM instance
         litellm_kwargs = {
             "model": model_string,
             "api_key": llm_config.api_key,
         }
 
         # Add optional parameters
         if llm_config.api_base:
             litellm_kwargs["api_base"] = llm_config.api_base
 
         # Add any additional litellm parameters
         if llm_config.litellm_params:
             litellm_kwargs.update(llm_config.litellm_params)
 
         return ChatLiteLLM(**litellm_kwargs)
 
     except Exception as e:
-        logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}")
+        logger.error(
+            f"Error getting LLM instance for user {user_id}, role {role}: {e!s}"
+        )
         return None
 
-async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
+
+async def get_user_long_context_llm(
+    session: AsyncSession, user_id: str
+) -> ChatLiteLLM | None:
     """Get user's long context LLM instance."""
     return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
 
-async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
+
+async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None:
     """Get user's fast LLM instance."""
     return await get_user_llm_instance(session, user_id, LLMRole.FAST)
 
-async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
+
+async def get_user_strategic_llm(
+    session: AsyncSession, user_id: str
+) -> ChatLiteLLM | None:
     """Get user's strategic LLM instance."""
     return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
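For context, the file this diff touches exposes one generic lookup (get_user_llm_instance) plus three role-specific wrappers. A minimal usage sketch follows; it is not part of this commit, and the async_session_factory import, the user id, and summarize_for_user are hypothetical names for illustration only.

import asyncio

from app.db import async_session_factory  # hypothetical; not shown in this diff

async def summarize_for_user(user_id: str, text: str) -> str | None:
    async with async_session_factory() as session:
        # Pick the user's latency-optimized model for a short task.
        llm = await get_user_fast_llm(session, user_id)
        if llm is None:
            return None  # no fast LLM configured for this user, or lookup failed
        # ChatLiteLLM is a LangChain chat model, so it supports ainvoke().
        response = await llm.ainvoke(f"Summarize in one sentence: {text}")
        return response.content

asyncio.run(summarize_for_user("user-123", "SurfSense is a personal research agent."))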