mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 10:09:08 +00:00
feat: added configurable LLMs
This commit is contained in:
parent
d0e9fdf810
commit
a85f7920a9
36 changed files with 3415 additions and 293 deletions
@@ -54,7 +54,7 @@ Open source and easy to deploy locally.
 - Support for multiple TTS providers (OpenAI, Azure, Google Vertex AI)

 ### 📊 **Advanced RAG Techniques**
-- Supports 150+ LLM's
+- Supports 100+ LLMs
 - Supports 6000+ Embedding Models.
 - Supports all major Rerankers (Pinecone, Cohere, Flashrank etc)
 - Uses Hierarchical Indices (2 tiered RAG setup).
@@ -1,51 +1,55 @@
-DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
+DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense

-SECRET_KEY="SECRET"
-NEXT_FRONTEND_URL="http://localhost:3000"
+SECRET_KEY=SECRET
+NEXT_FRONTEND_URL=http://localhost:3000

 #Auth
-AUTH_TYPE="GOOGLE" or "LOCAL"
+AUTH_TYPE=GOOGLE or LOCAL
 # For Google Auth Only
-GOOGLE_OAUTH_CLIENT_ID="924507538m"
-GOOGLE_OAUTH_CLIENT_SECRET="GOCSV"
+GOOGLE_OAUTH_CLIENT_ID=924507538m
+GOOGLE_OAUTH_CLIENT_SECRET=GOCSV

 #Embedding Model
-EMBEDDING_MODEL="mixedbread-ai/mxbai-embed-large-v1"
+EMBEDDING_MODEL=mixedbread-ai/mxbai-embed-large-v1

-RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
-RERANKERS_MODEL_TYPE="flashrank"
+RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2
+RERANKERS_MODEL_TYPE=flashrank

-# https://docs.litellm.ai/docs/providers
-FAST_LLM="openai/gpt-4o-mini"
-STRATEGIC_LLM="openai/gpt-4o"
-LONG_CONTEXT_LLM="gemini/gemini-2.0-flash"
-
 #LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers
-TTS_SERVICE="openai/tts-1"
+TTS_SERVICE=openai/tts-1
 #Respective TTS Service API
 TTS_SERVICE_API_KEY=
+#OPTIONAL: TTS Provider API Base
+TTS_SERVICE_API_BASE=

 #LiteLLM STT Provider: https://docs.litellm.ai/docs/audio_transcription#supported-providers
-STT_SERVICE="openai/whisper-1"
+STT_SERVICE=openai/whisper-1
 #Respective STT Service API
 STT_SERVICE_API_KEY=""
+#OPTIONAL: STT Provider API Base
+STT_SERVICE_API_BASE=

-# Chosen LiteLLM Providers Keys
-OPENAI_API_KEY="sk-proj-iA"
-GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"
-
-FIRECRAWL_API_KEY="fcr-01J0000000000000000000000"
+FIRECRAWL_API_KEY=fcr-01J0000000000000000000000

 #File Parser Service
-ETL_SERVICE="UNSTRUCTURED" or "LLAMACLOUD"
-UNSTRUCTURED_API_KEY="Tpu3P0U8iy"
-LLAMA_CLOUD_API_KEY="llx-nnn"
+ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD
+UNSTRUCTURED_API_KEY=Tpu3P0U8iy
+LLAMA_CLOUD_API_KEY=llx-nnn

 #OPTIONAL: Add these for LangSmith Observability
 LANGSMITH_TRACING=true
-LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
-LANGSMITH_API_KEY="lsv2_pt_....."
-LANGSMITH_PROJECT="surfsense"
+LANGSMITH_ENDPOINT=https://api.smith.langchain.com
+LANGSMITH_API_KEY=lsv2_pt_.....
+LANGSMITH_PROJECT=surfsense

-# OPTIONAL: LiteLLM API Base
-FAST_LLM_API_BASE=""
-STRATEGIC_LLM_API_BASE=""
-LONG_CONTEXT_LLM_API_BASE=""
-TTS_SERVICE_API_BASE=""
-STT_SERVICE_API_BASE=""
+
+# FAST_LLM=openai/gpt-4o-mini
+# STRATEGIC_LLM=openai/gpt-4o
+# LONG_CONTEXT_LLM=gemini/gemini-2.0-flash
+
+# FAST_LLM=ollama/gemma3:12b
+# STRATEGIC_LLM=ollama/deepseek-r1:8b
+# LONG_CONTEXT_LLM=ollama/deepseek-r1:8b
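Every LLM, TTS_SERVICE and STT_SERVICE value here is a LiteLLM model string of the form provider/model. As a rough illustration of how such a string plus an optional API base gets consumed — a minimal sketch mirroring the pre-commit config.py behaviour, not itself part of the diff:

# Sketch only: turning a LiteLLM-style "provider/model" env value into a chat
# model. ChatLiteLLM is the class the old config.py used; the env names below
# mirror the file above.
import os

from langchain_community.chat_models import ChatLiteLLM

fast_llm = os.getenv("FAST_LLM", "openai/gpt-4o-mini")  # "provider/model"
api_base = os.getenv("FAST_LLM_API_BASE")  # optional, e.g. a local Ollama host

# Pass api_base only when set, matching the old conditional in config.py.
if api_base:
    llm = ChatLiteLLM(model=fast_llm, api_base=api_base)
else:
    llm = ChatLiteLLM(model=fast_llm)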
@@ -0,0 +1,86 @@
"""Add LLMConfig table and user LLM preferences

Revision ID: 11
Revises: 10
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID, JSON


# revision identifiers, used by Alembic.
revision: str = "11"
down_revision: Union[str, None] = "10"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema - add LiteLLMProvider enum, LLMConfig table and user LLM preferences."""

    # Check if enum type exists and create if it doesn't
    op.execute("""
        DO $$
        BEGIN
            IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'litellmprovider') THEN
                CREATE TYPE litellmprovider AS ENUM ('OPENAI', 'ANTHROPIC', 'GROQ', 'COHERE', 'HUGGINGFACE', 'AZURE_OPENAI', 'GOOGLE', 'AWS_BEDROCK', 'OLLAMA', 'MISTRAL', 'TOGETHER_AI', 'REPLICATE', 'PALM', 'VERTEX_AI', 'ANYSCALE', 'PERPLEXITY', 'DEEPINFRA', 'AI21', 'NLPCLOUD', 'ALEPH_ALPHA', 'PETALS', 'CUSTOM');
            END IF;
        END$$;
    """)

    # Create llm_configs table using raw SQL to avoid enum creation conflicts
    op.execute("""
        CREATE TABLE llm_configs (
            id SERIAL PRIMARY KEY,
            created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
            name VARCHAR(100) NOT NULL,
            provider litellmprovider NOT NULL,
            custom_provider VARCHAR(100),
            model_name VARCHAR(100) NOT NULL,
            api_key TEXT NOT NULL,
            api_base VARCHAR(500),
            litellm_params JSONB,
            user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE
        )
    """)

    # Create indexes
    op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False)
    op.create_index(op.f('ix_llm_configs_created_at'), 'llm_configs', ['created_at'], unique=False)
    op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False)

    # Add LLM preference columns to user table
    op.add_column('user', sa.Column('long_context_llm_id', sa.Integer(), nullable=True))
    op.add_column('user', sa.Column('fast_llm_id', sa.Integer(), nullable=True))
    op.add_column('user', sa.Column('strategic_llm_id', sa.Integer(), nullable=True))

    # Create foreign key constraints for LLM preferences
    op.create_foreign_key(op.f('fk_user_long_context_llm_id_llm_configs'), 'user', 'llm_configs', ['long_context_llm_id'], ['id'], ondelete='SET NULL')
    op.create_foreign_key(op.f('fk_user_fast_llm_id_llm_configs'), 'user', 'llm_configs', ['fast_llm_id'], ['id'], ondelete='SET NULL')
    op.create_foreign_key(op.f('fk_user_strategic_llm_id_llm_configs'), 'user', 'llm_configs', ['strategic_llm_id'], ['id'], ondelete='SET NULL')


def downgrade() -> None:
    """Downgrade schema - remove LLMConfig table and user LLM preferences."""

    # Drop foreign key constraints
    op.drop_constraint(op.f('fk_user_strategic_llm_id_llm_configs'), 'user', type_='foreignkey')
    op.drop_constraint(op.f('fk_user_fast_llm_id_llm_configs'), 'user', type_='foreignkey')
    op.drop_constraint(op.f('fk_user_long_context_llm_id_llm_configs'), 'user', type_='foreignkey')

    # Drop LLM preference columns from user table
    op.drop_column('user', 'strategic_llm_id')
    op.drop_column('user', 'fast_llm_id')
    op.drop_column('user', 'long_context_llm_id')

    # Drop indexes and table
    op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_created_at'), table_name='llm_configs')
    op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs')
    op.drop_table('llm_configs')

    # Drop LiteLLMProvider enum
    op.execute("DROP TYPE IF EXISTS litellmprovider")
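A migration like this is normally applied with the Alembic CLI (alembic upgrade head) or programmatically. A minimal sketch, assuming a standard alembic.ini sits next to the backend:

# Sketch only: applying (or rolling back) this revision via Alembic's Python API.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "11")      # runs upgrade() above
# command.downgrade(cfg, "10")  # reverts to the previous revision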
@@ -16,7 +16,8 @@ class Configuration:
     # these values can be pre-set when you
     # create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/)
     # and when you invoke the graph
     podcast_title: str
+    user_id: str

     @classmethod
     def from_runnable_config(
@@ -14,13 +14,22 @@ from .configuration import Configuration
 from .state import PodcastTranscriptEntry, State, PodcastTranscripts
 from .prompts import get_podcast_generation_prompt
 from app.config import config as app_config
+from app.utils.llm_service import get_user_long_context_llm


 async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]:
     """Each node does work."""

-    # Initialize LLM
-    llm = app_config.long_context_llm_instance
+    # Get configuration from runnable config
+    configuration = Configuration.from_runnable_config(config)
+    user_id = configuration.user_id
+
+    # Get user's long context LLM
+    llm = await get_user_long_context_llm(state.db_session, user_id)
+    if not llm:
+        error_message = f"No long context LLM configured for user {user_id}"
+        print(error_message)
+        raise RuntimeError(error_message)

     # Get the prompt
     prompt = get_podcast_generation_prompt()
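The app.utils.llm_service helpers the new code imports are not shown in this diff. A minimal sketch of what get_user_long_context_llm plausibly does, assuming it resolves the user's preferred LLMConfig row (see the models change below) into a ChatLiteLLM instance — the names and details here are inferred, not taken from the commit:

# Hypothetical sketch of app/utils/llm_service.py; the real implementation is
# not part of this excerpt. Assumes the User.long_context_llm_id column and
# LLMConfig model introduced later in this commit.
from typing import Optional

from langchain_community.chat_models import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import LLMConfig, User


async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
    # user_id may need conversion to UUID depending on the actual schema.
    user = await session.get(User, user_id)
    if not user or not user.long_context_llm_id:
        return None

    result = await session.execute(
        select(LLMConfig).filter(LLMConfig.id == user.long_context_llm_id)
    )
    llm_config = result.scalars().first()
    if not llm_config:
        return None

    # LiteLLM expects "provider/model"; CUSTOM providers supply their own prefix.
    provider = (llm_config.custom_provider if llm_config.provider == "CUSTOM" else llm_config.provider).lower()
    return ChatLiteLLM(
        model=f"{provider}/{llm_config.model_name}",
        api_key=llm_config.api_key,
        api_base=llm_config.api_base,
        **(llm_config.litellm_params or {}),
    )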
@@ -139,6 +148,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
         response = await aspeech(
             model=app_config.TTS_SERVICE,
+            api_base=app_config.TTS_SERVICE_API_BASE,
             api_key=app_config.TTS_SERVICE_API_KEY,
             voice=voice,
             input=dialog,
             max_retries=2,
@@ -147,6 +157,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D
     else:
         response = await aspeech(
             model=app_config.TTS_SERVICE,
+            api_key=app_config.TTS_SERVICE_API_KEY,
             voice=voice,
             input=dialog,
             max_retries=2,
@@ -5,7 +5,7 @@ from __future__ import annotations

 from dataclasses import dataclass
 from typing import List, Optional
 from pydantic import BaseModel, Field

+from sqlalchemy.ext.asyncio import AsyncSession

 class PodcastTranscriptEntry(BaseModel):
     """
@@ -32,7 +32,8 @@ class State:
     See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
     for more information.
     """

     # Runtime context
+    db_session: AsyncSession
     source_content: str
     podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None
     final_podcast_file_path: Optional[str] = None
@@ -2,7 +2,6 @@ import asyncio
 import json
 from typing import Any, Dict, List

 from app.config import config as app_config
-from app.db import async_session_maker
 from app.utils.connector_service import ConnectorService
 from langchain_core.messages import HumanMessage, SystemMessage
@@ -274,6 +273,9 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
     Returns:
         Dict containing the answer outline in the "answer_outline" key for state update.
     """
+    from app.utils.llm_service import get_user_strategic_llm
+    from app.db import get_async_session
+
     streaming_service = state.streaming_service

     streaming_service.only_update_terminal("🔍 Generating answer outline...")
@@ -283,12 +285,18 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
     reformulated_query = state.reformulated_query
     user_query = configuration.user_query
     num_sections = configuration.num_sections
+    user_id = configuration.user_id

     streaming_service.only_update_terminal(f"🤔 Planning research approach for: \"{user_query[:100]}...\"")
     writer({"yeild_value": streaming_service._format_annotations()})

-    # Initialize LLM
-    llm = app_config.strategic_llm_instance
+    # Get user's strategic LLM
+    llm = await get_user_strategic_llm(state.db_session, user_id)
+    if not llm:
+        error_message = f"No strategic LLM configured for user {user_id}"
+        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
+        writer({"yeild_value": streaming_service._format_annotations()})
+        raise RuntimeError(error_message)

     # Create the human message content
     human_message_content = f"""
@@ -828,48 +836,47 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
     user_selected_documents = []
     user_selected_sources = []

-    async with async_session_maker() as db_session:
-        try:
-            # First, fetch user-selected documents if any
-            if configuration.document_ids_to_add_in_context:
-                streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
-                writer({"yeild_value": streaming_service._format_annotations()})
-
-                user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
-                    document_ids=configuration.document_ids_to_add_in_context,
-                    user_id=configuration.user_id,
-                    db_session=db_session
-                )
-
-                if user_selected_documents:
-                    streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
-                    writer({"yeild_value": streaming_service._format_annotations()})
-
-            # Create connector service inside the db_session scope
-            connector_service = ConnectorService(db_session, user_id=configuration.user_id)
-            await connector_service.initialize_counter()
-
-            relevant_documents = await fetch_relevant_documents(
-                research_questions=all_questions,
-                user_id=configuration.user_id,
-                search_space_id=configuration.search_space_id,
-                db_session=db_session,
-                connectors_to_search=configuration.connectors_to_search,
-                writer=writer,
-                state=state,
-                top_k=TOP_K,
-                connector_service=connector_service,
-                search_mode=configuration.search_mode,
-                user_selected_sources=user_selected_sources
-            )
-        except Exception as e:
-            error_message = f"Error fetching relevant documents: {str(e)}"
-            print(error_message)
-            streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-            # Log the error and continue with an empty list of documents
-            # This allows the process to continue, but the report might lack information
-            relevant_documents = []
+    try:
+        # First, fetch user-selected documents if any
+        if configuration.document_ids_to_add_in_context:
+            streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
+            writer({"yeild_value": streaming_service._format_annotations()})
+
+            user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
+                document_ids=configuration.document_ids_to_add_in_context,
+                user_id=configuration.user_id,
+                db_session=state.db_session
+            )
+
+            if user_selected_documents:
+                streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
+                writer({"yeild_value": streaming_service._format_annotations()})
+
+        # Create connector service using state db_session
+        connector_service = ConnectorService(state.db_session, user_id=configuration.user_id)
+        await connector_service.initialize_counter()
+
+        relevant_documents = await fetch_relevant_documents(
+            research_questions=all_questions,
+            user_id=configuration.user_id,
+            search_space_id=configuration.search_space_id,
+            db_session=state.db_session,
+            connectors_to_search=configuration.connectors_to_search,
+            writer=writer,
+            state=state,
+            top_k=TOP_K,
+            connector_service=connector_service,
+            search_mode=configuration.search_mode,
+            user_selected_sources=user_selected_sources
+        )
+    except Exception as e:
+        error_message = f"Error fetching relevant documents: {str(e)}"
+        print(error_message)
+        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
+        writer({"yeild_value": streaming_service._format_annotations()})
+        # Log the error and continue with an empty list of documents
+        # This allows the process to continue, but the report might lack information
+        relevant_documents = []

     # Combine user-selected documents with connector-fetched documents
     all_documents = user_selected_documents + relevant_documents
@@ -1014,82 +1021,80 @@ async def process_section_with_documents(
             for question in section_questions
         ]

-        # Create a new database session for this section
-        async with async_session_maker() as db_session:
-            # Call the sub_section_writer graph with the appropriate config
-            config = {
-                "configurable": {
-                    "sub_section_title": section_title,
-                    "sub_section_questions": section_questions,
-                    "sub_section_type": sub_section_type,
-                    "user_query": user_query,
-                    "relevant_documents": documents_to_use,
-                    "user_id": user_id,
-                    "search_space_id": search_space_id
-                }
-            }
-
-            # Create the initial state with db_session and chat_history
-            sub_state = {
-                "db_session": db_session,
-                "chat_history": state.chat_history
-            }
-
-            # Invoke the sub-section writer graph with streaming
-            print(f"Invoking sub_section_writer for: {section_title}")
-            if state and state.streaming_service and writer:
-                state.streaming_service.only_update_terminal(f"🧠 Analyzing information and drafting content for section: \"{section_title}\"")
-                writer({"yeild_value": state.streaming_service._format_annotations()})
-
-            # Variables to track streaming state
-            complete_content = ""  # Tracks the complete content received so far
-
-            async for chunk_type, chunk in sub_section_writer_graph.astream(sub_state, config, stream_mode=["values"]):
-                if "final_answer" in chunk:
-                    new_content = chunk["final_answer"]
-                    if new_content and new_content != complete_content:
-                        # Extract only the new content (delta)
-                        delta = new_content[len(complete_content):]
-
-                        # Update what we've processed so far
-                        complete_content = new_content
-
-                        # Only stream if there's actual new content
-                        if delta and state and state.streaming_service and writer:
-                            # Update terminal with real-time progress indicator
-                            state.streaming_service.only_update_terminal(f"✍️ Writing section {section_id+1}... ({len(complete_content.split())} words)")
-
-                            # Update section_contents with just the new delta
-                            section_contents[section_id]["content"] += delta
-
-                            # Build UI-friendly content for all sections
-                            complete_answer = []
-                            for i in range(len(section_contents)):
-                                if i in section_contents and section_contents[i]["content"]:
-                                    # Add section header
-                                    complete_answer.append(f"# {section_contents[i]['title']}")
-                                    complete_answer.append("")  # Empty line after title
-
-                                    # Add section content
-                                    content_lines = section_contents[i]["content"].split("\n")
-                                    complete_answer.extend(content_lines)
-                                    complete_answer.append("")  # Empty line after content
-
-                            # Update answer in UI in real-time
-                            state.streaming_service.only_update_answer(complete_answer)
-                            writer({"yeild_value": state.streaming_service._format_annotations()})
-
-            # Set default if no content was received
-            if not complete_content:
-                complete_content = "No content was generated for this section."
-                section_contents[section_id]["content"] = complete_content
-
-            # Final terminal update
-            if state and state.streaming_service and writer:
-                state.streaming_service.only_update_terminal(f"✅ Completed section: \"{section_title}\"")
-                writer({"yeild_value": state.streaming_service._format_annotations()})
-
-            return complete_content
+        # Call the sub_section_writer graph with the appropriate config
+        config = {
+            "configurable": {
+                "sub_section_title": section_title,
+                "sub_section_questions": section_questions,
+                "sub_section_type": sub_section_type,
+                "user_query": user_query,
+                "relevant_documents": documents_to_use,
+                "user_id": user_id,
+                "search_space_id": search_space_id
+            }
+        }
+
+        # Create the initial state with db_session and chat_history
+        sub_state = {
+            "db_session": state.db_session,
+            "chat_history": state.chat_history
+        }
+
+        # Invoke the sub-section writer graph with streaming
+        print(f"Invoking sub_section_writer for: {section_title}")
+        if state and state.streaming_service and writer:
+            state.streaming_service.only_update_terminal(f"🧠 Analyzing information and drafting content for section: \"{section_title}\"")
+            writer({"yeild_value": state.streaming_service._format_annotations()})
+
+        # Variables to track streaming state
+        complete_content = ""  # Tracks the complete content received so far
+
+        async for chunk_type, chunk in sub_section_writer_graph.astream(sub_state, config, stream_mode=["values"]):
+            if "final_answer" in chunk:
+                new_content = chunk["final_answer"]
+                if new_content and new_content != complete_content:
+                    # Extract only the new content (delta)
+                    delta = new_content[len(complete_content):]
+
+                    # Update what we've processed so far
+                    complete_content = new_content
+
+                    # Update section_contents with just the new delta
+                    section_contents[section_id]["content"] += delta
+
+                    # Only stream if there's actual new content
+                    if delta and state and state.streaming_service and writer:
+                        # Update terminal with real-time progress indicator
+                        state.streaming_service.only_update_terminal(f"✍️ Writing section {section_id+1}... ({len(complete_content.split())} words)")
+
+                        # Build UI-friendly content for all sections
+                        complete_answer = []
+                        for i in range(len(section_contents)):
+                            if i in section_contents and section_contents[i]["content"]:
+                                # Add section header
+                                complete_answer.append(f"# {section_contents[i]['title']}")
+                                complete_answer.append("")  # Empty line after title
+
+                                # Add section content
+                                content_lines = section_contents[i]["content"].split("\n")
+                                complete_answer.extend(content_lines)
+                                complete_answer.append("")  # Empty line after content
+
+                        # Update answer in UI in real-time
+                        state.streaming_service.only_update_answer(complete_answer)
+                        writer({"yeild_value": state.streaming_service._format_annotations()})
+
+        # Set default if no content was received
+        if not complete_content:
+            complete_content = "No content was generated for this section."
+            section_contents[section_id]["content"] = complete_content
+
+        # Final terminal update
+        if state and state.streaming_service and writer:
+            state.streaming_service.only_update_terminal(f"✅ Completed section: \"{section_title}\"")
+            writer({"yeild_value": state.streaming_service._format_annotations()})
+
+        return complete_content
     except Exception as e:
         print(f"Error processing section '{section_title}': {str(e)}")
@@ -1113,7 +1118,7 @@ async def reformulate_user_query(state: State, config: RunnableConfig, writer: S
     if len(state.chat_history) == 0:
         reformulated_query = user_query
     else:
-        reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query, chat_history_str)
+        reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query=user_query, session=state.db_session, user_id=configuration.user_id, chat_history_str=chat_history_str)

     return {
         "reformulated_query": reformulated_query
@@ -1152,50 +1157,49 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
     user_selected_documents = []
     user_selected_sources = []

-    async with async_session_maker() as db_session:
-        try:
-            # First, fetch user-selected documents if any
-            if configuration.document_ids_to_add_in_context:
-                streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
-                writer({"yeild_value": streaming_service._format_annotations()})
-
-                user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
-                    document_ids=configuration.document_ids_to_add_in_context,
-                    user_id=configuration.user_id,
-                    db_session=db_session
-                )
-
-                if user_selected_documents:
-                    streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
-                    writer({"yeild_value": streaming_service._format_annotations()})
-
-            # Create connector service inside the db_session scope
-            connector_service = ConnectorService(db_session, user_id=configuration.user_id)
-            await connector_service.initialize_counter()
-
-            # Use the reformulated query as a single research question
-            research_questions = [reformulated_query, user_query]
-
-            relevant_documents = await fetch_relevant_documents(
-                research_questions=research_questions,
-                user_id=configuration.user_id,
-                search_space_id=configuration.search_space_id,
-                db_session=db_session,
-                connectors_to_search=configuration.connectors_to_search,
-                writer=writer,
-                state=state,
-                top_k=TOP_K,
-                connector_service=connector_service,
-                search_mode=configuration.search_mode,
-                user_selected_sources=user_selected_sources
-            )
-        except Exception as e:
-            error_message = f"Error fetching relevant documents for QNA: {str(e)}"
-            print(error_message)
-            streaming_service.only_update_terminal(f"❌ {error_message}", "error")
-            # Continue with empty documents - the QNA agent will handle this gracefully
-            relevant_documents = []
+    try:
+        # First, fetch user-selected documents if any
+        if configuration.document_ids_to_add_in_context:
+            streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
+            writer({"yeild_value": streaming_service._format_annotations()})
+
+            user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
+                document_ids=configuration.document_ids_to_add_in_context,
+                user_id=configuration.user_id,
+                db_session=state.db_session
+            )
+
+            if user_selected_documents:
+                streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
+                writer({"yeild_value": streaming_service._format_annotations()})
+
+        # Create connector service using state db_session
+        connector_service = ConnectorService(state.db_session, user_id=configuration.user_id)
+        await connector_service.initialize_counter()
+
+        # Use the reformulated query as a single research question
+        research_questions = [reformulated_query, user_query]
+
+        relevant_documents = await fetch_relevant_documents(
+            research_questions=research_questions,
+            user_id=configuration.user_id,
+            search_space_id=configuration.search_space_id,
+            db_session=state.db_session,
+            connectors_to_search=configuration.connectors_to_search,
+            writer=writer,
+            state=state,
+            top_k=TOP_K,
+            connector_service=connector_service,
+            search_mode=configuration.search_mode,
+            user_selected_sources=user_selected_sources
+        )
+    except Exception as e:
+        error_message = f"Error fetching relevant documents for QNA: {str(e)}"
+        print(error_message)
+        streaming_service.only_update_terminal(f"❌ {error_message}", "error")
+        writer({"yeild_value": streaming_service._format_annotations()})
+        # Continue with empty documents - the QNA agent will handle this gracefully
+        relevant_documents = []

     # Combine user-selected documents with connector-fetched documents
     all_documents = user_selected_documents + relevant_documents
@@ -85,14 +85,20 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
     Returns:
         Dict containing the final answer in the "final_answer" key.
     """
+    from app.utils.llm_service import get_user_fast_llm
+
     # Get configuration and relevant documents from configuration
     configuration = Configuration.from_runnable_config(config)
     documents = state.reranked_documents
     user_query = configuration.user_query
+    user_id = configuration.user_id

-    # Initialize LLM
-    llm = app_config.fast_llm_instance
+    # Get user's fast LLM
+    llm = await get_user_fast_llm(state.db_session, user_id)
+    if not llm:
+        error_message = f"No fast LLM configured for user {user_id}"
+        print(error_message)
+        raise RuntimeError(error_message)

     # Determine if we have documents and optimize for token limits
     has_documents_initially = documents and len(documents) > 0
@@ -118,7 +124,7 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any

     # Optimize documents to fit within token limits
     optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
-        documents, base_messages, app_config.FAST_LLM
+        documents, base_messages, llm.model
     )

     # Update state based on optimization result
@@ -161,7 +167,7 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any
     ]

     # Log final token count
-    total_tokens = calculate_token_count(messages_with_chat_history, app_config.FAST_LLM)
+    total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
     print(f"Final token count: {total_tokens}")
@@ -91,13 +91,19 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
     Returns:
         Dict containing the final answer in the "final_answer" key.
     """
+    from app.utils.llm_service import get_user_fast_llm
+
     # Get configuration and relevant documents from configuration
     configuration = Configuration.from_runnable_config(config)
     documents = state.reranked_documents
+    user_id = configuration.user_id

-    # Initialize LLM
-    llm = app_config.fast_llm_instance
+    # Get user's fast LLM
+    llm = await get_user_fast_llm(state.db_session, user_id)
+    if not llm:
+        error_message = f"No fast LLM configured for user {user_id}"
+        print(error_message)
+        raise RuntimeError(error_message)

     # Extract configuration data
     section_title = configuration.sub_section_title
@@ -153,7 +159,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A

     # Optimize documents to fit within token limits
     optimized_documents, has_optimized_documents = optimize_documents_for_token_limit(
-        documents, base_messages, app_config.FAST_LLM
+        documents, base_messages, llm.model
     )

     # Update state based on optimization result
@@ -206,7 +212,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A
     ]

     # Log final token count
-    total_tokens = calculate_token_count(messages_with_chat_history, app_config.FAST_LLM)
+    total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
     print(f"Final token count: {total_tokens}")

     # Call the LLM and get the response
@@ -1,7 +1,6 @@
 from typing import List, Dict, Any, Tuple, NamedTuple
 from langchain_core.messages import BaseMessage
 from litellm import token_counter, get_model_info
-from app.config import config as app_config


 class DocumentTokenInfo(NamedTuple):
@@ -127,7 +126,7 @@ def get_model_context_window(model_name: str) -> int:
 def optimize_documents_for_token_limit(
     documents: List[Dict[str, Any]],
     base_messages: List[BaseMessage],
-    model_name: str = None
+    model_name: str
 ) -> Tuple[List[Dict[str, Any]], bool]:
     """
     Optimize documents to fit within token limits using binary search.
@@ -135,7 +134,7 @@ def optimize_documents_for_token_limit(
     Args:
         documents: List of documents with content and metadata
         base_messages: Base messages without documents (chat history + system + human message template)
-        model_name: Model name for token counting (defaults to app_config.FAST_LLM)
+        model_name: Model name for token counting (required)
         output_token_buffer: Number of tokens to reserve for model output

     Returns:
@@ -144,7 +143,7 @@ def optimize_documents_for_token_limit(
     if not documents:
         return [], False

-    model = model_name or app_config.FAST_LLM
+    model = model_name
     context_window = get_model_context_window(model)

     # Calculate base token cost
@@ -178,8 +177,8 @@ def optimize_documents_for_token_limit(
     return optimized_documents, has_documents_remaining


-def calculate_token_count(messages: List[BaseMessage], model_name: str = None) -> int:
+def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int:
     """Calculate token count for a list of LangChain messages."""
-    model = model_name or app_config.FAST_LLM
+    model = model_name
     messages_dict = convert_langchain_messages_to_dict(messages)
     return token_counter(messages=messages_dict, model=model)
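With the app_config fallback removed, callers must now pass the model explicitly — the agents above pass llm.model from the per-user config. A small usage sketch of the stricter signatures; the messages and model string below are placeholders, not values from the codebase:

# Usage sketch only; placeholder inputs.
from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You answer questions from provided documents."),
    HumanMessage(content="Summarize the sources."),
]
documents = [{"content": "example document text"}]

# model_name is now a required positional argument in both helpers.
optimized_docs, has_docs = optimize_documents_for_token_limit(
    documents, messages, "openai/gpt-4o-mini"
)
print(calculate_token_count(messages, "openai/gpt-4o-mini"))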
@@ -4,7 +4,7 @@ import shutil

 from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker
 from dotenv import load_dotenv
-from langchain_community.chat_models import ChatLiteLLM
 from rerankers import Reranker
@@ -49,31 +49,8 @@ class Config:
     GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")

-    # LONG-CONTEXT LLMS
-    LONG_CONTEXT_LLM = os.getenv("LONG_CONTEXT_LLM")
-    LONG_CONTEXT_LLM_API_BASE = os.getenv("LONG_CONTEXT_LLM_API_BASE")
-    if LONG_CONTEXT_LLM_API_BASE:
-        long_context_llm_instance = ChatLiteLLM(model=LONG_CONTEXT_LLM, api_base=LONG_CONTEXT_LLM_API_BASE)
-    else:
-        long_context_llm_instance = ChatLiteLLM(model=LONG_CONTEXT_LLM)
-
-    # FAST LLM
-    FAST_LLM = os.getenv("FAST_LLM")
-    FAST_LLM_API_BASE = os.getenv("FAST_LLM_API_BASE")
-    if FAST_LLM_API_BASE:
-        fast_llm_instance = ChatLiteLLM(model=FAST_LLM, api_base=FAST_LLM_API_BASE)
-    else:
-        fast_llm_instance = ChatLiteLLM(model=FAST_LLM)
-
-    # STRATEGIC LLM
-    STRATEGIC_LLM = os.getenv("STRATEGIC_LLM")
-    STRATEGIC_LLM_API_BASE = os.getenv("STRATEGIC_LLM_API_BASE")
-    if STRATEGIC_LLM_API_BASE:
-        strategic_llm_instance = ChatLiteLLM(model=STRATEGIC_LLM, api_base=STRATEGIC_LLM_API_BASE)
-    else:
-        strategic_llm_instance = ChatLiteLLM(model=STRATEGIC_LLM)
+    # LLM instances are now managed per-user through the LLMConfig system
+    # Legacy environment variables removed in favor of user-specific configurations

     # Chonkie Configuration | Edit this to your needs
     EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
@@ -114,10 +91,12 @@ class Config:
     # Litellm TTS Configuration
     TTS_SERVICE = os.getenv("TTS_SERVICE")
+    TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE")
     TTS_SERVICE_API_KEY = os.getenv("TTS_SERVICE_API_KEY")

     # Litellm STT Configuration
     STT_SERVICE = os.getenv("STT_SERVICE")
+    STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE")
     STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY")


     # Validation Checks
@@ -67,6 +67,30 @@ class ChatType(str, Enum):
     REPORT_GENERAL = "REPORT_GENERAL"
     REPORT_DEEP = "REPORT_DEEP"
     REPORT_DEEPER = "REPORT_DEEPER"

+
+class LiteLLMProvider(str, Enum):
+    OPENAI = "OPENAI"
+    ANTHROPIC = "ANTHROPIC"
+    GROQ = "GROQ"
+    COHERE = "COHERE"
+    HUGGINGFACE = "HUGGINGFACE"
+    AZURE_OPENAI = "AZURE_OPENAI"
+    GOOGLE = "GOOGLE"
+    AWS_BEDROCK = "AWS_BEDROCK"
+    OLLAMA = "OLLAMA"
+    MISTRAL = "MISTRAL"
+    TOGETHER_AI = "TOGETHER_AI"
+    REPLICATE = "REPLICATE"
+    PALM = "PALM"
+    VERTEX_AI = "VERTEX_AI"
+    ANYSCALE = "ANYSCALE"
+    PERPLEXITY = "PERPLEXITY"
+    DEEPINFRA = "DEEPINFRA"
+    AI21 = "AI21"
+    NLPCLOUD = "NLPCLOUD"
+    ALEPH_ALPHA = "ALEPH_ALPHA"
+    PETALS = "PETALS"
+    CUSTOM = "CUSTOM"
+

 class Base(DeclarativeBase):
     pass
@@ -152,6 +176,26 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
     user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
     user = relationship("User", back_populates="search_source_connectors")

+
+class LLMConfig(BaseModel, TimestampMixin):
+    __tablename__ = "llm_configs"
+
+    name = Column(String(100), nullable=False, index=True)
+    # Provider from the enum
+    provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
+    # Custom provider name when provider is CUSTOM
+    custom_provider = Column(String(100), nullable=True)
+    # Just the model name without provider prefix
+    model_name = Column(String(100), nullable=False)
+    # API Key should be encrypted before storing
+    api_key = Column(String, nullable=False)
+    api_base = Column(String(500), nullable=True)
+
+    # For any other parameters that litellm supports
+    litellm_params = Column(JSON, nullable=True, default={})
+
+    user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
+    user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id])
+

 if config.AUTH_TYPE == "GOOGLE":
     class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
         pass
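Provider and model name are stored separately, so downstream code has to join them back into LiteLLM's provider/model convention. A hedged sketch of that mapping — the actual helper is not in this diff:

# Hypothetical helper illustrating how an LLMConfig row maps onto a LiteLLM
# model string; not part of the commit.
def litellm_model_string(config: "LLMConfig") -> str:
    if config.provider == LiteLLMProvider.CUSTOM and config.custom_provider:
        return f"{config.custom_provider}/{config.model_name}"
    return f"{config.provider.value.lower()}/{config.model_name}"

# e.g. provider=OLLAMA, model_name="gemma3:12b" -> "ollama/gemma3:12b"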
@@ -163,11 +207,29 @@ if config.AUTH_TYPE == "GOOGLE":
         )
         search_spaces = relationship("SearchSpace", back_populates="user")
         search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
+        llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
+
+        long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
+        fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
+        strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
+
+        long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
+        fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
+        strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
 else:
     class User(SQLAlchemyBaseUserTableUUID, Base):
         search_spaces = relationship("SearchSpace", back_populates="user")
         search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
+        llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
+
+        long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
+        fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
+        strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
+
+        long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
+        fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
+        strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)


 engine = create_async_engine(DATABASE_URL)
@@ -4,6 +4,7 @@ from .documents_routes import router as documents_router
 from .podcasts_routes import router as podcasts_router
 from .chats_routes import router as chats_router
 from .search_source_connectors_routes import router as search_source_connectors_router
+from .llm_config_routes import router as llm_config_router

 router = APIRouter()

@@ -12,3 +13,4 @@ router.include_router(documents_router)
 router.include_router(podcasts_router)
 router.include_router(chats_router)
 router.include_router(search_source_connectors_router)
+router.include_router(llm_config_router)
@@ -38,21 +38,24 @@ async def create_documents(
         fastapi_background_tasks.add_task(
             process_extension_document_with_new_session,
             individual_document,
-            request.search_space_id
+            request.search_space_id,
+            str(user.id)
         )
     elif request.document_type == DocumentType.CRAWLED_URL:
         for url in request.content:
             fastapi_background_tasks.add_task(
                 process_crawled_url_with_new_session,
                 url,
-                request.search_space_id
+                request.search_space_id,
+                str(user.id)
             )
     elif request.document_type == DocumentType.YOUTUBE_VIDEO:
         for url in request.content:
             fastapi_background_tasks.add_task(
                 process_youtube_video_with_new_session,
                 url,
-                request.search_space_id
+                request.search_space_id,
+                str(user.id)
             )
     else:
         raise HTTPException(
@@ -106,7 +109,8 @@ async def create_documents(
                 process_file_in_background_with_new_session,
                 temp_path,
                 file.filename,
-                search_space_id
+                search_space_id,
+                str(user.id)
             )
         except Exception as e:
             raise HTTPException(
@@ -130,6 +134,7 @@ async def process_file_in_background(
     file_path: str,
     filename: str,
     search_space_id: int,
+    user_id: str,
     session: AsyncSession
 ):
     try:
|
|||
session,
|
||||
filename,
|
||||
markdown_content,
|
||||
search_space_id
|
||||
search_space_id,
|
||||
user_id
|
||||
)
|
||||
# Check if the file is an audio file
|
||||
elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')):
|
||||
|
@@ -162,11 +168,13 @@ async def process_file_in_background(
                 transcription_response = await atranscription(
                     model=app_config.STT_SERVICE,
                     file=audio_file,
-                    api_base=app_config.STT_SERVICE_API_BASE
+                    api_base=app_config.STT_SERVICE_API_BASE,
+                    api_key=app_config.STT_SERVICE_API_KEY
                 )
             else:
                 transcription_response = await atranscription(
                     model=app_config.STT_SERVICE,
+                    api_key=app_config.STT_SERVICE_API_KEY,
                     file=audio_file
                 )
@@ -187,7 +195,8 @@ async def process_file_in_background(
                 session,
                 filename,
                 transcribed_text,
-                search_space_id
+                search_space_id,
+                user_id
             )
         else:
             if app_config.ETL_SERVICE == "UNSTRUCTURED":
@@ -218,7 +227,8 @@ async def process_file_in_background(
                 session,
                 filename,
                 docs,
-                search_space_id
+                search_space_id,
+                user_id
             )
         elif app_config.ETL_SERVICE == "LLAMACLOUD":
             from llama_cloud_services import LlamaParse
@@ -256,7 +266,8 @@ async def process_file_in_background(
                 session,
                 filename,
                 llamacloud_markdown_document=markdown_content,
-                search_space_id=search_space_id
+                search_space_id=search_space_id,
+                user_id=user_id
             )
     except Exception as e:
         import logging
@@ -426,14 +437,15 @@ async def delete_document(

 async def process_extension_document_with_new_session(
     individual_document,
-    search_space_id: int
+    search_space_id: int,
+    user_id: str
 ):
     """Create a new session and process extension document."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
         try:
-            await add_extension_received_document(session, individual_document, search_space_id)
+            await add_extension_received_document(session, individual_document, search_space_id, user_id)
         except Exception as e:
             import logging
             logging.error(f"Error processing extension document: {str(e)}")
@@ -441,14 +453,15 @@ async def process_extension_document_with_new_session(

 async def process_crawled_url_with_new_session(
     url: str,
-    search_space_id: int
+    search_space_id: int,
+    user_id: str
 ):
     """Create a new session and process crawled URL."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
         try:
-            await add_crawled_url_document(session, url, search_space_id)
+            await add_crawled_url_document(session, url, search_space_id, user_id)
         except Exception as e:
             import logging
             logging.error(f"Error processing crawled URL: {str(e)}")
@@ -457,25 +470,27 @@ async def process_crawled_url_with_new_session(

 async def process_file_in_background_with_new_session(
     file_path: str,
     filename: str,
-    search_space_id: int
+    search_space_id: int,
+    user_id: str
 ):
     """Create a new session and process file."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
-        await process_file_in_background(file_path, filename, search_space_id, session)
+        await process_file_in_background(file_path, filename, search_space_id, user_id, session)


 async def process_youtube_video_with_new_session(
     url: str,
-    search_space_id: int
+    search_space_id: int,
+    user_id: str
 ):
     """Create a new session and process YouTube video."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
         try:
-            await add_youtube_video_document(session, url, search_space_id)
+            await add_youtube_video_document(session, url, search_space_id, user_id)
         except Exception as e:
             import logging
             logging.error(f"Error processing YouTube video: {str(e)}")
surfsense_backend/app/routes/llm_config_routes.py (new file, 243 lines)
@@ -0,0 +1,243 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from typing import List, Optional
from pydantic import BaseModel
from app.db import get_async_session, User, LLMConfig
from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead
from app.users import current_active_user
from app.utils.check_ownership import check_ownership

router = APIRouter()


class LLMPreferencesUpdate(BaseModel):
    """Schema for updating user LLM preferences"""
    long_context_llm_id: Optional[int] = None
    fast_llm_id: Optional[int] = None
    strategic_llm_id: Optional[int] = None


class LLMPreferencesRead(BaseModel):
    """Schema for reading user LLM preferences"""
    long_context_llm_id: Optional[int] = None
    fast_llm_id: Optional[int] = None
    strategic_llm_id: Optional[int] = None
    long_context_llm: Optional[LLMConfigRead] = None
    fast_llm: Optional[LLMConfigRead] = None
    strategic_llm: Optional[LLMConfigRead] = None


@router.post("/llm-configs/", response_model=LLMConfigRead)
async def create_llm_config(
    llm_config: LLMConfigCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Create a new LLM configuration for the authenticated user"""
    try:
        db_llm_config = LLMConfig(**llm_config.model_dump(), user_id=user.id)
        session.add(db_llm_config)
        await session.commit()
        await session.refresh(db_llm_config)
        return db_llm_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create LLM configuration: {str(e)}"
        )


@router.get("/llm-configs/", response_model=List[LLMConfigRead])
async def read_llm_configs(
    skip: int = 0,
    limit: int = 200,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Get all LLM configurations for the authenticated user"""
    try:
        result = await session.execute(
            select(LLMConfig)
            .filter(LLMConfig.user_id == user.id)
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch LLM configurations: {str(e)}"
        )


@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def read_llm_config(
    llm_config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Get a specific LLM configuration by ID"""
    try:
        llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
        return llm_config
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch LLM configuration: {str(e)}"
        )


@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def update_llm_config(
    llm_config_id: int,
    llm_config_update: LLMConfigUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Update an existing LLM configuration"""
    try:
        db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
        update_data = llm_config_update.model_dump(exclude_unset=True)

        for key, value in update_data.items():
            setattr(db_llm_config, key, value)

        await session.commit()
        await session.refresh(db_llm_config)
        return db_llm_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update LLM configuration: {str(e)}"
        )


@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
async def delete_llm_config(
    llm_config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Delete an LLM configuration"""
    try:
        db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
        await session.delete(db_llm_config)
        await session.commit()
        return {"message": "LLM configuration deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete LLM configuration: {str(e)}"
        )


# User LLM Preferences endpoints

@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
async def get_user_llm_preferences(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Get the current user's LLM preferences"""
    try:
        # Refresh user to get latest relationships
        await session.refresh(user)

        result = {
            "long_context_llm_id": user.long_context_llm_id,
            "fast_llm_id": user.fast_llm_id,
            "strategic_llm_id": user.strategic_llm_id,
            "long_context_llm": None,
            "fast_llm": None,
            "strategic_llm": None,
        }

        # Fetch the actual LLM configs if they exist
        if user.long_context_llm_id:
            long_context_llm = await session.execute(
                select(LLMConfig).filter(
                    LLMConfig.id == user.long_context_llm_id,
                    LLMConfig.user_id == user.id
                )
            )
            llm_config = long_context_llm.scalars().first()
            if llm_config:
                result["long_context_llm"] = llm_config

        if user.fast_llm_id:
            fast_llm = await session.execute(
                select(LLMConfig).filter(
                    LLMConfig.id == user.fast_llm_id,
                    LLMConfig.user_id == user.id
                )
            )
            llm_config = fast_llm.scalars().first()
            if llm_config:
                result["fast_llm"] = llm_config

        if user.strategic_llm_id:
            strategic_llm = await session.execute(
                select(LLMConfig).filter(
                    LLMConfig.id == user.strategic_llm_id,
                    LLMConfig.user_id == user.id
                )
            )
            llm_config = strategic_llm.scalars().first()
            if llm_config:
                result["strategic_llm"] = llm_config

        return result
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch LLM preferences: {str(e)}"
        )


@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
async def update_user_llm_preferences(
    preferences: LLMPreferencesUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Update the current user's LLM preferences"""
    try:
        # Validate that all provided LLM config IDs belong to the user
        update_data = preferences.model_dump(exclude_unset=True)

        for key, llm_config_id in update_data.items():
            if llm_config_id is not None:
                # Verify ownership of the LLM config
                result = await session.execute(
                    select(LLMConfig).filter(
                        LLMConfig.id == llm_config_id,
                        LLMConfig.user_id == user.id
                    )
                )
                llm_config = result.scalars().first()
                if not llm_config:
                    raise HTTPException(
                        status_code=404,
                        detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it"
                    )

        # Update user preferences
        for key, value in update_data.items():
            setattr(user, key, value)

        await session.commit()
        await session.refresh(user)

        # Return updated preferences
        return await get_user_llm_preferences(session, user)
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update LLM preferences: {str(e)}"
        )
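End to end, a client would create a config and then point one or more preference slots at it. A sketch with httpx; the JSON field names mirror the LLMConfig columns, the base path is an assumption, and auth is elided:

# Hypothetical client-side usage of the new endpoints; the API base path and
# token are placeholders, not values from the codebase.
import httpx

API = "http://localhost:8000"  # assumed backend address

with httpx.Client(headers={"Authorization": "Bearer <token>"}) as client:
    created = client.post(f"{API}/llm-configs/", json={
        "name": "Local Ollama",
        "provider": "OLLAMA",
        "model_name": "gemma3:12b",
        "api_key": "not-needed-for-local",
        "api_base": "http://localhost:11434",
    }).json()

    # Use the same config for all three roles.
    client.put(f"{API}/users/me/llm-preferences", json={
        "fast_llm_id": created["id"],
        "strategic_llm_id": created["id"],
        "long_context_llm_id": created["id"],
    })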
@@ -128,14 +128,15 @@ async def delete_podcast(

 async def generate_chat_podcast_with_new_session(
     chat_id: int,
     search_space_id: int,
-    podcast_title: str = "SurfSense Podcast"
+    podcast_title: str,
+    user_id: int
 ):
     """Create a new session and process chat podcast generation."""
     from app.db import async_session_maker

     async with async_session_maker() as session:
         try:
-            await generate_chat_podcast(session, chat_id, search_space_id, podcast_title)
+            await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id)
         except Exception as e:
             import logging
             logging.error(f"Error generating podcast from chat: {str(e)}")
@ -175,7 +176,8 @@ async def generate_podcast(
|
|||
generate_chat_podcast_with_new_session,
|
||||
chat_id,
|
||||
request.search_space_id,
|
||||
request.podcast_title
|
||||
request.podcast_title,
|
||||
user.id
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
|
@@ -328,25 +328,25 @@ async def index_connector_content(
     if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
         # Run indexing in background
         logger.info(f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
-        background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to)
+        background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
         response_message = "Slack indexing started in the background."

     elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
         # Run indexing in background
         logger.info(f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
-        background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to)
+        background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
         response_message = "Notion indexing started in the background."

     elif connector.connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR:
         # Run indexing in background
         logger.info(f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
-        background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to)
+        background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
         response_message = "GitHub indexing started in the background."

     elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
         # Run indexing in background
         logger.info(f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}")
-        background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to)
+        background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to)
         response_message = "Linear indexing started in the background."

     elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:

@@ -355,7 +355,7 @@ async def index_connector_content(
             f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
         )
         background_tasks.add_task(
-            run_discord_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to
+            run_discord_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to
         )
         response_message = "Discord indexing started in the background."
@@ -410,6 +410,7 @@ async def update_connector_last_indexed(
 async def run_slack_indexing_with_new_session(
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):

@@ -418,12 +419,13 @@ async def run_slack_indexing_with_new_session(
     This prevents session leaks by creating a dedicated session for the background task.
     """
     async with async_session_maker() as session:
-        await run_slack_indexing(session, connector_id, search_space_id, start_date, end_date)
+        await run_slack_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)

 async def run_slack_indexing(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):

@@ -434,6 +436,7 @@ async def run_slack_indexing(
         session: Database session
         connector_id: ID of the Slack connector
         search_space_id: ID of the search space
+        user_id: ID of the user
         start_date: Start date for indexing
         end_date: End date for indexing
     """

@@ -443,6 +446,7 @@ async def run_slack_indexing(
             session=session,
             connector_id=connector_id,
             search_space_id=search_space_id,
+            user_id=user_id,
             start_date=start_date,
             end_date=end_date,
             update_last_indexed=False  # Don't update timestamp in the indexing function

@@ -460,6 +464,7 @@ async def run_slack_indexing(
 async def run_notion_indexing_with_new_session(
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):

@@ -468,12 +473,13 @@ async def run_notion_indexing_with_new_session(
     This prevents session leaks by creating a dedicated session for the background task.
     """
     async with async_session_maker() as session:
-        await run_notion_indexing(session, connector_id, search_space_id, start_date, end_date)
+        await run_notion_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)

 async def run_notion_indexing(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):

@@ -484,6 +490,7 @@ async def run_notion_indexing(
         session: Database session
         connector_id: ID of the Notion connector
         search_space_id: ID of the search space
+        user_id: ID of the user
         start_date: Start date for indexing
         end_date: End date for indexing
     """

@@ -493,6 +500,7 @@ async def run_notion_indexing(
             session=session,
             connector_id=connector_id,
             search_space_id=search_space_id,
+            user_id=user_id,
             start_date=start_date,
             end_date=end_date,
             update_last_indexed=False  # Don't update timestamp in the indexing function

@@ -511,26 +519,28 @@ async def run_notion_indexing(
 async def run_github_indexing_with_new_session(
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):
     """Wrapper to run GitHub indexing with its own database session."""
     logger.info(f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}")
     async with async_session_maker() as session:
-        await run_github_indexing(session, connector_id, search_space_id, start_date, end_date)
+        await run_github_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
     logger.info(f"Background task finished: Indexing GitHub connector {connector_id}")

 async def run_github_indexing(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):
     """Runs the GitHub indexing task and updates the timestamp."""
     try:
         indexed_count, error_message = await index_github_repos(
-            session, connector_id, search_space_id, start_date, end_date, update_last_indexed=False
+            session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False
         )
         if error_message:
             logger.error(f"GitHub indexing failed for connector {connector_id}: {error_message}")

@@ -549,26 +559,28 @@ async def run_github_indexing(
 async def run_linear_indexing_with_new_session(
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):
     """Wrapper to run Linear indexing with its own database session."""
     logger.info(f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}")
     async with async_session_maker() as session:
-        await run_linear_indexing(session, connector_id, search_space_id, start_date, end_date)
+        await run_linear_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)
     logger.info(f"Background task finished: Indexing Linear connector {connector_id}")

 async def run_linear_indexing(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):
     """Runs the Linear indexing task and updates the timestamp."""
     try:
         indexed_count, error_message = await index_linear_issues(
-            session, connector_id, search_space_id, start_date, end_date, update_last_indexed=False
+            session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False
         )
         if error_message:
             logger.error(f"Linear indexing failed for connector {connector_id}: {error_message}")

@@ -587,6 +599,7 @@ async def run_linear_indexing(
 async def run_discord_indexing_with_new_session(
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):

@@ -595,12 +608,13 @@ async def run_discord_indexing_with_new_session(
     This prevents session leaks by creating a dedicated session for the background task.
     """
     async with async_session_maker() as session:
-        await run_discord_indexing(session, connector_id, search_space_id, start_date, end_date)
+        await run_discord_indexing(session, connector_id, search_space_id, user_id, start_date, end_date)

 async def run_discord_indexing(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str,
     end_date: str
 ):

@@ -610,6 +624,7 @@ async def run_discord_indexing(
         session: Database session
         connector_id: ID of the Discord connector
         search_space_id: ID of the search space
+        user_id: ID of the user
         start_date: Start date for indexing
         end_date: End date for indexing
     """

@@ -619,6 +634,7 @@ async def run_discord_indexing(
             session=session,
             connector_id=connector_id,
             search_space_id=search_space_id,
+            user_id=user_id,
             start_date=start_date,
             end_date=end_date,
             update_last_indexed=False  # Don't update timestamp in the indexing function
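Every connector follows the same shape: the route casts `user.id` to `str` (user IDs are UUIDs under fastapi-users) and hands it to a wrapper that opens a fresh session, because the request-scoped session is already closed by the time the background task runs. A generic sketch of that wrapper pattern, with illustrative names rather than the project's:

    from app.db import async_session_maker

    # Illustrative only: the shape shared by run_*_indexing_with_new_session.
    async def run_with_new_session(runner, *args):
        async with async_session_maker() as session:  # dedicated session per task
            await runner(session, *args)  # e.g. run_slack_indexing(session, ...)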
@@ -13,6 +13,7 @@ from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
 from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
 from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
 from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
+from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigUpdate, LLMConfigRead

 __all__ = [
     "AISDKChatRequest",

@@ -48,4 +49,8 @@ __all__ = [
     "SearchSourceConnectorCreate",
     "SearchSourceConnectorUpdate",
     "SearchSourceConnectorRead",
+    "LLMConfigBase",
+    "LLMConfigCreate",
+    "LLMConfigUpdate",
+    "LLMConfigRead",
 ]
surfsense_backend/app/schemas/llm_config.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from datetime import datetime
import uuid
from typing import Optional, Dict, Any
from pydantic import BaseModel, ConfigDict, Field
from .base import IDModel, TimestampModel
from app.db import LiteLLMProvider


class LLMConfigBase(BaseModel):
    name: str = Field(..., max_length=100, description="User-friendly name for the LLM configuration")
    provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
    custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
    model_name: str = Field(..., max_length=100, description="Model name without provider prefix")
    api_key: str = Field(..., description="API key for the provider")
    api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
    litellm_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional LiteLLM parameters")


class LLMConfigCreate(LLMConfigBase):
    pass


class LLMConfigUpdate(BaseModel):
    name: Optional[str] = Field(None, max_length=100, description="User-friendly name for the LLM configuration")
    provider: Optional[LiteLLMProvider] = Field(None, description="LiteLLM provider type")
    custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM")
    model_name: Optional[str] = Field(None, max_length=100, description="Model name without provider prefix")
    api_key: Optional[str] = Field(None, description="API key for the provider")
    api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL")
    litellm_params: Optional[Dict[str, Any]] = Field(None, description="Additional LiteLLM parameters")


class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
    id: int
    created_at: datetime
    user_id: uuid.UUID

    model_config = ConfigDict(from_attributes=True)
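For reference, constructing a create payload against this schema might look like the following sketch (values are placeholders and the API key is fake; this assumes LiteLLMProvider accepts the provider name as its string value, as the frontend's provider list suggests):

    from app.schemas import LLMConfigCreate

    config = LLMConfigCreate(
        name="My OpenAI GPT-4o",              # user-friendly label
        provider="OPENAI",                    # coerced to the LiteLLMProvider enum
        model_name="gpt-4o",                  # no provider prefix here
        api_key="sk-placeholder",             # fake key for illustration
        api_base=None,
        litellm_params={"temperature": 0.2},
    )
    print(config.model_dump())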
@@ -7,6 +7,7 @@ from app.schemas import ExtensionDocumentContent
 from app.config import config
 from app.prompts import SUMMARY_PROMPT_TEMPLATE
 from app.utils.document_converters import convert_document_to_markdown, generate_content_hash
+from app.utils.llm_service import get_user_long_context_llm
 from langchain_core.documents import Document as LangChainDocument
 from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
 from langchain_community.document_transformers import MarkdownifyTransformer

@@ -18,9 +19,8 @@ import logging

 md = MarkdownifyTransformer()

-
 async def add_crawled_url_document(
-    session: AsyncSession, url: str, search_space_id: int
+    session: AsyncSession, url: str, search_space_id: int, user_id: str
 ) -> Optional[Document]:
     try:
         if not validators.url(url):

@@ -84,8 +84,13 @@ async def add_crawled_url_document(
             logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
             return existing_document

+        # Get user's long context LLM
+        user_llm = await get_user_long_context_llm(session, user_id)
+        if not user_llm:
+            raise RuntimeError(f"No long context LLM configured for user {user_id}")
+
         # Generate summary
-        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+        summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
         summary_result = await summary_chain.ainvoke(
             {"document": combined_document_string}
         )

@@ -130,7 +135,7 @@ async def add_crawled_url_document(

 async def add_extension_received_document(
-    session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int
+    session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int, user_id: str
 ) -> Optional[Document]:
     """
     Process and store document content received from the SurfSense Extension.

@@ -186,8 +191,13 @@ async def add_extension_received_document(
             logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
             return existing_document

+        # Get user's long context LLM
+        user_llm = await get_user_long_context_llm(session, user_id)
+        if not user_llm:
+            raise RuntimeError(f"No long context LLM configured for user {user_id}")
+
         # Generate summary
-        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+        summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
         summary_result = await summary_chain.ainvoke(
             {"document": combined_document_string}
         )

@@ -230,7 +240,7 @@ async def add_extension_received_document(

 async def add_received_markdown_file_document(
-    session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int
+    session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int, user_id: str
 ) -> Optional[Document]:
     try:
         content_hash = generate_content_hash(file_in_markdown)

@@ -245,8 +255,13 @@ async def add_received_markdown_file_document(
             logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
             return existing_document

+        # Get user's long context LLM
+        user_llm = await get_user_long_context_llm(session, user_id)
+        if not user_llm:
+            raise RuntimeError(f"No long context LLM configured for user {user_id}")
+
         # Generate summary
-        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+        summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
         summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
         summary_content = summary_result.content
         summary_embedding = config.embedding_model_instance.embed(summary_content)

@@ -292,6 +307,7 @@ async def add_received_file_document_using_unstructured(
     file_name: str,
     unstructured_processed_elements: List[LangChainDocument],
     search_space_id: int,
+    user_id: str,
 ) -> Optional[Document]:
     try:
         file_in_markdown = await convert_document_to_markdown(

@@ -312,8 +328,13 @@ async def add_received_file_document_using_unstructured(

         # TODO: Check if file_markdown exceeds token limit of embedding model

+        # Get user's long context LLM
+        user_llm = await get_user_long_context_llm(session, user_id)
+        if not user_llm:
+            raise RuntimeError(f"No long context LLM configured for user {user_id}")
+
         # Generate summary
-        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+        summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
         summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
         summary_content = summary_result.content
         summary_embedding = config.embedding_model_instance.embed(summary_content)

@@ -360,6 +381,7 @@ async def add_received_file_document_using_llamacloud(
     file_name: str,
     llamacloud_markdown_document: str,
     search_space_id: int,
+    user_id: str,
 ) -> Optional[Document]:
     """
     Process and store document content parsed by LlamaCloud.

@@ -389,8 +411,13 @@ async def add_received_file_document_using_llamacloud(
             logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
             return existing_document

+        # Get user's long context LLM
+        user_llm = await get_user_long_context_llm(session, user_id)
+        if not user_llm:
+            raise RuntimeError(f"No long context LLM configured for user {user_id}")
+
         # Generate summary
-        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+        summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
         summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
         summary_content = summary_result.content
         summary_embedding = config.embedding_model_instance.embed(summary_content)

@@ -433,7 +460,7 @@ async def add_received_file_document_using_llamacloud(

 async def add_youtube_video_document(
-    session: AsyncSession, url: str, search_space_id: int
+    session: AsyncSession, url: str, search_space_id: int, user_id: str
 ):
     """
     Process a YouTube video URL, extract transcripts, and store as a document.

@@ -541,8 +568,13 @@ async def add_youtube_video_document(
             logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.")
             return existing_document

+        # Get user's long context LLM
+        user_llm = await get_user_long_context_llm(session, user_id)
+        if not user_llm:
+            raise RuntimeError(f"No long context LLM configured for user {user_id}")
+
         # Generate summary
-        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+        summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
         summary_result = await summary_chain.ainvoke(
             {"document": combined_document_string}
         )
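Call sites now have to thread the owning user's ID through every ingestion helper. A sketch of invoking one of the updated functions (the URL and IDs are placeholders; `session` and `user` would come from the route's dependencies):

    document = await add_crawled_url_document(
        session=session,
        url="https://example.com/post",  # placeholder URL
        search_space_id=1,               # placeholder ID
        user_id=str(user.id),            # newly required argument
    )

If the user has not assigned a long-context model yet, these helpers now fail fast with a RuntimeError instead of silently using a global default.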
@@ -3,9 +3,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.future import select
 from datetime import datetime, timedelta, timezone
-from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType
+from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType, SearchSpace
 from app.config import config
 from app.prompts import SUMMARY_PROMPT_TEMPLATE
+from app.utils.llm_service import get_user_long_context_llm
 from app.connectors.slack_history import SlackHistory
 from app.connectors.notion_history import NotionHistoryConnector
 from app.connectors.github_connector import GitHubConnector

@@ -24,6 +25,7 @@ async def index_slack_messages(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str = None,
     end_date: str = None,
     update_last_indexed: bool = True

@@ -211,8 +213,16 @@ async def index_slack_messages(
                     documents_skipped += 1
                     continue

+                # Get user's long context LLM
+                user_llm = await get_user_long_context_llm(session, user_id)
+                if not user_llm:
+                    logger.error(f"No long context LLM configured for user {user_id}")
+                    skipped_channels.append(f"{channel_name} (no LLM configured)")
+                    documents_skipped += 1
+                    continue
+
                 # Generate summary
-                summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+                summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
                 summary_result = await summary_chain.ainvoke({"document": combined_document_string})
                 summary_content = summary_result.content
                 summary_embedding = config.embedding_model_instance.embed(summary_content)

@@ -289,6 +299,7 @@ async def index_notion_pages(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str = None,
     end_date: str = None,
     update_last_indexed: bool = True

@@ -476,9 +487,17 @@ async def index_notion_pages(
                 documents_skipped += 1
                 continue

+            # Get user's long context LLM
+            user_llm = await get_user_long_context_llm(session, user_id)
+            if not user_llm:
+                logger.error(f"No long context LLM configured for user {user_id}")
+                skipped_pages.append(f"{page_title} (no LLM configured)")
+                documents_skipped += 1
+                continue
+
             # Generate summary
             logger.debug(f"Generating summary for page {page_title}")
-            summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+            summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
             summary_result = await summary_chain.ainvoke({"document": combined_document_string})
             summary_content = summary_result.content
             summary_embedding = config.embedding_model_instance.embed(summary_content)

@@ -549,6 +568,7 @@ async def index_github_repos(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str = None,
     end_date: str = None,
     update_last_indexed: bool = True

@@ -717,6 +737,7 @@ async def index_linear_issues(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str = None,
     end_date: str = None,
     update_last_indexed: bool = True

@@ -955,6 +976,7 @@ async def index_discord_messages(
     session: AsyncSession,
     connector_id: int,
     search_space_id: int,
+    user_id: str,
     start_date: str = None,
     end_date: str = None,
     update_last_indexed: bool = True

@@ -1142,8 +1164,16 @@ async def index_discord_messages(
                     documents_skipped += 1
                     continue

+                # Get user's long context LLM
+                user_llm = await get_user_long_context_llm(session, user_id)
+                if not user_llm:
+                    logger.error(f"No long context LLM configured for user {user_id}")
+                    skipped_channels.append(f"{guild_name}#{channel_name} (no LLM configured)")
+                    documents_skipped += 1
+                    continue
+
                 # Generate summary using summary_chain
-                summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
+                summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm
                 summary_result = await summary_chain.ainvoke({"document": combined_document_string})
                 summary_content = summary_result.content
                 summary_embedding = await asyncio.to_thread(
@@ -21,7 +21,8 @@ async def generate_chat_podcast(
     session: AsyncSession,
     chat_id: int,
     search_space_id: int,
-    podcast_title: str
+    podcast_title: str,
+    user_id: int
 ):
     # Fetch the chat with the specified ID
     query = select(Chat).filter(

@@ -57,12 +58,14 @@ async def generate_chat_podcast(
     # Pass it to the SurfSense Podcaster
     config = {
         "configurable": {
-            "podcast_title" : "Surfsense",
+            "podcast_title": "SurfSense",
+            "user_id": str(user_id),
         }
     }
     # Initialize state with database session and streaming service
     initial_state = State(
         source_content=chat_history_str,
         db_session=session
     )

     # Run the graph directly
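With `user_id` carried in `config["configurable"]`, any node of the podcaster graph can resolve the caller's models at run time. A hedged sketch of how a node might read it (LangGraph passes the run config to nodes; the node name and returned key here are illustrative, not the project's actual node):

    from langchain_core.runnables import RunnableConfig
    from app.utils.llm_service import get_user_long_context_llm

    # Illustrative node shape only.
    async def example_node(state, config: RunnableConfig):
        user_id = config["configurable"]["user_id"]
        user_llm = await get_user_long_context_llm(state.db_session, user_id)
        if user_llm is None:
            raise RuntimeError(f"No long context LLM configured for user {user_id}")
        result = await user_llm.ainvoke(state.source_content)  # draft from chat history
        return {"podcast_transcript": result.content}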
surfsense_backend/app/utils/llm_service.py (new file, 120 lines)
@@ -0,0 +1,120 @@
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from langchain_community.chat_models import ChatLiteLLM
import logging

from app.db import User, LLMConfig

logger = logging.getLogger(__name__)

class LLMRole:
    LONG_CONTEXT = "long_context"
    FAST = "fast"
    STRATEGIC = "strategic"

async def get_user_llm_instance(
    session: AsyncSession,
    user_id: str,
    role: str
) -> Optional[ChatLiteLLM]:
    """
    Get a ChatLiteLLM instance for a specific user and role.

    Args:
        session: Database session
        user_id: User ID
        role: LLM role ('long_context', 'fast', or 'strategic')

    Returns:
        ChatLiteLLM instance or None if not found
    """
    try:
        # Get user with their LLM preferences
        result = await session.execute(
            select(User).where(User.id == user_id)
        )
        user = result.scalars().first()

        if not user:
            logger.error(f"User {user_id} not found")
            return None

        # Get the appropriate LLM config ID based on role
        llm_config_id = None
        if role == LLMRole.LONG_CONTEXT:
            llm_config_id = user.long_context_llm_id
        elif role == LLMRole.FAST:
            llm_config_id = user.fast_llm_id
        elif role == LLMRole.STRATEGIC:
            llm_config_id = user.strategic_llm_id
        else:
            logger.error(f"Invalid LLM role: {role}")
            return None

        if not llm_config_id:
            logger.error(f"No {role} LLM configured for user {user_id}")
            return None

        # Get the LLM configuration
        result = await session.execute(
            select(LLMConfig).where(
                LLMConfig.id == llm_config_id,
                LLMConfig.user_id == user_id
            )
        )
        llm_config = result.scalars().first()

        if not llm_config:
            logger.error(f"LLM config {llm_config_id} not found for user {user_id}")
            return None

        # Build the model string for litellm
        if llm_config.custom_provider:
            model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
        else:
            # Map provider enum to litellm format
            provider_map = {
                "OPENAI": "openai",
                "ANTHROPIC": "anthropic",
                "GROQ": "groq",
                "COHERE": "cohere",
                "GOOGLE": "gemini",
                "OLLAMA": "ollama",
                "MISTRAL": "mistral",
                # Add more mappings as needed
            }
            provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower())
            model_string = f"{provider_prefix}/{llm_config.model_name}"

        # Create ChatLiteLLM instance
        litellm_kwargs = {
            "model": model_string,
            "api_key": llm_config.api_key,
        }

        # Add optional parameters
        if llm_config.api_base:
            litellm_kwargs["api_base"] = llm_config.api_base

        # Add any additional litellm parameters
        if llm_config.litellm_params:
            litellm_kwargs.update(llm_config.litellm_params)

        return ChatLiteLLM(**litellm_kwargs)

    except Exception as e:
        logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}")
        return None

async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
    """Get user's long context LLM instance."""
    return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)

async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
    """Get user's fast LLM instance."""
    return await get_user_llm_instance(session, user_id, LLMRole.FAST)

async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]:
    """Get user's strategic LLM instance."""
    return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
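The provider mapping means a config with provider GOOGLE and model_name "gemini-2.0-flash" resolves to the LiteLLM string "gemini/gemini-2.0-flash", while a CUSTOM provider uses `custom_provider` verbatim as the prefix. A short usage sketch of the role helpers (`session` and `user` are placeholders for the route's dependencies):

    from app.utils.llm_service import get_user_fast_llm

    fast_llm = await get_user_fast_llm(session, str(user.id))
    if fast_llm is None:
        raise RuntimeError("User has not assigned a fast LLM yet")

    reply = await fast_llm.ainvoke("Summarize SurfSense in one sentence.")
    print(reply.content)

Note that the helpers deliberately return None rather than raising, so each caller decides whether a missing role assignment is fatal.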
@@ -1,6 +1,8 @@
 import datetime
 from langchain.schema import HumanMessage, SystemMessage, AIMessage
 from app.config import config
+from app.utils.llm_service import get_user_strategic_llm
+from sqlalchemy.ext.asyncio import AsyncSession
 from typing import Any, List, Optional

@@ -10,14 +12,21 @@ class QueryService:
     """

     @staticmethod
-    async def reformulate_query_with_chat_history(user_query: str, chat_history_str: Optional[str] = None) -> str:
+    async def reformulate_query_with_chat_history(
+        user_query: str,
+        session: AsyncSession,
+        user_id: str,
+        chat_history_str: Optional[str] = None
+    ) -> str:
         """
-        Reformulate the user query using the STRATEGIC_LLM to make it more
+        Reformulate the user query using the user's strategic LLM to make it more
         effective for information retrieval and research purposes.

         Args:
             user_query: The original user query
-            chat_history: Optional list of previous chat messages
+            session: Database session for accessing user LLM configs
+            user_id: User ID to get their specific LLM configuration
+            chat_history_str: Optional chat history string

         Returns:
             str: The reformulated query

@@ -26,8 +35,11 @@ class QueryService:
             return user_query

         try:
-            # Get the strategic LLM instance from config
-            llm = config.strategic_llm_instance
+            # Get the user's strategic LLM instance
+            llm = await get_user_strategic_llm(session, user_id)
+            if not llm:
+                print(f"Warning: No strategic LLM configured for user {user_id}. Using original query.")
+                return user_query

             # Create system message with instructions
             system_message = SystemMessage(
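A quick sketch of a call under the new signature (the query text, session, and user values are placeholders; as the diff above shows, a user without a strategic LLM simply gets the original query back):

    reformulated = await QueryService.reformulate_query_with_chat_history(
        user_query="what changed in our pricing page?",  # placeholder query
        session=session,
        user_id=str(user.id),
        chat_history_str=None,
    )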
@@ -19,7 +19,9 @@ import {
   FolderOpen,
   Upload,
   ChevronDown,
-  Filter
+  Filter,
+  Brain,
+  Zap
 } from 'lucide-react';
 import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
 import { Button } from '@/components/ui/button';

@@ -42,6 +44,13 @@ import {
   DropdownMenuSeparator,
   DropdownMenuTrigger,
 } from "@/components/ui/dropdown-menu";
+import {
+  Select,
+  SelectContent,
+  SelectItem,
+  SelectTrigger,
+  SelectValue,
+} from "@/components/ui/select";
 import { Badge } from "@/components/ui/badge";
 import { Skeleton } from "@/components/ui/skeleton";
 import {

@@ -62,6 +71,7 @@ import { MarkdownViewer } from '@/components/markdown-viewer';
 import { Logo } from '@/components/Logo';
 import { useSearchSourceConnectors } from '@/hooks';
 import { useDocuments } from '@/hooks/use-documents';
+import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';

 interface SourceItem {
   id: number;

@@ -374,6 +384,8 @@ const ChatPage = () => {
   const [currentDate, setCurrentDate] = useState<string>('');
   const terminalMessagesRef = useRef<HTMLDivElement>(null);
   const { connectorSourceItems, isLoading: isLoadingConnectors } = useSearchSourceConnectors();
+  const { llmConfigs } = useLLMConfigs();
+  const { preferences, updatePreferences } = useLLMPreferences();

   const INITIAL_SOURCES_DISPLAY = 3;

@@ -457,6 +469,8 @@ const ChatPage = () => {
     setCurrentTime(new Date().toTimeString().split(' ')[0]);
   }, []);

   // Add this CSS to remove input shadow and improve the UI
   useEffect(() => {
     if (typeof document !== 'undefined') {

@@ -710,6 +724,7 @@ const ChatPage = () => {
     if (!input.trim() || status !== 'ready') return;

     // Validation: require at least one connector OR at least one document
+    // Note: Fast LLM selection updates user preferences automatically
     // if (selectedConnectors.length === 0 && selectedDocuments.length === 0) {
     //   alert("Please select at least one connector or document");
     //   return;

@@ -1569,6 +1584,75 @@ const ChatPage = () => {
               onChange={setResearchMode}
             />
           </div>

+          {/* Fast LLM Selector */}
+          <div className="h-8 min-w-0">
+            <Select
+              value={preferences.fast_llm_id?.toString() || ""}
+              onValueChange={(value) => {
+                const llmId = value ? parseInt(value) : undefined;
+                updatePreferences({ fast_llm_id: llmId });
+              }}
+            >
+              <SelectTrigger className="h-8 w-auto min-w-[120px] px-3 text-xs border-border bg-background hover:bg-muted/50">
+                <div className="flex items-center gap-2">
+                  <Zap className="h-3 w-3 text-primary" />
+                  <SelectValue placeholder="Fast LLM">
+                    {preferences.fast_llm_id && (() => {
+                      const selectedConfig = llmConfigs.find(config => config.id === preferences.fast_llm_id);
+                      return selectedConfig ? (
+                        <div className="flex items-center gap-1">
+                          <span className="font-medium">{selectedConfig.provider}</span>
+                          <span className="text-muted-foreground">•</span>
+                          <span className="hidden sm:inline text-muted-foreground">{selectedConfig.name}</span>
+                        </div>
+                      ) : "Select LLM";
+                    })()}
+                  </SelectValue>
+                </div>
+              </SelectTrigger>
+              <SelectContent align="end" className="w-[280px]">
+                <div className="px-2 py-1.5 text-xs font-medium text-muted-foreground border-b">
+                  Answer LLM Selection
+                </div>
+                {llmConfigs.length === 0 ? (
+                  <div className="px-2 py-3 text-center text-sm text-muted-foreground">
+                    <Brain className="h-4 w-4 mx-auto mb-1 opacity-50" />
+                    <p>No LLM configurations found</p>
+                    <p className="text-xs">Configure models in Settings</p>
+                  </div>
+                ) : (
+                  llmConfigs.map((config) => (
+                    <SelectItem key={config.id} value={config.id.toString()}>
+                      <div className="flex items-center justify-between w-full">
+                        <div className="flex items-center gap-3">
+                          <div className="flex h-8 w-8 items-center justify-center rounded-md bg-primary/10">
+                            <Brain className="h-4 w-4 text-primary" />
+                          </div>
+                          <div className="space-y-1">
+                            <div className="flex items-center gap-2">
+                              <span className="font-medium text-sm">{config.name}</span>
+                              <Badge variant="outline" className="text-xs">
+                                {config.provider}
+                              </Badge>
+                            </div>
+                            <p className="text-xs text-muted-foreground font-mono">
+                              {config.model_name}
+                            </p>
+                          </div>
+                        </div>
+                        {preferences.fast_llm_id === config.id && (
+                          <div className="flex h-4 w-4 items-center justify-center rounded-full bg-primary">
+                            <div className="h-2 w-2 rounded-full bg-primary-foreground" />
+                          </div>
+                        )}
+                      </div>
+                    </SelectItem>
+                  ))
+                )}
+              </SelectContent>
+            </Select>
+          </div>
         </div>
       </div>
     </div>
surfsense_web/app/dashboard/layout.tsx (new file, 90 lines)
@@ -0,0 +1,90 @@
"use client";

import { useEffect, useState } from 'react';
import { useRouter } from 'next/navigation';
import { useLLMPreferences } from '@/hooks/use-llm-configs';
import { Loader2 } from 'lucide-react';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';

interface DashboardLayoutProps {
  children: React.ReactNode;
}

export default function DashboardLayout({ children }: DashboardLayoutProps) {
  const router = useRouter();
  const { loading, error, isOnboardingComplete } = useLLMPreferences();
  const [isCheckingAuth, setIsCheckingAuth] = useState(true);

  useEffect(() => {
    // Check if user is authenticated
    const token = localStorage.getItem('surfsense_bearer_token');
    if (!token) {
      router.push('/login');
      return;
    }
    setIsCheckingAuth(false);
  }, [router]);

  useEffect(() => {
    // Wait for preferences to load, then check if onboarding is complete
    if (!loading && !error && !isCheckingAuth) {
      if (!isOnboardingComplete()) {
        router.push('/onboard');
      }
    }
  }, [loading, error, isCheckingAuth, isOnboardingComplete, router]);

  // Show loading screen while checking authentication or loading preferences
  if (isCheckingAuth || loading) {
    return (
      <div className="flex flex-col items-center justify-center min-h-screen space-y-4">
        <Card className="w-[350px] bg-background/60 backdrop-blur-sm">
          <CardHeader className="pb-2">
            <CardTitle className="text-xl font-medium">Loading Dashboard</CardTitle>
            <CardDescription>Checking your configuration...</CardDescription>
          </CardHeader>
          <CardContent className="flex justify-center py-6">
            <Loader2 className="h-12 w-12 text-primary animate-spin" />
          </CardContent>
        </Card>
      </div>
    );
  }

  // Show error screen if there's an error loading preferences
  if (error) {
    return (
      <div className="flex flex-col items-center justify-center min-h-screen space-y-4">
        <Card className="w-[400px] bg-background/60 backdrop-blur-sm border-destructive/20">
          <CardHeader className="pb-2">
            <CardTitle className="text-xl font-medium text-destructive">Configuration Error</CardTitle>
            <CardDescription>Failed to load your LLM configuration</CardDescription>
          </CardHeader>
          <CardContent>
            <p className="text-sm text-muted-foreground">{error}</p>
          </CardContent>
        </Card>
      </div>
    );
  }

  // Only render children if onboarding is complete
  if (isOnboardingComplete()) {
    return <>{children}</>;
  }

  // This should not be reached due to redirect, but just in case
  return (
    <div className="flex flex-col items-center justify-center min-h-screen space-y-4">
      <Card className="w-[350px] bg-background/60 backdrop-blur-sm">
        <CardHeader className="pb-2">
          <CardTitle className="text-xl font-medium">Redirecting...</CardTitle>
          <CardDescription>Taking you to complete your setup</CardDescription>
        </CardHeader>
        <CardContent className="flex justify-center py-6">
          <Loader2 className="h-12 w-12 text-primary animate-spin" />
        </CardContent>
      </Card>
    </div>
  );
}
surfsense_web/app/onboard/page.tsx (new file, 227 lines)
@@ -0,0 +1,227 @@
"use client";

import React, { useState, useEffect } from 'react';
import { useRouter } from 'next/navigation';
import { motion, AnimatePresence } from 'framer-motion';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Progress } from '@/components/ui/progress';
import { CheckCircle, ArrowRight, ArrowLeft, Bot, Sparkles, Zap, Brain } from 'lucide-react';
import { Logo } from '@/components/Logo';
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
import { AddProviderStep } from '@/components/onboard/add-provider-step';
import { AssignRolesStep } from '@/components/onboard/assign-roles-step';
import { CompletionStep } from '@/components/onboard/completion-step';

const TOTAL_STEPS = 3;

const OnboardPage = () => {
  const router = useRouter();
  const { llmConfigs, loading: configsLoading } = useLLMConfigs();
  const { preferences, loading: preferencesLoading, isOnboardingComplete, refreshPreferences } = useLLMPreferences();
  const [currentStep, setCurrentStep] = useState(1);
  const [hasUserProgressed, setHasUserProgressed] = useState(false);

  // Check if user is authenticated
  useEffect(() => {
    const token = localStorage.getItem('surfsense_bearer_token');
    if (!token) {
      router.push('/login');
      return;
    }
  }, [router]);

  // Track if user has progressed beyond step 1
  useEffect(() => {
    if (currentStep > 1) {
      setHasUserProgressed(true);
    }
  }, [currentStep]);

  // Redirect to dashboard if onboarding is already complete and user hasn't progressed (fresh page load)
  useEffect(() => {
    if (!preferencesLoading && isOnboardingComplete() && !hasUserProgressed) {
      router.push('/dashboard');
    }
  }, [preferencesLoading, isOnboardingComplete, hasUserProgressed, router]);

  const progress = (currentStep / TOTAL_STEPS) * 100;

  const stepTitles = [
    "Add LLM Provider",
    "Assign LLM Roles",
    "Setup Complete"
  ];

  const stepDescriptions = [
    "Configure your first model provider",
    "Assign specific roles to your LLM configurations",
    "You're all set to start using SurfSense!"
  ];

  const canProceedToStep2 = !configsLoading && llmConfigs.length > 0;
  const canProceedToStep3 = !preferencesLoading && preferences.long_context_llm_id && preferences.fast_llm_id && preferences.strategic_llm_id;

  const handleNext = () => {
    if (currentStep < TOTAL_STEPS) {
      setCurrentStep(currentStep + 1);
    }
  };

  const handlePrevious = () => {
    if (currentStep > 1) {
      setCurrentStep(currentStep - 1);
    }
  };

  const handleComplete = () => {
    router.push('/dashboard');
  };

  if (configsLoading || preferencesLoading) {
    return (
      <div className="flex flex-col items-center justify-center min-h-screen">
        <Card className="w-[350px] bg-background/60 backdrop-blur-sm">
          <CardContent className="flex flex-col items-center justify-center py-12">
            <Bot className="h-12 w-12 text-primary animate-pulse mb-4" />
            <p className="text-sm text-muted-foreground">Loading your configuration...</p>
          </CardContent>
        </Card>
      </div>
    );
  }

  return (
    <div className="min-h-screen bg-gradient-to-br from-background via-background to-muted/20 flex items-center justify-center p-4">
      <motion.div
        initial={{ opacity: 0, y: 20 }}
        animate={{ opacity: 1, y: 0 }}
        transition={{ duration: 0.5 }}
        className="w-full max-w-4xl"
      >
        {/* Header */}
        <div className="text-center mb-8">
          <div className="flex items-center justify-center mb-4">
            <Logo className="w-12 h-12 mr-3" />
            <h1 className="text-3xl font-bold">Welcome to SurfSense</h1>
          </div>
          <p className="text-muted-foreground text-lg">Let's configure your SurfSense to get started</p>
        </div>

        {/* Progress */}
        <Card className="mb-8 bg-background/60 backdrop-blur-sm">
          <CardContent className="pt-6">
            <div className="flex items-center justify-between mb-4">
              <div className="text-sm font-medium">Step {currentStep} of {TOTAL_STEPS}</div>
              <div className="text-sm text-muted-foreground">{Math.round(progress)}% Complete</div>
            </div>
            <Progress value={progress} className="mb-4" />
            <div className="grid grid-cols-3 gap-4">
              {Array.from({ length: TOTAL_STEPS }, (_, i) => {
                const stepNum = i + 1;
                const isCompleted = stepNum < currentStep;
                const isCurrent = stepNum === currentStep;

                return (
                  <div key={stepNum} className="flex items-center space-x-2">
                    <div className={`w-8 h-8 rounded-full flex items-center justify-center text-sm font-medium ${
                      isCompleted
                        ? 'bg-primary text-primary-foreground'
                        : isCurrent
                          ? 'bg-primary/20 text-primary border-2 border-primary'
                          : 'bg-muted text-muted-foreground'
                    }`}>
                      {isCompleted ? <CheckCircle className="w-4 h-4" /> : stepNum}
                    </div>
                    <div className="flex-1 min-w-0">
                      <p className={`text-sm font-medium truncate ${
                        isCurrent ? 'text-foreground' : 'text-muted-foreground'
                      }`}>
                        {stepTitles[i]}
                      </p>
                    </div>
                  </div>
                );
              })}
            </div>
          </CardContent>
        </Card>

        {/* Step Content */}
        <Card className="min-h-[500px] bg-background/60 backdrop-blur-sm">
          <CardHeader className="text-center">
            <CardTitle className="text-2xl flex items-center justify-center gap-2">
              {currentStep === 1 && <Bot className="w-6 h-6" />}
              {currentStep === 2 && <Sparkles className="w-6 h-6" />}
              {currentStep === 3 && <CheckCircle className="w-6 h-6" />}
              {stepTitles[currentStep - 1]}
            </CardTitle>
            <CardDescription className="text-base">
              {stepDescriptions[currentStep - 1]}
            </CardDescription>
          </CardHeader>
          <CardContent>
            <AnimatePresence mode="wait">
              <motion.div
                key={currentStep}
                initial={{ opacity: 0, x: 20 }}
                animate={{ opacity: 1, x: 0 }}
                exit={{ opacity: 0, x: -20 }}
                transition={{ duration: 0.3 }}
              >
                {currentStep === 1 && <AddProviderStep />}
                {currentStep === 2 && <AssignRolesStep onPreferencesUpdated={refreshPreferences} />}
                {currentStep === 3 && <CompletionStep />}
              </motion.div>
            </AnimatePresence>
          </CardContent>
        </Card>

        {/* Navigation */}
        <div className="flex justify-between mt-8">
          <Button
            variant="outline"
            onClick={handlePrevious}
            disabled={currentStep === 1}
            className="flex items-center gap-2"
          >
            <ArrowLeft className="w-4 h-4" />
            Previous
          </Button>

          <div className="flex gap-2">
            {currentStep < TOTAL_STEPS && (
              <Button
                onClick={handleNext}
                disabled={
                  (currentStep === 1 && !canProceedToStep2) ||
                  (currentStep === 2 && !canProceedToStep3)
                }
                className="flex items-center gap-2"
              >
                Next
                <ArrowRight className="w-4 h-4" />
              </Button>
            )}

            {currentStep === TOTAL_STEPS && (
              <Button
                onClick={handleComplete}
                className="flex items-center gap-2"
              >
                Complete Setup
                <CheckCircle className="w-4 h-4" />
              </Button>
            )}
          </div>
        </div>
      </motion.div>
    </div>
  );
};

export default OnboardPage;
surfsense_web/app/settings/page.tsx (new file, 60 lines)
@@ -0,0 +1,60 @@
"use client";

import React from 'react';
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs';
import { Separator } from '@/components/ui/separator';
import { Bot, Settings, Brain } from 'lucide-react';
import { ModelConfigManager } from '@/components/settings/model-config-manager';
import { LLMRoleManager } from '@/components/settings/llm-role-manager';

export default function SettingsPage() {
  return (
    <div className="min-h-screen bg-background">
      <div className="container max-w-7xl mx-auto p-6 lg:p-8">
        <div className="space-y-8">
          {/* Header Section */}
          <div className="space-y-4">
            <div className="flex items-center space-x-4">
              <div className="flex h-12 w-12 items-center justify-center rounded-lg bg-primary/10">
                <Settings className="h-6 w-6 text-primary" />
              </div>
              <div className="space-y-1">
                <h1 className="text-3xl font-bold tracking-tight">Settings</h1>
                <p className="text-lg text-muted-foreground">
                  Manage your LLM configurations and role assignments.
                </p>
              </div>
            </div>
            <Separator className="my-6" />
          </div>

          {/* Settings Content */}
          <Tabs defaultValue="models" className="space-y-8">
            <div className="overflow-x-auto">
              <TabsList className="grid w-full min-w-fit grid-cols-2 lg:w-auto lg:inline-grid">
                <TabsTrigger value="models" className="flex items-center gap-2 text-sm">
                  <Bot className="h-4 w-4" />
                  <span className="hidden sm:inline">Model Configs</span>
                  <span className="sm:hidden">Models</span>
                </TabsTrigger>
                <TabsTrigger value="roles" className="flex items-center gap-2 text-sm">
                  <Brain className="h-4 w-4" />
                  <span className="hidden sm:inline">LLM Roles</span>
                  <span className="sm:hidden">Roles</span>
                </TabsTrigger>
              </TabsList>
            </div>

            <TabsContent value="models" className="space-y-6">
              <ModelConfigManager />
            </TabsContent>

            <TabsContent value="roles" className="space-y-6">
              <LLMRoleManager />
            </TabsContent>
          </Tabs>
        </div>
      </div>
    </div>
  );
}
surfsense_web/components/onboard/add-provider-step.tsx (new file, 255 lines)
@@ -0,0 +1,255 @@
"use client";

import React, { useState } from 'react';
import { motion } from 'framer-motion';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
import { Input } from '@/components/ui/input';
import { Label } from '@/components/ui/label';
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';

import { Badge } from '@/components/ui/badge';
import { Plus, Trash2, Bot, AlertCircle } from 'lucide-react';
import { useLLMConfigs, CreateLLMConfig } from '@/hooks/use-llm-configs';
import { toast } from 'sonner';
import { Alert, AlertDescription } from '@/components/ui/alert';

const LLM_PROVIDERS = [
  { value: 'OPENAI', label: 'OpenAI', example: 'gpt-4o, gpt-4, gpt-3.5-turbo' },
  { value: 'ANTHROPIC', label: 'Anthropic', example: 'claude-3-5-sonnet-20241022, claude-3-opus-20240229' },
  { value: 'GROQ', label: 'Groq', example: 'llama3-70b-8192, mixtral-8x7b-32768' },
  { value: 'COHERE', label: 'Cohere', example: 'command-r-plus, command-r' },
  { value: 'HUGGINGFACE', label: 'HuggingFace', example: 'microsoft/DialoGPT-medium' },
  { value: 'AZURE_OPENAI', label: 'Azure OpenAI', example: 'gpt-4, gpt-35-turbo' },
  { value: 'GOOGLE', label: 'Google', example: 'gemini-pro, gemini-pro-vision' },
  { value: 'AWS_BEDROCK', label: 'AWS Bedrock', example: 'anthropic.claude-v2' },
  { value: 'OLLAMA', label: 'Ollama', example: 'llama2, codellama' },
  { value: 'MISTRAL', label: 'Mistral', example: 'mistral-large-latest, mistral-medium' },
  { value: 'TOGETHER_AI', label: 'Together AI', example: 'togethercomputer/llama-2-70b-chat' },
  { value: 'REPLICATE', label: 'Replicate', example: 'meta/llama-2-70b-chat' },
  { value: 'CUSTOM', label: 'Custom Provider', example: 'your-custom-model' },
];

export function AddProviderStep() {
  const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs();
  const [isAddingNew, setIsAddingNew] = useState(false);
  const [formData, setFormData] = useState<CreateLLMConfig>({
    name: '',
    provider: '',
    custom_provider: '',
    model_name: '',
    api_key: '',
    api_base: '',
    litellm_params: {}
  });
  const [isSubmitting, setIsSubmitting] = useState(false);

  const handleInputChange = (field: keyof CreateLLMConfig, value: string) => {
    setFormData(prev => ({ ...prev, [field]: value }));
  };

  const handleSubmit = async (e: React.FormEvent) => {
    e.preventDefault();
    if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) {
      toast.error('Please fill in all required fields');
      return;
    }

    setIsSubmitting(true);
    const result = await createLLMConfig(formData);
    setIsSubmitting(false);

    if (result) {
      setFormData({
        name: '',
        provider: '',
        custom_provider: '',
        model_name: '',
        api_key: '',
        api_base: '',
        litellm_params: {}
      });
      setIsAddingNew(false);
    }
  };

  const selectedProvider = LLM_PROVIDERS.find(p => p.value === formData.provider);

  return (
    <div className="space-y-6">
      {/* Info Alert */}
      <Alert>
        <AlertCircle className="h-4 w-4" />
        <AlertDescription>
          Add at least one LLM provider to continue. You can configure multiple providers and choose specific roles for each one in the next step.
        </AlertDescription>
      </Alert>

      {/* Existing Configurations */}
      {llmConfigs.length > 0 && (
        <div className="space-y-4">
          <h3 className="text-lg font-semibold">Your LLM Configurations</h3>
          <div className="grid gap-4">
            {llmConfigs.map((config) => (
              <motion.div
                key={config.id}
                initial={{ opacity: 0, y: 10 }}
                animate={{ opacity: 1, y: 0 }}
                exit={{ opacity: 0, y: -10 }}
              >
                <Card className="border-l-4 border-l-primary">
                  <CardContent className="pt-4">
                    <div className="flex items-center justify-between">
                      <div className="flex-1">
                        <div className="flex items-center gap-2 mb-2">
                          <Bot className="w-4 h-4" />
                          <h4 className="font-medium">{config.name}</h4>
                          <Badge variant="secondary">{config.provider}</Badge>
                        </div>
                        <p className="text-sm text-muted-foreground">
                          Model: {config.model_name}
                          {config.api_base && ` • Base: ${config.api_base}`}
                        </p>
                      </div>
                      <Button
                        variant="ghost"
                        size="sm"
                        onClick={() => deleteLLMConfig(config.id)}
                        className="text-destructive hover:text-destructive"
                      >
                        <Trash2 className="w-4 h-4" />
                      </Button>
                    </div>
                  </CardContent>
                </Card>
              </motion.div>
            ))}
          </div>
        </div>
      )}

      {/* Add New Provider */}
      {!isAddingNew ? (
        <Card className="border-dashed border-2 hover:border-primary/50 transition-colors">
          <CardContent className="flex flex-col items-center justify-center py-12">
            <Plus className="w-12 h-12 text-muted-foreground mb-4" />
            <h3 className="text-lg font-semibold mb-2">Add LLM Provider</h3>
            <p className="text-muted-foreground text-center mb-4">
              Configure your first model provider to get started
            </p>
            <Button onClick={() => setIsAddingNew(true)}>
              <Plus className="w-4 h-4 mr-2" />
              Add Provider
            </Button>
          </CardContent>
        </Card>
      ) : (
        <Card>
          <CardHeader>
            <CardTitle>Add New LLM Provider</CardTitle>
            <CardDescription>
              Configure a new language model provider for your AI assistant
            </CardDescription>
          </CardHeader>
          <CardContent>
            <form onSubmit={handleSubmit} className="space-y-4">
              <div className="grid grid-cols-1 md:grid-cols-2 gap-4">
                <div className="space-y-2">
                  <Label htmlFor="name">Configuration Name *</Label>
                  <Input
                    id="name"
                    placeholder="e.g., My OpenAI GPT-4"
                    value={formData.name}
                    onChange={(e) => handleInputChange('name', e.target.value)}
                    required
                  />
                </div>

                <div className="space-y-2">
                  <Label htmlFor="provider">Provider *</Label>
                  <Select value={formData.provider} onValueChange={(value) => handleInputChange('provider', value)}>
                    <SelectTrigger>
                      <SelectValue placeholder="Select a provider" />
                    </SelectTrigger>
                    <SelectContent>
                      {LLM_PROVIDERS.map((provider) => (
                        <SelectItem key={provider.value} value={provider.value}>
                          {provider.label}
                        </SelectItem>
                      ))}
                    </SelectContent>
                  </Select>
                </div>
              </div>

              {formData.provider === 'CUSTOM' && (
                <div className="space-y-2">
                  <Label htmlFor="custom_provider">Custom Provider Name *</Label>
                  <Input
                    id="custom_provider"
                    placeholder="e.g., my-custom-provider"
                    value={formData.custom_provider}
                    onChange={(e) => handleInputChange('custom_provider', e.target.value)}
                    required
                  />
                </div>
              )}

              <div className="space-y-2">
                <Label htmlFor="model_name">Model Name *</Label>
                <Input
                  id="model_name"
                  placeholder={selectedProvider?.example || "e.g., gpt-4"}
                  value={formData.model_name}
                  onChange={(e) => handleInputChange('model_name', e.target.value)}
                  required
                />
                {selectedProvider && (
                  <p className="text-xs text-muted-foreground">
                    Examples: {selectedProvider.example}
                  </p>
                )}
              </div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="api_key">API Key *</Label>
|
||||
<Input
|
||||
id="api_key"
|
||||
type="password"
|
||||
placeholder="Your API key"
|
||||
value={formData.api_key}
|
||||
onChange={(e) => handleInputChange('api_key', e.target.value)}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="api_base">API Base URL (Optional)</Label>
|
||||
<Input
|
||||
id="api_base"
|
||||
placeholder="e.g., https://api.openai.com/v1"
|
||||
value={formData.api_base}
|
||||
onChange={(e) => handleInputChange('api_base', e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-2 pt-4">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
{isSubmitting ? 'Adding...' : 'Add Provider'}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={() => setIsAddingNew(false)}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
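Note: the onboarding form above only collects a provider, a model name, and a key; how the backend combines them is not shown in this diff. As a rough sketch (assumption: the server resolves a saved config into a LiteLLM-style "provider/model" string, and the helper name below is hypothetical, not part of this commit), the mapping might look like:

// Hypothetical helper, a minimal sketch assuming LiteLLM-style model strings.
function toLiteLLMModel(config: { provider: string; custom_provider?: string; model_name: string }): string {
  // CUSTOM configs carry their own provider prefix; built-in ones reuse the enum value.
  const prefix = config.provider === 'CUSTOM'
    ? (config.custom_provider ?? 'custom')
    : config.provider.toLowerCase();
  return `${prefix}/${config.model_name}`; // e.g. "openai/gpt-4o"
}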
232 surfsense_web/components/onboard/assign-roles-step.tsx Normal file
@@ -0,0 +1,232 @@
"use client";
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { motion } from 'framer-motion';
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Brain, Zap, Bot, AlertCircle, CheckCircle } from 'lucide-react';
|
||||
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert';
|
||||
|
||||
const ROLE_DESCRIPTIONS = {
|
||||
long_context: {
|
||||
icon: Brain,
|
||||
title: 'Long Context LLM',
|
||||
description: 'Handles complex tasks requiring extensive context understanding and reasoning',
|
||||
color: 'bg-blue-100 text-blue-800 border-blue-200',
|
||||
examples: 'Document analysis, research synthesis, complex Q&A'
|
||||
},
|
||||
fast: {
|
||||
icon: Zap,
|
||||
title: 'Fast LLM',
|
||||
description: 'Optimized for quick responses and real-time interactions',
|
||||
color: 'bg-green-100 text-green-800 border-green-200',
|
||||
examples: 'Quick searches, simple questions, instant responses'
|
||||
},
|
||||
strategic: {
|
||||
icon: Bot,
|
||||
title: 'Strategic LLM',
|
||||
description: 'Advanced reasoning for planning and strategic decision making',
|
||||
color: 'bg-purple-100 text-purple-800 border-purple-200',
|
||||
examples: 'Planning workflows, strategic analysis, complex problem solving'
|
||||
}
|
||||
};
|
||||
|
||||
interface AssignRolesStepProps {
|
||||
onPreferencesUpdated?: () => Promise<void>;
|
||||
}
|
||||
|
||||
export function AssignRolesStep({ onPreferencesUpdated }: AssignRolesStepProps) {
|
||||
const { llmConfigs } = useLLMConfigs();
|
||||
const { preferences, updatePreferences } = useLLMPreferences();
|
||||
|
||||
|
||||
const [assignments, setAssignments] = useState({
|
||||
long_context_llm_id: preferences.long_context_llm_id || '',
|
||||
fast_llm_id: preferences.fast_llm_id || '',
|
||||
strategic_llm_id: preferences.strategic_llm_id || ''
|
||||
});
|
||||
|
||||
|
||||
useEffect(() => {
|
||||
setAssignments({
|
||||
long_context_llm_id: preferences.long_context_llm_id || '',
|
||||
fast_llm_id: preferences.fast_llm_id || '',
|
||||
strategic_llm_id: preferences.strategic_llm_id || ''
|
||||
});
|
||||
}, [preferences]);
|
||||
|
||||
const handleRoleAssignment = async (role: string, configId: string) => {
|
||||
const newAssignments = {
|
||||
...assignments,
|
||||
[role]: configId === '' ? '' : parseInt(configId)
|
||||
};
|
||||
|
||||
setAssignments(newAssignments);
|
||||
|
||||
// Auto-save if this assignment completes all roles
|
||||
const hasAllAssignments = newAssignments.long_context_llm_id && newAssignments.fast_llm_id && newAssignments.strategic_llm_id;
|
||||
|
||||
if (hasAllAssignments) {
|
||||
const numericAssignments = {
|
||||
long_context_llm_id: typeof newAssignments.long_context_llm_id === 'string' ? parseInt(newAssignments.long_context_llm_id) : newAssignments.long_context_llm_id,
|
||||
fast_llm_id: typeof newAssignments.fast_llm_id === 'string' ? parseInt(newAssignments.fast_llm_id) : newAssignments.fast_llm_id,
|
||||
strategic_llm_id: typeof newAssignments.strategic_llm_id === 'string' ? parseInt(newAssignments.strategic_llm_id) : newAssignments.strategic_llm_id,
|
||||
};
|
||||
|
||||
const success = await updatePreferences(numericAssignments);
|
||||
|
||||
// Refresh parent preferences state
|
||||
if (success && onPreferencesUpdated) {
|
||||
await onPreferencesUpdated();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
const isAssignmentComplete = assignments.long_context_llm_id && assignments.fast_llm_id && assignments.strategic_llm_id;
|
||||
|
||||
if (llmConfigs.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-12">
|
||||
<AlertCircle className="w-16 h-16 text-muted-foreground mb-4" />
|
||||
<h3 className="text-lg font-semibold mb-2">No LLM Configurations Found</h3>
|
||||
<p className="text-muted-foreground text-center">
|
||||
Please add at least one LLM provider in the previous step before assigning roles.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* Info Alert */}
|
||||
<Alert>
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>
|
||||
Assign your LLM configurations to specific roles. Each role serves different purposes in your workflow.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
{/* Role Assignment Cards */}
|
||||
<div className="grid gap-6">
|
||||
{Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => {
|
||||
const IconComponent = role.icon;
|
||||
const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments];
|
||||
const assignedConfig = llmConfigs.find(config => config.id === currentAssignment);
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
key={key}
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ delay: Object.keys(ROLE_DESCRIPTIONS).indexOf(key) * 0.1 }}
|
||||
>
|
||||
<Card className={`border-l-4 ${currentAssignment ? 'border-l-primary' : 'border-l-muted'}`}>
|
||||
<CardHeader className="pb-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className={`p-2 rounded-lg ${role.color}`}>
|
||||
<IconComponent className="w-5 h-5" />
|
||||
</div>
|
||||
<div>
|
||||
<CardTitle className="text-lg">{role.title}</CardTitle>
|
||||
<CardDescription className="mt-1">{role.description}</CardDescription>
|
||||
</div>
|
||||
</div>
|
||||
{currentAssignment && (
|
||||
<CheckCircle className="w-5 h-5 text-green-500" />
|
||||
)}
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
<div className="text-sm text-muted-foreground">
|
||||
<strong>Use cases:</strong> {role.examples}
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Assign LLM Configuration:</label>
|
||||
<Select
|
||||
value={currentAssignment?.toString() || ''}
|
||||
onValueChange={(value) => handleRoleAssignment(`${key}_llm_id`, value)}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select an LLM configuration" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{llmConfigs
|
||||
.filter(config => config.id && config.id.toString().trim() !== '')
|
||||
.map((config) => (
|
||||
<SelectItem key={config.id} value={config.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
<span>{config.name}</span>
|
||||
<span className="text-muted-foreground">({config.model_name})</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{assignedConfig && (
|
||||
<div className="mt-3 p-3 bg-muted/50 rounded-lg">
|
||||
<div className="flex items-center gap-2 text-sm">
|
||||
<Bot className="w-4 h-4" />
|
||||
<span className="font-medium">Assigned:</span>
|
||||
<Badge variant="secondary">{assignedConfig.provider}</Badge>
|
||||
<span>{assignedConfig.name}</span>
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground mt-1">
|
||||
Model: {assignedConfig.model_name}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</motion.div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
{/* Status Indicator */}
|
||||
{isAssignmentComplete && (
|
||||
<div className="flex justify-center pt-4">
|
||||
<div className="flex items-center gap-2 px-4 py-2 bg-green-50 text-green-700 rounded-lg border border-green-200">
|
||||
<CheckCircle className="w-4 h-4" />
|
||||
<span className="text-sm font-medium">All roles assigned and saved!</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Progress Indicator */}
|
||||
<div className="flex justify-center">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<span>Progress:</span>
|
||||
<div className="flex gap-1">
|
||||
{Object.keys(ROLE_DESCRIPTIONS).map((key, index) => (
|
||||
<div
|
||||
key={key}
|
||||
className={`w-2 h-2 rounded-full ${
|
||||
assignments[`${key}_llm_id` as keyof typeof assignments]
|
||||
? 'bg-primary'
|
||||
: 'bg-muted'
|
||||
}`}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
<span>
|
||||
{Object.values(assignments).filter(Boolean).length} of {Object.keys(ROLE_DESCRIPTIONS).length} roles assigned
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
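The auto-save path in handleRoleAssignment above mixes string form state with numeric config IDs before calling updatePreferences. A minimal standalone sketch of that normalization (the helper name is ours for illustration; the component inlines the same logic per field):

// Sketch of the string-to-number normalization performed before saving.
// toNumericId is a hypothetical name, not part of this commit.
function toNumericId(value: string | number): number | undefined {
  if (typeof value === 'number') return value;
  return value === '' ? undefined : parseInt(value, 10);
}
// toNumericId('') === undefined; toNumericId('3') === 3; toNumericId(7) === 7.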
125 surfsense_web/components/onboard/completion-step.tsx Normal file
@@ -0,0 +1,125 @@
"use client";
|
||||
|
||||
import React from 'react';
|
||||
import { motion } from 'framer-motion';
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { CheckCircle, Bot, Brain, Zap, Sparkles, ArrowRight } from 'lucide-react';
|
||||
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
|
||||
|
||||
|
||||
const ROLE_ICONS = {
|
||||
long_context: Brain,
|
||||
fast: Zap,
|
||||
strategic: Bot
|
||||
};
|
||||
|
||||
export function CompletionStep() {
|
||||
const { llmConfigs } = useLLMConfigs();
|
||||
const { preferences } = useLLMPreferences();
|
||||
|
||||
const assignedConfigs = {
|
||||
long_context: llmConfigs.find(c => c.id === preferences.long_context_llm_id),
|
||||
fast: llmConfigs.find(c => c.id === preferences.fast_llm_id),
|
||||
strategic: llmConfigs.find(c => c.id === preferences.strategic_llm_id)
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-8">
|
||||
{/* Success Message */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0, scale: 0.95 }}
|
||||
animate={{ opacity: 1, scale: 1 }}
|
||||
transition={{ duration: 0.5 }}
|
||||
className="text-center"
|
||||
>
|
||||
<div className="w-20 h-20 mx-auto mb-6 bg-green-100 rounded-full flex items-center justify-center">
|
||||
<CheckCircle className="w-10 h-10 text-green-600" />
|
||||
</div>
|
||||
<h2 className="text-2xl font-bold mb-2">Setup Complete!</h2>
|
||||
</motion.div>
|
||||
|
||||
{/* Configuration Summary */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ delay: 0.2 }}
|
||||
>
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<Sparkles className="w-5 h-5" />
|
||||
Your LLM Configuration
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Here's a summary of your setup
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
{Object.entries(assignedConfigs).map(([role, config]) => {
|
||||
if (!config) return null;
|
||||
|
||||
const IconComponent = ROLE_ICONS[role as keyof typeof ROLE_ICONS];
|
||||
const roleDisplayNames = {
|
||||
long_context: 'Long Context LLM',
|
||||
fast: 'Fast LLM',
|
||||
strategic: 'Strategic LLM'
|
||||
};
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
key={role}
|
||||
initial={{ opacity: 0, x: -10 }}
|
||||
animate={{ opacity: 1, x: 0 }}
|
||||
transition={{ delay: 0.3 + Object.keys(assignedConfigs).indexOf(role) * 0.1 }}
|
||||
className="flex items-center justify-between p-3 bg-muted/50 rounded-lg"
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<div className="p-2 bg-background rounded-md">
|
||||
<IconComponent className="w-4 h-4" />
|
||||
</div>
|
||||
<div>
|
||||
<p className="font-medium">{roleDisplayNames[role as keyof typeof roleDisplayNames]}</p>
|
||||
<p className="text-sm text-muted-foreground">{config.name}</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Badge variant="outline">{config.provider}</Badge>
|
||||
<span className="text-sm text-muted-foreground">{config.model_name}</span>
|
||||
</div>
|
||||
</motion.div>
|
||||
);
|
||||
})}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</motion.div>
|
||||
|
||||
|
||||
{/* Next Steps */}
|
||||
<motion.div
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ delay: 0.6 }}
|
||||
>
|
||||
<Card className="border-primary/20 bg-primary/5">
|
||||
<CardContent className="pt-6">
|
||||
<div className="flex items-center gap-3 mb-4">
|
||||
<div className="p-2 bg-primary rounded-md">
|
||||
<ArrowRight className="w-4 h-4 text-primary-foreground" />
|
||||
</div>
|
||||
<h3 className="text-lg font-semibold">Ready to Get Started?</h3>
|
||||
</div>
|
||||
<p className="text-muted-foreground mb-4">
|
||||
Click "Complete Setup" to enter your dashboard and start exploring!
|
||||
</p>
|
||||
<div className="flex flex-wrap gap-2 text-sm">
|
||||
<Badge variant="secondary">✓ {llmConfigs.length} LLM provider{llmConfigs.length > 1 ? 's' : ''} configured</Badge>
|
||||
<Badge variant="secondary">✓ All roles assigned</Badge>
|
||||
<Badge variant="secondary">✓ Ready to use</Badge>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</motion.div>
|
||||
</div>
|
||||
);
|
||||
}
465 surfsense_web/components/settings/llm-role-manager.tsx Normal file
@@ -0,0 +1,465 @@
"use client";
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { motion } from 'framer-motion';
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
Brain,
|
||||
Zap,
|
||||
Bot,
|
||||
AlertCircle,
|
||||
CheckCircle,
|
||||
Settings2,
|
||||
RefreshCw,
|
||||
Save,
|
||||
RotateCcw,
|
||||
Loader2
|
||||
} from 'lucide-react';
|
||||
import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs';
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert';
|
||||
import { toast } from 'sonner';
|
||||
|
||||
const ROLE_DESCRIPTIONS = {
|
||||
long_context: {
|
||||
icon: Brain,
|
||||
title: 'Long Context LLM',
|
||||
description: 'Handles complex tasks requiring extensive context understanding and reasoning',
|
||||
color: 'bg-blue-100 text-blue-800 border-blue-200',
|
||||
examples: 'Document analysis, research synthesis, complex Q&A',
|
||||
characteristics: ['Large context window', 'Deep reasoning', 'Complex analysis']
|
||||
},
|
||||
fast: {
|
||||
icon: Zap,
|
||||
title: 'Fast LLM',
|
||||
description: 'Optimized for quick responses and real-time interactions',
|
||||
color: 'bg-green-100 text-green-800 border-green-200',
|
||||
examples: 'Quick searches, simple questions, instant responses',
|
||||
characteristics: ['Low latency', 'Quick responses', 'Real-time chat']
|
||||
},
|
||||
strategic: {
|
||||
icon: Bot,
|
||||
title: 'Strategic LLM',
|
||||
description: 'Advanced reasoning for planning and strategic decision making',
|
||||
color: 'bg-purple-100 text-purple-800 border-purple-200',
|
||||
examples: 'Planning workflows, strategic analysis, complex problem solving',
|
||||
characteristics: ['Strategic thinking', 'Long-term planning', 'Complex reasoning']
|
||||
}
|
||||
};
|
||||
|
||||
export function LLMRoleManager() {
|
||||
const { llmConfigs, loading: configsLoading, error: configsError, refreshConfigs } = useLLMConfigs();
|
||||
const { preferences, loading: preferencesLoading, error: preferencesError, updatePreferences, refreshPreferences } = useLLMPreferences();
|
||||
|
||||
const [assignments, setAssignments] = useState({
|
||||
long_context_llm_id: preferences.long_context_llm_id || '',
|
||||
fast_llm_id: preferences.fast_llm_id || '',
|
||||
strategic_llm_id: preferences.strategic_llm_id || ''
|
||||
});
|
||||
|
||||
const [hasChanges, setHasChanges] = useState(false);
|
||||
const [isSaving, setIsSaving] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const newAssignments = {
|
||||
long_context_llm_id: preferences.long_context_llm_id || '',
|
||||
fast_llm_id: preferences.fast_llm_id || '',
|
||||
strategic_llm_id: preferences.strategic_llm_id || ''
|
||||
};
|
||||
setAssignments(newAssignments);
|
||||
setHasChanges(false);
|
||||
}, [preferences]);
|
||||
|
||||
const handleRoleAssignment = (role: string, configId: string) => {
|
||||
const newAssignments = {
|
||||
...assignments,
|
||||
[role]: configId === 'unassigned' ? '' : parseInt(configId)
|
||||
};
|
||||
|
||||
setAssignments(newAssignments);
|
||||
|
||||
// Check if there are changes compared to current preferences
|
||||
const currentPrefs = {
|
||||
long_context_llm_id: preferences.long_context_llm_id || '',
|
||||
fast_llm_id: preferences.fast_llm_id || '',
|
||||
strategic_llm_id: preferences.strategic_llm_id || ''
|
||||
};
|
||||
|
||||
const hasChangesNow = Object.keys(newAssignments).some(
|
||||
key => newAssignments[key as keyof typeof newAssignments] !== currentPrefs[key as keyof typeof currentPrefs]
|
||||
);
|
||||
|
||||
setHasChanges(hasChangesNow);
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
setIsSaving(true);
|
||||
|
||||
const numericAssignments = {
|
||||
long_context_llm_id: typeof assignments.long_context_llm_id === 'string'
|
||||
? (assignments.long_context_llm_id ? parseInt(assignments.long_context_llm_id) : undefined)
|
||||
: assignments.long_context_llm_id,
|
||||
fast_llm_id: typeof assignments.fast_llm_id === 'string'
|
||||
? (assignments.fast_llm_id ? parseInt(assignments.fast_llm_id) : undefined)
|
||||
: assignments.fast_llm_id,
|
||||
strategic_llm_id: typeof assignments.strategic_llm_id === 'string'
|
||||
? (assignments.strategic_llm_id ? parseInt(assignments.strategic_llm_id) : undefined)
|
||||
: assignments.strategic_llm_id,
|
||||
};
|
||||
|
||||
const success = await updatePreferences(numericAssignments);
|
||||
|
||||
if (success) {
|
||||
setHasChanges(false);
|
||||
toast.success('LLM role assignments saved successfully!');
|
||||
}
|
||||
|
||||
setIsSaving(false);
|
||||
};
|
||||
|
||||
const handleReset = () => {
|
||||
setAssignments({
|
||||
long_context_llm_id: preferences.long_context_llm_id || '',
|
||||
fast_llm_id: preferences.fast_llm_id || '',
|
||||
strategic_llm_id: preferences.strategic_llm_id || ''
|
||||
});
|
||||
setHasChanges(false);
|
||||
};
|
||||
|
||||
const isAssignmentComplete = assignments.long_context_llm_id && assignments.fast_llm_id && assignments.strategic_llm_id;
|
||||
const assignedConfigIds = Object.values(assignments).filter(id => id !== '');
|
||||
const availableConfigs = llmConfigs.filter(config => config.id && config.id.toString().trim() !== '');
|
||||
|
||||
const isLoading = configsLoading || preferencesLoading;
|
||||
const hasError = configsError || preferencesError;
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* Header */}
|
||||
<div className="flex flex-col space-y-4 lg:flex-row lg:items-center lg:justify-between lg:space-y-0">
|
||||
<div className="space-y-1">
|
||||
<div className="flex items-center space-x-3">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-lg bg-purple-500/10">
|
||||
<Settings2 className="h-5 w-5 text-purple-600" />
|
||||
</div>
|
||||
<div>
|
||||
<h2 className="text-2xl font-bold tracking-tight">LLM Role Management</h2>
|
||||
<p className="text-muted-foreground">
|
||||
Assign your LLM configurations to specific roles for different purposes.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={refreshConfigs}
|
||||
disabled={isLoading}
|
||||
className="flex items-center gap-2"
|
||||
>
|
||||
<RefreshCw className={`h-4 w-4 ${configsLoading ? 'animate-spin' : ''}`} />
|
||||
<span className="hidden sm:inline">Refresh Configs</span>
|
||||
<span className="sm:hidden">Configs</span>
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={refreshPreferences}
|
||||
disabled={isLoading}
|
||||
className="flex items-center gap-2"
|
||||
>
|
||||
<RefreshCw className={`h-4 w-4 ${preferencesLoading ? 'animate-spin' : ''}`} />
|
||||
<span className="hidden sm:inline">Refresh Preferences</span>
|
||||
<span className="sm:hidden">Prefs</span>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Error Alert */}
|
||||
{hasError && (
|
||||
<Alert variant="destructive">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>
|
||||
{configsError || preferencesError}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Loading State */}
|
||||
{isLoading && (
|
||||
<Card>
|
||||
<CardContent className="flex items-center justify-center py-12">
|
||||
<div className="flex items-center gap-2 text-muted-foreground">
|
||||
<Loader2 className="w-5 h-5 animate-spin" />
|
||||
<span>
|
||||
{configsLoading && preferencesLoading ? 'Loading configurations and preferences...' :
|
||||
configsLoading ? 'Loading configurations...' :
|
||||
'Loading preferences...'}
|
||||
</span>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* Stats Overview */}
|
||||
{!isLoading && !hasError && (
|
||||
<div className="grid gap-4 sm:grid-cols-2 lg:grid-cols-4">
|
||||
<Card className="border-l-4 border-l-blue-500">
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-center justify-between space-x-4">
|
||||
<div className="space-y-1">
|
||||
<p className="text-3xl font-bold tracking-tight">{availableConfigs.length}</p>
|
||||
<p className="text-sm font-medium text-muted-foreground">Available Models</p>
|
||||
</div>
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-blue-500/10">
|
||||
<Bot className="h-6 w-6 text-blue-600" />
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card className="border-l-4 border-l-purple-500">
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-center justify-between space-x-4">
|
||||
<div className="space-y-1">
|
||||
<p className="text-3xl font-bold tracking-tight">{assignedConfigIds.length}</p>
|
||||
<p className="text-sm font-medium text-muted-foreground">Assigned Roles</p>
|
||||
</div>
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-purple-500/10">
|
||||
<CheckCircle className="h-6 w-6 text-purple-600" />
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card className={`border-l-4 ${
|
||||
isAssignmentComplete ? 'border-l-green-500' : 'border-l-yellow-500'
|
||||
}`}>
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-center justify-between space-x-4">
|
||||
<div className="space-y-1">
|
||||
<p className="text-3xl font-bold tracking-tight">
|
||||
{Math.round((assignedConfigIds.length / 3) * 100)}%
|
||||
</p>
|
||||
<p className="text-sm font-medium text-muted-foreground">Completion</p>
|
||||
</div>
|
||||
<div className={`flex h-12 w-12 items-center justify-center rounded-lg ${
|
||||
isAssignmentComplete ? 'bg-green-500/10' : 'bg-yellow-500/10'
|
||||
}`}>
|
||||
{isAssignmentComplete ? (
|
||||
<CheckCircle className="h-6 w-6 text-green-600" />
|
||||
) : (
|
||||
<AlertCircle className="h-6 w-6 text-yellow-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card className={`border-l-4 ${
|
||||
isAssignmentComplete ? 'border-l-emerald-500' : 'border-l-orange-500'
|
||||
}`}>
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-center justify-between space-x-4">
|
||||
<div className="space-y-1">
|
||||
<p className={`text-3xl font-bold tracking-tight ${
|
||||
isAssignmentComplete ? 'text-emerald-600' : 'text-orange-600'
|
||||
}`}>
|
||||
{isAssignmentComplete ? 'Ready' : 'Setup'}
|
||||
</p>
|
||||
<p className="text-sm font-medium text-muted-foreground">Status</p>
|
||||
</div>
|
||||
<div className={`flex h-12 w-12 items-center justify-center rounded-lg ${
|
||||
isAssignmentComplete ? 'bg-emerald-500/10' : 'bg-orange-500/10'
|
||||
}`}>
|
||||
{isAssignmentComplete ? (
|
||||
<CheckCircle className="h-6 w-6 text-emerald-600" />
|
||||
) : (
|
||||
<RefreshCw className="h-6 w-6 text-orange-600" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Info Alert */}
|
||||
{!isLoading && !hasError && (
|
||||
<div className="space-y-6">
|
||||
{availableConfigs.length === 0 ? (
|
||||
<Alert variant="destructive">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>
|
||||
No LLM configurations found. Please add at least one LLM provider in the Model Configs tab before assigning roles.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
) : !isAssignmentComplete ? (
|
||||
<Alert>
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>
|
||||
Complete all role assignments to enable full functionality. Each role serves different purposes in your workflow.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
) : (
|
||||
<Alert>
|
||||
<CheckCircle className="h-4 w-4" />
|
||||
<AlertDescription>
|
||||
All roles are assigned and ready to use! Your LLM configuration is complete.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Role Assignment Cards */}
|
||||
{availableConfigs.length > 0 && (
|
||||
<div className="grid gap-6">
|
||||
{Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => {
|
||||
const IconComponent = role.icon;
|
||||
const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments];
|
||||
const assignedConfig = availableConfigs.find(config => config.id === currentAssignment);
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
key={key}
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ delay: Object.keys(ROLE_DESCRIPTIONS).indexOf(key) * 0.1 }}
|
||||
>
|
||||
<Card className={`border-l-4 ${currentAssignment ? 'border-l-primary' : 'border-l-muted'} hover:shadow-md transition-shadow`}>
|
||||
<CardHeader className="pb-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
<div className={`p-2 rounded-lg ${role.color}`}>
|
||||
<IconComponent className="w-5 h-5" />
|
||||
</div>
|
||||
<div>
|
||||
<CardTitle className="text-lg">{role.title}</CardTitle>
|
||||
<CardDescription className="mt-1">{role.description}</CardDescription>
|
||||
</div>
|
||||
</div>
|
||||
{currentAssignment && (
|
||||
<CheckCircle className="w-5 h-5 text-green-500" />
|
||||
)}
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
<div className="space-y-2">
|
||||
<div className="text-sm text-muted-foreground">
|
||||
<strong>Use cases:</strong> {role.examples}
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{role.characteristics.map((char, idx) => (
|
||||
<Badge key={idx} variant="outline" className="text-xs">
|
||||
{char}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Assign LLM Configuration:</label>
|
||||
<Select
|
||||
value={currentAssignment?.toString() || 'unassigned'}
|
||||
onValueChange={(value) => handleRoleAssignment(`${key}_llm_id`, value)}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select an LLM configuration" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="unassigned">
|
||||
<span className="text-muted-foreground">Unassigned</span>
|
||||
</SelectItem>
|
||||
{availableConfigs.map((config) => (
|
||||
<SelectItem key={config.id} value={config.id.toString()}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Badge variant="outline" className="text-xs">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
<span>{config.name}</span>
|
||||
<span className="text-muted-foreground">({config.model_name})</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
{assignedConfig && (
|
||||
<div className="mt-3 p-3 bg-muted/50 rounded-lg">
|
||||
<div className="flex items-center gap-2 text-sm">
|
||||
<Bot className="w-4 h-4" />
|
||||
<span className="font-medium">Assigned:</span>
|
||||
<Badge variant="secondary">{assignedConfig.provider}</Badge>
|
||||
<span>{assignedConfig.name}</span>
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground mt-1">
|
||||
Model: {assignedConfig.model_name}
|
||||
</div>
|
||||
{assignedConfig.api_base && (
|
||||
<div className="text-xs text-muted-foreground">
|
||||
Base: {assignedConfig.api_base}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</motion.div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Action Buttons */}
|
||||
{hasChanges && (
|
||||
<div className="flex justify-center gap-3 pt-4">
|
||||
<Button onClick={handleSave} disabled={isSaving} className="flex items-center gap-2">
|
||||
<Save className="w-4 h-4" />
|
||||
{isSaving ? 'Saving...' : 'Save Changes'}
|
||||
</Button>
|
||||
<Button variant="outline" onClick={handleReset} disabled={isSaving} className="flex items-center gap-2">
|
||||
<RotateCcw className="w-4 h-4" />
|
||||
Reset
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Status Indicator */}
|
||||
{isAssignmentComplete && !hasChanges && (
|
||||
<div className="flex justify-center pt-4">
|
||||
<div className="flex items-center gap-2 px-4 py-2 bg-green-50 text-green-700 rounded-lg border border-green-200">
|
||||
<CheckCircle className="w-4 h-4" />
|
||||
<span className="text-sm font-medium">All roles assigned and saved!</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Progress Indicator */}
|
||||
<div className="flex justify-center">
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<span>Progress:</span>
|
||||
<div className="flex gap-1">
|
||||
{Object.keys(ROLE_DESCRIPTIONS).map((key, index) => (
|
||||
<div
|
||||
key={key}
|
||||
className={`w-2 h-2 rounded-full ${
|
||||
assignments[`${key}_llm_id` as keyof typeof assignments]
|
||||
? 'bg-primary'
|
||||
: 'bg-muted'
|
||||
}`}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
<span>
|
||||
{assignedConfigIds.length} of {Object.keys(ROLE_DESCRIPTIONS).length} roles assigned
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
631 surfsense_web/components/settings/model-config-manager.tsx Normal file
@@ -0,0 +1,631 @@
"use client";
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { motion, AnimatePresence } from 'framer-motion';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { Label } from '@/components/ui/label';
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select';
|
||||
import { Badge } from '@/components/ui/badge';
|
||||
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from '@/components/ui/dialog';
|
||||
import {
|
||||
Plus,
|
||||
Trash2,
|
||||
Bot,
|
||||
AlertCircle,
|
||||
Edit3,
|
||||
Settings2,
|
||||
Eye,
|
||||
EyeOff,
|
||||
CheckCircle,
|
||||
Clock,
|
||||
AlertTriangle,
|
||||
RefreshCw,
|
||||
Loader2
|
||||
} from 'lucide-react';
|
||||
import { useLLMConfigs, CreateLLMConfig, UpdateLLMConfig, LLMConfig } from '@/hooks/use-llm-configs';
|
||||
import { toast } from 'sonner';
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert';
|
||||
|
||||
const LLM_PROVIDERS = [
|
||||
{
|
||||
value: 'OPENAI',
|
||||
label: 'OpenAI',
|
||||
example: 'gpt-4o, gpt-4, gpt-3.5-turbo',
|
||||
description: 'Most popular and versatile AI models'
|
||||
},
|
||||
{
|
||||
value: 'ANTHROPIC',
|
||||
label: 'Anthropic',
|
||||
example: 'claude-3-5-sonnet-20241022, claude-3-opus-20240229',
|
||||
description: 'Constitutional AI with strong reasoning'
|
||||
},
|
||||
{
|
||||
value: 'GROQ',
|
||||
label: 'Groq',
|
||||
example: 'llama3-70b-8192, mixtral-8x7b-32768',
|
||||
description: 'Ultra-fast inference speeds'
|
||||
},
|
||||
{
|
||||
value: 'COHERE',
|
||||
label: 'Cohere',
|
||||
example: 'command-r-plus, command-r',
|
||||
description: 'Enterprise-focused language models'
|
||||
},
|
||||
{
|
||||
value: 'HUGGINGFACE',
|
||||
label: 'HuggingFace',
|
||||
example: 'microsoft/DialoGPT-medium',
|
||||
description: 'Open source model hub'
|
||||
},
|
||||
{
|
||||
value: 'AZURE_OPENAI',
|
||||
label: 'Azure OpenAI',
|
||||
example: 'gpt-4, gpt-35-turbo',
|
||||
description: 'Enterprise OpenAI through Azure'
|
||||
},
|
||||
{
|
||||
value: 'GOOGLE',
|
||||
label: 'Google',
|
||||
example: 'gemini-pro, gemini-pro-vision',
|
||||
description: 'Google\'s Gemini AI models'
|
||||
},
|
||||
{
|
||||
value: 'AWS_BEDROCK',
|
||||
label: 'AWS Bedrock',
|
||||
example: 'anthropic.claude-v2',
|
||||
description: 'AWS managed AI service'
|
||||
},
|
||||
{
|
||||
value: 'OLLAMA',
|
||||
label: 'Ollama',
|
||||
example: 'llama2, codellama',
|
||||
description: 'Run models locally'
|
||||
},
|
||||
{
|
||||
value: 'MISTRAL',
|
||||
label: 'Mistral',
|
||||
example: 'mistral-large-latest, mistral-medium',
|
||||
description: 'European AI excellence'
|
||||
},
|
||||
{
|
||||
value: 'TOGETHER_AI',
|
||||
label: 'Together AI',
|
||||
example: 'togethercomputer/llama-2-70b-chat',
|
||||
description: 'Decentralized AI platform'
|
||||
},
|
||||
{
|
||||
value: 'REPLICATE',
|
||||
label: 'Replicate',
|
||||
example: 'meta/llama-2-70b-chat',
|
||||
description: 'Run models via API'
|
||||
},
|
||||
{
|
||||
value: 'CUSTOM',
|
||||
label: 'Custom Provider',
|
||||
example: 'your-custom-model',
|
||||
description: 'Your own model endpoint'
|
||||
},
|
||||
];
|
||||
|
||||
export function ModelConfigManager() {
|
||||
const { llmConfigs, loading, error, createLLMConfig, updateLLMConfig, deleteLLMConfig, refreshConfigs } = useLLMConfigs();
|
||||
const [isAddingNew, setIsAddingNew] = useState(false);
|
||||
const [editingConfig, setEditingConfig] = useState<LLMConfig | null>(null);
|
||||
const [showApiKey, setShowApiKey] = useState<Record<number, boolean>>({});
|
||||
const [formData, setFormData] = useState<CreateLLMConfig>({
|
||||
name: '',
|
||||
provider: '',
|
||||
custom_provider: '',
|
||||
model_name: '',
|
||||
api_key: '',
|
||||
api_base: '',
|
||||
litellm_params: {}
|
||||
});
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
// Populate form when editing
|
||||
useEffect(() => {
|
||||
if (editingConfig) {
|
||||
setFormData({
|
||||
name: editingConfig.name,
|
||||
provider: editingConfig.provider,
|
||||
custom_provider: editingConfig.custom_provider || '',
|
||||
model_name: editingConfig.model_name,
|
||||
api_key: editingConfig.api_key,
|
||||
api_base: editingConfig.api_base || '',
|
||||
litellm_params: editingConfig.litellm_params || {}
|
||||
});
|
||||
}
|
||||
}, [editingConfig]);
|
||||
|
||||
const handleInputChange = (field: keyof CreateLLMConfig, value: string) => {
|
||||
setFormData(prev => ({ ...prev, [field]: value }));
|
||||
};
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) {
|
||||
toast.error('Please fill in all required fields');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsSubmitting(true);
|
||||
|
||||
let result;
|
||||
if (editingConfig) {
|
||||
// Update existing config
|
||||
result = await updateLLMConfig(editingConfig.id, formData);
|
||||
} else {
|
||||
// Create new config
|
||||
result = await createLLMConfig(formData);
|
||||
}
|
||||
|
||||
setIsSubmitting(false);
|
||||
|
||||
if (result) {
|
||||
setFormData({
|
||||
name: '',
|
||||
provider: '',
|
||||
custom_provider: '',
|
||||
model_name: '',
|
||||
api_key: '',
|
||||
api_base: '',
|
||||
litellm_params: {}
|
||||
});
|
||||
setIsAddingNew(false);
|
||||
setEditingConfig(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async (id: number) => {
|
||||
if (confirm('Are you sure you want to delete this configuration? This action cannot be undone.')) {
|
||||
await deleteLLMConfig(id);
|
||||
}
|
||||
};
|
||||
|
||||
const toggleApiKeyVisibility = (configId: number) => {
|
||||
setShowApiKey(prev => ({
|
||||
...prev,
|
||||
[configId]: !prev[configId]
|
||||
}));
|
||||
};
|
||||
|
||||
const selectedProvider = LLM_PROVIDERS.find(p => p.value === formData.provider);
|
||||
|
||||
const getProviderInfo = (providerValue: string) => {
|
||||
return LLM_PROVIDERS.find(p => p.value === providerValue);
|
||||
};
|
||||
|
||||
const maskApiKey = (apiKey: string) => {
|
||||
if (apiKey.length <= 8) return '*'.repeat(apiKey.length);
|
||||
return apiKey.substring(0, 4) + '*'.repeat(apiKey.length - 8) + apiKey.substring(apiKey.length - 4);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
{/* Header */}
|
||||
<div className="flex flex-col space-y-4 sm:flex-row sm:items-center sm:justify-between sm:space-y-0">
|
||||
<div className="space-y-1">
|
||||
<div className="flex items-center space-x-3">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-lg bg-blue-500/10">
|
||||
<Settings2 className="h-5 w-5 text-blue-600" />
|
||||
</div>
|
||||
<div>
|
||||
<h2 className="text-2xl font-bold tracking-tight">Model Configurations</h2>
|
||||
<p className="text-muted-foreground">
|
||||
Manage your LLM provider configurations and API settings.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={refreshConfigs}
|
||||
disabled={loading}
|
||||
className="flex items-center gap-2"
|
||||
>
|
||||
<RefreshCw className={`h-4 w-4 ${loading ? 'animate-spin' : ''}`} />
|
||||
Refresh
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Error Alert */}
|
||||
{error && (
|
||||
<Alert variant="destructive">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>
|
||||
{error}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Loading State */}
|
||||
{loading && (
|
||||
<Card>
|
||||
<CardContent className="flex items-center justify-center py-12">
|
||||
<div className="flex items-center gap-2 text-muted-foreground">
|
||||
<Loader2 className="w-5 h-5 animate-spin" />
|
||||
<span>Loading configurations...</span>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* Stats Overview */}
|
||||
{!loading && !error && (
|
||||
<div className="grid gap-4 md:grid-cols-3">
|
||||
<Card className="border-l-4 border-l-blue-500">
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-center justify-between space-x-4">
|
||||
<div className="space-y-1">
|
||||
<p className="text-3xl font-bold tracking-tight">{llmConfigs.length}</p>
|
||||
<p className="text-sm font-medium text-muted-foreground">Total Configurations</p>
|
||||
</div>
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-blue-500/10">
|
||||
<Bot className="h-6 w-6 text-blue-600" />
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card className="border-l-4 border-l-green-500">
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-center justify-between space-x-4">
|
||||
<div className="space-y-1">
|
||||
<p className="text-3xl font-bold tracking-tight">
|
||||
{new Set(llmConfigs.map(c => c.provider)).size}
|
||||
</p>
|
||||
<p className="text-sm font-medium text-muted-foreground">Unique Providers</p>
|
||||
</div>
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-green-500/10">
|
||||
<CheckCircle className="h-6 w-6 text-green-600" />
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
<Card className="border-l-4 border-l-emerald-500">
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-center justify-between space-x-4">
|
||||
<div className="space-y-1">
|
||||
<p className="text-3xl font-bold tracking-tight text-emerald-600">Active</p>
|
||||
<p className="text-sm font-medium text-muted-foreground">System Status</p>
|
||||
</div>
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-emerald-500/10">
|
||||
<CheckCircle className="h-6 w-6 text-emerald-600" />
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Configuration Management */}
|
||||
{!loading && !error && (
|
||||
<div className="space-y-6">
|
||||
<div className="flex flex-col space-y-4 sm:flex-row sm:items-center sm:justify-between sm:space-y-0">
|
||||
<div>
|
||||
<h3 className="text-xl font-semibold tracking-tight">Your Configurations</h3>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Manage and configure your LLM providers
|
||||
</p>
|
||||
</div>
|
||||
<Button onClick={() => setIsAddingNew(true)} className="flex items-center gap-2">
|
||||
<Plus className="h-4 w-4" />
|
||||
Add Configuration
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{llmConfigs.length === 0 ? (
|
||||
<Card className="border-dashed border-2 border-muted-foreground/25">
|
||||
<CardContent className="flex flex-col items-center justify-center py-16 text-center">
|
||||
<div className="rounded-full bg-muted p-4 mb-6">
|
||||
<Bot className="h-10 w-10 text-muted-foreground" />
|
||||
</div>
|
||||
<div className="space-y-2 mb-6">
|
||||
<h3 className="text-xl font-semibold">No Configurations Yet</h3>
|
||||
<p className="text-muted-foreground max-w-sm">
|
||||
Get started by adding your first LLM provider configuration to begin using the system.
|
||||
</p>
|
||||
</div>
|
||||
<Button onClick={() => setIsAddingNew(true)} size="lg">
|
||||
<Plus className="h-4 w-4 mr-2" />
|
||||
Add First Configuration
|
||||
</Button>
|
||||
</CardContent>
|
||||
</Card>
|
||||
) : (
|
||||
<div className="grid gap-4">
|
||||
<AnimatePresence>
|
||||
{llmConfigs.map((config) => {
|
||||
const providerInfo = getProviderInfo(config.provider);
|
||||
return (
|
||||
<motion.div
|
||||
key={config.id}
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -10 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
>
|
||||
<Card className="group border-l-4 border-l-primary/50 hover:border-l-primary hover:shadow-md transition-all duration-200">
|
||||
<CardContent className="p-6">
|
||||
<div className="flex items-start justify-between">
|
||||
<div className="flex-1 space-y-4">
|
||||
{/* Header */}
|
||||
<div className="flex items-start gap-4">
|
||||
<div className="flex h-12 w-12 items-center justify-center rounded-lg bg-primary/10 group-hover:bg-primary/20 transition-colors">
|
||||
<Bot className="h-6 w-6 text-primary" />
|
||||
</div>
|
||||
<div className="flex-1 space-y-2">
|
||||
<div className="flex items-center gap-3">
|
||||
<h4 className="text-lg font-semibold tracking-tight">{config.name}</h4>
|
||||
<Badge variant="secondary" className="text-xs font-medium">
|
||||
{config.provider}
|
||||
</Badge>
|
||||
</div>
|
||||
<p className="text-sm text-muted-foreground font-mono">
|
||||
{config.model_name}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Provider Description */}
|
||||
{providerInfo && (
|
||||
<p className="text-sm text-muted-foreground">
|
||||
{providerInfo.description}
|
||||
</p>
|
||||
)}
|
||||
|
||||
{/* Configuration Details */}
|
||||
<div className="grid gap-4 sm:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<Label className="text-xs font-medium uppercase tracking-wide text-muted-foreground">
|
||||
API Key
|
||||
</Label>
|
||||
<div className="flex items-center space-x-2">
|
||||
<code className="flex-1 rounded-md bg-muted px-3 py-2 text-xs font-mono">
|
||||
{showApiKey[config.id]
|
||||
? config.api_key
|
||||
: maskApiKey(config.api_key)
|
||||
}
|
||||
</code>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => toggleApiKeyVisibility(config.id)}
|
||||
className="h-8 w-8 p-0"
|
||||
>
|
||||
{showApiKey[config.id] ? (
|
||||
<EyeOff className="h-3 w-3" />
|
||||
) : (
|
||||
<Eye className="h-3 w-3" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{config.api_base && (
|
||||
<div className="space-y-2">
|
||||
<Label className="text-xs font-medium uppercase tracking-wide text-muted-foreground">
|
||||
API Base URL
|
||||
</Label>
|
||||
<code className="block rounded-md bg-muted px-3 py-2 text-xs font-mono break-all">
|
||||
{config.api_base}
|
||||
</code>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Metadata */}
|
||||
<div className="flex flex-wrap items-center gap-4 pt-4 border-t border-border/50">
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<Clock className="h-3 w-3" />
|
||||
<span>Created {new Date(config.created_at).toLocaleDateString()}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2 text-xs">
|
||||
<div className="h-2 w-2 rounded-full bg-green-500"></div>
|
||||
<span className="text-green-600 font-medium">Active</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Actions */}
|
||||
<div className="flex flex-col gap-2 ml-6">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setEditingConfig(config)}
|
||||
className="h-8 w-8 p-0"
|
||||
>
|
||||
<Edit3 className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => handleDelete(config.id)}
|
||||
className="h-8 w-8 p-0 border-destructive/20 text-destructive hover:bg-destructive hover:text-destructive-foreground"
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</motion.div>
|
||||
);
|
||||
})}
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Add/Edit Configuration Dialog */}
|
||||
<Dialog open={isAddingNew || !!editingConfig} onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
setIsAddingNew(false);
|
||||
setEditingConfig(null);
|
||||
setFormData({
|
||||
name: '',
|
||||
provider: '',
|
||||
custom_provider: '',
|
||||
model_name: '',
|
||||
api_key: '',
|
||||
api_base: '',
|
||||
litellm_params: {}
|
||||
});
|
||||
}
|
||||
}}>
|
||||
<DialogContent className="max-w-2xl max-h-[90vh] overflow-y-auto">
|
||||
<DialogHeader>
|
||||
<DialogTitle className="flex items-center gap-2">
|
||||
{editingConfig ? <Edit3 className="w-5 h-5" /> : <Plus className="w-5 h-5" />}
|
||||
{editingConfig ? 'Edit LLM Configuration' : 'Add New LLM Configuration'}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{editingConfig
|
||||
? 'Update your language model provider configuration'
|
||||
: 'Configure a new language model provider for your AI assistant'
|
||||
}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-4">
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="name">Configuration Name *</Label>
|
||||
<Input
|
||||
id="name"
|
||||
placeholder="e.g., My OpenAI GPT-4"
|
||||
value={formData.name}
|
||||
onChange={(e) => handleInputChange('name', e.target.value)}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="provider">Provider *</Label>
|
||||
<Select value={formData.provider} onValueChange={(value) => handleInputChange('provider', value)}>
|
||||
<SelectTrigger className="h-auto min-h-[2.5rem] py-2">
|
||||
<SelectValue placeholder="Select a provider">
|
||||
{formData.provider && (
|
||||
<div className="flex items-center space-x-2 py-1">
|
||||
<div className="font-medium">
|
||||
{LLM_PROVIDERS.find(p => p.value === formData.provider)?.label}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
•
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{LLM_PROVIDERS.find(p => p.value === formData.provider)?.description}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</SelectValue>
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{LLM_PROVIDERS.map((provider) => (
|
||||
<SelectItem key={provider.value} value={provider.value}>
|
||||
<div className="space-y-1 py-1">
|
||||
<div className="font-medium">{provider.label}</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{provider.description}
|
||||
</div>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{formData.provider === 'CUSTOM' && (
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="custom_provider">Custom Provider Name *</Label>
|
||||
<Input
|
||||
id="custom_provider"
|
||||
placeholder="e.g., my-custom-provider"
|
||||
value={formData.custom_provider}
|
||||
onChange={(e) => handleInputChange('custom_provider', e.target.value)}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="model_name">Model Name *</Label>
|
||||
<Input
|
||||
id="model_name"
|
||||
placeholder={selectedProvider?.example || "e.g., gpt-4"}
|
||||
value={formData.model_name}
|
||||
onChange={(e) => handleInputChange('model_name', e.target.value)}
|
||||
required
|
||||
/>
|
||||
{selectedProvider && (
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Examples: {selectedProvider.example}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="api_key">API Key *</Label>
|
||||
<Input
|
||||
id="api_key"
|
||||
type="password"
|
||||
placeholder="Your API key"
|
||||
value={formData.api_key}
|
||||
onChange={(e) => handleInputChange('api_key', e.target.value)}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="api_base">API Base URL (Optional)</Label>
|
||||
<Input
|
||||
id="api_base"
|
||||
placeholder="e.g., https://api.openai.com/v1"
|
||||
value={formData.api_base}
|
||||
onChange={(e) => handleInputChange('api_base', e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-2 pt-4">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
{isSubmitting
|
||||
? (editingConfig ? 'Updating...' : 'Adding...')
|
||||
: (editingConfig ? 'Update Configuration' : 'Add Configuration')
|
||||
}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
setIsAddingNew(false);
|
||||
setEditingConfig(null);
|
||||
setFormData({
|
||||
name: '',
|
||||
provider: '',
|
||||
custom_provider: '',
|
||||
model_name: '',
|
||||
api_key: '',
|
||||
api_base: '',
|
||||
litellm_params: {}
|
||||
});
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
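For reference, maskApiKey above keeps the first and last four characters and stars everything in between, fully masking keys of eight characters or fewer. Worked examples (illustrative key values only):

// maskApiKey('sk-proj-abcdef123456')  -> "sk-p************3456"  (20 chars: 4 + 12 stars + 4)
// maskApiKey('short')                 -> "*****"                 (<= 8 chars: fully masked)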
@@ -4,6 +4,7 @@ import {
   BadgeCheck,
   ChevronsUpDown,
   LogOut,
   Settings,
 } from "lucide-react"

 import {
@@ -93,6 +94,10 @@ export function NavUser({
         </DropdownMenuItem>
       </DropdownMenuGroup>
       <DropdownMenuSeparator />
       <DropdownMenuItem onClick={() => router.push(`/settings`)}>
         <Settings />
         Settings
       </DropdownMenuItem>
       <DropdownMenuItem onClick={handleLogout}>
         <LogOut />
         Log out
|
surfsense_web/components/ui/progress.tsx (new file)
@@ -0,0 +1,29 @@
"use client"
|
||||
|
||||
import * as React from "react"
|
||||
import { cn } from "@/lib/utils"
|
||||
|
||||
interface ProgressProps extends React.HTMLAttributes<HTMLDivElement> {
|
||||
value?: number
|
||||
}
|
||||
|
||||
const Progress = React.forwardRef<HTMLDivElement, ProgressProps>(
|
||||
({ className, value = 0, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"relative h-4 w-full overflow-hidden rounded-full bg-secondary",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<div
|
||||
className="h-full bg-primary transition-all duration-300 ease-in-out"
|
||||
style={{ width: `${Math.min(100, Math.max(0, value))}%` }}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
)
|
||||
Progress.displayName = "Progress"
|
||||
|
||||
export { Progress }
|
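A quick usage sketch for this Progress primitive (a hypothetical consumer, not part of the commit). Since the component clamps value to 0–100 itself, callers can pass raw percentages:

import { Progress } from "@/components/ui/progress";

// Hypothetical consumer: renders a thinner bar at the given percentage.
export function UploadStatus({ percent }: { percent: number }) {
  // Out-of-range values are safe: the inner div's width is clamped to 0–100%.
  return <Progress value={percent} className="h-2" />;
}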
surfsense_web/hooks/use-llm-configs.ts (new file)
@@ -0,0 +1,246 @@
"use client"
|
||||
import { useState, useEffect } from 'react';
|
||||
import { toast } from 'sonner';
|
||||
|
||||
export interface LLMConfig {
|
||||
id: number;
|
||||
name: string;
|
||||
provider: string;
|
||||
custom_provider?: string;
|
||||
model_name: string;
|
||||
api_key: string;
|
||||
api_base?: string;
|
||||
litellm_params?: Record<string, any>;
|
||||
created_at: string;
|
||||
user_id: string;
|
||||
}
|
||||
|
||||
export interface LLMPreferences {
|
||||
long_context_llm_id?: number;
|
||||
fast_llm_id?: number;
|
||||
strategic_llm_id?: number;
|
||||
long_context_llm?: LLMConfig;
|
||||
fast_llm?: LLMConfig;
|
||||
strategic_llm?: LLMConfig;
|
||||
}
|
||||
|
||||
export interface CreateLLMConfig {
|
||||
name: string;
|
||||
provider: string;
|
||||
custom_provider?: string;
|
||||
model_name: string;
|
||||
api_key: string;
|
||||
api_base?: string;
|
||||
litellm_params?: Record<string, any>;
|
||||
}
|
||||
|
||||
export interface UpdateLLMConfig {
|
||||
name?: string;
|
||||
provider?: string;
|
||||
custom_provider?: string;
|
||||
model_name?: string;
|
||||
api_key?: string;
|
||||
api_base?: string;
|
||||
litellm_params?: Record<string, any>;
|
||||
}
|
||||
|
||||
export function useLLMConfigs() {
|
||||
const [llmConfigs, setLlmConfigs] = useState<LLMConfig[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const fetchLLMConfigs = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
|
||||
},
|
||||
method: "GET",
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch LLM configurations");
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setLlmConfigs(data);
|
||||
setError(null);
|
||||
} catch (err: any) {
|
||||
setError(err.message || 'Failed to fetch LLM configurations');
|
||||
console.error('Error fetching LLM configurations:', err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchLLMConfigs();
|
||||
}, []);
|
||||
|
||||
const createLLMConfig = async (config: CreateLLMConfig): Promise<LLMConfig | null> => {
|
||||
try {
|
||||
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
|
||||
},
|
||||
body: JSON.stringify(config),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
throw new Error(errorData.detail || 'Failed to create LLM configuration');
|
||||
}
|
||||
|
||||
const newConfig = await response.json();
|
||||
setLlmConfigs(prev => [...prev, newConfig]);
|
||||
toast.success('LLM configuration created successfully');
|
||||
return newConfig;
|
||||
} catch (err: any) {
|
||||
toast.error(err.message || 'Failed to create LLM configuration');
|
||||
console.error('Error creating LLM configuration:', err);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
const deleteLLMConfig = async (id: number): Promise<boolean> => {
|
||||
try {
|
||||
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/${id}`, {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to delete LLM configuration');
|
||||
}
|
||||
|
||||
setLlmConfigs(prev => prev.filter(config => config.id !== id));
|
||||
toast.success('LLM configuration deleted successfully');
|
||||
return true;
|
||||
} catch (err: any) {
|
||||
toast.error(err.message || 'Failed to delete LLM configuration');
|
||||
console.error('Error deleting LLM configuration:', err);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const updateLLMConfig = async (id: number, config: UpdateLLMConfig): Promise<LLMConfig | null> => {
|
||||
try {
|
||||
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/${id}`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
|
||||
},
|
||||
body: JSON.stringify(config),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
throw new Error(errorData.detail || 'Failed to update LLM configuration');
|
||||
}
|
||||
|
||||
const updatedConfig = await response.json();
|
||||
setLlmConfigs(prev => prev.map(c => c.id === id ? updatedConfig : c));
|
||||
toast.success('LLM configuration updated successfully');
|
||||
return updatedConfig;
|
||||
} catch (err: any) {
|
||||
toast.error(err.message || 'Failed to update LLM configuration');
|
||||
console.error('Error updating LLM configuration:', err);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
llmConfigs,
|
||||
loading,
|
||||
error,
|
||||
createLLMConfig,
|
||||
updateLLMConfig,
|
||||
deleteLLMConfig,
|
||||
refreshConfigs: fetchLLMConfigs
|
||||
};
|
||||
}
|
||||
|
||||
export function useLLMPreferences() {
|
||||
const [preferences, setPreferences] = useState<LLMPreferences>({});
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const fetchPreferences = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
|
||||
},
|
||||
method: "GET",
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch LLM preferences");
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
setPreferences(data);
|
||||
setError(null);
|
||||
} catch (err: any) {
|
||||
setError(err.message || 'Failed to fetch LLM preferences');
|
||||
console.error('Error fetching LLM preferences:', err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchPreferences();
|
||||
}, []);
|
||||
|
||||
const updatePreferences = async (newPreferences: Partial<LLMPreferences>): Promise<boolean> => {
|
||||
try {
|
||||
const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`,
|
||||
},
|
||||
body: JSON.stringify(newPreferences),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
throw new Error(errorData.detail || 'Failed to update LLM preferences');
|
||||
}
|
||||
|
||||
const updatedPreferences = await response.json();
|
||||
setPreferences(updatedPreferences);
|
||||
toast.success('LLM preferences updated successfully');
|
||||
return true;
|
||||
} catch (err: any) {
|
||||
toast.error(err.message || 'Failed to update LLM preferences');
|
||||
console.error('Error updating LLM preferences:', err);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
const isOnboardingComplete = (): boolean => {
|
||||
return !!(
|
||||
preferences.long_context_llm_id &&
|
||||
preferences.fast_llm_id &&
|
||||
preferences.strategic_llm_id
|
||||
);
|
||||
};
|
||||
|
||||
return {
|
||||
preferences,
|
||||
loading,
|
||||
error,
|
||||
updatePreferences,
|
||||
refreshPreferences: fetchPreferences,
|
||||
isOnboardingComplete
|
||||
};
|
||||
}
|
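A minimal sketch of how a settings page might consume these hooks (the component and its markup are hypothetical; only the hook APIs defined above are assumed):

"use client"
import { useLLMConfigs, useLLMPreferences } from "@/hooks/use-llm-configs";

// Hypothetical settings panel: lists saved configs and assigns one as the fast LLM.
export function LLMSettingsExample() {
  const { llmConfigs, loading } = useLLMConfigs();
  const { updatePreferences, isOnboardingComplete } = useLLMPreferences();

  if (loading) return <p>Loading…</p>;

  return (
    <div>
      {llmConfigs.map(cfg => (
        <button key={cfg.id} onClick={() => updatePreferences({ fast_llm_id: cfg.id })}>
          Use {cfg.name} as fast LLM
        </button>
      ))}
      {/* Onboarding is complete only when all three role IDs are set. */}
      {!isOnboardingComplete() && <p>Assign all three LLM roles to finish onboarding.</p>}
    </div>
  );
}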