feat: add configurable LLMs

DESKTOP-RTLN3BA\$punk 2025-06-09 15:50:15 -07:00
parent d0e9fdf810
commit a85f7920a9
36 changed files with 3415 additions and 293 deletions

View file

@ -54,7 +54,7 @@ Open source and easy to deploy locally.
- Support for multiple TTS providers (OpenAI, Azure, Google Vertex AI)
### 📊 **Advanced RAG Techniques**
- Supports 150+ LLMs
- Supports 100+ LLMs
- Supports 6000+ Embedding Models.
- Supports all major Rerankers (Pinecone, Cohere, Flashrank, etc.)
- Uses Hierarchical Indices (2 tiered RAG setup).

View file

@ -1,51 +1,55 @@
DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
SECRET_KEY="SECRET"
NEXT_FRONTEND_URL="http://localhost:3000"
SECRET_KEY=SECRET
NEXT_FRONTEND_URL=http://localhost:3000
#Auth
AUTH_TYPE="GOOGLE" or "LOCAL"
AUTH_TYPE=GOOGLE or LOCAL
# For Google Auth Only
GOOGLE_OAUTH_CLIENT_ID="924507538m"
GOOGLE_OAUTH_CLIENT_SECRET="GOCSV"
GOOGLE_OAUTH_CLIENT_ID=924507538m
GOOGLE_OAUTH_CLIENT_SECRET=GOCSV
#Embedding Model
EMBEDDING_MODEL="mixedbread-ai/mxbai-embed-large-v1"
EMBEDDING_MODEL=mixedbread-ai/mxbai-embed-large-v1
RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
RERANKERS_MODEL_TYPE="flashrank"
RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
RERANKERS_MODEL_TYPE=flashrank
# https://docs.litellm.ai/docs/providers
FAST_LLM="openai/gpt-4o-mini"
STRATEGIC_LLM="openai/gpt-4o"
LONG_CONTEXT_LLM="gemini/gemini-2.0-flash"
#LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers
TTS_SERVICE="openai/tts-1"
TTS_SERVICE=openai/tts-1
#Respective TTS Service API
TTS_SERVICE_API_KEY=
#OPTIONAL: TTS Provider API Base
TTS_SERVICE_API_BASE=
#LiteLLM STT Provider: https://docs.litellm.ai/docs/audio_transcription#supported-providers
STT_SERVICE="openai/whisper-1"
STT_SERVICE=openai/whisper-1
#Respective STT Service API
STT_SERVICE_API_KEY=""
#OPTIONAL: STT Provider API Base
STT_SERVICE_API_BASE=
# Chosen LiteLLM Providers Keys
OPENAI_API_KEY="sk-proj-iA"
GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"
FIRECRAWL_API_KEY="fcr-01J0000000000000000000000"
FIRECRAWL_API_KEY=fcr-01J0000000000000000000000
#File Parser Service
ETL_SERVICE="UNSTRUCTURED" or "LLAMACLOUD"
UNSTRUCTURED_API_KEY="Tpu3P0U8iy"
LLAMA_CLOUD_API_KEY="llx-nnn"
ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD
UNSTRUCTURED_API_KEY=Tpu3P0U8iy
LLAMA_CLOUD_API_KEY=llx-nnn
#OPTIONAL: Add these for LangSmith Observability
LANGSMITH_TRACING=true
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
LANGSMITH_API_KEY="lsv2_pt_....."
LANGSMITH_PROJECT="surfsense"
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
LANGSMITH_API_KEY=lsv2_pt_.....
LANGSMITH_PROJECT=surfsense
# OPTIONAL: LiteLLM API Base
FAST_LLM_API_BASE=""
STRATEGIC_LLM_API_BASE=""
LONG_CONTEXT_LLM_API_BASE=""
TTS_SERVICE_API_BASE=""
STT_SERVICE_API_BASE=""
# FAST_LLM=openai/gpt-4o-mini
# STRATEGIC_LLM=openai/gpt-4o
# LONG_CONTEXT_LLM=gemini/gemini-2.0-flash
# FAST_LLM=ollama/gemma3:12b
# STRATEGIC_LLM=ollama/deepseek-r1:8b
# LONG_CONTEXT_LLM=ollama/deepseek-r1:8b

View file

@ -0,0 +1,86 @@
"""Add LLMConfig table and user LLM preferences
Revision ID: 11
Revises: 10
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSON
# revision identifiers, used by Alembic.
revision: str = "11"
down_revision: Union[str, None] = "10"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema - add LiteLLMProvider enum, LLMConfig table and user LLM preferences."""
# Check if enum type exists and create if it doesn't
op.execute("""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'litellmprovider') THEN
CREATE TYPE litellmprovider AS ENUM ('OPENAI', 'ANTHROPIC', 'GROQ', 'COHERE', 'HUGGINGFACE', 'AZURE_OPENAI', 'GOOGLE', 'AWS_BEDROCK', 'OLLAMA', 'MISTRAL', 'TOGETHER_AI', 'REPLICATE', 'PALM', 'VERTEX_AI', 'ANYSCALE', 'PERPLEXITY', 'DEEPINFRA', 'AI21', 'NLPCLOUD', 'ALEPH_ALPHA', 'PETALS', 'CUSTOM');
END IF;
END$$;
""")
# Create llm_configs table using raw SQL to avoid enum creation conflicts
op.execute("""
CREATE TABLE llm_configs (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
name VARCHAR(100) NOT NULL,
provider litellmprovider NOT NULL,
custom_provider VARCHAR(100),
model_name VARCHAR(100) NOT NULL,
api_key TEXT NOT NULL,
api_base VARCHAR(500),
litellm_params JSONB,
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE
)
""")
# Create indexes
op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
op.create_index(op.f('ix_llm_configs_created_at'), 'llm_configs', ['created_at'], unique=False)
op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)
# Add LLM preference columns to user table
op.add_column('user', sa.Column('long_context_llm_id', sa.Integer(), nullable=True))
op.add_column('user', sa.Column('fast_llm_id', sa.Integer(), nullable=True))
op.add_column('user', sa.Column('strategic_llm_id', sa.Integer(), nullable=True))
# Create foreign key constraints for LLM preferences
op.create_foreign_key(op.f('fk_user_long_context_llm_id_llm_configs'), 'user', 'llm_configs', ['long_context_llm_id'], ['id'], ondelete='SET NULL')
op.create_foreign_key(op.f('fk_user_fast_llm_id_llm_configs'), 'user', 'llm_configs', ['fast_llm_id'], ['id'], ondelete='SET NULL')
op.create_foreign_key(op.f('fk_user_strategic_llm_id_llm_configs'), 'user', 'llm_configs', ['strategic_llm_id'], ['id'], ondelete='SET NULL')
def downgrade() -> None:
"""Downgrade schema - remove LLMConfig table and user LLM preferences."""
# Drop foreign key constraints
op.drop_constraint(op.f('fk_user_strategic_llm_id_llm_configs'), 'user', type_='foreignkey')
op.drop_constraint(op.f('fk_user_fast_llm_id_llm_configs'), 'user', type_='foreignkey')
op.drop_constraint(op.f('fk_user_long_context_llm_id_llm_configs'), 'user', type_='foreignkey')
# Drop LLM preference columns from user table
op.drop_column('user', 'strategic_llm_id')
op.drop_column('user', 'fast_llm_id')
op.drop_column('user', 'long_context_llm_id')
# Drop indexes and table
op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
op.drop_index(op.f('ix_llm_configs_created_at'), table_name='llm_configs')
op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs')
op.drop_table('llm_configs')
# Drop LiteLLMProvider enum
op.execute("DROP TYPE IF EXISTS litellmprovider")

View file

@ -16,7 +16,8 @@ class Configuration:
# these values can be pre-set when you
# create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/)
# and when you invoke the graph
podcast_title: str
podcast_title: str
user_id: str
@classmethod
def from_runnable_config(

View file

@ -14,13 +14,22 @@ from .configuration import Configuration
from .state import PodcastTranscriptEntry, State, PodcastTranscripts
from .prompts import get_podcast_generation_prompt
from app.config import config as app_config
from app.utils.llm_service import get_user_long_context_llm
async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]:
"""Each node does work."""
# Initialize LLM
llm = app_config.long_context_llm_instance
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
user_id = configuration.user_id
# Get user's long context LLM
llm = await get_user_long_context_llm(state.db_session, user_id)
if not llm:
error_message = f"No long context LLM configured for user {user_id}"
print(error_message)
raise RuntimeError(error_message)
# Get the prompt
prompt = get_podcast_generation_prompt()
@ -139,6 +148,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
response = await aspeech(
model=app_config.TTS_SERVICE,
api_base=app_config.TTS_SERVICE_API_BASE,
api_key=app_config.TTS_SERVICE_API_KEY,
voice=voice,
input=dialog,
max_retries=2,
@ -147,6 +157,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
else:
response = await aspeech(
model=app_config.TTS_SERVICE,
api_key=app_config.TTS_SERVICE_API_KEY,
voice=voice,
input=dialog,
max_retries=2,
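The `get_user_long_context_llm` helper (and its `get_user_fast_llm` / `get_user_strategic_llm` siblings) comes from `app/utils/llm_service.py`, whose hunks are not shown in this view. A minimal sketch of the lookup it plausibly performs, assuming the resolved config is wrapped in `ChatLiteLLM` with a `provider/model_name` string; the provider-prefix mapping and the key/base-URL wiring are assumptions, not the file's actual contents:

```python
import uuid
from typing import Optional

from langchain_community.chat_models import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import LLMConfig, LiteLLMProvider, User


async def get_user_long_context_llm(
    session: AsyncSession, user_id: str
) -> Optional[ChatLiteLLM]:
    """Resolve the user's long-context preference to a chat model, or None if unset."""
    user = await session.get(User, uuid.UUID(str(user_id)))
    if not user or not user.long_context_llm_id:
        return None

    result = await session.execute(
        select(LLMConfig).filter(
            LLMConfig.id == user.long_context_llm_id,
            LLMConfig.user_id == user.id,
        )
    )
    llm_config = result.scalars().first()
    if not llm_config:
        return None

    # LiteLLM expects "provider/model", e.g. "ollama/gemma3:12b"; CUSTOM configs
    # supply their own prefix via custom_provider.
    prefix = (
        llm_config.custom_provider
        if llm_config.provider == LiteLLMProvider.CUSTOM
        else llm_config.provider.value.lower()
    )

    # The config's api_key, api_base and litellm_params would also be wired in
    # here; the exact ChatLiteLLM keyword arguments are omitted because they
    # depend on the installed langchain_community version.
    return ChatLiteLLM(model=f"{prefix}/{llm_config.model_name}")
```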

View file

@ -5,7 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
class PodcastTranscriptEntry(BaseModel):
"""
@ -32,7 +32,8 @@ class State:
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
for more information.
"""
# Runtime context
db_session: AsyncSession
source_content: str
podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None
final_podcast_file_path: Optional[str] = None

View file

@ -2,7 +2,6 @@ import asyncio
import json
from typing import Any, Dict, List
from app.config import config as app_config
from app.db import async_session_maker
from app.utils.connector_service import ConnectorService
from langchain_core.messages import HumanMessage, SystemMessage
@ -274,6 +273,9 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
Returns:
Dict containing the answer outline in the "answer_outline" key for state update.
"""
from app.utils.llm_service import get_user_strategic_llm
from app.db import get_async_session
streaming_service = state.streaming_service
streaming_service.only_update_terminal("🔍 Generating answer outline...")
@ -283,12 +285,18 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
reformulated_query = state.reformulated_query
user_query = configuration.user_query
num_sections = configuration.num_sections
user_id = configuration.user_id
streaming_service.only_update_terminal(f"🤔 Planning research approach for: \"{user_query[:100]}...\"")
writer({"yeild_value": streaming_service._format_annotations()})
# Initialize LLM
llm = app_config.strategic_llm_instance
# Get user's strategic LLM
llm = await get_user_strategic_llm(state.db_session, user_id)
if not llm:
error_message = f"No strategic LLM configured for user {user_id}"
streaming_service.only_update_terminal(f"{error_message}", "error")
writer({"yeild_value": streaming_service._format_annotations()})
raise RuntimeError(error_message)
# Create the human message content
human_message_content = f"""
@ -828,48 +836,47 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
user_selected_documents = []
user_selected_sources = []
async with async_session_maker() as db_session:
try:
# First, fetch user-selected documents if any
if configuration.document_ids_to_add_in_context:
streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
writer({"yeild_value": streaming_service._format_annotations()})
user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
document_ids=configuration.document_ids_to_add_in_context,
user_id=configuration.user_id,
db_session=db_session
)
if user_selected_documents:
streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
writer({"yeild_value": streaming_service._format_annotations()})
# Create connector service inside the db_session scope
connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
relevant_documents = await fetch_relevant_documents(
research_questions=all_questions,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
db_session=db_session,
connectors_to_search=configuration.connectors_to_search,
writer=writer,
state=state,
top_k=TOP_K,
connector_service=connector_service,
search_mode=configuration.search_mode,
user_selected_sources=user_selected_sources
)
except Exception as e:
error_message = f"Error fetching relevant documents: {str(e)}"
print(error_message)
streaming_service.only_update_terminal(f"{error_message}", "error")
try:
# First, fetch user-selected documents if any
if configuration.document_ids_to_add_in_context:
streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
writer({"yeild_value": streaming_service._format_annotations()})
# Log the error and continue with an empty list of documents
# This allows the process to continue, but the report might lack information
relevant_documents = []
user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
document_ids=configuration.document_ids_to_add_in_context,
user_id=configuration.user_id,
db_session=state.db_session
)
if user_selected_documents:
streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
writer({"yeild_value": streaming_service._format_annotations()})
# Create connector service using state db_session
connector_service = ConnectorService(state.db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
relevant_documents = await fetch_relevant_documents(
research_questions=all_questions,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
db_session=state.db_session,
connectors_to_search=configuration.connectors_to_search,
writer=writer,
state=state,
top_k=TOP_K,
connector_service=connector_service,
search_mode=configuration.search_mode,
user_selected_sources=user_selected_sources
)
except Exception as e:
error_message = f"Error fetching relevant documents: {str(e)}"
print(error_message)
streaming_service.only_update_terminal(f"{error_message}", "error")
writer({"yeild_value": streaming_service._format_annotations()})
# Log the error and continue with an empty list of documents
# This allows the process to continue, but the report might lack information
relevant_documents = []
# Combine user-selected documents with connector-fetched documents
all_documents = user_selected_documents + relevant_documents
@ -1014,82 +1021,80 @@ async def process_section_with_documents(
for question in section_questions
]
# Create a new database session for this section
async with async_session_maker() as db_session:
# Call the sub_section_writer graph with the appropriate config
config = {
"configurable": {
"sub_section_title": section_title,
"sub_section_questions": section_questions,
"sub_section_type": sub_section_type,
"user_query": user_query,
"relevant_documents": documents_to_use,
"user_id": user_id,
"search_space_id": search_space_id
}
# Call the sub_section_writer graph with the appropriate config
config = {
"configurable": {
"sub_section_title": section_title,
"sub_section_questions": section_questions,
"sub_section_type": sub_section_type,
"user_query": user_query,
"relevant_documents": documents_to_use,
"user_id": user_id,
"search_space_id": search_space_id
}
# Create the initial state with db_session and chat_history
sub_state = {
"db_session": db_session,
"chat_history": state.chat_history
}
# Invoke the sub-section writer graph with streaming
print(f"Invoking sub_section_writer for: {section_title}")
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"🧠 Analyzing information and drafting content for section: \"{section_title}\"")
writer({"yeild_value": state.streaming_service._format_annotations()})
# Variables to track streaming state
complete_content = "" # Tracks the complete content received so far
async for chunk_type, chunk in sub_section_writer_graph.astream(sub_state, config, stream_mode=["values"]):
if "final_answer" in chunk:
new_content = chunk["final_answer"]
if new_content and new_content != complete_content:
# Extract only the new content (delta)
delta = new_content[len(complete_content):]
}
# Create the initial state with db_session and chat_history
sub_state = {
"db_session": state.db_session,
"chat_history": state.chat_history
}
# Invoke the sub-section writer graph with streaming
print(f"Invoking sub_section_writer for: {section_title}")
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"🧠 Analyzing information and drafting content for section: \"{section_title}\"")
writer({"yeild_value": state.streaming_service._format_annotations()})
# Variables to track streaming state
complete_content = "" # Tracks the complete content received so far
async for chunk_type, chunk in sub_section_writer_graph.astream(sub_state, config, stream_mode=["values"]):
if "final_answer" in chunk:
new_content = chunk["final_answer"]
if new_content and new_content != complete_content:
# Extract only the new content (delta)
delta = new_content[len(complete_content):]
# Update what we've processed so far
complete_content = new_content
# Only stream if there's actual new content
if delta and state and state.streaming_service and writer:
# Update terminal with real-time progress indicator
state.streaming_service.only_update_terminal(f"✍️ Writing section {section_id+1}... ({len(complete_content.split())} words)")
# Update what we've processed so far
complete_content = new_content
# Update section_contents with just the new delta
section_contents[section_id]["content"] += delta
# Only stream if there's actual new content
if delta and state and state.streaming_service and writer:
# Update terminal with real-time progress indicator
state.streaming_service.only_update_terminal(f"✍️ Writing section {section_id+1}... ({len(complete_content.split())} words)")
# Update section_contents with just the new delta
section_contents[section_id]["content"] += delta
# Build UI-friendly content for all sections
complete_answer = []
for i in range(len(section_contents)):
if i in section_contents and section_contents[i]["content"]:
# Add section header
complete_answer.append(f"# {section_contents[i]['title']}")
complete_answer.append("") # Empty line after title
# Add section content
content_lines = section_contents[i]["content"].split("\n")
complete_answer.extend(content_lines)
complete_answer.append("") # Empty line after content
# Update answer in UI in real-time
state.streaming_service.only_update_answer(complete_answer)
writer({"yeild_value": state.streaming_service._format_annotations()})
# Set default if no content was received
if not complete_content:
complete_content = "No content was generated for this section."
section_contents[section_id]["content"] = complete_content
# Final terminal update
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"✅ Completed section: \"{section_title}\"")
writer({"yeild_value": state.streaming_service._format_annotations()})
return complete_content
# Build UI-friendly content for all sections
complete_answer = []
for i in range(len(section_contents)):
if i in section_contents and section_contents[i]["content"]:
# Add section header
complete_answer.append(f"# {section_contents[i]['title']}")
complete_answer.append("") # Empty line after title
# Add section content
content_lines = section_contents[i]["content"].split("\n")
complete_answer.extend(content_lines)
complete_answer.append("") # Empty line after content
# Update answer in UI in real-time
state.streaming_service.only_update_answer(complete_answer)
writer({"yeild_value": state.streaming_service._format_annotations()})
# Set default if no content was received
if not complete_content:
complete_content = "No content was generated for this section."
section_contents[section_id]["content"] = complete_content
# Final terminal update
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"✅ Completed section: \"{section_title}\"")
writer({"yeild_value": state.streaming_service._format_annotations()})
return complete_content
except Exception as e:
print(f"Error processing section '{section_title}': {str(e)}")
@ -1113,7 +1118,7 @@ async def reformulate_user_query(state: State, config: RunnableConfig, writer: S
if len(state.chat_history) == 0:
reformulated_query = user_query
else:
reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query, chat_history_str)
reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query=user_query, session=state.db_session, user_id=configuration.user_id, chat_history_str=chat_history_str)
return {
"reformulated_query": reformulated_query
@ -1152,50 +1157,49 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
user_selected_documents = []
user_selected_sources = []
async with async_session_maker() as db_session:
try:
# First, fetch user-selected documents if any
if configuration.document_ids_to_add_in_context:
streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
writer({"yeild_value": streaming_service._format_annotations()})
user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
document_ids=configuration.document_ids_to_add_in_context,
user_id=configuration.user_id,
db_session=db_session
)
if user_selected_documents:
streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
writer({"yeild_value": streaming_service._format_annotations()})
# Create connector service inside the db_session scope
connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
# Use the reformulated query as a single research question
research_questions = [reformulated_query, user_query]
relevant_documents = await fetch_relevant_documents(
research_questions=research_questions,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
db_session=db_session,
connectors_to_search=configuration.connectors_to_search,
writer=writer,
state=state,
top_k=TOP_K,
connector_service=connector_service,
search_mode=configuration.search_mode,
user_selected_sources=user_selected_sources
)
except Exception as e:
error_message = f"Error fetching relevant documents for QNA: {str(e)}"
print(error_message)
streaming_service.only_update_terminal(f"{error_message}", "error")
try:
# First, fetch user-selected documents if any
if configuration.document_ids_to_add_in_context:
streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
writer({"yeild_value": streaming_service._format_annotations()})
# Continue with empty documents - the QNA agent will handle this gracefully
relevant_documents = []
user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
document_ids=configuration.document_ids_to_add_in_context,
user_id=configuration.user_id,
db_session=state.db_session
)
if user_selected_documents:
streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
writer({"yeild_value": streaming_service._format_annotations()})
# Create connector service using state db_session
connector_service = ConnectorService(state.db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
# Use the reformulated query as a single research question
research_questions = [reformulated_query, user_query]
relevant_documents = await fetch_relevant_documents(
research_questions=research_questions,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
db_session=state.db_session,
connectors_to_search=configuration.connectors_to_search,
writer=writer,
state=state,
top_k=TOP_K,
connector_service=connector_service,
search_mode=configuration.search_mode,
user_selected_sources=user_selected_sources
)
except Exception as e:
error_message = f"Error fetching relevant documents for QNA: {str(e)}"
print(error_message)
streaming_service.only_update_terminal(f"{error_message}", "error")
writer({"yeild_value": streaming_service._format_annotations()})
# Continue with empty documents - the QNA agent will handle this gracefully
relevant_documents = []
# Combine user-selected documents with connector-fetched documents
all_documents = user_selected_documents + relevant_documents

View file

@ -85,14 +85,20 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
Returns:
Dict containing the final answer in the "final_answer" key.
"""
from app.utils.llm_service import get_user_fast_llm
# Get configuration and relevant documents from configuration
configuration = Configuration.from_runnable_config(config)
documents = state.reranked_documents
user_query = configuration.user_query
user_id = configuration.user_id
# Initialize LLM
llm = app_config.fast_llm_instance
# Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id)
if not llm:
error_message = f"No fast LLM configured for user {user_id}"
print(error_message)
raise RuntimeError(error_message)
# Determine if we have documents and optimize for token limits
has_documents_initially = documents and len(documents) > 0
@ -118,7 +124,7 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
# Optimize documents to fit within token limits
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
documents, base_messages, app_config.FAST_LLM
documents, base_messages, llm.model
)
# Update state based on optimization result
@ -161,7 +167,7 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
]
# Log final token count
total_tokens = calculate_token_count(messages_with_chat_history, app_config.FAST_LLM)
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
print(f"Final token count: {total_tokens}")

View file

@ -91,13 +91,19 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
Returns:
Dict containing the final answer in the "final_answer" key.
"""
from app.utils.llm_service import get_user_fast_llm
# Get configuration and relevant documents from configuration
configuration = Configuration.from_runnable_config(config)
documents = state.reranked_documents
user_id = configuration.user_id
# Initialize LLM
llm = app_config.fast_llm_instance
# Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id)
if not llm:
error_message = f"No fast LLM configured for user {user_id}"
print(error_message)
raise RuntimeError(error_message)
# Extract configuration data
section_title = configuration.sub_section_title
@ -153,7 +159,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
# Optimize documents to fit within token limits
optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
documents, base_messages, app_config.FAST_LLM
documents, base_messages, llm.model
)
# Update state based on optimization result
@ -206,7 +212,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
]
# Log final token count
total_tokens = calculate_token_count(messages_with_chat_history, app_config.FAST_LLM)
total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
print(f"Final token count: {total_tokens}")
# Call the LLM and get the response

View file

@ -1,7 +1,6 @@
from typing import List, Dict, Any, Tuple, NamedTuple
from langchain_core.messages import BaseMessage
from litellm import token_counter, get_model_info
from app.config import config as app_config
class DocumentTokenInfo(NamedTuple):
@ -127,7 +126,7 @@ def get_model_context_window(model_name: str) -> int:
def optimize_documents_for_token_limit(
documents: List[Dict[str, Any]],
base_messages: List[BaseMessage],
model_name: str = None
model_name: str
) -> Tuple[List[Dict[str, Any]], bool]:
"""
Optimize documents to fit within token limits using binary search.
@ -135,7 +134,7 @@ def optimize_documents_for_token_limit(
Args:
documents: List of documents with content and metadata
base_messages: Base messages without documents (chat history + system + human message template)
model_name: Model name for token counting (defaults to app_config.FAST_LLM)
model_name: Model name for token counting (required)
output_token_buffer: Number of tokens to reserve for model output
Returns:
@ -144,7 +143,7 @@ def optimize_documents_for_token_limit(
if not documents:
return [], False
model = model_name or app_config.FAST_LLM
model = model_name
context_window = get_model_context_window(model)
# Calculate base token cost
@ -178,8 +177,8 @@ def optimize_documents_for_token_limit(
return optimized_documents, has_documents_remaining
def calculate_token_count(messages: List[BaseMessage], model_name: str = None) -> int:
def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
"""Calculate token count for a list of LangChain messages."""
model = model_name or app_config.FAST_LLM
model = model_name
messages_dict = convert_langchain_messages_to_dict(messages)
return token_counter(messages=messages_dict, model=model)
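With the `app_config.FAST_LLM` fallback removed, both helpers now require the caller to pass the resolved model string (`llm.model` above). A small standalone example of the underlying LiteLLM call they wrap; the model name is illustrative:

```python
from litellm import token_counter

messages = [
    {"role": "system", "content": "You are a helpful research assistant."},
    {"role": "user", "content": "Summarise last week's indexed Slack messages."},
]

# The model string now comes from the user's LLMConfig rather than an
# environment variable, so the token budget follows the selected provider.
print(token_counter(model="openai/gpt-4o-mini", messages=messages))
```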

View file

@ -4,7 +4,7 @@ import shutil
from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
from dotenv import load_dotenv
from langchain_community.chat_models import ChatLiteLLM
from rerankers import Reranker
@ -49,31 +49,8 @@ class Config:
GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
# LONG-CONTEXT LLMS
LONG_CONTEXT_LLM = os.getenv("LONG_CONTEXT_LLM")
LONG_CONTEXT_LLM_API_BASE = os.getenv("LONG_CONTEXT_LLM_API_BASE")
if LONG_CONTEXT_LLM_API_BASE:
long_context_llm_instance = ChatLiteLLM(model=LONG_CONTEXT_LLM, api_base=LONG_CONTEXT_LLM_API_BASE)
else:
long_context_llm_instance = ChatLiteLLM(model=LONG_CONTEXT_LLM)
# FAST LLM
FAST_LLM = os.getenv("FAST_LLM")
FAST_LLM_API_BASE = os.getenv("FAST_LLM_API_BASE")
if FAST_LLM_API_BASE:
fast_llm_instance = ChatLiteLLM(model=FAST_LLM, api_base=FAST_LLM_API_BASE)
else:
fast_llm_instance = ChatLiteLLM(model=FAST_LLM)
# STRATEGIC LLM
STRATEGIC_LLM = os.getenv("STRATEGIC_LLM")
STRATEGIC_LLM_API_BASE = os.getenv("STRATEGIC_LLM_API_BASE")
if STRATEGIC_LLM_API_BASE:
strategic_llm_instance = ChatLiteLLM(model=STRATEGIC_LLM, api_base=STRATEGIC_LLM_API_BASE)
else:
strategic_llm_instance = ChatLiteLLM(model=STRATEGIC_LLM)
# LLM instances are now managed per-user through the LLMConfig system
# Legacy environment variables removed in favor of user-specific configurations
# Chonkie Configuration | Edit this to your needs
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
@ -114,10 +91,12 @@ class Config:
# Litellm TTS Configuration
TTS_SERVICE = os.getenv("TTS_SERVICE")
TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE")
TTS_SERVICE_API_KEY = os.getenv("TTS_SERVICE_API_KEY")
# Litellm STT Configuration
STT_SERVICE = os.getenv("STT_SERVICE")
STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")
# Validation Checks

View file

@ -67,6 +67,30 @@ class ChatType(str, Enum):
REPORT_GENERAL = "REPORT_GENERAL"
REPORT_DEEP = "REPORT_DEEP"
REPORT_DEEPER = "REPORT_DEEPER"
class LiteLLMProvider(str, Enum):
OPENAI = "OPENAI"
ANTHROPIC = "ANTHROPIC"
GROQ = "GROQ"
COHERE = "COHERE"
HUGGINGFACE = "HUGGINGFACE"
AZURE_OPENAI = "AZURE_OPENAI"
GOOGLE = "GOOGLE"
AWS_BEDROCK = "AWS_BEDROCK"
OLLAMA = "OLLAMA"
MISTRAL = "MISTRAL"
TOGETHER_AI = "TOGETHER_AI"
REPLICATE = "REPLICATE"
PALM = "PALM"
VERTEX_AI = "VERTEX_AI"
ANYSCALE = "ANYSCALE"
PERPLEXITY = "PERPLEXITY"
DEEPINFRA = "DEEPINFRA"
AI21 = "AI21"
NLPCLOUD = "NLPCLOUD"
ALEPH_ALPHA = "ALEPH_ALPHA"
PETALS = "PETALS"
CUSTOM = "CUSTOM"
class Base(DeclarativeBase):
pass
@ -152,6 +176,26 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
user = relationship("User", back_populates="search_source_connectors")
class LLMConfig(BaseModel, TimestampMixin):
__tablename__ = "llm_configs"
name = Column(String(100), nullable=False, index=True)
# Provider from the enum
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
# Custom provider name when provider is CUSTOM
custom_provider = Column(String(100), nullable=True)
# Just the model name without provider prefix
model_name = Column(String(100), nullable=False)
# API Key should be encrypted before storing
api_key = Column(String, nullable=False)
api_base = Column(String(500), nullable=True)
# For any other parameters that litellm supports
litellm_params = Column(JSON, nullable=True, default={})
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id])
if config.AUTH_TYPE == "GOOGLE":
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
pass
@ -163,11 +207,29 @@ if config.AUTH_TYPE == "GOOGLE":
)
search_spaces = relationship("SearchSpace", back_populates="user")
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
else:
class User(SQLAlchemyBaseUserTableUUID, Base):
search_spaces = relationship("SearchSpace", back_populates="user")
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
engine = create_async_engine(DATABASE_URL)
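Since the three preference columns are declared with `ondelete="SET NULL"`, deleting a config simply clears whichever roles pointed at it. When the preference relationships are needed in one round trip (the preferences route later in this commit refreshes and re-queries instead), an eager-loading query is one option; a sketch:

```python
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from app.db import User, async_session_maker


async def load_user_with_llm_preferences(user_id):
    """Fetch a user with fast/strategic/long-context configs eagerly loaded."""
    async with async_session_maker() as session:
        result = await session.execute(
            select(User)
            .options(
                selectinload(User.fast_llm),
                selectinload(User.strategic_llm),
                selectinload(User.long_context_llm),
            )
            .filter(User.id == user_id)
        )
        return result.scalars().first()
```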

View file

@ -4,6 +4,7 @@ from .documents_routes import router as documents_router
from .podcasts_routes import router as podcasts_router
from .chats_routes import router as chats_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .llm_config_routes import router as llm_config_router
router = APIRouter()
@ -12,3 +13,4 @@ router.include_router(documents_router)
router.include_router(podcasts_router)
router.include_router(chats_router)
router.include_router(search_source_connectors_router)
router.include_router(llm_config_router)

View file

@ -38,21 +38,24 @@ async def create_documents(
fastapi_background_tasks.add_task(
process_extension_document_with_new_session,
individual_document,
request.search_space_id
request.search_space_id,
str(user.id)
)
elif request.document_type == DocumentType.CRAWLED_URL:
for url in request.content:
fastapi_background_tasks.add_task(
process_crawled_url_with_new_session,
url,
request.search_space_id
request.search_space_id,
str(user.id)
)
elif request.document_type == DocumentType.YOUTUBE_VIDEO:
for url in request.content:
fastapi_background_tasks.add_task(
process_youtube_video_with_new_session,
url,
request.search_space_id
request.search_space_id,
str(user.id)
)
else:
raise HTTPException(
@ -106,7 +109,8 @@ async def create_documents(
process_file_in_background_with_new_session,
temp_path,
file.filename,
search_space_id
search_space_id,
str(user.id)
)
except Exception as e:
raise HTTPException(
@ -130,6 +134,7 @@ async def process_file_in_background(
file_path: str,
filename: str,
search_space_id: int,
user_id: str,
session: AsyncSession
):
try:
@ -151,7 +156,8 @@ async def process_file_in_background(
session,
filename,
markdown_content,
search_space_id
search_space_id,
user_id
)
# Check if the file is an audio file
elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
@ -162,11 +168,13 @@ async def process_file_in_background(
transcription_response = await atranscription(
model=app_config.STT_SERVICE,
file=audio_file,
api_base=app_config.STT_SERVICE_API_BASE
api_base=app_config.STT_SERVICE_API_BASE,
api_key=app_config.STT_SERVICE_API_KEY
)
else:
transcription_response = await atranscription(
model=app_config.STT_SERVICE,
api_key=app_config.STT_SERVICE_API_KEY,
file=audio_file
)
@ -187,7 +195,8 @@ async def process_file_in_background(
session,
filename,
transcribed_text,
search_space_id
search_space_id,
user_id
)
else:
if app_config.ETL_SERVICE == "UNSTRUCTURED":
@ -218,7 +227,8 @@ async def process_file_in_background(
session,
filename,
docs,
search_space_id
search_space_id,
user_id
)
elif app_config.ETL_SERVICE == "LLAMACLOUD":
from llama_cloud_services import LlamaParse
@ -256,7 +266,8 @@ async def process_file_in_background(
session,
filename,
llamacloud_markdown_document=markdown_content,
search_space_id=search_space_id
search_space_id=search_space_id,
user_id=user_id
)
except Exception as e:
import logging
@ -426,14 +437,15 @@ async def delete_document(
async def process_extension_document_with_new_session(
individual_document,
search_space_id: int
search_space_id: int,
user_id: str
):
"""Create a new session and process extension document."""
from app.db import async_session_maker
async with async_session_maker() as session:
try:
await add_extension_received_document(session, individual_document, search_space_id)
await add_extension_received_document(session, individual_document, search_space_id, user_id)
except Exception as e:
import logging
logging.error(f"Error processing extension document: {str(e)}")
@ -441,14 +453,15 @@ async def process_extension_document_with_new_session(
async def process_crawled_url_with_new_session(
url: str,
search_space_id: int
search_space_id: int,
user_id: str
):
"""Create a new session and process crawled URL."""
from app.db import async_session_maker
async with async_session_maker() as session:
try:
await add_crawled_url_document(session, url, search_space_id)
await add_crawled_url_document(session, url, search_space_id, user_id)
except Exception as e:
import logging
logging.error(f"Error processing crawled URL: {str(e)}")
@ -457,25 +470,27 @@ async def process_crawled_url_with_new_session(
async def process_file_in_background_with_new_session(
file_path: str,
filename: str,
search_space_id: int
search_space_id: int,
user_id: str
):
"""Create a new session and process file."""
from app.db import async_session_maker
async with async_session_maker() as session:
await process_file_in_background(file_path, filename, search_space_id, session)
await process_file_in_background(file_path, filename, search_space_id, user_id, session)
async def process_youtube_video_with_new_session(
url: str,
search_space_id: int
search_space_id: int,
user_id: str
):
"""Create a new session and process YouTube video."""
from app.db import async_session_maker
async with async_session_maker() as session:
try:
await add_youtube_video_document(session, url, search_space_id)
await add_youtube_video_document(session, url, search_space_id, user_id)
except Exception as e:
import logging
logging.error(f"Error processing YouTube video: {str(e)}")

View file

@ -0,0 +1,243 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List, Optional
from pydantic import BaseModel
from app.db import get_async_session, User, LLMConfig
from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
router = APIRouter()
class LLMPreferencesUpdate(BaseModel):
"""Schema for updating user LLM preferences"""
long_context_llm_id: Optional[int] = None
fast_llm_id: Optional[int] = None
strategic_llm_id: Optional[int] = None
class LLMPreferencesRead(BaseModel):
"""Schema for reading user LLM preferences"""
long_context_llm_id: Optional[int] = None
fast_llm_id: Optional[int] = None
strategic_llm_id: Optional[int] = None
long_context_llm: Optional[LLMConfigRead] = None
fast_llm: Optional[LLMConfigRead] = None
strategic_llm: Optional[LLMConfigRead] = None
@router.post("/llm-configs/", response_model=LLMConfigRead)
async def create_llm_config(
llm_config: LLMConfigCreate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Create a new LLM configuration for the authenticated user"""
try:
db_llm_config = LLMConfig(**llm_config.model_dump(), user_id=user.id)
session.add(db_llm_config)
await session.commit()
await session.refresh(db_llm_config)
return db_llm_config
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to create LLM configuration: {str(e)}"
)
@router.get("/llm-configs/", response_model=List[LLMConfigRead])
async def read_llm_configs(
skip: int = 0,
limit: int = 200,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Get all LLM configurations for the authenticated user"""
try:
result = await session.execute(
select(LLMConfig)
.filter(LLMConfig.user_id == user.id)
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch LLM configurations: {str(e)}"
)
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def read_llm_config(
llm_config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Get a specific LLM configuration by ID"""
try:
llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
return llm_config
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch LLM configuration: {str(e)}"
)
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def update_llm_config(
llm_config_id: int,
llm_config_update: LLMConfigUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Update an existing LLM configuration"""
try:
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
update_data = llm_config_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_llm_config, key, value)
await session.commit()
await session.refresh(db_llm_config)
return db_llm_config
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update LLM configuration: {str(e)}"
)
@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
async def delete_llm_config(
llm_config_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Delete an LLM configuration"""
try:
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
await session.delete(db_llm_config)
await session.commit()
return {"message": "LLM configuration deleted successfully"}
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to delete LLM configuration: {str(e)}"
)
# User LLM Preferences endpoints
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
async def get_user_llm_preferences(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Get the current user's LLM preferences"""
try:
# Refresh user to get latest relationships
await session.refresh(user)
result = {
"long_context_llm_id": user.long_context_llm_id,
"fast_llm_id": user.fast_llm_id,
"strategic_llm_id": user.strategic_llm_id,
"long_context_llm": None,
"fast_llm": None,
"strategic_llm": None,
}
# Fetch the actual LLM configs if they exist
if user.long_context_llm_id:
long_context_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.long_context_llm_id,
LLMConfig.user_id == user.id
)
)
llm_config = long_context_llm.scalars().first()
if llm_config:
result["long_context_llm"] = llm_config
if user.fast_llm_id:
fast_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.fast_llm_id,
LLMConfig.user_id == user.id
)
)
llm_config = fast_llm.scalars().first()
if llm_config:
result["fast_llm"] = llm_config
if user.strategic_llm_id:
strategic_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.strategic_llm_id,
LLMConfig.user_id == user.id
)
)
llm_config = strategic_llm.scalars().first()
if llm_config:
result["strategic_llm"] = llm_config
return result
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to fetch LLM preferences: {str(e)}"
)
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
async def update_user_llm_preferences(
preferences: LLMPreferencesUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user)
):
"""Update the current user's LLM preferences"""
try:
# Validate that all provided LLM config IDs belong to the user
update_data = preferences.model_dump(exclude_unset=True)
for key, llm_config_id in update_data.items():
if llm_config_id is not None:
# Verify ownership of the LLM config
result = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == llm_config_id,
LLMConfig.user_id == user.id
)
)
llm_config = result.scalars().first()
if not llm_config:
raise HTTPException(
status_code=404,
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it"
)
# Update user preferences
for key, value in update_data.items():
setattr(user, key, value)
await session.commit()
await session.refresh(user)
# Return updated preferences
return await get_user_llm_preferences(session, user)
except HTTPException:
raise
except Exception as e:
await session.rollback()
raise HTTPException(
status_code=500,
detail=f"Failed to update LLM preferences: {str(e)}"
)
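Taken together, these routes let a client register a provider configuration and then assign it to one of the three roles. A hedged end-to-end sketch with `httpx`; the base URL prefix and the bearer-token header are assumptions about how the aggregate router and fastapi-users auth are mounted:

```python
import httpx

BASE = "http://localhost:8000"  # actual prefix depends on router mounting (assumption)
HEADERS = {"Authorization": "Bearer <access-token>"}  # fastapi-users bearer auth assumed

# 1. Register an Ollama configuration for the current user.
created = httpx.post(
    f"{BASE}/llm-configs/",
    headers=HEADERS,
    json={
        "name": "Local Gemma",
        "provider": "OLLAMA",
        "model_name": "gemma3:12b",
        "api_key": "not-needed-for-ollama",
        "api_base": "http://localhost:11434",
    },
).json()

# 2. Point the fast-LLM role at it. Because the endpoint applies
#    model_dump(exclude_unset=True), the other two roles are left untouched.
httpx.put(
    f"{BASE}/users/me/llm-preferences",
    headers=HEADERS,
    json={"fast_llm_id": created["id"]},
)
```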

View file

@ -128,14 +128,15 @@ async def delete_podcast(
async def generate_chat_podcast_with_new_session(
chat_id: int,
search_space_id: int,
podcast_title: str = "SurfSense Podcast"
podcast_title: str,
user_id: str
):
"""Create a new session and process chat podcast generation."""
from app.db import async_session_maker
async with async_session_maker() as session:
try:
await generate_chat_podcast(session, chat_id, search_space_id, podcast_title)
await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id)
except Exception as e:
import logging
logging.error(f"Error generating podcast from chat: {str(e)}")
@ -175,7 +176,8 @@ async def generate_podcast(
generate_chat_podcast_with_new_session,
chat_id,
request.search_space_id,
request.podcast_title
request.podcast_title,
str(user.id)
)
return {

View file

@ -328,25 +328,25 @@ async def index_connector_content(
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
# Run indexing in background
logger.info(f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to)
background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
response_message = "Slack indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
# Run indexing in background
logger.info(f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to)
background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
response_message = "Notion indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR:
# Run indexing in background
logger.info(f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to)
background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
response_message = "GitHub indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
# Run indexing in background
logger.info(f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to)
background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
response_message = "Linear indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
@ -355,7 +355,7 @@ async def index_connector_content(
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
background_tasks.add_task(
run_discord_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to
run_discord_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Discord indexing started in the background."
@ -410,6 +410,7 @@ async def update_connector_last_indexed(
async def run_slack_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
@ -418,12 +419,13 @@ async def run_slack_indexing_with_new_session(
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_slack_indexing(session, connector_id, search_space_id, start_date, end_date)
await run_slack_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
async def run_slack_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
@ -434,6 +436,7 @@ async def run_slack_indexing(
session: Database session
connector_id: ID of the Slack connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
@ -443,6 +446,7 @@ async def run_slack_indexing(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
update_last_indexed=False # Don't update timestamp in the indexing function
@ -460,6 +464,7 @@ async def run_slack_indexing(
async def run_notion_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
@ -468,12 +473,13 @@ async def run_notion_indexing_with_new_session(
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_notion_indexing(session, connector_id, search_space_id, start_date, end_date)
await run_notion_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
async def run_notion_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
@ -484,6 +490,7 @@ async def run_notion_indexing(
session: Database session
connector_id: ID of the Notion connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
@ -493,6 +500,7 @@ async def run_notion_indexing(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
update_last_indexed=False # Don't update timestamp in the indexing function
@ -511,26 +519,28 @@ async def run_notion_indexing(
async def run_github_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
"""Wrapper to run GitHub indexing with its own database session."""
logger.info(f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}")
async with async_session_maker() as session:
await run_github_indexing(session, connector_id, search_space_id, start_date, end_date)
await run_github_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
logger.info(f"Background task finished: Indexing GitHub connector {connector_id}")
async def run_github_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
"""Runs the GitHub indexing task and updates the timestamp."""
try:
indexed_count, error_message = await index_github_repos(
session, connector_id, search_space_id, start_date, end_date, update_last_indexed=False
session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False
)
if error_message:
logger.error(f"GitHub indexing failed for connector {connector_id}: {error_message}")
@ -549,26 +559,28 @@ async def run_github_indexing(
async def run_linear_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
"""Wrapper to run Linear indexing with its own database session."""
logger.info(f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}")
async with async_session_maker() as session:
await run_linear_indexing(session, connector_id, search_space_id, start_date, end_date)
await run_linear_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")
async def run_linear_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
"""Runs the Linear indexing task and updates the timestamp."""
try:
indexed_count, error_message = await index_linear_issues(
session, connector_id, search_space_id, start_date, end_date, update_last_indexed=False
session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False
)
if error_message:
logger.error(f"Linear indexing failed for connector {connector_id}: {error_message}")
@ -587,6 +599,7 @@ async def run_linear_indexing(
async def run_discord_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
@ -595,12 +608,13 @@ async def run_discord_indexing_with_new_session(
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_discord_indexing(session, connector_id, search_space_id, start_date, end_date)
await run_discord_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
async def run_discord_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str
):
@ -610,6 +624,7 @@ async def run_discord_indexing(
session: Database session
connector_id: ID of the Discord connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
@ -619,6 +634,7 @@ async def run_discord_indexing(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
update_last_indexed=False # Don't update timestamp in the indexing function

View file

@ -13,6 +13,7 @@ from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
__all__ = [
"AISDKChatRequest",
@ -48,4 +49,8 @@ __all__ = [
"SearchSourceConnectorCreate",
"SearchSourceConnectorUpdate",
"SearchSourceConnectorRead",
"LLMConfigBase",
"LLMConfigCreate",
"LLMConfigUpdate",
"LLMConfigRead",
]

View file

@ -0,0 +1,34 @@
from datetime import datetime
import uuid
from typing import Optional, Dict, Any
from pydantic import BaseModel, ConfigDict, Field
from .base import IDModel, TimestampModel
from app.db import LiteLLMProvider
class LLMConfigBase(BaseModel):
name: str = Field(..., max_length=100, description="User-friendly name for the LLM configuration")
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
model_name: str = Field(..., max_length=100, description="Model name without provider prefix")
api_key: str = Field(..., description="API key for the provider")
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
litellm_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional LiteLLM parameters")
class LLMConfigCreate(LLMConfigBase):
pass
class LLMConfigUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=100, description="User-friendly name for the LLM configuration")
provider: Optional[LiteLLMProvider] = Field(None, description="LiteLLM provider type")
custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
model_name: Optional[str] = Field(None, max_length=100, description="Model name without provider prefix")
api_key: Optional[str] = Field(None, description="API key for the provider")
api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
litellm_params: Optional[Dict[str, Any]] = Field(None, description="Additional LiteLLM parameters")
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
id: int
created_at: datetime
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
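For reference, a minimal sketch of the payload shape `LLMConfigCreate` accepts; the concrete values below are illustrative assumptions, not taken from this commit:

```python
# Illustrative payload matching the LLMConfigCreate fields above.
example_llm_config = {
    "name": "My OpenAI GPT-4o",          # user-friendly label
    "provider": "OPENAI",                # LiteLLMProvider enum value
    "custom_provider": None,             # only set when provider is CUSTOM
    "model_name": "gpt-4o",              # model name without provider prefix
    "api_key": "sk-...",                 # provider API key (example placeholder)
    "api_base": None,                    # optional API base override
    "litellm_params": {"temperature": 0.2},
}
```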

View file

@ -7,6 +7,7 @@ from app.schemas import ExtensionDocumentContent
from app.config import config
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import convert_document_to_markdown, generate_content_hash
from app.utils.llm_service import get_user_long_context_llm
from langchain_core.documents import Document as LangChainDocument
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
from langchain_community.document_transformers import MarkdownifyTransformer
@ -18,9 +19,8 @@ import logging
md = MarkdownifyTransformer()
async def add_crawled_url_document(
session: AsyncSession, url: str, search_space_id: int
session: AsyncSession, url: str, search_space_id: int, user_id: str
) -> Optional[Document]:
try:
if not validators.url(url):
@ -84,8 +84,13 @@ async def add_crawled_url_document(
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke(
{"document": combined_document_string}
)
@ -130,7 +135,7 @@ async def add_crawled_url_document(
async def add_extension_received_document(
session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int
session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int, user_id: str
) -> Optional[Document]:
"""
Process and store document content received from the SurfSense Extension.
@ -186,8 +191,13 @@ async def add_extension_received_document(
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke(
{"document": combined_document_string}
)
@ -230,7 +240,7 @@ async def add_extension_received_document(
async def add_received_markdown_file_document(
session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int
session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int, user_id: str
) -> Optional[Document]:
try:
content_hash = generate_content_hash(file_in_markdown)
@ -245,8 +255,13 @@ async def add_received_markdown_file_document(
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(summary_content)
@ -292,6 +307,7 @@ async def add_received_file_document_using_unstructured(
file_name: str,
unstructured_processed_elements: List[LangChainDocument],
search_space_id: int,
user_id: str,
) -> Optional[Document]:
try:
file_in_markdown = await convert_document_to_markdown(
@ -312,8 +328,13 @@ async def add_received_file_document_using_unstructured(
# TODO: Check if file_markdown exceeds token limit of embedding model
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(summary_content)
@ -360,6 +381,7 @@ async def add_received_file_document_using_llamacloud(
file_name: str,
llamacloud_markdown_document: str,
search_space_id: int,
user_id: str,
) -> Optional[Document]:
"""
Process and store document content parsed by LlamaCloud.
@ -389,8 +411,13 @@ async def add_received_file_document_using_llamacloud(
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(summary_content)
@ -433,7 +460,7 @@ async def add_received_file_document_using_llamacloud(
async def add_youtube_video_document(
session: AsyncSession, url: str, search_space_id: int
session: AsyncSession, url: str, search_space_id: int, user_id: str
):
"""
Process a YouTube video URL, extract transcripts, and store as a document.
@ -541,8 +568,13 @@ async def add_youtube_video_document(
logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke(
{"document": combined_document_string}
)
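Call sites for these document helpers now pass the owning user's id alongside the session and search space; a minimal call-site sketch (the URL, search space id, and `current_user` are illustrative assumptions):

```python
# Sketch only: values are examples; current_user is assumed to come from the
# route's auth dependency.
doc = await add_crawled_url_document(
    session,
    url="https://example.com/article",
    search_space_id=1,
    user_id=str(current_user.id),
)
```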

View file

@ -3,9 +3,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.future import select
from datetime import datetime, timedelta, timezone
from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType
from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType, SearchSpace
from app.config import config
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.llm_service import get_user_long_context_llm
from app.connectors.slack_history import SlackHistory
from app.connectors.notion_history import NotionHistoryConnector
from app.connectors.github_connector import GitHubConnector
@ -24,6 +25,7 @@ async def index_slack_messages(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str = None,
end_date: str = None,
update_last_indexed: bool = True
@ -211,8 +213,16 @@ async def index_slack_messages(
documents_skipped += 1
continue
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
logger.error(f"No long context LLM configured for user {user_id}")
skipped_channels.append(f"{channel_name} (no LLM configured)")
documents_skipped += 1
continue
# Generate summary
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(summary_content)
@ -289,6 +299,7 @@ async def index_notion_pages(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str = None,
end_date: str = None,
update_last_indexed: bool = True
@ -476,9 +487,17 @@ async def index_notion_pages(
documents_skipped += 1
continue
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
logger.error(f"No long context LLM configured for user {user_id}")
skipped_pages.append(f"{page_title} (no LLM configured)")
documents_skipped += 1
continue
# Generate summary
logger.debug(f"Generating summary for page {page_title}")
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
summary_content = summary_result.content
summary_embedding = config.embedding_model_instance.embed(summary_content)
@ -549,6 +568,7 @@ async def index_github_repos(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str = None,
end_date: str = None,
update_last_indexed: bool = True
@ -717,6 +737,7 @@ async def index_linear_issues(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str = None,
end_date: str = None,
update_last_indexed: bool = True
@ -955,6 +976,7 @@ async def index_discord_messages(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str = None,
end_date: str = None,
update_last_indexed: bool = True
@ -1142,8 +1164,16 @@ async def index_discord_messages(
documents_skipped += 1
continue
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
if not user_llm:
logger.error(f"No long context LLM configured for user {user_id}")
skipped_channels.append(f"{guild_name}#{channel_name} (no LLM configured)")
documents_skipped += 1
continue
# Generate summary using summary_chain
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
summary_content = summary_result.content
summary_embedding = await asyncio.to_thread(

View file

@ -21,7 +21,8 @@ async def generate_chat_podcast(
session: AsyncSession,
chat_id: int,
search_space_id: int,
podcast_title: str
podcast_title: str,
user_id: int
):
# Fetch the chat with the specified ID
query = select(Chat).filter(
@ -57,12 +58,14 @@ async def generate_chat_podcast(
# Pass it to the SurfSense Podcaster
config = {
"configurable": {
"podcast_title" : "Surfsense",
"podcast_title": "SurfSense",
"user_id": str(user_id),
}
}
# Initialize state with database session and streaming service
initial_state = State(
source_content=chat_history_str,
db_session=session
)
# Run the graph directly
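The `user_id` placed into `configurable` above is presumably how the podcast graph's nodes look up this user's LLM and TTS settings; invoking the compiled graph then looks roughly like this (the `podcaster_graph` name is an assumption, `ainvoke(state, config=...)` is the standard LangGraph call shape):

```python
# Sketch only: podcaster_graph stands in for the compiled podcast graph this
# function runs; it receives user_id via config["configurable"].
result = await podcaster_graph.ainvoke(initial_state, config=config)
```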

View file

@ -0,0 +1,120 @@
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from langchain_community.chat_models import ChatLiteLLM
import logging
from app.db import User, LLMConfig
logger = logging.getLogger(__name__)
class LLMRole:
LONG_CONTEXT = "long_context"
FAST = "fast"
STRATEGIC = "strategic"
async def get_user_llm_instance(
session: AsyncSession,
user_id: str,
role: str
) -> Optional[ChatLiteLLM]:
"""
Get a ChatLiteLLM instance for a specific user and role.
Args:
session: Database session
user_id: User ID
role: LLM role ('long_context', 'fast', or 'strategic')
Returns:
ChatLiteLLM instance or None if not found
"""
try:
# Get user with their LLM preferences
result = await session.execute(
select(User).where(User.id == user_id)
)
user = result.scalars().first()
if not user:
logger.error(f"User {user_id} not found")
return None
# Get the appropriate LLM config ID based on role
llm_config_id = None
if role == LLMRole.LONG_CONTEXT:
llm_config_id = user.long_context_llm_id
elif role == LLMRole.FAST:
llm_config_id = user.fast_llm_id
elif role == LLMRole.STRATEGIC:
llm_config_id = user.strategic_llm_id
else:
logger.error(f"Invalid LLM role: {role}")
return None
if not llm_config_id:
logger.error(f"No {role} LLM configured for user {user_id}")
return None
# Get the LLM configuration
result = await session.execute(
select(LLMConfig).where(
LLMConfig.id == llm_config_id,
LLMConfig.user_id == user_id
)
)
llm_config = result.scalars().first()
if not llm_config:
logger.error(f"LLM config {llm_config_id} not found for user {user_id}")
return None
# Build the model string for litellm
if llm_config.custom_provider:
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
else:
# Map provider enum to litellm format
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama",
"MISTRAL": "mistral",
# Add more mappings as needed
}
provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower())
model_string = f"{provider_prefix}/{llm_config.model_name}"
# Create ChatLiteLLM instance
litellm_kwargs = {
"model": model_string,
"api_key": llm_config.api_key,
}
# Add optional parameters
if llm_config.api_base:
litellm_kwargs["api_base"] = llm_config.api_base
# Add any additional litellm parameters
if llm_config.litellm_params:
litellm_kwargs.update(llm_config.litellm_params)
return ChatLiteLLM(**litellm_kwargs)
except Exception as e:
logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}")
return None
async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
"""Get user's long context LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
"""Get user's fast LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.FAST)
async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
"""Get user's strategic LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
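The document and connector changes earlier in this diff all consume these helpers through the same two-step pattern: resolve the per-user LLM, then pipe the summary prompt into it. A condensed sketch (the `summarize_for_user` wrapper itself is illustrative, not part of this commit):

```python
from sqlalchemy.ext.asyncio import AsyncSession

from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.llm_service import get_user_long_context_llm


async def summarize_for_user(session: AsyncSession, user_id: str, text: str) -> str:
    # Resolve the ChatLiteLLM instance assigned to this user's long-context role.
    user_llm = await get_user_long_context_llm(session, user_id)
    if not user_llm:
        raise RuntimeError(f"No long context LLM configured for user {user_id}")
    # Same LCEL pipe used by the document and connector indexing code above.
    summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
    result = await summary_chain.ainvoke({"document": text})
    return result.content
```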

View file

@ -1,6 +1,8 @@
import datetime
from langchain.schema import HumanMessage, SystemMessage, AIMessage
from app.config import config
from app.utils.llm_service import get_user_strategic_llm
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Any, List, Optional
@ -10,14 +12,21 @@ class QueryService:
"""
@staticmethod
async def reformulate_query_with_chat_history(user_query: str, chat_history_str: Optional[str] = None) -> str:
async def reformulate_query_with_chat_history(
user_query: str,
session: AsyncSession,
user_id: str,
chat_history_str: Optional[str] = None
) -> str:
"""
Reformulate the user query using the STRATEGIC_LLM to make it more
Reformulate the user query using the user's strategic LLM to make it more
effective for information retrieval and research purposes.
Args:
user_query: The original user query
chat_history: Optional list of previous chat messages
session: Database session for accessing user LLM configs
user_id: User ID to get their specific LLM configuration
chat_history_str: Optional chat history string
Returns:
str: The reformulated query
@ -26,8 +35,11 @@ class QueryService:
return user_query
try:
# Get the strategic LLM instance from config
llm = config.strategic_llm_instance
# Get the user's strategic LLM instance
llm = await get_user_strategic_llm(session, user_id)
if not llm:
print(f"Warning: No strategic LLM configured for user {user_id}. Using original query.")
return user_query
# Create system message with instructions
system_message = SystemMessage(
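Callers therefore have to thread the database session and user id through as well; a call-site sketch (variable names are assumptions):

```python
# Sketch: reformulate a query against the caller's own strategic LLM config.
reformulated_query = await QueryService.reformulate_query_with_chat_history(
    user_query=user_query,
    session=session,
    user_id=str(current_user.id),
    chat_history_str=chat_history_str,
)
```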

View file

@ -19,7 +19,9 @@ import {
FolderOpen,
Upload,
ChevronDown,
Filter
Filter,
Brain,
Zap
} from 'lucide-react';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
@ -42,6 +44,13 @@ import {
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Badge } from "@/components/ui/badge";
import { Skeleton } from "@/components/ui/skeleton";
import {
@ -62,6 +71,7 @@ import { MarkdownViewer } from '@/components/markdown-viewer';
import { Logo } from '@/components/Logo';
import { useSearchSourceConnectors } from '@/hooks';
import { useDocuments } from '@/hooks/use-documents';
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
interface SourceItem {
id: number;
@ -374,6 +384,8 @@ const ChatPage = () => {
const [currentDate, setCurrentDate] = useState<string>('');
const terminalMessagesRef = useRef<HTMLDivElement>(null);
const { connectorSourceItems, isLoading: isLoadingConnectors } = useSearchSourceConnectors();
const { llmConfigs } = useLLMConfigs();
const { preferences, updatePreferences } = useLLMPreferences();
const INITIAL_SOURCES_DISPLAY = 3;
@ -457,6 +469,8 @@ const ChatPage = () => {
setCurrentTime(new Date().toTimeString().split(' ')[0]);
}, []);
// Add this CSS to remove input shadow and improve the UI
useEffect(() => {
if (typeof document !== 'undefined') {
@ -710,6 +724,7 @@ const ChatPage = () => {
if (!input.trim() || status !== 'ready') return;
// Validation: require at least one connector OR at least one document
// Note: Fast LLM selection updates user preferences automatically
// if (selectedConnectors.length === 0 && selectedDocuments.length === 0) {
// alert("Please select at least one connector or document");
// return;
@ -1569,6 +1584,75 @@ const ChatPage = () => {
onChange={setResearchMode}
/>
</div>
{/* Fast LLM Selector */}
<div className="h-8 min-w-0">
<Select
value={preferences.fast_llm_id?.toString() || ""}
onValueChange={(value) => {
const llmId = value ? parseInt(value) : undefined;
updatePreferences({ fast_llm_id: llmId });
}}
>
<SelectTrigger className="h-8 w-auto min-w-[120px] px-3 text-xs border-border bg-background hover:bg-muted/50">
<div className="flex items-center gap-2">
<Zap className="h-3 w-3 text-primary" />
<SelectValue placeholder="Fast LLM">
{preferences.fast_llm_id && (() => {
const selectedConfig = llmConfigs.find(config => config.id === preferences.fast_llm_id);
return selectedConfig ? (
<div className="flex items-center gap-1">
<span className="font-medium">{selectedConfig.provider}</span>
<span className="text-muted-foreground"></span>
<span className="hidden sm:inline text-muted-foreground">{selectedConfig.name}</span>
</div>
) : "Select LLM";
})()}
</SelectValue>
</div>
</SelectTrigger>
<SelectContent align="end" className="w-[280px]">
<div className="px-2 py-1.5 text-xs font-medium text-muted-foreground border-b">
Answer LLM Selection
</div>
{llmConfigs.length === 0 ? (
<div className="px-2 py-3 text-center text-sm text-muted-foreground">
<Brain className="h-4 w-4 mx-auto mb-1 opacity-50" />
<p>No LLM configurations found</p>
<p className="text-xs">Configure models in Settings</p>
</div>
) : (
llmConfigs.map((config) => (
<SelectItem key={config.id} value={config.id.toString()}>
<div className="flex items-center justify-between w-full">
<div className="flex items-center gap-3">
<div className="flex h-8 w-8 items-center justify-center rounded-md bg-primary/10">
<Brain className="h-4 w-4 text-primary" />
</div>
<div className="space-y-1">
<div className="flex items-center gap-2">
<span className="font-medium text-sm">{config.name}</span>
<Badge variant="outline" className="text-xs">
{config.provider}
</Badge>
</div>
<p className="text-xs text-muted-foreground font-mono">
{config.model_name}
</p>
</div>
</div>
{preferences.fast_llm_id === config.id && (
<div className="flex h-4 w-4 items-center justify-center rounded-full bg-primary">
<div className="h-2 w-2 rounded-full bg-primary-foreground" />
</div>
)}
</div>
</SelectItem>
))
)}
</SelectContent>
</Select>
</div>
</div>
</div>
</div>

View file

@ -0,0 +1,90 @@
"use client";
import { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import { useLLMPreferences } from '@/hooks/use-llm-configs';
import { Loader2 } from 'lucide-react';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
interface DashboardLayoutProps {
children: React.ReactNode;
}
export default function DashboardLayout({ children }: DashboardLayoutProps) {
const router = useRouter();
const { loading, error, isOnboardingComplete } = useLLMPreferences();
const [isCheckingAuth, setIsCheckingAuth] = useState(true);
useEffect(() => {
// Check if user is authenticated
const token = localStorage.getItem('surfsense_bearer_token');
if (!token) {
router.push('/login');
return;
}
setIsCheckingAuth(false);
}, [router]);
useEffect(() => {
// Wait for preferences to load, then check if onboarding is complete
if (!loading && !error && !isCheckingAuth) {
if (!isOnboardingComplete()) {
router.push('/onboard');
}
}
}, [loading, error, isCheckingAuth, isOnboardingComplete, router]);
// Show loading screen while checking authentication or loading preferences
if (isCheckingAuth || loading) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">Loading Dashboard</CardTitle>
<CardDescription>Checking your configuration...</CardDescription>
</CardHeader>
<CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" />
</CardContent>
</Card>
</div>
);
}
// Show error screen if there's an error loading preferences
if (error) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[400px] bg-background/60 backdrop-blur-sm border-destructive/20">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium text-destructive">Configuration Error</CardTitle>
<CardDescription>Failed to load your LLM configuration</CardDescription>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">{error}</p>
</CardContent>
</Card>
</div>
);
}
// Only render children if onboarding is complete
if (isOnboardingComplete()) {
return <>{children}</>;
}
// This should not be reached due to redirect, but just in case
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">Redirecting...</CardTitle>
<CardDescription>Taking you to complete your setup</CardDescription>
</CardHeader>
<CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" />
</CardContent>
</Card>
</div>
);
}

View file

@ -0,0 +1,227 @@
"use client";
import React, { useState, useEffect } from 'react';
import { useRouter } from 'next/navigation';
import { motion, AnimatePresence } from 'framer-motion';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Progress } from '@/components/ui/progress';
import { CheckCircle, ArrowRight, ArrowLeft, Bot, Sparkles, Zap, Brain } from 'lucide-react';
import { Logo } from '@/components/Logo';
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
import { AddProviderStep } from '@/components/onboard/add-provider-step';
import { AssignRolesStep } from '@/components/onboard/assign-roles-step';
import { CompletionStep } from '@/components/onboard/completion-step';
const TOTAL_STEPS = 3;
const OnboardPage = () => {
const router = useRouter();
const { llmConfigs, loading: configsLoading } = useLLMConfigs();
const { preferences, loading: preferencesLoading, isOnboardingComplete, refreshPreferences } = useLLMPreferences();
const [currentStep, setCurrentStep] = useState(1);
const [hasUserProgressed, setHasUserProgressed] = useState(false);
// Check if user is authenticated
useEffect(() => {
const token = localStorage.getItem('surfsense_bearer_token');
if (!token) {
router.push('/login');
return;
}
}, [router]);
// Track if user has progressed beyond step 1
useEffect(() => {
if (currentStep > 1) {
setHasUserProgressed(true);
}
}, [currentStep]);
// Redirect to dashboard if onboarding is already complete and user hasn't progressed (fresh page load)
useEffect(() => {
if (!preferencesLoading && isOnboardingComplete() && !hasUserProgressed) {
router.push('/dashboard');
}
}, [preferencesLoading, isOnboardingComplete, hasUserProgressed, router]);
const progress = (currentStep / TOTAL_STEPS) * 100;
const stepTitles = [
"Add LLM Provider",
"Assign LLM Roles",
"Setup Complete"
];
const stepDescriptions = [
"Configure your first model provider",
"Assign specific roles to your LLM configurations",
"You're all set to start using SurfSense!"
];
const canProceedToStep2 = !configsLoading && llmConfigs.length > 0;
const canProceedToStep3 = !preferencesLoading && preferences.long_context_llm_id && preferences.fast_llm_id && preferences.strategic_llm_id;
const handleNext = () => {
if (currentStep < TOTAL_STEPS) {
setCurrentStep(currentStep + 1);
}
};
const handlePrevious = () => {
if (currentStep > 1) {
setCurrentStep(currentStep - 1);
}
};
const handleComplete = () => {
router.push('/dashboard');
};
if (configsLoading || preferencesLoading) {
return (
<div className="flex flex-col items-center justify-center min-h-screen">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardContent className="flex flex-col items-center justify-center py-12">
<Bot className="h-12 w-12 text-primary animate-pulse mb-4" />
<p className="text-sm text-muted-foreground">Loading your configuration...</p>
</CardContent>
</Card>
</div>
);
}
return (
<div className="min-h-screen bg-gradient-to-br from-background via-background to-muted/20 flex items-center justify-center p-4">
<motion.div
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
transition={{ duration: 0.5 }}
className="w-full max-w-4xl"
>
{/* Header */}
<div className="text-center mb-8">
<div className="flex items-center justify-center mb-4">
<Logo className="w-12 h-12 mr-3" />
<h1 className="text-3xl font-bold">Welcome to SurfSense</h1>
</div>
<p className="text-muted-foreground text-lg">Let's configure your SurfSense to get started</p>
</div>
{/* Progress */}
<Card className="mb-8 bg-background/60 backdrop-blur-sm">
<CardContent className="pt-6">
<div className="flex items-center justify-between mb-4">
<div className="text-sm font-medium">Step {currentStep} of {TOTAL_STEPS}</div>
<div className="text-sm text-muted-foreground">{Math.round(progress)}% Complete</div>
</div>
<Progress value={progress} className="mb-4" />
<div className="grid grid-cols-3 gap-4">
{Array.from({ length: TOTAL_STEPS }, (_, i) => {
const stepNum = i + 1;
const isCompleted = stepNum < currentStep;
const isCurrent = stepNum === currentStep;
return (
<div key={stepNum} className="flex items-center space-x-2">
<div className={`w-8 h-8 rounded-full flex items-center justify-center text-sm font-medium ${
isCompleted
? 'bg-primary text-primary-foreground'
: isCurrent
? 'bg-primary/20 text-primary border-2 border-primary'
: 'bg-muted text-muted-foreground'
}`}>
{isCompleted ? <CheckCircle className="w-4 h-4" /> : stepNum}
</div>
<div className="flex-1 min-w-0">
<p className={`text-sm font-medium truncate ${
isCurrent ? 'text-foreground' : 'text-muted-foreground'
}`}>
{stepTitles[i]}
</p>
</div>
</div>
);
})}
</div>
</CardContent>
</Card>
{/* Step Content */}
<Card className="min-h-[500px] bg-background/60 backdrop-blur-sm">
<CardHeader className="text-center">
<CardTitle className="text-2xl flex items-center justify-center gap-2">
{currentStep === 1 && <Bot className="w-6 h-6" />}
{currentStep === 2 && <Sparkles className="w-6 h-6" />}
{currentStep === 3 && <CheckCircle className="w-6 h-6" />}
{stepTitles[currentStep - 1]}
</CardTitle>
<CardDescription className="text-base">
{stepDescriptions[currentStep - 1]}
</CardDescription>
</CardHeader>
<CardContent>
<AnimatePresence mode="wait">
<motion.div
key={currentStep}
initial={{ opacity: 0, x: 20 }}
animate={{ opacity: 1, x: 0 }}
exit={{ opacity: 0, x: -20 }}
transition={{ duration: 0.3 }}
>
{currentStep === 1 && <AddProviderStep />}
{currentStep === 2 && <AssignRolesStep onPreferencesUpdated={refreshPreferences} />}
{currentStep === 3 && <CompletionStep />}
</motion.div>
</AnimatePresence>
</CardContent>
</Card>
{/* Navigation */}
<div className="flex justify-between mt-8">
<Button
variant="outline"
onClick={handlePrevious}
disabled={currentStep === 1}
className="flex items-center gap-2"
>
<ArrowLeft className="w-4 h-4" />
Previous
</Button>
<div className="flex gap-2">
{currentStep < TOTAL_STEPS && (
<Button
onClick={handleNext}
disabled={
(currentStep === 1 && !canProceedToStep2) ||
(currentStep === 2 && !canProceedToStep3)
}
className="flex items-center gap-2"
>
Next
<ArrowRight className="w-4 h-4" />
</Button>
)}
{currentStep === TOTAL_STEPS && (
<Button
onClick={handleComplete}
className="flex items-center gap-2"
>
Complete Setup
<CheckCircle className="w-4 h-4" />
</Button>
)}
</div>
</div>
</motion.div>
</div>
);
};
export default OnboardPage;

View file

@ -0,0 +1,60 @@
"use client";
import React from 'react';
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
import { Separator } from '@/components/ui/separator';
import { Bot, Settings, Brain } from 'lucide-react';
import { ModelConfigManager } from '@/components/settings/model-config-manager';
import { LLMRoleManager } from '@/components/settings/llm-role-manager';
export default function SettingsPage() {
return (
<div className="min-h-screen bg-background">
<div className="container max-w-7xl mx-auto p-6 lg:p-8">
<div className="space-y-8">
{/* Header Section */}
<div className="space-y-4">
<div className="flex items-center space-x-4">
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-primary/10">
<Settings className="h-6 w-6 text-primary" />
</div>
<div className="space-y-1">
<h1 className="text-3xl font-bold tracking-tight">Settings</h1>
<p className="text-lg text-muted-foreground">
Manage your LLM configurations and role assignments.
</p>
</div>
</div>
<Separator className="my-6" />
</div>
{/* Settings Content */}
<Tabs defaultValue="models" className="space-y-8">
<div className="overflow-x-auto">
<TabsList className="grid w-full min-w-fit grid-cols-2 lg:w-auto lg:inline-grid">
<TabsTrigger value="models" className="flex items-center gap-2 text-sm">
<Bot className="h-4 w-4" />
<span className="hidden sm:inline">Model Configs</span>
<span className="sm:hidden">Models</span>
</TabsTrigger>
<TabsTrigger value="roles" className="flex items-center gap-2 text-sm">
<Brain className="h-4 w-4" />
<span className="hidden sm:inline">LLM Roles</span>
<span className="sm:hidden">Roles</span>
</TabsTrigger>
</TabsList>
</div>
<TabsContent value="models" className="space-y-6">
<ModelConfigManager />
</TabsContent>
<TabsContent value="roles" className="space-y-6">
<LLMRoleManager />
</TabsContent>
</Tabs>
</div>
</div>
</div>
);
}

View file

@ -0,0 +1,255 @@
"use client";
import React, { useState } from 'react';
import { motion } from 'framer-motion';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import { Badge } from '@/components/ui/badge';
import { Plus, Trash2, Bot, AlertCircle } from 'lucide-react';
import { useLLMConfigs, CreateLLMConfig } from '@/hooks/use-llm-configs';
import { toast } from 'sonner';
import { Alert, AlertDescription } from '@/components/ui/alert';
const LLM_PROVIDERS = [
{ value: 'OPENAI', label: 'OpenAI', example: 'gpt-4o, gpt-4, gpt-3.5-turbo' },
{ value: 'ANTHROPIC', label: 'Anthropic', example: 'claude-3-5-sonnet-20241022, claude-3-opus-20240229' },
{ value: 'GROQ', label: 'Groq', example: 'llama3-70b-8192, mixtral-8x7b-32768' },
{ value: 'COHERE', label: 'Cohere', example: 'command-r-plus, command-r' },
{ value: 'HUGGINGFACE', label: 'HuggingFace', example: 'microsoft/DialoGPT-medium' },
{ value: 'AZURE_OPENAI', label: 'Azure OpenAI', example: 'gpt-4, gpt-35-turbo' },
{ value: 'GOOGLE', label: 'Google', example: 'gemini-pro, gemini-pro-vision' },
{ value: 'AWS_BEDROCK', label: 'AWS Bedrock', example: 'anthropic.claude-v2' },
{ value: 'OLLAMA', label: 'Ollama', example: 'llama2, codellama' },
{ value: 'MISTRAL', label: 'Mistral', example: 'mistral-large-latest, mistral-medium' },
{ value: 'TOGETHER_AI', label: 'Together AI', example: 'togethercomputer/llama-2-70b-chat' },
{ value: 'REPLICATE', label: 'Replicate', example: 'meta/llama-2-70b-chat' },
{ value: 'CUSTOM', label: 'Custom Provider', example: 'your-custom-model' },
];
export function AddProviderStep() {
const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs();
const [isAddingNew, setIsAddingNew] = useState(false);
const [formData, setFormData] = useState<CreateLLMConfig>({
name: '',
provider: '',
custom_provider: '',
model_name: '',
api_key: '',
api_base: '',
litellm_params: {}
});
const [isSubmitting, setIsSubmitting] = useState(false);
const handleInputChange = (field: keyof CreateLLMConfig, value: string) => {
setFormData(prev => ({ ...prev, [field]: value }));
};
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) {
toast.error('Please fill in all required fields');
return;
}
setIsSubmitting(true);
const result = await createLLMConfig(formData);
setIsSubmitting(false);
if (result) {
setFormData({
name: '',
provider: '',
custom_provider: '',
model_name: '',
api_key: '',
api_base: '',
litellm_params: {}
});
setIsAddingNew(false);
}
};
const selectedProvider = LLM_PROVIDERS.find(p => p.value === formData.provider);
return (
<div className="space-y-6">
{/* Info Alert */}
<Alert>
<AlertCircle className="h-4 w-4" />
<AlertDescription>
Add at least one LLM provider to continue. You can configure multiple providers and choose specific roles for each one in the next step.
</AlertDescription>
</Alert>
{/* Existing Configurations */}
{llmConfigs.length > 0 && (
<div className="space-y-4">
<h3 className="text-lg font-semibold">Your LLM Configurations</h3>
<div className="grid gap-4">
{llmConfigs.map((config) => (
<motion.div
key={config.id}
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
>
<Card className="border-l-4 border-l-primary">
<CardContent className="pt-4">
<div className="flex items-center justify-between">
<div className="flex-1">
<div className="flex items-center gap-2 mb-2">
<Bot className="w-4 h-4" />
<h4 className="font-medium">{config.name}</h4>
<Badge variant="secondary">{config.provider}</Badge>
</div>
<p className="text-sm text-muted-foreground">
Model: {config.model_name}
{config.api_base && ` • Base: ${config.api_base}`}
</p>
</div>
<Button
variant="ghost"
size="sm"
onClick={() => deleteLLMConfig(config.id)}
className="text-destructive hover:text-destructive"
>
<Trash2 className="w-4 h-4" />
</Button>
</div>
</CardContent>
</Card>
</motion.div>
))}
</div>
</div>
)}
{/* Add New Provider */}
{!isAddingNew ? (
<Card className="border-dashed border-2 hover:border-primary/50 transition-colors">
<CardContent className="flex flex-col items-center justify-center py-12">
<Plus className="w-12 h-12 text-muted-foreground mb-4" />
<h3 className="text-lg font-semibold mb-2">Add LLM Provider</h3>
<p className="text-muted-foreground text-center mb-4">
Configure your first model provider to get started
</p>
<Button onClick={() => setIsAddingNew(true)}>
<Plus className="w-4 h-4 mr-2" />
Add Provider
</Button>
</CardContent>
</Card>
) : (
<Card>
<CardHeader>
<CardTitle>Add New LLM Provider</CardTitle>
<CardDescription>
Configure a new language model provider for your AI assistant
</CardDescription>
</CardHeader>
<CardContent>
<form onSubmit={handleSubmit} className="space-y-4">
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="name">Configuration Name *</Label>
<Input
id="name"
placeholder="e.g., My OpenAI GPT-4"
value={formData.name}
onChange={(e) => handleInputChange('name', e.target.value)}
required
/>
</div>
<div className="space-y-2">
<Label htmlFor="provider">Provider *</Label>
<Select value={formData.provider} onValueChange={(value) => handleInputChange('provider', value)}>
<SelectTrigger>
<SelectValue placeholder="Select a provider" />
</SelectTrigger>
<SelectContent>
{LLM_PROVIDERS.map((provider) => (
<SelectItem key={provider.value} value={provider.value}>
{provider.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
</div>
{formData.provider === 'CUSTOM' && (
<div className="space-y-2">
<Label htmlFor="custom_provider">Custom Provider Name *</Label>
<Input
id="custom_provider"
placeholder="e.g., my-custom-provider"
value={formData.custom_provider}
onChange={(e) => handleInputChange('custom_provider', e.target.value)}
required
/>
</div>
)}
<div className="space-y-2">
<Label htmlFor="model_name">Model Name *</Label>
<Input
id="model_name"
placeholder={selectedProvider?.example || "e.g., gpt-4"}
value={formData.model_name}
onChange={(e) => handleInputChange('model_name', e.target.value)}
required
/>
{selectedProvider && (
<p className="text-xs text-muted-foreground">
Examples: {selectedProvider.example}
</p>
)}
</div>
<div className="space-y-2">
<Label htmlFor="api_key">API Key *</Label>
<Input
id="api_key"
type="password"
placeholder="Your API key"
value={formData.api_key}
onChange={(e) => handleInputChange('api_key', e.target.value)}
required
/>
</div>
<div className="space-y-2">
<Label htmlFor="api_base">API Base URL (Optional)</Label>
<Input
id="api_base"
placeholder="e.g., https://api.openai.com/v1"
value={formData.api_base}
onChange={(e) => handleInputChange('api_base', e.target.value)}
/>
</div>
<div className="flex gap-2 pt-4">
<Button type="submit" disabled={isSubmitting}>
{isSubmitting ? 'Adding...' : 'Add Provider'}
</Button>
<Button
type="button"
variant="outline"
onClick={() => setIsAddingNew(false)}
disabled={isSubmitting}
>
Cancel
</Button>
</div>
</form>
</CardContent>
</Card>
)}
</div>
);
}

View file

@ -0,0 +1,232 @@
"use client";
import React, { useState, useEffect } from 'react';
import { motion } from 'framer-motion';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import { Badge } from '@/components/ui/badge';
import { Brain, Zap, Bot, AlertCircle, CheckCircle } from 'lucide-react';
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
import { Alert, AlertDescription } from '@/components/ui/alert';
const ROLE_DESCRIPTIONS = {
long_context: {
icon: Brain,
title: 'Long Context LLM',
description: 'Handles complex tasks requiring extensive context understanding and reasoning',
color: 'bg-blue-100 text-blue-800 border-blue-200',
examples: 'Document analysis, research synthesis, complex Q&A'
},
fast: {
icon: Zap,
title: 'Fast LLM',
description: 'Optimized for quick responses and real-time interactions',
color: 'bg-green-100 text-green-800 border-green-200',
examples: 'Quick searches, simple questions, instant responses'
},
strategic: {
icon: Bot,
title: 'Strategic LLM',
description: 'Advanced reasoning for planning and strategic decision making',
color: 'bg-purple-100 text-purple-800 border-purple-200',
examples: 'Planning workflows, strategic analysis, complex problem solving'
}
};
interface AssignRolesStepProps {
onPreferencesUpdated?: () => Promise<void>;
}
export function AssignRolesStep({ onPreferencesUpdated }: AssignRolesStepProps) {
const { llmConfigs } = useLLMConfigs();
const { preferences, updatePreferences } = useLLMPreferences();
const [assignments, setAssignments] = useState({
long_context_llm_id: preferences.long_context_llm_id || '',
fast_llm_id: preferences.fast_llm_id || '',
strategic_llm_id: preferences.strategic_llm_id || ''
});
useEffect(() => {
setAssignments({
long_context_llm_id: preferences.long_context_llm_id || '',
fast_llm_id: preferences.fast_llm_id || '',
strategic_llm_id: preferences.strategic_llm_id || ''
});
}, [preferences]);
const handleRoleAssignment = async (role: string, configId: string) => {
const newAssignments = {
...assignments,
[role]: configId === '' ? '' : parseInt(configId)
};
setAssignments(newAssignments);
// Auto-save if this assignment completes all roles
const hasAllAssignments = newAssignments.long_context_llm_id && newAssignments.fast_llm_id && newAssignments.strategic_llm_id;
if (hasAllAssignments) {
const numericAssignments = {
long_context_llm_id: typeof newAssignments.long_context_llm_id === 'string' ? parseInt(newAssignments.long_context_llm_id) : newAssignments.long_context_llm_id,
fast_llm_id: typeof newAssignments.fast_llm_id === 'string' ? parseInt(newAssignments.fast_llm_id) : newAssignments.fast_llm_id,
strategic_llm_id: typeof newAssignments.strategic_llm_id === 'string' ? parseInt(newAssignments.strategic_llm_id) : newAssignments.strategic_llm_id,
};
const success = await updatePreferences(numericAssignments);
// Refresh parent preferences state
if (success && onPreferencesUpdated) {
await onPreferencesUpdated();
}
}
};
const isAssignmentComplete = assignments.long_context_llm_id && assignments.fast_llm_id && assignments.strategic_llm_id;
if (llmConfigs.length === 0) {
return (
<div className="flex flex-col items-center justify-center py-12">
<AlertCircle className="w-16 h-16 text-muted-foreground mb-4" />
<h3 className="text-lg font-semibold mb-2">No LLM Configurations Found</h3>
<p className="text-muted-foreground text-center">
Please add at least one LLM provider in the previous step before assigning roles.
</p>
</div>
);
}
return (
<div className="space-y-6">
{/* Info Alert */}
<Alert>
<AlertCircle className="h-4 w-4" />
<AlertDescription>
Assign your LLM configurations to specific roles. Each role serves different purposes in your workflow.
</AlertDescription>
</Alert>
{/* Role Assignment Cards */}
<div className="grid gap-6">
{Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => {
const IconComponent = role.icon;
const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments];
const assignedConfig = llmConfigs.find(config => config.id === currentAssignment);
return (
<motion.div
key={key}
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: Object.keys(ROLE_DESCRIPTIONS).indexOf(key) * 0.1 }}
>
<Card className={`border-l-4 ${currentAssignment ? 'border-l-primary' : 'border-l-muted'}`}>
<CardHeader className="pb-3">
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<div className={`p-2 rounded-lg ${role.color}`}>
<IconComponent className="w-5 h-5" />
</div>
<div>
<CardTitle className="text-lg">{role.title}</CardTitle>
<CardDescription className="mt-1">{role.description}</CardDescription>
</div>
</div>
{currentAssignment && (
<CheckCircle className="w-5 h-5 text-green-500" />
)}
</div>
</CardHeader>
<CardContent className="space-y-4">
<div className="text-sm text-muted-foreground">
<strong>Use cases:</strong> {role.examples}
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Assign LLM Configuration:</label>
<Select
value={currentAssignment?.toString() || ''}
onValueChange={(value) => handleRoleAssignment(`${key}_llm_id`, value)}
>
<SelectTrigger>
<SelectValue placeholder="Select an LLM configuration" />
</SelectTrigger>
<SelectContent>
{llmConfigs
.filter(config => config.id && config.id.toString().trim() !== '')
.map((config) => (
<SelectItem key={config.id} value={config.id.toString()}>
<div className="flex items-center gap-2">
<Badge variant="outline" className="text-xs">
{config.provider}
</Badge>
<span>{config.name}</span>
<span className="text-muted-foreground">({config.model_name})</span>
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
{assignedConfig && (
<div className="mt-3 p-3 bg-muted/50 rounded-lg">
<div className="flex items-center gap-2 text-sm">
<Bot className="w-4 h-4" />
<span className="font-medium">Assigned:</span>
<Badge variant="secondary">{assignedConfig.provider}</Badge>
<span>{assignedConfig.name}</span>
</div>
<div className="text-xs text-muted-foreground mt-1">
Model: {assignedConfig.model_name}
</div>
</div>
)}
</CardContent>
</Card>
</motion.div>
);
})}
</div>
{/* Status Indicator */}
{isAssignmentComplete && (
<div className="flex justify-center pt-4">
<div className="flex items-center gap-2 px-4 py-2 bg-green-50 text-green-700 rounded-lg border border-green-200">
<CheckCircle className="w-4 h-4" />
<span className="text-sm font-medium">All roles assigned and saved!</span>
</div>
</div>
)}
{/* Progress Indicator */}
<div className="flex justify-center">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<span>Progress:</span>
<div className="flex gap-1">
{Object.keys(ROLE_DESCRIPTIONS).map((key, index) => (
<div
key={key}
className={`w-2 h-2 rounded-full ${
assignments[`${key}_llm_id` as keyof typeof assignments]
? 'bg-primary'
: 'bg-muted'
}`}
/>
))}
</div>
<span>
{Object.values(assignments).filter(Boolean).length} of {Object.keys(ROLE_DESCRIPTIONS).length} roles assigned
</span>
</div>
</div>
</div>
);
}

View file

@ -0,0 +1,125 @@
"use client";
import React from 'react';
import { motion } from 'framer-motion';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Badge } from '@/components/ui/badge';
import { CheckCircle, Bot, Brain, Zap, Sparkles, ArrowRight } from 'lucide-react';
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
const ROLE_ICONS = {
long_context: Brain,
fast: Zap,
strategic: Bot
};
export function CompletionStep() {
const { llmConfigs } = useLLMConfigs();
const { preferences } = useLLMPreferences();
const assignedConfigs = {
long_context: llmConfigs.find(c => c.id === preferences.long_context_llm_id),
fast: llmConfigs.find(c => c.id === preferences.fast_llm_id),
strategic: llmConfigs.find(c => c.id === preferences.strategic_llm_id)
};
return (
<div className="space-y-8">
{/* Success Message */}
<motion.div
initial={{ opacity: 0, scale: 0.95 }}
animate={{ opacity: 1, scale: 1 }}
transition={{ duration: 0.5 }}
className="text-center"
>
<div className="w-20 h-20 mx-auto mb-6 bg-green-100 rounded-full flex items-center justify-center">
<CheckCircle className="w-10 h-10 text-green-600" />
</div>
<h2 className="text-2xl font-bold mb-2">Setup Complete!</h2>
</motion.div>
{/* Configuration Summary */}
<motion.div
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.2 }}
>
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<Sparkles className="w-5 h-5" />
Your LLM Configuration
</CardTitle>
<CardDescription>
Here's a summary of your setup
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
{Object.entries(assignedConfigs).map(([role, config]) => {
if (!config) return null;
const IconComponent = ROLE_ICONS[role as keyof typeof ROLE_ICONS];
const roleDisplayNames = {
long_context: 'Long Context LLM',
fast: 'Fast LLM',
strategic: 'Strategic LLM'
};
return (
<motion.div
key={role}
initial={{ opacity: 0, x: -10 }}
animate={{ opacity: 1, x: 0 }}
transition={{ delay: 0.3 + Object.keys(assignedConfigs).indexOf(role) * 0.1 }}
className="flex items-center justify-between p-3 bg-muted/50 rounded-lg"
>
<div className="flex items-center gap-3">
<div className="p-2 bg-background rounded-md">
<IconComponent className="w-4 h-4" />
</div>
<div>
<p className="font-medium">{roleDisplayNames[role as keyof typeof roleDisplayNames]}</p>
<p className="text-sm text-muted-foreground">{config.name}</p>
</div>
</div>
<div className="flex items-center gap-2">
<Badge variant="outline">{config.provider}</Badge>
<span className="text-sm text-muted-foreground">{config.model_name}</span>
</div>
</motion.div>
);
})}
</CardContent>
</Card>
</motion.div>
{/* Next Steps */}
<motion.div
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.6 }}
>
<Card className="border-primary/20 bg-primary/5">
<CardContent className="pt-6">
<div className="flex items-center gap-3 mb-4">
<div className="p-2 bg-primary rounded-md">
<ArrowRight className="w-4 h-4 text-primary-foreground" />
</div>
<h3 className="text-lg font-semibold">Ready to Get Started?</h3>
</div>
<p className="text-muted-foreground mb-4">
Click "Complete Setup" to enter your dashboard and start exploring!
</p>
<div className="flex flex-wrap gap-2 text-sm">
<Badge variant="secondary"> {llmConfigs.length} LLM provider{llmConfigs.length > 1 ? 's' : ''} configured</Badge>
<Badge variant="secondary"> All roles assigned</Badge>
<Badge variant="secondary"> Ready to use</Badge>
</div>
</CardContent>
</Card>
</motion.div>
</div>
);
}

View file

@ -0,0 +1,465 @@
"use client";
import React, { useState, useEffect } from 'react';
import { motion } from 'framer-motion';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import { Badge } from '@/components/ui/badge';
import { Button } from '@/components/ui/button';
import {
Brain,
Zap,
Bot,
AlertCircle,
CheckCircle,
Settings2,
RefreshCw,
Save,
RotateCcw,
Loader2
} from 'lucide-react';
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
import { Alert, AlertDescription } from '@/components/ui/alert';
import { toast } from 'sonner';
const ROLE_DESCRIPTIONS = {
long_context: {
icon: Brain,
title: 'Long Context LLM',
description: 'Handles complex tasks requiring extensive context understanding and reasoning',
color: 'bg-blue-100 text-blue-800 border-blue-200',
examples: 'Document analysis, research synthesis, complex Q&A',
characteristics: ['Large context window', 'Deep reasoning', 'Complex analysis']
},
fast: {
icon: Zap,
title: 'Fast LLM',
description: 'Optimized for quick responses and real-time interactions',
color: 'bg-green-100 text-green-800 border-green-200',
examples: 'Quick searches, simple questions, instant responses',
characteristics: ['Low latency', 'Quick responses', 'Real-time chat']
},
strategic: {
icon: Bot,
title: 'Strategic LLM',
description: 'Advanced reasoning for planning and strategic decision making',
color: 'bg-purple-100 text-purple-800 border-purple-200',
examples: 'Planning workflows, strategic analysis, complex problem solving',
characteristics: ['Strategic thinking', 'Long-term planning', 'Complex reasoning']
}
};
export function LLMRoleManager() {
const { llmConfigs, loading: configsLoading, error: configsError, refreshConfigs } = useLLMConfigs();
const { preferences, loading: preferencesLoading, error: preferencesError, updatePreferences, refreshPreferences } = useLLMPreferences();
const [assignments, setAssignments] = useState({
long_context_llm_id: preferences.long_context_llm_id || '',
fast_llm_id: preferences.fast_llm_id || '',
strategic_llm_id: preferences.strategic_llm_id || ''
});
const [hasChanges, setHasChanges] = useState(false);
const [isSaving, setIsSaving] = useState(false);
useEffect(() => {
const newAssignments = {
long_context_llm_id: preferences.long_context_llm_id || '',
fast_llm_id: preferences.fast_llm_id || '',
strategic_llm_id: preferences.strategic_llm_id || ''
};
setAssignments(newAssignments);
setHasChanges(false);
}, [preferences]);
const handleRoleAssignment = (role: string, configId: string) => {
const newAssignments = {
...assignments,
[role]: configId === 'unassigned' ? '' : parseInt(configId)
};
setAssignments(newAssignments);
// Check if there are changes compared to current preferences
const currentPrefs = {
long_context_llm_id: preferences.long_context_llm_id || '',
fast_llm_id: preferences.fast_llm_id || '',
strategic_llm_id: preferences.strategic_llm_id || ''
};
const hasChangesNow = Object.keys(newAssignments).some(
key => newAssignments[key as keyof typeof newAssignments] !== currentPrefs[key as keyof typeof currentPrefs]
);
setHasChanges(hasChangesNow);
};
const handleSave = async () => {
setIsSaving(true);
const numericAssignments = {
long_context_llm_id: typeof assignments.long_context_llm_id === 'string'
? (assignments.long_context_llm_id ? parseInt(assignments.long_context_llm_id) : undefined)
: assignments.long_context_llm_id,
fast_llm_id: typeof assignments.fast_llm_id === 'string'
? (assignments.fast_llm_id ? parseInt(assignments.fast_llm_id) : undefined)
: assignments.fast_llm_id,
strategic_llm_id: typeof assignments.strategic_llm_id === 'string'
? (assignments.strategic_llm_id ? parseInt(assignments.strategic_llm_id) : undefined)
: assignments.strategic_llm_id,
};
const success = await updatePreferences(numericAssignments);
if (success) {
setHasChanges(false);
toast.success('LLM role assignments saved successfully!');
}
setIsSaving(false);
};
const handleReset = () => {
setAssignments({
long_context_llm_id: preferences.long_context_llm_id || '',
fast_llm_id: preferences.fast_llm_id || '',
strategic_llm_id: preferences.strategic_llm_id || ''
});
setHasChanges(false);
};
const isAssignmentComplete = assignments.long_context_llm_id && assignments.fast_llm_id && assignments.strategic_llm_id;
const assignedConfigIds = Object.values(assignments).filter(id => id !== '');
const availableConfigs = llmConfigs.filter(config => config.id && config.id.toString().trim() !== '');
const isLoading = configsLoading || preferencesLoading;
const hasError = configsError || preferencesError;
return (
<div className="space-y-6">
{/* Header */}
<div className="flex flex-col space-y-4 lg:flex-row lg:items-center lg:justify-between lg:space-y-0">
<div className="space-y-1">
<div className="flex items-center space-x-3">
<div className="flex h-10 w-10 items-center justify-center rounded-lg bg-purple-500/10">
<Settings2 className="h-5 w-5 text-purple-600" />
</div>
<div>
<h2 className="text-2xl font-bold tracking-tight">LLM Role Management</h2>
<p className="text-muted-foreground">
Assign your LLM configurations to specific roles for different purposes.
</p>
</div>
</div>
</div>
<div className="flex flex-wrap gap-2">
<Button
variant="outline"
size="sm"
onClick={refreshConfigs}
disabled={isLoading}
className="flex items-center gap-2"
>
<RefreshCw className={`h-4 w-4 ${configsLoading ? 'animate-spin' : ''}`} />
<span className="hidden sm:inline">Refresh Configs</span>
<span className="sm:hidden">Configs</span>
</Button>
<Button
variant="outline"
size="sm"
onClick={refreshPreferences}
disabled={isLoading}
className="flex items-center gap-2"
>
<RefreshCw className={`h-4 w-4 ${preferencesLoading ? 'animate-spin' : ''}`} />
<span className="hidden sm:inline">Refresh Preferences</span>
<span className="sm:hidden">Prefs</span>
</Button>
</div>
</div>
{/* Error Alert */}
{hasError && (
<Alert variant="destructive">
<AlertCircle className="h-4 w-4" />
<AlertDescription>
{configsError || preferencesError}
</AlertDescription>
</Alert>
)}
{/* Loading State */}
{isLoading && (
<Card>
<CardContent className="flex items-center justify-center py-12">
<div className="flex items-center gap-2 text-muted-foreground">
<Loader2 className="w-5 h-5 animate-spin" />
<span>
{configsLoading && preferencesLoading ? 'Loading configurations and preferences...' :
configsLoading ? 'Loading configurations...' :
'Loading preferences...'}
</span>
</div>
</CardContent>
</Card>
)}
{/* Stats Overview */}
{!isLoading && !hasError && (
<div className="grid gap-4 sm:grid-cols-2 lg:grid-cols-4">
<Card className="border-l-4 border-l-blue-500">
<CardContent className="p-6">
<div className="flex items-center justify-between space-x-4">
<div className="space-y-1">
<p className="text-3xl font-bold tracking-tight">{availableConfigs.length}</p>
<p className="text-sm font-medium text-muted-foreground">Available Models</p>
</div>
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-blue-500/10">
<Bot className="h-6 w-6 text-blue-600" />
</div>
</div>
</CardContent>
</Card>
<Card className="border-l-4 border-l-purple-500">
<CardContent className="p-6">
<div className="flex items-center justify-between space-x-4">
<div className="space-y-1">
<p className="text-3xl font-bold tracking-tight">{assignedConfigIds.length}</p>
<p className="text-sm font-medium text-muted-foreground">Assigned Roles</p>
</div>
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-purple-500/10">
<CheckCircle className="h-6 w-6 text-purple-600" />
</div>
</div>
</CardContent>
</Card>
<Card className={`border-l-4 ${
isAssignmentComplete ? 'border-l-green-500' : 'border-l-yellow-500'
}`}>
<CardContent className="p-6">
<div className="flex items-center justify-between space-x-4">
<div className="space-y-1">
<p className="text-3xl font-bold tracking-tight">
{Math.round((assignedConfigIds.length / 3) * 100)}%
</p>
<p className="text-sm font-medium text-muted-foreground">Completion</p>
</div>
<div className={`flex h-12 w-12 items-center justify-center rounded-lg ${
isAssignmentComplete ? 'bg-green-500/10' : 'bg-yellow-500/10'
}`}>
{isAssignmentComplete ? (
<CheckCircle className="h-6 w-6 text-green-600" />
) : (
<AlertCircle className="h-6 w-6 text-yellow-600" />
)}
</div>
</div>
</CardContent>
</Card>
<Card className={`border-l-4 ${
isAssignmentComplete ? 'border-l-emerald-500' : 'border-l-orange-500'
}`}>
<CardContent className="p-6">
<div className="flex items-center justify-between space-x-4">
<div className="space-y-1">
<p className={`text-3xl font-bold tracking-tight ${
isAssignmentComplete ? 'text-emerald-600' : 'text-orange-600'
}`}>
{isAssignmentComplete ? 'Ready' : 'Setup'}
</p>
<p className="text-sm font-medium text-muted-foreground">Status</p>
</div>
<div className={`flex h-12 w-12 items-center justify-center rounded-lg ${
isAssignmentComplete ? 'bg-emerald-500/10' : 'bg-orange-500/10'
}`}>
{isAssignmentComplete ? (
<CheckCircle className="h-6 w-6 text-emerald-600" />
) : (
<RefreshCw className="h-6 w-6 text-orange-600" />
)}
</div>
</div>
</CardContent>
</Card>
</div>
)}
{/* Info Alert */}
{!isLoading && !hasError && (
<div className="space-y-6">
{availableConfigs.length === 0 ? (
<Alert variant="destructive">
<AlertCircle className="h-4 w-4" />
<AlertDescription>
No LLM configurations found. Please add at least one LLM provider in the Model Configs tab before assigning roles.
</AlertDescription>
</Alert>
) : !isAssignmentComplete ? (
<Alert>
<AlertCircle className="h-4 w-4" />
<AlertDescription>
Complete all role assignments to enable full functionality. Each role serves a different purpose in your workflow.
</AlertDescription>
</Alert>
) : (
<Alert>
<CheckCircle className="h-4 w-4" />
<AlertDescription>
All roles are assigned and ready to use! Your LLM configuration is complete.
</AlertDescription>
</Alert>
)}
{/* Role Assignment Cards */}
{availableConfigs.length > 0 && (
<div className="grid gap-6">
{Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => {
const IconComponent = role.icon;
const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments];
const assignedConfig = availableConfigs.find(config => config.id === currentAssignment);
return (
<motion.div
key={key}
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: Object.keys(ROLE_DESCRIPTIONS).indexOf(key) * 0.1 }}
>
<Card className={`border-l-4 ${currentAssignment ? 'border-l-primary' : 'border-l-muted'} hover:shadow-md transition-shadow`}>
<CardHeader className="pb-3">
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<div className={`p-2 rounded-lg ${role.color}`}>
<IconComponent className="w-5 h-5" />
</div>
<div>
<CardTitle className="text-lg">{role.title}</CardTitle>
<CardDescription className="mt-1">{role.description}</CardDescription>
</div>
</div>
{currentAssignment && (
<CheckCircle className="w-5 h-5 text-green-500" />
)}
</div>
</CardHeader>
<CardContent className="space-y-4">
<div className="space-y-2">
<div className="text-sm text-muted-foreground">
<strong>Use cases:</strong> {role.examples}
</div>
<div className="flex flex-wrap gap-1">
{role.characteristics.map((char, idx) => (
<Badge key={idx} variant="outline" className="text-xs">
{char}
</Badge>
))}
</div>
</div>
<div className="space-y-2">
<label className="text-sm font-medium">Assign LLM Configuration:</label>
<Select
value={currentAssignment?.toString() || 'unassigned'}
onValueChange={(value) => handleRoleAssignment(`${key}_llm_id`, value)}
>
<SelectTrigger>
<SelectValue placeholder="Select an LLM configuration" />
</SelectTrigger>
<SelectContent>
<SelectItem value="unassigned">
<span className="text-muted-foreground">Unassigned</span>
</SelectItem>
{availableConfigs.map((config) => (
<SelectItem key={config.id} value={config.id.toString()}>
<div className="flex items-center gap-2">
<Badge variant="outline" className="text-xs">
{config.provider}
</Badge>
<span>{config.name}</span>
<span className="text-muted-foreground">({config.model_name})</span>
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
{assignedConfig && (
<div className="mt-3 p-3 bg-muted/50 rounded-lg">
<div className="flex items-center gap-2 text-sm">
<Bot className="w-4 h-4" />
<span className="font-medium">Assigned:</span>
<Badge variant="secondary">{assignedConfig.provider}</Badge>
<span>{assignedConfig.name}</span>
</div>
<div className="text-xs text-muted-foreground mt-1">
Model: {assignedConfig.model_name}
</div>
{assignedConfig.api_base && (
<div className="text-xs text-muted-foreground">
Base: {assignedConfig.api_base}
</div>
)}
</div>
)}
</CardContent>
</Card>
</motion.div>
);
})}
</div>
)}
{/* Action Buttons */}
{hasChanges && (
<div className="flex justify-center gap-3 pt-4">
<Button onClick={handleSave} disabled={isSaving} className="flex items-center gap-2">
<Save className="w-4 h-4" />
{isSaving ? 'Saving...' : 'Save Changes'}
</Button>
<Button variant="outline" onClick={handleReset} disabled={isSaving} className="flex items-center gap-2">
<RotateCcw className="w-4 h-4" />
Reset
</Button>
</div>
)}
{/* Status Indicator */}
{isAssignmentComplete && !hasChanges && (
<div className="flex justify-center pt-4">
<div className="flex items-center gap-2 px-4 py-2 bg-green-50 text-green-700 rounded-lg border border-green-200">
<CheckCircle className="w-4 h-4" />
<span className="text-sm font-medium">All roles assigned and saved!</span>
</div>
</div>
)}
{/* Progress Indicator */}
<div className="flex justify-center">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<span>Progress:</span>
<div className="flex gap-1">
{Object.keys(ROLE_DESCRIPTIONS).map((key, index) => (
<div
key={key}
className={`w-2 h-2 rounded-full ${
assignments[`${key}_llm_id` as keyof typeof assignments]
? 'bg-primary'
: 'bg-muted'
}`}
/>
))}
</div>
<span>
{assignedConfigIds.length} of {Object.keys(ROLE_DESCRIPTIONS).length} roles assigned
</span>
</div>
</div>
</div>
)}
</div>
);
}

View file

@ -0,0 +1,631 @@
"use client";
import React, { useState, useEffect } from 'react';
import { motion, AnimatePresence } from 'framer-motion';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
import { Badge } from '@/components/ui/badge';
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from '@/components/ui/dialog';
import {
Plus,
Trash2,
Bot,
AlertCircle,
Edit3,
Settings2,
Eye,
EyeOff,
CheckCircle,
Clock,
AlertTriangle,
RefreshCw,
Loader2
} from 'lucide-react';
import { useLLMConfigs, CreateLLMConfig, UpdateLLMConfig, LLMConfig } from '@/hooks/use-llm-configs';
import { toast } from 'sonner';
import { Alert, AlertDescription } from '@/components/ui/alert';
const LLM_PROVIDERS = [
{
value: 'OPENAI',
label: 'OpenAI',
example: 'gpt-4o, gpt-4, gpt-3.5-turbo',
description: 'Most popular and versatile AI models'
},
{
value: 'ANTHROPIC',
label: 'Anthropic',
example: 'claude-3-5-sonnet-20241022, claude-3-opus-20240229',
description: 'Constitutional AI with strong reasoning'
},
{
value: 'GROQ',
label: 'Groq',
example: 'llama3-70b-8192, mixtral-8x7b-32768',
description: 'Ultra-fast inference speeds'
},
{
value: 'COHERE',
label: 'Cohere',
example: 'command-r-plus, command-r',
description: 'Enterprise-focused language models'
},
{
value: 'HUGGINGFACE',
label: 'HuggingFace',
example: 'microsoft/DialoGPT-medium',
description: 'Open source model hub'
},
{
value: 'AZURE_OPENAI',
label: 'Azure OpenAI',
example: 'gpt-4, gpt-35-turbo',
description: 'Enterprise OpenAI through Azure'
},
{
value: 'GOOGLE',
label: 'Google',
example: 'gemini-pro, gemini-pro-vision',
description: 'Google\'s Gemini AI models'
},
{
value: 'AWS_BEDROCK',
label: 'AWS Bedrock',
example: 'anthropic.claude-v2',
description: 'AWS managed AI service'
},
{
value: 'OLLAMA',
label: 'Ollama',
example: 'llama2, codellama',
description: 'Run models locally'
},
{
value: 'MISTRAL',
label: 'Mistral',
example: 'mistral-large-latest, mistral-medium',
description: 'European AI excellence'
},
{
value: 'TOGETHER_AI',
label: 'Together AI',
example: 'togethercomputer/llama-2-70b-chat',
description: 'Decentralized AI platform'
},
{
value: 'REPLICATE',
label: 'Replicate',
example: 'meta/llama-2-70b-chat',
description: 'Run models via API'
},
{
value: 'CUSTOM',
label: 'Custom Provider',
example: 'your-custom-model',
description: 'Your own model endpoint'
},
];
export function ModelConfigManager() {
const { llmConfigs, loading, error, createLLMConfig, updateLLMConfig, deleteLLMConfig, refreshConfigs } = useLLMConfigs();
const [isAddingNew, setIsAddingNew] = useState(false);
const [editingConfig, setEditingConfig] = useState<LLMConfig | null>(null);
const [showApiKey, setShowApiKey] = useState<Record<number, boolean>>({});
const [formData, setFormData] = useState<CreateLLMConfig>({
name: '',
provider: '',
custom_provider: '',
model_name: '',
api_key: '',
api_base: '',
litellm_params: {}
});
const [isSubmitting, setIsSubmitting] = useState(false);
// Populate form when editing
useEffect(() => {
if (editingConfig) {
setFormData({
name: editingConfig.name,
provider: editingConfig.provider,
custom_provider: editingConfig.custom_provider || '',
model_name: editingConfig.model_name,
api_key: editingConfig.api_key,
api_base: editingConfig.api_base || '',
litellm_params: editingConfig.litellm_params || {}
});
}
}, [editingConfig]);
const handleInputChange = (field: keyof CreateLLMConfig, value: string) => {
setFormData(prev => ({ ...prev, [field]: value }));
};
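// Validate required fields, then create or update the configuration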
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) {
toast.error('Please fill in all required fields');
return;
}
setIsSubmitting(true);
let result;
if (editingConfig) {
// Update existing config
result = await updateLLMConfig(editingConfig.id, formData);
} else {
// Create new config
result = await createLLMConfig(formData);
}
setIsSubmitting(false);
if (result) {
setFormData({
name: '',
provider: '',
custom_provider: '',
model_name: '',
api_key: '',
api_base: '',
litellm_params: {}
});
setIsAddingNew(false);
setEditingConfig(null);
}
};
const handleDelete = async (id: number) => {
if (confirm('Are you sure you want to delete this configuration? This action cannot be undone.')) {
await deleteLLMConfig(id);
}
};
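// Toggle visibility of a stored API key in the list view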
const toggleApiKeyVisibility = (configId: number) => {
setShowApiKey(prev => ({
...prev,
[configId]: !prev[configId]
}));
};
const selectedProvider = LLM_PROVIDERS.find(p => p.value === formData.provider);
const getProviderInfo = (providerValue: string) => {
return LLM_PROVIDERS.find(p => p.value === providerValue);
};
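// Show only the first and last four characters of a stored API key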
const maskApiKey = (apiKey: string) => {
if (apiKey.length <= 8) return '*'.repeat(apiKey.length);
return apiKey.substring(0, 4) + '*'.repeat(apiKey.length - 8) + apiKey.substring(apiKey.length - 4);
};
return (
<div className="space-y-6">
{/* Header */}
<div className="flex flex-col space-y-4 sm:flex-row sm:items-center sm:justify-between sm:space-y-0">
<div className="space-y-1">
<div className="flex items-center space-x-3">
<div className="flex h-10 w-10 items-center justify-center rounded-lg bg-blue-500/10">
<Settings2 className="h-5 w-5 text-blue-600" />
</div>
<div>
<h2 className="text-2xl font-bold tracking-tight">Model Configurations</h2>
<p className="text-muted-foreground">
Manage your LLM provider configurations and API settings.
</p>
</div>
</div>
</div>
<div className="flex items-center space-x-2">
<Button
variant="outline"
size="sm"
onClick={refreshConfigs}
disabled={loading}
className="flex items-center gap-2"
>
<RefreshCw className={`h-4 w-4 ${loading ? 'animate-spin' : ''}`} />
Refresh
</Button>
</div>
</div>
{/* Error Alert */}
{error && (
<Alert variant="destructive">
<AlertCircle className="h-4 w-4" />
<AlertDescription>
{error}
</AlertDescription>
</Alert>
)}
{/* Loading State */}
{loading && (
<Card>
<CardContent className="flex items-center justify-center py-12">
<div className="flex items-center gap-2 text-muted-foreground">
<Loader2 className="w-5 h-5 animate-spin" />
<span>Loading configurations...</span>
</div>
</CardContent>
</Card>
)}
{/* Stats Overview */}
{!loading && !error && (
<div className="grid gap-4 md:grid-cols-3">
<Card className="border-l-4 border-l-blue-500">
<CardContent className="p-6">
<div className="flex items-center justify-between space-x-4">
<div className="space-y-1">
<p className="text-3xl font-bold tracking-tight">{llmConfigs.length}</p>
<p className="text-sm font-medium text-muted-foreground">Total Configurations</p>
</div>
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-blue-500/10">
<Bot className="h-6 w-6 text-blue-600" />
</div>
</div>
</CardContent>
</Card>
<Card className="border-l-4 border-l-green-500">
<CardContent className="p-6">
<div className="flex items-center justify-between space-x-4">
<div className="space-y-1">
<p className="text-3xl font-bold tracking-tight">
{new Set(llmConfigs.map(c => c.provider)).size}
</p>
<p className="text-sm font-medium text-muted-foreground">Unique Providers</p>
</div>
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-green-500/10">
<CheckCircle className="h-6 w-6 text-green-600" />
</div>
</div>
</CardContent>
</Card>
<Card className="border-l-4 border-l-emerald-500">
<CardContent className="p-6">
<div className="flex items-center justify-between space-x-4">
<div className="space-y-1">
<p className="text-3xl font-bold tracking-tight text-emerald-600">Active</p>
<p className="text-sm font-medium text-muted-foreground">System Status</p>
</div>
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-emerald-500/10">
<CheckCircle className="h-6 w-6 text-emerald-600" />
</div>
</div>
</CardContent>
</Card>
</div>
)}
{/* Configuration Management */}
{!loading && !error && (
<div className="space-y-6">
<div className="flex flex-col space-y-4 sm:flex-row sm:items-center sm:justify-between sm:space-y-0">
<div>
<h3 className="text-xl font-semibold tracking-tight">Your Configurations</h3>
<p className="text-sm text-muted-foreground">
Manage and configure your LLM providers
</p>
</div>
<Button onClick={() => setIsAddingNew(true)} className="flex items-center gap-2">
<Plus className="h-4 w-4" />
Add Configuration
</Button>
</div>
{llmConfigs.length === 0 ? (
<Card className="border-dashed border-2 border-muted-foreground/25">
<CardContent className="flex flex-col items-center justify-center py-16 text-center">
<div className="rounded-full bg-muted p-4 mb-6">
<Bot className="h-10 w-10 text-muted-foreground" />
</div>
<div className="space-y-2 mb-6">
<h3 className="text-xl font-semibold">No Configurations Yet</h3>
<p className="text-muted-foreground max-w-sm">
Add your first LLM provider configuration to start using the system.
</p>
</div>
<Button onClick={() => setIsAddingNew(true)} size="lg">
<Plus className="h-4 w-4 mr-2" />
Add First Configuration
</Button>
</CardContent>
</Card>
) : (
<div className="grid gap-4">
<AnimatePresence>
{llmConfigs.map((config) => {
const providerInfo = getProviderInfo(config.provider);
return (
<motion.div
key={config.id}
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
transition={{ duration: 0.2 }}
>
<Card className="group border-l-4 border-l-primary/50 hover:border-l-primary hover:shadow-md transition-all duration-200">
<CardContent className="p-6">
<div className="flex items-start justify-between">
<div className="flex-1 space-y-4">
{/* Header */}
<div className="flex items-start gap-4">
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-primary/10 group-hover:bg-primary/20 transition-colors">
<Bot className="h-6 w-6 text-primary" />
</div>
<div className="flex-1 space-y-2">
<div className="flex items-center gap-3">
<h4 className="text-lg font-semibold tracking-tight">{config.name}</h4>
<Badge variant="secondary" className="text-xs font-medium">
{config.provider}
</Badge>
</div>
<p className="text-sm text-muted-foreground font-mono">
{config.model_name}
</p>
</div>
</div>
{/* Provider Description */}
{providerInfo && (
<p className="text-sm text-muted-foreground">
{providerInfo.description}
</p>
)}
{/* Configuration Details */}
<div className="grid gap-4 sm:grid-cols-2">
<div className="space-y-2">
<Label className="text-xs font-medium uppercase tracking-wide text-muted-foreground">
API Key
</Label>
<div className="flex items-center space-x-2">
<code className="flex-1 rounded-md bg-muted px-3 py-2 text-xs font-mono">
{showApiKey[config.id]
? config.api_key
: maskApiKey(config.api_key)
}
</code>
<Button
variant="ghost"
size="sm"
onClick={() => toggleApiKeyVisibility(config.id)}
className="h-8 w-8 p-0"
>
{showApiKey[config.id] ? (
<EyeOff className="h-3 w-3" />
) : (
<Eye className="h-3 w-3" />
)}
</Button>
</div>
</div>
{config.api_base && (
<div className="space-y-2">
<Label className="text-xs font-medium uppercase tracking-wide text-muted-foreground">
API Base URL
</Label>
<code className="block rounded-md bg-muted px-3 py-2 text-xs font-mono break-all">
{config.api_base}
</code>
</div>
)}
</div>
{/* Metadata */}
<div className="flex flex-wrap items-center gap-4 pt-4 border-t border-border/50">
<div className="flex items-center gap-2 text-xs text-muted-foreground">
<Clock className="h-3 w-3" />
<span>Created {new Date(config.created_at).toLocaleDateString()}</span>
</div>
<div className="flex items-center gap-2 text-xs">
<div className="h-2 w-2 rounded-full bg-green-500"></div>
<span className="text-green-600 font-medium">Active</span>
</div>
</div>
</div>
{/* Actions */}
<div className="flex flex-col gap-2 ml-6">
<Button
variant="outline"
size="sm"
onClick={() => setEditingConfig(config)}
className="h-8 w-8 p-0"
>
<Edit3 className="h-4 w-4" />
</Button>
<Button
variant="outline"
size="sm"
onClick={() => handleDelete(config.id)}
className="h-8 w-8 p-0 border-destructive/20 text-destructive hover:bg-destructive hover:text-destructive-foreground"
>
<Trash2 className="h-4 w-4" />
</Button>
</div>
</div>
</CardContent>
</Card>
</motion.div>
);
})}
</AnimatePresence>
</div>
)}
</div>
)}
{/* Add/Edit Configuration Dialog */}
<Dialog open={isAddingNew || !!editingConfig} onOpenChange={(open) => {
if (!open) {
setIsAddingNew(false);
setEditingConfig(null);
setFormData({
name: '',
provider: '',
custom_provider: '',
model_name: '',
api_key: '',
api_base: '',
litellm_params: {}
});
}
}}>
<DialogContent className="max-w-2xl max-h-[90vh] overflow-y-auto">
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
{editingConfig ? <Edit3 className="w-5 h-5" /> : <Plus className="w-5 h-5" />}
{editingConfig ? 'Edit LLM Configuration' : 'Add New LLM Configuration'}
</DialogTitle>
<DialogDescription>
{editingConfig
? 'Update your language model provider configuration'
: 'Configure a new language model provider for your AI assistant'
}
</DialogDescription>
</DialogHeader>
<form onSubmit={handleSubmit} className="space-y-4">
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="name">Configuration Name *</Label>
<Input
id="name"
placeholder="e.g., My OpenAI GPT-4"
value={formData.name}
onChange={(e) => handleInputChange('name', e.target.value)}
required
/>
</div>
<div className="space-y-2">
<Label htmlFor="provider">Provider *</Label>
<Select value={formData.provider} onValueChange={(value) => handleInputChange('provider', value)}>
<SelectTrigger className="h-auto min-h-[2.5rem] py-2">
<SelectValue placeholder="Select a provider">
{formData.provider && (
<div className="flex items-center space-x-2 py-1">
<div className="font-medium">
{LLM_PROVIDERS.find(p => p.value === formData.provider)?.label}
</div>
<div className="text-xs text-muted-foreground">
</div>
<div className="text-xs text-muted-foreground">
{LLM_PROVIDERS.find(p => p.value === formData.provider)?.description}
</div>
</div>
)}
</SelectValue>
</SelectTrigger>
<SelectContent>
{LLM_PROVIDERS.map((provider) => (
<SelectItem key={provider.value} value={provider.value}>
<div className="space-y-1 py-1">
<div className="font-medium">{provider.label}</div>
<div className="text-xs text-muted-foreground">
{provider.description}
</div>
</div>
</SelectItem>
))}
</SelectContent>
</Select>
</div>
</div>
{formData.provider === 'CUSTOM' && (
<div className="space-y-2">
<Label htmlFor="custom_provider">Custom Provider Name *</Label>
<Input
id="custom_provider"
placeholder="e.g., my-custom-provider"
value={formData.custom_provider}
onChange={(e) => handleInputChange('custom_provider', e.target.value)}
required
/>
</div>
)}
<div className="space-y-2">
<Label htmlFor="model_name">Model Name *</Label>
<Input
id="model_name"
placeholder={selectedProvider?.example || "e.g., gpt-4"}
value={formData.model_name}
onChange={(e) => handleInputChange('model_name', e.target.value)}
required
/>
{selectedProvider && (
<p className="text-xs text-muted-foreground">
Examples: {selectedProvider.example}
</p>
)}
</div>
<div className="space-y-2">
<Label htmlFor="api_key">API Key *</Label>
<Input
id="api_key"
type="password"
placeholder="Your API key"
value={formData.api_key}
onChange={(e) => handleInputChange('api_key', e.target.value)}
required
/>
</div>
<div className="space-y-2">
<Label htmlFor="api_base">API Base URL (Optional)</Label>
<Input
id="api_base"
placeholder="e.g., https://api.openai.com/v1"
value={formData.api_base}
onChange={(e) => handleInputChange('api_base', e.target.value)}
/>
</div>
<div className="flex gap-2 pt-4">
<Button type="submit" disabled={isSubmitting}>
{isSubmitting
? (editingConfig ? 'Updating...' : 'Adding...')
: (editingConfig ? 'Update Configuration' : 'Add Configuration')
}
</Button>
<Button
type="button"
variant="outline"
onClick={() => {
setIsAddingNew(false);
setEditingConfig(null);
setFormData({
name: '',
provider: '',
custom_provider: '',
model_name: '',
api_key: '',
api_base: '',
litellm_params: {}
});
}}
disabled={isSubmitting}
>
Cancel
</Button>
</div>
</form>
</DialogContent>
</Dialog>
</div>
);
}

View file

@ -4,6 +4,7 @@ import {
BadgeCheck,
ChevronsUpDown,
LogOut,
Settings,
} from "lucide-react"
import {
@ -93,6 +94,10 @@ export function NavUser({
</DropdownMenuItem>
</DropdownMenuGroup>
<DropdownMenuSeparator />
<DropdownMenuItem onClick={() => router.push(`/settings`)}>
<Settings />
Settings
</DropdownMenuItem>
<DropdownMenuItem onClick={handleLogout}>
<LogOut />
Log out

View file

@ -0,0 +1,29 @@
"use client"
import * as React from "react"
import { cn } from "@/lib/utils"
interface ProgressProps extends React.HTMLAttributes<HTMLDivElement> {
value?: number
}
const Progress = React.forwardRef<HTMLDivElement, ProgressProps>(
({ className, value = 0, ...props }, ref) => (
<div
ref={ref}
className={cn(
"relative h-4 w-full overflow-hidden rounded-full bg-secondary",
className
)}
{...props}
>
<div
className="h-full bg-primary transition-all duration-300 ease-in-out"
style={{ width: `${Math.min(100, Math.max(0, value))}%` }}
/>
</div>
)
)
Progress.displayName = "Progress"
export { Progress }

View file

@ -0,0 +1,246 @@
"use client"
import { useState, useEffect } from 'react';
import { toast } from 'sonner';
export interface LLMConfig {
id: number;
name: string;
provider: string;
custom_provider?: string;
model_name: string;
api_key: string;
api_base?: string;
litellm_params?: Record<string, any>;
created_at: string;
user_id: string;
}
export interface LLMPreferences {
long_context_llm_id?: number;
fast_llm_id?: number;
strategic_llm_id?: number;
long_context_llm?: LLMConfig;
fast_llm?: LLMConfig;
strategic_llm?: LLMConfig;
}
export interface CreateLLMConfig {
name: string;
provider: string;
custom_provider?: string;
model_name: string;
api_key: string;
api_base?: string;
litellm_params?: Record<string, any>;
}
export interface UpdateLLMConfig {
name?: string;
provider?: string;
custom_provider?: string;
model_name?: string;
api_key?: string;
api_base?: string;
litellm_params?: Record<string, any>;
}
export function useLLMConfigs() {
const [llmConfigs, setLlmConfigs] = useState<LLMConfig[]>([]);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
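// Fetch the current user's LLM configurations from the backend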
const fetchLLMConfigs = async () => {
try {
setLoading(true);
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`, {
headers: {
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
},
method: "GET",
});
if (!response.ok) {
throw new Error("Failed to fetch LLM configurations");
}
const data = await response.json();
setLlmConfigs(data);
setError(null);
} catch (err: any) {
setError(err.message || 'Failed to fetch LLM configurations');
console.error('Error fetching LLM configurations:', err);
} finally {
setLoading(false);
}
};
useEffect(() => {
fetchLLMConfigs();
}, []);
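// Create a configuration on the server and append it to local state on success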
const createLLMConfig = async (config: CreateLLMConfig): Promise<LLMConfig | null> => {
try {
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
},
body: JSON.stringify(config),
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(errorData.detail || 'Failed to create LLM configuration');
}
const newConfig = await response.json();
setLlmConfigs(prev => [...prev, newConfig]);
toast.success('LLM configuration created successfully');
return newConfig;
} catch (err: any) {
toast.error(err.message || 'Failed to create LLM configuration');
console.error('Error creating LLM configuration:', err);
return null;
}
};
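// Delete a configuration on the server and remove it from local state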
const deleteLLMConfig = async (id: number): Promise<boolean> => {
try {
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/${id}`, {
method: 'DELETE',
headers: {
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
},
});
if (!response.ok) {
throw new Error('Failed to delete LLM configuration');
}
setLlmConfigs(prev => prev.filter(config => config.id !== id));
toast.success('LLM configuration deleted successfully');
return true;
} catch (err: any) {
toast.error(err.message || 'Failed to delete LLM configuration');
console.error('Error deleting LLM configuration:', err);
return false;
}
};
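// Update a configuration on the server and swap it into local state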
const updateLLMConfig = async (id: number, config: UpdateLLMConfig): Promise<LLMConfig | null> => {
try {
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/${id}`, {
method: 'PUT',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
},
body: JSON.stringify(config),
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(errorData.detail || 'Failed to update LLM configuration');
}
const updatedConfig = await response.json();
setLlmConfigs(prev => prev.map(c => c.id === id ? updatedConfig : c));
toast.success('LLM configuration updated successfully');
return updatedConfig;
} catch (err: any) {
toast.error(err.message || 'Failed to update LLM configuration');
console.error('Error updating LLM configuration:', err);
return null;
}
};
return {
llmConfigs,
loading,
error,
createLLMConfig,
updateLLMConfig,
deleteLLMConfig,
refreshConfigs: fetchLLMConfigs
};
}
export function useLLMPreferences() {
const [preferences, setPreferences] = useState<LLMPreferences>({});
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
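// Fetch the user's saved LLM role preferences (long context, fast, strategic)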
const fetchPreferences = async () => {
try {
setLoading(true);
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, {
headers: {
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
},
method: "GET",
});
if (!response.ok) {
throw new Error("Failed to fetch LLM preferences");
}
const data = await response.json();
setPreferences(data);
setError(null);
} catch (err: any) {
setError(err.message || 'Failed to fetch LLM preferences');
console.error('Error fetching LLM preferences:', err);
} finally {
setLoading(false);
}
};
useEffect(() => {
fetchPreferences();
}, []);
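// Persist a partial update of role-to-config assignments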
const updatePreferences = async (newPreferences: Partial<LLMPreferences>): Promise<boolean> => {
try {
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, {
method: 'PUT',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
},
body: JSON.stringify(newPreferences),
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(errorData.detail || 'Failed to update LLM preferences');
}
const updatedPreferences = await response.json();
setPreferences(updatedPreferences);
toast.success('LLM preferences updated successfully');
return true;
} catch (err: any) {
toast.error(err.message || 'Failed to update LLM preferences');
console.error('Error updating LLM preferences:', err);
return false;
}
};
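// Onboarding is complete once all three LLM roles have an assigned configuration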
const isOnboardingComplete = (): boolean => {
return !!(
preferences.long_context_llm_id &&
preferences.fast_llm_id &&
preferences.strategic_llm_id
);
};
return {
preferences,
loading,
error,
updatePreferences,
refreshPreferences: fetchPreferences,
isOnboardingComplete
};
}
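// Usage sketch (illustrative only; the component name below is hypothetical, not part of this commit):
//
// function LLMSettingsSummary() {
//   const { llmConfigs, loading } = useLLMConfigs();
//   const { isOnboardingComplete } = useLLMPreferences();
//   if (loading) return <span>Loading…</span>;
//   return <span>{isOnboardingComplete() ? 'Ready' : `${llmConfigs.length} configs, roles unassigned`}</span>;
// }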