From a85f7920a9d2200bb594a0c4c56e0fa1591ac225 Mon Sep 17 00:00:00 2001 From: "DESKTOP-RTLN3BA\\$punk" Date: Mon, 9 Jun 2025 15:50:15 -0700 Subject: [PATCH] feat: added configurable LLM's --- README.md | 2 +- surfsense_backend/.env.example | 66 +- ..._add_llm_config_table_and_relationships.py | 86 +++ .../app/agents/podcaster/configuration.py | 3 +- .../app/agents/podcaster/nodes.py | 15 +- .../app/agents/podcaster/state.py | 5 +- .../app/agents/researcher/nodes.py | 326 ++++----- .../app/agents/researcher/qna_agent/nodes.py | 14 +- .../researcher/sub_section_writer/nodes.py | 14 +- .../app/agents/researcher/utils.py | 11 +- surfsense_backend/app/config/__init__.py | 31 +- surfsense_backend/app/db.py | 62 ++ surfsense_backend/app/routes/__init__.py | 2 + .../app/routes/documents_routes.py | 49 +- .../app/routes/llm_config_routes.py | 243 +++++++ .../app/routes/podcasts_routes.py | 8 +- .../routes/search_source_connectors_routes.py | 40 +- surfsense_backend/app/schemas/__init__.py | 5 + surfsense_backend/app/schemas/llm_config.py | 34 + .../app/tasks/background_tasks.py | 54 +- .../app/tasks/connectors_indexing_tasks.py | 38 +- surfsense_backend/app/tasks/podcast_tasks.py | 7 +- surfsense_backend/app/utils/llm_service.py | 120 ++++ surfsense_backend/app/utils/query_service.py | 22 +- .../researcher/[chat_id]/page.tsx | 86 ++- surfsense_web/app/dashboard/layout.tsx | 90 +++ surfsense_web/app/onboard/page.tsx | 227 +++++++ surfsense_web/app/settings/page.tsx | 60 ++ .../components/onboard/add-provider-step.tsx | 255 +++++++ .../components/onboard/assign-roles-step.tsx | 232 +++++++ .../components/onboard/completion-step.tsx | 125 ++++ .../components/settings/llm-role-manager.tsx | 465 +++++++++++++ .../settings/model-config-manager.tsx | 631 ++++++++++++++++++ surfsense_web/components/sidebar/nav-user.tsx | 5 + surfsense_web/components/ui/progress.tsx | 29 + surfsense_web/hooks/use-llm-configs.ts | 246 +++++++ 36 files changed, 3415 insertions(+), 293 deletions(-) create mode 100644 surfsense_backend/alembic/versions/11_add_llm_config_table_and_relationships.py create mode 100644 surfsense_backend/app/routes/llm_config_routes.py create mode 100644 surfsense_backend/app/schemas/llm_config.py create mode 100644 surfsense_backend/app/utils/llm_service.py create mode 100644 surfsense_web/app/dashboard/layout.tsx create mode 100644 surfsense_web/app/onboard/page.tsx create mode 100644 surfsense_web/app/settings/page.tsx create mode 100644 surfsense_web/components/onboard/add-provider-step.tsx create mode 100644 surfsense_web/components/onboard/assign-roles-step.tsx create mode 100644 surfsense_web/components/onboard/completion-step.tsx create mode 100644 surfsense_web/components/settings/llm-role-manager.tsx create mode 100644 surfsense_web/components/settings/model-config-manager.tsx create mode 100644 surfsense_web/components/ui/progress.tsx create mode 100644 surfsense_web/hooks/use-llm-configs.ts diff --git a/README.md b/README.md index 7d53ff6..4ad0d86 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Open source and easy to deploy locally. - Support for multiple TTS providers (OpenAI, Azure, Google Vertex AI) ### 📊 **Advanced RAG Techniques** -- Supports 150+ LLM's +- Supports 100+ LLM's - Supports 6000+ Embedding Models. - Supports all major Rerankers (Pinecode, Cohere, Flashrank etc) - Uses Hierarchical Indices (2 tiered RAG setup). 
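The .env.example changes that follow drop the global FAST_LLM/STRATEGIC_LLM/LONG_CONTEXT_LLM variables; model selection moves to per-user records served by the llm-configs routes added later in this patch. A minimal sketch of the resulting setup flow, assuming the backend is reachable at http://localhost:8000 with the router mounted at the application root and a valid fastapi-users bearer token; field names follow the LLMConfig model and LLMPreferencesUpdate schema introduced here, and the Ollama values are only examples:

```python
import httpx

BASE = "http://localhost:8000"  # assumed local backend URL
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder fastapi-users token

# Register a provider + model (maps to the new llm_configs table)
cfg = httpx.post(
    f"{BASE}/llm-configs/",
    headers=HEADERS,
    json={
        "name": "Local Ollama",
        "provider": "OLLAMA",                # one of the LiteLLMProvider enum values
        "model_name": "deepseek-r1:8b",      # model name without the provider prefix
        "api_key": "not-needed-for-ollama",  # api_key is NOT NULL in the new table, even for local providers
        "api_base": "http://localhost:11434",
    },
).json()

# Assign the config to the three roles the agents resolve at runtime
httpx.put(
    f"{BASE}/users/me/llm-preferences",
    headers=HEADERS,
    json={
        "fast_llm_id": cfg["id"],            # LLMConfigRead is expected to expose the row id
        "strategic_llm_id": cfg["id"],
        "long_context_llm_id": cfg["id"],
    },
)
```

The onboarding UI added in surfsense_web (add-provider-step, assign-roles-step) presumably drives these same two endpoints.
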
diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index c0032a9..0915906 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -1,51 +1,55 @@ -DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense" +DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense -SECRET_KEY="SECRET" -NEXT_FRONTEND_URL="http://localhost:3000" +SECRET_KEY=SECRET +NEXT_FRONTEND_URL=http://localhost:3000 #Auth -AUTH_TYPE="GOOGLE" or "LOCAL" +AUTH_TYPE=GOOGLE or LOCAL # For Google Auth Only -GOOGLE_OAUTH_CLIENT_ID="924507538m" -GOOGLE_OAUTH_CLIENT_SECRET="GOCSV" +GOOGLE_OAUTH_CLIENT_ID=924507538m +GOOGLE_OAUTH_CLIENT_SECRET=GOCSV #Embedding Model -EMBEDDING_MODEL="mixedbread-ai/mxbai-embed-large-v1" +EMBEDDING_MODEL=mixedbread-ai/mxbai-embed-large-v1 -RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2" -RERANKERS_MODEL_TYPE="flashrank" +RERANKERS_MODEL_NAME=ms-marco-MiniLM-L-12-v2 +RERANKERS_MODEL_TYPE=flashrank -# https://docs.litellm.ai/docs/providers -FAST_LLM="openai/gpt-4o-mini" -STRATEGIC_LLM="openai/gpt-4o" -LONG_CONTEXT_LLM="gemini/gemini-2.0-flash" #LiteLLM TTS Provider: https://docs.litellm.ai/docs/text_to_speech#supported-providers -TTS_SERVICE="openai/tts-1" +TTS_SERVICE=openai/tts-1 +#Respective TTS Service API +TTS_SERVICE_API_KEY= +#OPTIONAL: TTS Provider API Base +TTS_SERVICE_API_BASE= #LiteLLM STT Provider: https://docs.litellm.ai/docs/audio_transcription#supported-providers -STT_SERVICE="openai/whisper-1" +STT_SERVICE=openai/whisper-1 +#Respective STT Service API +STT_SERVICE_API_KEY="" +#OPTIONAL: STT Provider API Base +STT_SERVICE_API_BASE= -# Chosen LiteLLM Providers Keys -OPENAI_API_KEY="sk-proj-iA" -GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124" -FIRECRAWL_API_KEY="fcr-01J0000000000000000000000" +FIRECRAWL_API_KEY=fcr-01J0000000000000000000000 #File Parser Service -ETL_SERVICE="UNSTRUCTURED" or "LLAMACLOUD" -UNSTRUCTURED_API_KEY="Tpu3P0U8iy" -LLAMA_CLOUD_API_KEY="llx-nnn" +ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD +UNSTRUCTURED_API_KEY=Tpu3P0U8iy +LLAMA_CLOUD_API_KEY=llx-nnn #OPTIONAL: Add these for LangSmith Observability LANGSMITH_TRACING=true -LANGSMITH_ENDPOINT="https://api.smith.langchain.com" -LANGSMITH_API_KEY="lsv2_pt_....." -LANGSMITH_PROJECT="surfsense" +LANGSMITH_ENDPOINT=https://api.smith.langchain.com +LANGSMITH_API_KEY=lsv2_pt_..... +LANGSMITH_PROJECT=surfsense -# OPTIONAL: LiteLLM API Base -FAST_LLM_API_BASE="" -STRATEGIC_LLM_API_BASE="" -LONG_CONTEXT_LLM_API_BASE="" -TTS_SERVICE_API_BASE="" -STT_SERVICE_API_BASE="" + + +# FAST_LLM=openai/gpt-4o-mini +# STRATEGIC_LLM=openai/gpt-4o +# LONG_CONTEXT_LLM=gemini/gemini-2.0-flash + +# FAST_LLM=ollama/gemma3:12b +# STRATEGIC_LLM=ollama/deepseek-r1:8b +# LONG_CONTEXT_LLM=ollama/deepseek-r1:8b diff --git a/surfsense_backend/alembic/versions/11_add_llm_config_table_and_relationships.py b/surfsense_backend/alembic/versions/11_add_llm_config_table_and_relationships.py new file mode 100644 index 0000000..83fdef1 --- /dev/null +++ b/surfsense_backend/alembic/versions/11_add_llm_config_table_and_relationships.py @@ -0,0 +1,86 @@ +"""Add LLMConfig table and user LLM preferences + +Revision ID: 11 +Revises: 10 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID, JSON + + +# revision identifiers, used by Alembic. 
+revision: str = "11" +down_revision: Union[str, None] = "10" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema - add LiteLLMProvider enum, LLMConfig table and user LLM preferences.""" + + # Check if enum type exists and create if it doesn't + op.execute(""" + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'litellmprovider') THEN + CREATE TYPE litellmprovider AS ENUM ('OPENAI', 'ANTHROPIC', 'GROQ', 'COHERE', 'HUGGINGFACE', 'AZURE_OPENAI', 'GOOGLE', 'AWS_BEDROCK', 'OLLAMA', 'MISTRAL', 'TOGETHER_AI', 'REPLICATE', 'PALM', 'VERTEX_AI', 'ANYSCALE', 'PERPLEXITY', 'DEEPINFRA', 'AI21', 'NLPCLOUD', 'ALEPH_ALPHA', 'PETALS', 'CUSTOM'); + END IF; + END$$; + """) + + # Create llm_configs table using raw SQL to avoid enum creation conflicts + op.execute(""" + CREATE TABLE llm_configs ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + name VARCHAR(100) NOT NULL, + provider litellmprovider NOT NULL, + custom_provider VARCHAR(100), + model_name VARCHAR(100) NOT NULL, + api_key TEXT NOT NULL, + api_base VARCHAR(500), + litellm_params JSONB, + user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE + ) + """) + + # Create indexes + op.create_index(op.f('ix_llm_configs_id'), 'llm_configs', ['id'], unique=False) + op.create_index(op.f('ix_llm_configs_created_at'), 'llm_configs', ['created_at'], unique=False) + op.create_index(op.f('ix_llm_configs_name'), 'llm_configs', ['name'], unique=False) + + # Add LLM preference columns to user table + op.add_column('user', sa.Column('long_context_llm_id', sa.Integer(), nullable=True)) + op.add_column('user', sa.Column('fast_llm_id', sa.Integer(), nullable=True)) + op.add_column('user', sa.Column('strategic_llm_id', sa.Integer(), nullable=True)) + + # Create foreign key constraints for LLM preferences + op.create_foreign_key(op.f('fk_user_long_context_llm_id_llm_configs'), 'user', 'llm_configs', ['long_context_llm_id'], ['id'], ondelete='SET NULL') + op.create_foreign_key(op.f('fk_user_fast_llm_id_llm_configs'), 'user', 'llm_configs', ['fast_llm_id'], ['id'], ondelete='SET NULL') + op.create_foreign_key(op.f('fk_user_strategic_llm_id_llm_configs'), 'user', 'llm_configs', ['strategic_llm_id'], ['id'], ondelete='SET NULL') + + +def downgrade() -> None: + """Downgrade schema - remove LLMConfig table and user LLM preferences.""" + + # Drop foreign key constraints + op.drop_constraint(op.f('fk_user_strategic_llm_id_llm_configs'), 'user', type_='foreignkey') + op.drop_constraint(op.f('fk_user_fast_llm_id_llm_configs'), 'user', type_='foreignkey') + op.drop_constraint(op.f('fk_user_long_context_llm_id_llm_configs'), 'user', type_='foreignkey') + + # Drop LLM preference columns from user table + op.drop_column('user', 'strategic_llm_id') + op.drop_column('user', 'fast_llm_id') + op.drop_column('user', 'long_context_llm_id') + + # Drop indexes and table + op.drop_index(op.f('ix_llm_configs_name'), table_name='llm_configs') + op.drop_index(op.f('ix_llm_configs_created_at'), table_name='llm_configs') + op.drop_index(op.f('ix_llm_configs_id'), table_name='llm_configs') + op.drop_table('llm_configs') + + # Drop LiteLLMProvider enum + op.execute("DROP TYPE IF EXISTS litellmprovider") \ No newline at end of file diff --git a/surfsense_backend/app/agents/podcaster/configuration.py b/surfsense_backend/app/agents/podcaster/configuration.py index 6bbb4ce..062b1ee 100644 --- 
a/surfsense_backend/app/agents/podcaster/configuration.py +++ b/surfsense_backend/app/agents/podcaster/configuration.py @@ -16,7 +16,8 @@ class Configuration: # these values can be pre-set when you # create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/) # and when you invoke the graph - podcast_title: str + podcast_title: str + user_id: str @classmethod def from_runnable_config( diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py index 9ea590a..d1dea9b 100644 --- a/surfsense_backend/app/agents/podcaster/nodes.py +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -14,13 +14,22 @@ from .configuration import Configuration from .state import PodcastTranscriptEntry, State, PodcastTranscripts from .prompts import get_podcast_generation_prompt from app.config import config as app_config +from app.utils.llm_service import get_user_long_context_llm async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dict[str, Any]: """Each node does work.""" - # Initialize LLM - llm = app_config.long_context_llm_instance + # Get configuration from runnable config + configuration = Configuration.from_runnable_config(config) + user_id = configuration.user_id + + # Get user's long context LLM + llm = await get_user_long_context_llm(state.db_session, user_id) + if not llm: + error_message = f"No long context LLM configured for user {user_id}" + print(error_message) + raise RuntimeError(error_message) # Get the prompt prompt = get_podcast_generation_prompt() @@ -139,6 +148,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D response = await aspeech( model=app_config.TTS_SERVICE, api_base=app_config.TTS_SERVICE_API_BASE, + api_key=app_config.TTS_SERVICE_API_KEY, voice=voice, input=dialog, max_retries=2, @@ -147,6 +157,7 @@ async def create_merged_podcast_audio(state: State, config: RunnableConfig) -> D else: response = await aspeech( model=app_config.TTS_SERVICE, + api_key=app_config.TTS_SERVICE_API_KEY, voice=voice, input=dialog, max_retries=2, diff --git a/surfsense_backend/app/agents/podcaster/state.py b/surfsense_backend/app/agents/podcaster/state.py index d77270d..79fccef 100644 --- a/surfsense_backend/app/agents/podcaster/state.py +++ b/surfsense_backend/app/agents/podcaster/state.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass from typing import List, Optional from pydantic import BaseModel, Field - +from sqlalchemy.ext.asyncio import AsyncSession class PodcastTranscriptEntry(BaseModel): """ @@ -32,7 +32,8 @@ class State: See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state for more information. 
""" - + # Runtime context + db_session: AsyncSession source_content: str podcast_transcript: Optional[List[PodcastTranscriptEntry]] = None final_podcast_file_path: Optional[str] = None diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py index 22c70cc..aa44bba 100644 --- a/surfsense_backend/app/agents/researcher/nodes.py +++ b/surfsense_backend/app/agents/researcher/nodes.py @@ -2,7 +2,6 @@ import asyncio import json from typing import Any, Dict, List -from app.config import config as app_config from app.db import async_session_maker from app.utils.connector_service import ConnectorService from langchain_core.messages import HumanMessage, SystemMessage @@ -274,6 +273,9 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str Returns: Dict containing the answer outline in the "answer_outline" key for state update. """ + from app.utils.llm_service import get_user_strategic_llm + from app.db import get_async_session + streaming_service = state.streaming_service streaming_service.only_update_terminal("🔍 Generating answer outline...") @@ -283,12 +285,18 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str reformulated_query = state.reformulated_query user_query = configuration.user_query num_sections = configuration.num_sections + user_id = configuration.user_id streaming_service.only_update_terminal(f"🤔 Planning research approach for: \"{user_query[:100]}...\"") writer({"yeild_value": streaming_service._format_annotations()}) - # Initialize LLM - llm = app_config.strategic_llm_instance + # Get user's strategic LLM + llm = await get_user_strategic_llm(state.db_session, user_id) + if not llm: + error_message = f"No strategic LLM configured for user {user_id}" + streaming_service.only_update_terminal(f"❌ {error_message}", "error") + writer({"yeild_value": streaming_service._format_annotations()}) + raise RuntimeError(error_message) # Create the human message content human_message_content = f""" @@ -828,48 +836,47 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW user_selected_documents = [] user_selected_sources = [] - async with async_session_maker() as db_session: - try: - # First, fetch user-selected documents if any - if configuration.document_ids_to_add_in_context: - streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...") - writer({"yeild_value": streaming_service._format_annotations()}) - - user_selected_sources, user_selected_documents = await fetch_documents_by_ids( - document_ids=configuration.document_ids_to_add_in_context, - user_id=configuration.user_id, - db_session=db_session - ) - - if user_selected_documents: - streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context") - writer({"yeild_value": streaming_service._format_annotations()}) - - # Create connector service inside the db_session scope - connector_service = ConnectorService(db_session, user_id=configuration.user_id) - await connector_service.initialize_counter() - - relevant_documents = await fetch_relevant_documents( - research_questions=all_questions, - user_id=configuration.user_id, - search_space_id=configuration.search_space_id, - db_session=db_session, - connectors_to_search=configuration.connectors_to_search, - writer=writer, - state=state, - top_k=TOP_K, - connector_service=connector_service, - 
search_mode=configuration.search_mode, - user_selected_sources=user_selected_sources - ) - except Exception as e: - error_message = f"Error fetching relevant documents: {str(e)}" - print(error_message) - streaming_service.only_update_terminal(f"❌ {error_message}", "error") + try: + # First, fetch user-selected documents if any + if configuration.document_ids_to_add_in_context: + streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...") writer({"yeild_value": streaming_service._format_annotations()}) - # Log the error and continue with an empty list of documents - # This allows the process to continue, but the report might lack information - relevant_documents = [] + + user_selected_sources, user_selected_documents = await fetch_documents_by_ids( + document_ids=configuration.document_ids_to_add_in_context, + user_id=configuration.user_id, + db_session=state.db_session + ) + + if user_selected_documents: + streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context") + writer({"yeild_value": streaming_service._format_annotations()}) + + # Create connector service using state db_session + connector_service = ConnectorService(state.db_session, user_id=configuration.user_id) + await connector_service.initialize_counter() + + relevant_documents = await fetch_relevant_documents( + research_questions=all_questions, + user_id=configuration.user_id, + search_space_id=configuration.search_space_id, + db_session=state.db_session, + connectors_to_search=configuration.connectors_to_search, + writer=writer, + state=state, + top_k=TOP_K, + connector_service=connector_service, + search_mode=configuration.search_mode, + user_selected_sources=user_selected_sources + ) + except Exception as e: + error_message = f"Error fetching relevant documents: {str(e)}" + print(error_message) + streaming_service.only_update_terminal(f"❌ {error_message}", "error") + writer({"yeild_value": streaming_service._format_annotations()}) + # Log the error and continue with an empty list of documents + # This allows the process to continue, but the report might lack information + relevant_documents = [] # Combine user-selected documents with connector-fetched documents all_documents = user_selected_documents + relevant_documents @@ -1014,82 +1021,80 @@ async def process_section_with_documents( for question in section_questions ] - # Create a new database session for this section - async with async_session_maker() as db_session: - # Call the sub_section_writer graph with the appropriate config - config = { - "configurable": { - "sub_section_title": section_title, - "sub_section_questions": section_questions, - "sub_section_type": sub_section_type, - "user_query": user_query, - "relevant_documents": documents_to_use, - "user_id": user_id, - "search_space_id": search_space_id - } + # Call the sub_section_writer graph with the appropriate config + config = { + "configurable": { + "sub_section_title": section_title, + "sub_section_questions": section_questions, + "sub_section_type": sub_section_type, + "user_query": user_query, + "relevant_documents": documents_to_use, + "user_id": user_id, + "search_space_id": search_space_id } - - # Create the initial state with db_session and chat_history - sub_state = { - "db_session": db_session, - "chat_history": state.chat_history - } - - # Invoke the sub-section writer graph with streaming - print(f"Invoking sub_section_writer for: {section_title}") - if 
state and state.streaming_service and writer: - state.streaming_service.only_update_terminal(f"🧠 Analyzing information and drafting content for section: \"{section_title}\"") - writer({"yeild_value": state.streaming_service._format_annotations()}) - - # Variables to track streaming state - complete_content = "" # Tracks the complete content received so far - - async for chunk_type, chunk in sub_section_writer_graph.astream(sub_state, config, stream_mode=["values"]): - if "final_answer" in chunk: - new_content = chunk["final_answer"] - if new_content and new_content != complete_content: - # Extract only the new content (delta) - delta = new_content[len(complete_content):] + } + + # Create the initial state with db_session and chat_history + sub_state = { + "db_session": state.db_session, + "chat_history": state.chat_history + } + + # Invoke the sub-section writer graph with streaming + print(f"Invoking sub_section_writer for: {section_title}") + if state and state.streaming_service and writer: + state.streaming_service.only_update_terminal(f"🧠 Analyzing information and drafting content for section: \"{section_title}\"") + writer({"yeild_value": state.streaming_service._format_annotations()}) + + # Variables to track streaming state + complete_content = "" # Tracks the complete content received so far + + async for chunk_type, chunk in sub_section_writer_graph.astream(sub_state, config, stream_mode=["values"]): + if "final_answer" in chunk: + new_content = chunk["final_answer"] + if new_content and new_content != complete_content: + # Extract only the new content (delta) + delta = new_content[len(complete_content):] + + # Update what we've processed so far + complete_content = new_content + + # Only stream if there's actual new content + if delta and state and state.streaming_service and writer: + # Update terminal with real-time progress indicator + state.streaming_service.only_update_terminal(f"✍️ Writing section {section_id+1}... ({len(complete_content.split())} words)") - # Update what we've processed so far - complete_content = new_content + # Update section_contents with just the new delta + section_contents[section_id]["content"] += delta - # Only stream if there's actual new content - if delta and state and state.streaming_service and writer: - # Update terminal with real-time progress indicator - state.streaming_service.only_update_terminal(f"✍️ Writing section {section_id+1}... ({len(complete_content.split())} words)") - - # Update section_contents with just the new delta - section_contents[section_id]["content"] += delta - - # Build UI-friendly content for all sections - complete_answer = [] - for i in range(len(section_contents)): - if i in section_contents and section_contents[i]["content"]: - # Add section header - complete_answer.append(f"# {section_contents[i]['title']}") - complete_answer.append("") # Empty line after title - - # Add section content - content_lines = section_contents[i]["content"].split("\n") - complete_answer.extend(content_lines) - complete_answer.append("") # Empty line after content - - # Update answer in UI in real-time - state.streaming_service.only_update_answer(complete_answer) - writer({"yeild_value": state.streaming_service._format_annotations()}) - - # Set default if no content was received - if not complete_content: - complete_content = "No content was generated for this section." 
- section_contents[section_id]["content"] = complete_content - - # Final terminal update - if state and state.streaming_service and writer: - state.streaming_service.only_update_terminal(f"✅ Completed section: \"{section_title}\"") - writer({"yeild_value": state.streaming_service._format_annotations()}) - - return complete_content + # Build UI-friendly content for all sections + complete_answer = [] + for i in range(len(section_contents)): + if i in section_contents and section_contents[i]["content"]: + # Add section header + complete_answer.append(f"# {section_contents[i]['title']}") + complete_answer.append("") # Empty line after title + + # Add section content + content_lines = section_contents[i]["content"].split("\n") + complete_answer.extend(content_lines) + complete_answer.append("") # Empty line after content + + # Update answer in UI in real-time + state.streaming_service.only_update_answer(complete_answer) + writer({"yeild_value": state.streaming_service._format_annotations()}) + + # Set default if no content was received + if not complete_content: + complete_content = "No content was generated for this section." + section_contents[section_id]["content"] = complete_content + + # Final terminal update + if state and state.streaming_service and writer: + state.streaming_service.only_update_terminal(f"✅ Completed section: \"{section_title}\"") + writer({"yeild_value": state.streaming_service._format_annotations()}) + + return complete_content except Exception as e: print(f"Error processing section '{section_title}': {str(e)}") @@ -1113,7 +1118,7 @@ async def reformulate_user_query(state: State, config: RunnableConfig, writer: S if len(state.chat_history) == 0: reformulated_query = user_query else: - reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query, chat_history_str) + reformulated_query = await QueryService.reformulate_query_with_chat_history(user_query=user_query, session=state.db_session, user_id=configuration.user_id, chat_history_str=chat_history_str) return { "reformulated_query": reformulated_query @@ -1152,50 +1157,49 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre user_selected_documents = [] user_selected_sources = [] - async with async_session_maker() as db_session: - try: - # First, fetch user-selected documents if any - if configuration.document_ids_to_add_in_context: - streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...") - writer({"yeild_value": streaming_service._format_annotations()}) - - user_selected_sources, user_selected_documents = await fetch_documents_by_ids( - document_ids=configuration.document_ids_to_add_in_context, - user_id=configuration.user_id, - db_session=db_session - ) - - if user_selected_documents: - streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context") - writer({"yeild_value": streaming_service._format_annotations()}) - - # Create connector service inside the db_session scope - connector_service = ConnectorService(db_session, user_id=configuration.user_id) - await connector_service.initialize_counter() - - # Use the reformulated query as a single research question - research_questions = [reformulated_query, user_query] - - relevant_documents = await fetch_relevant_documents( - research_questions=research_questions, - user_id=configuration.user_id, - search_space_id=configuration.search_space_id, - db_session=db_session, - 
connectors_to_search=configuration.connectors_to_search, - writer=writer, - state=state, - top_k=TOP_K, - connector_service=connector_service, - search_mode=configuration.search_mode, - user_selected_sources=user_selected_sources - ) - except Exception as e: - error_message = f"Error fetching relevant documents for QNA: {str(e)}" - print(error_message) - streaming_service.only_update_terminal(f"❌ {error_message}", "error") + try: + # First, fetch user-selected documents if any + if configuration.document_ids_to_add_in_context: + streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...") writer({"yeild_value": streaming_service._format_annotations()}) - # Continue with empty documents - the QNA agent will handle this gracefully - relevant_documents = [] + + user_selected_sources, user_selected_documents = await fetch_documents_by_ids( + document_ids=configuration.document_ids_to_add_in_context, + user_id=configuration.user_id, + db_session=state.db_session + ) + + if user_selected_documents: + streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context") + writer({"yeild_value": streaming_service._format_annotations()}) + + # Create connector service using state db_session + connector_service = ConnectorService(state.db_session, user_id=configuration.user_id) + await connector_service.initialize_counter() + + # Use the reformulated query as a single research question + research_questions = [reformulated_query, user_query] + + relevant_documents = await fetch_relevant_documents( + research_questions=research_questions, + user_id=configuration.user_id, + search_space_id=configuration.search_space_id, + db_session=state.db_session, + connectors_to_search=configuration.connectors_to_search, + writer=writer, + state=state, + top_k=TOP_K, + connector_service=connector_service, + search_mode=configuration.search_mode, + user_selected_sources=user_selected_sources + ) + except Exception as e: + error_message = f"Error fetching relevant documents for QNA: {str(e)}" + print(error_message) + streaming_service.only_update_terminal(f"❌ {error_message}", "error") + writer({"yeild_value": streaming_service._format_annotations()}) + # Continue with empty documents - the QNA agent will handle this gracefully + relevant_documents = [] # Combine user-selected documents with connector-fetched documents all_documents = user_selected_documents + relevant_documents diff --git a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py index 4684290..565d806 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py @@ -85,14 +85,20 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any Returns: Dict containing the final answer in the "final_answer" key. 
""" + from app.utils.llm_service import get_user_fast_llm # Get configuration and relevant documents from configuration configuration = Configuration.from_runnable_config(config) documents = state.reranked_documents user_query = configuration.user_query + user_id = configuration.user_id - # Initialize LLM - llm = app_config.fast_llm_instance + # Get user's fast LLM + llm = await get_user_fast_llm(state.db_session, user_id) + if not llm: + error_message = f"No fast LLM configured for user {user_id}" + print(error_message) + raise RuntimeError(error_message) # Determine if we have documents and optimize for token limits has_documents_initially = documents and len(documents) > 0 @@ -118,7 +124,7 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any # Optimize documents to fit within token limits optimized_documents, has_optimized_documents = optimize_documents_for_token_limit( - documents, base_messages, app_config.FAST_LLM + documents, base_messages, llm.model ) # Update state based on optimization result @@ -161,7 +167,7 @@ async def answer_question(state: State, config: RunnableConfig) -> Dict[str, Any ] # Log final token count - total_tokens = calculate_token_count(messages_with_chat_history, app_config.FAST_LLM) + total_tokens = calculate_token_count(messages_with_chat_history, llm.model) print(f"Final token count: {total_tokens}") diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py index 2475d24..2febe54 100644 --- a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py +++ b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py @@ -91,13 +91,19 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A Returns: Dict containing the final answer in the "final_answer" key. 
""" + from app.utils.llm_service import get_user_fast_llm # Get configuration and relevant documents from configuration configuration = Configuration.from_runnable_config(config) documents = state.reranked_documents + user_id = configuration.user_id - # Initialize LLM - llm = app_config.fast_llm_instance + # Get user's fast LLM + llm = await get_user_fast_llm(state.db_session, user_id) + if not llm: + error_message = f"No fast LLM configured for user {user_id}" + print(error_message) + raise RuntimeError(error_message) # Extract configuration data section_title = configuration.sub_section_title @@ -153,7 +159,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A # Optimize documents to fit within token limits optimized_documents, has_optimized_documents = optimize_documents_for_token_limit( - documents, base_messages, app_config.FAST_LLM + documents, base_messages, llm.model ) # Update state based on optimization result @@ -206,7 +212,7 @@ async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, A ] # Log final token count - total_tokens = calculate_token_count(messages_with_chat_history, app_config.FAST_LLM) + total_tokens = calculate_token_count(messages_with_chat_history, llm.model) print(f"Final token count: {total_tokens}") # Call the LLM and get the response diff --git a/surfsense_backend/app/agents/researcher/utils.py b/surfsense_backend/app/agents/researcher/utils.py index e40ad1b..0cf8137 100644 --- a/surfsense_backend/app/agents/researcher/utils.py +++ b/surfsense_backend/app/agents/researcher/utils.py @@ -1,7 +1,6 @@ from typing import List, Dict, Any, Tuple, NamedTuple from langchain_core.messages import BaseMessage from litellm import token_counter, get_model_info -from app.config import config as app_config class DocumentTokenInfo(NamedTuple): @@ -127,7 +126,7 @@ def get_model_context_window(model_name: str) -> int: def optimize_documents_for_token_limit( documents: List[Dict[str, Any]], base_messages: List[BaseMessage], - model_name: str = None + model_name: str ) -> Tuple[List[Dict[str, Any]], bool]: """ Optimize documents to fit within token limits using binary search. 
@@ -135,7 +134,7 @@ def optimize_documents_for_token_limit( Args: documents: List of documents with content and metadata base_messages: Base messages without documents (chat history + system + human message template) - model_name: Model name for token counting (defaults to app_config.FAST_LLM) + model_name: Model name for token counting (required) output_token_buffer: Number of tokens to reserve for model output Returns: @@ -144,7 +143,7 @@ def optimize_documents_for_token_limit( if not documents: return [], False - model = model_name or app_config.FAST_LLM + model = model_name context_window = get_model_context_window(model) # Calculate base token cost @@ -178,8 +177,8 @@ def optimize_documents_for_token_limit( return optimized_documents, has_documents_remaining -def calculate_token_count(messages: List[BaseMessage], model_name: str = None) -> int: +def calculate_token_count(messages: List[BaseMessage], model_name: str) -> int: """Calculate token count for a list of LangChain messages.""" - model = model_name or app_config.FAST_LLM + model = model_name messages_dict = convert_langchain_messages_to_dict(messages) return token_counter(messages=messages_dict, model=model) diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 9135c32..90011ab 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -4,7 +4,7 @@ import shutil from chonkie import AutoEmbeddings, CodeChunker, RecursiveChunker from dotenv import load_dotenv -from langchain_community.chat_models import ChatLiteLLM + from rerankers import Reranker @@ -49,31 +49,8 @@ class Config: GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET") - # LONG-CONTEXT LLMS - LONG_CONTEXT_LLM = os.getenv("LONG_CONTEXT_LLM") - LONG_CONTEXT_LLM_API_BASE = os.getenv("LONG_CONTEXT_LLM_API_BASE") - if LONG_CONTEXT_LLM_API_BASE: - long_context_llm_instance = ChatLiteLLM(model=LONG_CONTEXT_LLM, api_base=LONG_CONTEXT_LLM_API_BASE) - else: - long_context_llm_instance = ChatLiteLLM(model=LONG_CONTEXT_LLM) - - # FAST LLM - FAST_LLM = os.getenv("FAST_LLM") - FAST_LLM_API_BASE = os.getenv("FAST_LLM_API_BASE") - if FAST_LLM_API_BASE: - fast_llm_instance = ChatLiteLLM(model=FAST_LLM, api_base=FAST_LLM_API_BASE) - else: - fast_llm_instance = ChatLiteLLM(model=FAST_LLM) - - - - # STRATEGIC LLM - STRATEGIC_LLM = os.getenv("STRATEGIC_LLM") - STRATEGIC_LLM_API_BASE = os.getenv("STRATEGIC_LLM_API_BASE") - if STRATEGIC_LLM_API_BASE: - strategic_llm_instance = ChatLiteLLM(model=STRATEGIC_LLM, api_base=STRATEGIC_LLM_API_BASE) - else: - strategic_llm_instance = ChatLiteLLM(model=STRATEGIC_LLM) + # LLM instances are now managed per-user through the LLMConfig system + # Legacy environment variables removed in favor of user-specific configurations # Chonkie Configuration | Edit this to your needs EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL") @@ -114,10 +91,12 @@ class Config: # Litellm TTS Configuration TTS_SERVICE = os.getenv("TTS_SERVICE") TTS_SERVICE_API_BASE = os.getenv("TTS_SERVICE_API_BASE") + TTS_SERVICE_API_KEY = os.getenv("TTS_SERVICE_API_KEY") # Litellm STT Configuration STT_SERVICE = os.getenv("STT_SERVICE") STT_SERVICE_API_BASE = os.getenv("STT_SERVICE_API_BASE") + STT_SERVICE_API_KEY = os.getenv("STT_SERVICE_API_KEY") # Validation Checks diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 8c7f311..0c6311a 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -67,6 +67,30 @@ class ChatType(str, Enum): 
REPORT_GENERAL = "REPORT_GENERAL" REPORT_DEEP = "REPORT_DEEP" REPORT_DEEPER = "REPORT_DEEPER" + +class LiteLLMProvider(str, Enum): + OPENAI = "OPENAI" + ANTHROPIC = "ANTHROPIC" + GROQ = "GROQ" + COHERE = "COHERE" + HUGGINGFACE = "HUGGINGFACE" + AZURE_OPENAI = "AZURE_OPENAI" + GOOGLE = "GOOGLE" + AWS_BEDROCK = "AWS_BEDROCK" + OLLAMA = "OLLAMA" + MISTRAL = "MISTRAL" + TOGETHER_AI = "TOGETHER_AI" + REPLICATE = "REPLICATE" + PALM = "PALM" + VERTEX_AI = "VERTEX_AI" + ANYSCALE = "ANYSCALE" + PERPLEXITY = "PERPLEXITY" + DEEPINFRA = "DEEPINFRA" + AI21 = "AI21" + NLPCLOUD = "NLPCLOUD" + ALEPH_ALPHA = "ALEPH_ALPHA" + PETALS = "PETALS" + CUSTOM = "CUSTOM" class Base(DeclarativeBase): pass @@ -152,6 +176,26 @@ class SearchSourceConnector(BaseModel, TimestampMixin): user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False) user = relationship("User", back_populates="search_source_connectors") +class LLMConfig(BaseModel, TimestampMixin): + __tablename__ = "llm_configs" + + name = Column(String(100), nullable=False, index=True) + # Provider from the enum + provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False) + # Custom provider name when provider is CUSTOM + custom_provider = Column(String(100), nullable=True) + # Just the model name without provider prefix + model_name = Column(String(100), nullable=False) + # API Key should be encrypted before storing + api_key = Column(String, nullable=False) + api_base = Column(String(500), nullable=True) + + # For any other parameters that litellm supports + litellm_params = Column(JSON, nullable=True, default={}) + + user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False) + user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id]) + if config.AUTH_TYPE == "GOOGLE": class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): pass @@ -163,11 +207,29 @@ if config.AUTH_TYPE == "GOOGLE": ) search_spaces = relationship("SearchSpace", back_populates="user") search_source_connectors = relationship("SearchSourceConnector", back_populates="user") + llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan") + + long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) + fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) + strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) + + long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True) + fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True) + strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True) else: class User(SQLAlchemyBaseUserTableUUID, Base): search_spaces = relationship("SearchSpace", back_populates="user") search_source_connectors = relationship("SearchSourceConnector", back_populates="user") + llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan") + + long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) + fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) + strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True) + + long_context_llm = relationship("LLMConfig", 
foreign_keys=[long_context_llm_id], post_update=True) + fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True) + strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True) engine = create_async_engine(DATABASE_URL) diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index c2266f9..3420f35 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -4,6 +4,7 @@ from .documents_routes import router as documents_router from .podcasts_routes import router as podcasts_router from .chats_routes import router as chats_router from .search_source_connectors_routes import router as search_source_connectors_router +from .llm_config_routes import router as llm_config_router router = APIRouter() @@ -12,3 +13,4 @@ router.include_router(documents_router) router.include_router(podcasts_router) router.include_router(chats_router) router.include_router(search_source_connectors_router) +router.include_router(llm_config_router) diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index acd246e..5b82dd7 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -38,21 +38,24 @@ async def create_documents( fastapi_background_tasks.add_task( process_extension_document_with_new_session, individual_document, - request.search_space_id + request.search_space_id, + str(user.id) ) elif request.document_type == DocumentType.CRAWLED_URL: for url in request.content: fastapi_background_tasks.add_task( process_crawled_url_with_new_session, url, - request.search_space_id + request.search_space_id, + str(user.id) ) elif request.document_type == DocumentType.YOUTUBE_VIDEO: for url in request.content: fastapi_background_tasks.add_task( process_youtube_video_with_new_session, url, - request.search_space_id + request.search_space_id, + str(user.id) ) else: raise HTTPException( @@ -106,7 +109,8 @@ async def create_documents( process_file_in_background_with_new_session, temp_path, file.filename, - search_space_id + search_space_id, + str(user.id) ) except Exception as e: raise HTTPException( @@ -130,6 +134,7 @@ async def process_file_in_background( file_path: str, filename: str, search_space_id: int, + user_id: str, session: AsyncSession ): try: @@ -151,7 +156,8 @@ async def process_file_in_background( session, filename, markdown_content, - search_space_id + search_space_id, + user_id ) # Check if the file is an audio file elif filename.lower().endswith(('.mp3', '.mp4', '.mpeg', '.mpga', '.m4a', '.wav', '.webm')): @@ -162,11 +168,13 @@ async def process_file_in_background( transcription_response = await atranscription( model=app_config.STT_SERVICE, file=audio_file, - api_base=app_config.STT_SERVICE_API_BASE + api_base=app_config.STT_SERVICE_API_BASE, + api_key=app_config.STT_SERVICE_API_KEY ) else: transcription_response = await atranscription( model=app_config.STT_SERVICE, + api_key=app_config.STT_SERVICE_API_KEY, file=audio_file ) @@ -187,7 +195,8 @@ async def process_file_in_background( session, filename, transcribed_text, - search_space_id + search_space_id, + user_id ) else: if app_config.ETL_SERVICE == "UNSTRUCTURED": @@ -218,7 +227,8 @@ async def process_file_in_background( session, filename, docs, - search_space_id + search_space_id, + user_id ) elif app_config.ETL_SERVICE == "LLAMACLOUD": from llama_cloud_services import LlamaParse @@ -256,7 +266,8 @@ 
async def process_file_in_background( session, filename, llamacloud_markdown_document=markdown_content, - search_space_id=search_space_id + search_space_id=search_space_id, + user_id=user_id ) except Exception as e: import logging @@ -426,14 +437,15 @@ async def delete_document( async def process_extension_document_with_new_session( individual_document, - search_space_id: int + search_space_id: int, + user_id: str ): """Create a new session and process extension document.""" from app.db import async_session_maker async with async_session_maker() as session: try: - await add_extension_received_document(session, individual_document, search_space_id) + await add_extension_received_document(session, individual_document, search_space_id, user_id) except Exception as e: import logging logging.error(f"Error processing extension document: {str(e)}") @@ -441,14 +453,15 @@ async def process_extension_document_with_new_session( async def process_crawled_url_with_new_session( url: str, - search_space_id: int + search_space_id: int, + user_id: str ): """Create a new session and process crawled URL.""" from app.db import async_session_maker async with async_session_maker() as session: try: - await add_crawled_url_document(session, url, search_space_id) + await add_crawled_url_document(session, url, search_space_id, user_id) except Exception as e: import logging logging.error(f"Error processing crawled URL: {str(e)}") @@ -457,25 +470,27 @@ async def process_crawled_url_with_new_session( async def process_file_in_background_with_new_session( file_path: str, filename: str, - search_space_id: int + search_space_id: int, + user_id: str ): """Create a new session and process file.""" from app.db import async_session_maker async with async_session_maker() as session: - await process_file_in_background(file_path, filename, search_space_id, session) + await process_file_in_background(file_path, filename, search_space_id, user_id, session) async def process_youtube_video_with_new_session( url: str, - search_space_id: int + search_space_id: int, + user_id: str ): """Create a new session and process YouTube video.""" from app.db import async_session_maker async with async_session_maker() as session: try: - await add_youtube_video_document(session, url, search_space_id) + await add_youtube_video_document(session, url, search_space_id, user_id) except Exception as e: import logging logging.error(f"Error processing YouTube video: {str(e)}") diff --git a/surfsense_backend/app/routes/llm_config_routes.py b/surfsense_backend/app/routes/llm_config_routes.py new file mode 100644 index 0000000..644503f --- /dev/null +++ b/surfsense_backend/app/routes/llm_config_routes.py @@ -0,0 +1,243 @@ +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from typing import List, Optional +from pydantic import BaseModel +from app.db import get_async_session, User, LLMConfig +from app.schemas import LLMConfigCreate, LLMConfigUpdate, LLMConfigRead +from app.users import current_active_user +from app.utils.check_ownership import check_ownership + +router = APIRouter() + +class LLMPreferencesUpdate(BaseModel): + """Schema for updating user LLM preferences""" + long_context_llm_id: Optional[int] = None + fast_llm_id: Optional[int] = None + strategic_llm_id: Optional[int] = None + +class LLMPreferencesRead(BaseModel): + """Schema for reading user LLM preferences""" + long_context_llm_id: Optional[int] = None + fast_llm_id: Optional[int] = None + 
strategic_llm_id: Optional[int] = None + long_context_llm: Optional[LLMConfigRead] = None + fast_llm: Optional[LLMConfigRead] = None + strategic_llm: Optional[LLMConfigRead] = None + +@router.post("/llm-configs/", response_model=LLMConfigRead) +async def create_llm_config( + llm_config: LLMConfigCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user) +): + """Create a new LLM configuration for the authenticated user""" + try: + db_llm_config = LLMConfig(**llm_config.model_dump(), user_id=user.id) + session.add(db_llm_config) + await session.commit() + await session.refresh(db_llm_config) + return db_llm_config + except HTTPException: + raise + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to create LLM configuration: {str(e)}" + ) + +@router.get("/llm-configs/", response_model=List[LLMConfigRead]) +async def read_llm_configs( + skip: int = 0, + limit: int = 200, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user) +): + """Get all LLM configurations for the authenticated user""" + try: + result = await session.execute( + select(LLMConfig) + .filter(LLMConfig.user_id == user.id) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to fetch LLM configurations: {str(e)}" + ) + +@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead) +async def read_llm_config( + llm_config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user) +): + """Get a specific LLM configuration by ID""" + try: + llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) + return llm_config + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to fetch LLM configuration: {str(e)}" + ) + +@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead) +async def update_llm_config( + llm_config_id: int, + llm_config_update: LLMConfigUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user) +): + """Update an existing LLM configuration""" + try: + db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) + update_data = llm_config_update.model_dump(exclude_unset=True) + + for key, value in update_data.items(): + setattr(db_llm_config, key, value) + + await session.commit() + await session.refresh(db_llm_config) + return db_llm_config + except HTTPException: + raise + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to update LLM configuration: {str(e)}" + ) + +@router.delete("/llm-configs/{llm_config_id}", response_model=dict) +async def delete_llm_config( + llm_config_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user) +): + """Delete an LLM configuration""" + try: + db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) + await session.delete(db_llm_config) + await session.commit() + return {"message": "LLM configuration deleted successfully"} + except HTTPException: + raise + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to delete LLM configuration: {str(e)}" + ) + +# User LLM Preferences endpoints + +@router.get("/users/me/llm-preferences", 
response_model=LLMPreferencesRead) +async def get_user_llm_preferences( + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user) +): + """Get the current user's LLM preferences""" + try: + # Refresh user to get latest relationships + await session.refresh(user) + + result = { + "long_context_llm_id": user.long_context_llm_id, + "fast_llm_id": user.fast_llm_id, + "strategic_llm_id": user.strategic_llm_id, + "long_context_llm": None, + "fast_llm": None, + "strategic_llm": None, + } + + # Fetch the actual LLM configs if they exist + if user.long_context_llm_id: + long_context_llm = await session.execute( + select(LLMConfig).filter( + LLMConfig.id == user.long_context_llm_id, + LLMConfig.user_id == user.id + ) + ) + llm_config = long_context_llm.scalars().first() + if llm_config: + result["long_context_llm"] = llm_config + + if user.fast_llm_id: + fast_llm = await session.execute( + select(LLMConfig).filter( + LLMConfig.id == user.fast_llm_id, + LLMConfig.user_id == user.id + ) + ) + llm_config = fast_llm.scalars().first() + if llm_config: + result["fast_llm"] = llm_config + + if user.strategic_llm_id: + strategic_llm = await session.execute( + select(LLMConfig).filter( + LLMConfig.id == user.strategic_llm_id, + LLMConfig.user_id == user.id + ) + ) + llm_config = strategic_llm.scalars().first() + if llm_config: + result["strategic_llm"] = llm_config + + return result + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Failed to fetch LLM preferences: {str(e)}" + ) + +@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead) +async def update_user_llm_preferences( + preferences: LLMPreferencesUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user) +): + """Update the current user's LLM preferences""" + try: + # Validate that all provided LLM config IDs belong to the user + update_data = preferences.model_dump(exclude_unset=True) + + for key, llm_config_id in update_data.items(): + if llm_config_id is not None: + # Verify ownership of the LLM config + result = await session.execute( + select(LLMConfig).filter( + LLMConfig.id == llm_config_id, + LLMConfig.user_id == user.id + ) + ) + llm_config = result.scalars().first() + if not llm_config: + raise HTTPException( + status_code=404, + detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it" + ) + + # Update user preferences + for key, value in update_data.items(): + setattr(user, key, value) + + await session.commit() + await session.refresh(user) + + # Return updated preferences + return await get_user_llm_preferences(session, user) + except HTTPException: + raise + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to update LLM preferences: {str(e)}" + ) \ No newline at end of file diff --git a/surfsense_backend/app/routes/podcasts_routes.py b/surfsense_backend/app/routes/podcasts_routes.py index bc82e21..507c15e 100644 --- a/surfsense_backend/app/routes/podcasts_routes.py +++ b/surfsense_backend/app/routes/podcasts_routes.py @@ -128,14 +128,15 @@ async def delete_podcast( async def generate_chat_podcast_with_new_session( chat_id: int, search_space_id: int, - podcast_title: str = "SurfSense Podcast" + podcast_title: str, + user_id: int ): """Create a new session and process chat podcast generation.""" from app.db import async_session_maker async with async_session_maker() as session: try: - await 
generate_chat_podcast(session, chat_id, search_space_id, podcast_title) + await generate_chat_podcast(session, chat_id, search_space_id, podcast_title, user_id) except Exception as e: import logging logging.error(f"Error generating podcast from chat: {str(e)}") @@ -175,7 +176,8 @@ async def generate_podcast( generate_chat_podcast_with_new_session, chat_id, request.search_space_id, - request.podcast_title + request.podcast_title, + user.id ) return { diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index ab69639..54f97d6 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -328,25 +328,25 @@ async def index_connector_content( if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR: # Run indexing in background logger.info(f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}") - background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to) + background_tasks.add_task(run_slack_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to) response_message = "Slack indexing started in the background." elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR: # Run indexing in background logger.info(f"Triggering Notion indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}") - background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to) + background_tasks.add_task(run_notion_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to) response_message = "Notion indexing started in the background." elif connector.connector_type == SearchSourceConnectorType.GITHUB_CONNECTOR: # Run indexing in background logger.info(f"Triggering GitHub indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}") - background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to) + background_tasks.add_task(run_github_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to) response_message = "GitHub indexing started in the background." elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR: # Run indexing in background logger.info(f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}") - background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to) + background_tasks.add_task(run_linear_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to) response_message = "Linear indexing started in the background." 
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR: @@ -355,7 +355,7 @@ async def index_connector_content( f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" ) background_tasks.add_task( - run_discord_indexing_with_new_session, connector_id, search_space_id, indexing_from, indexing_to + run_discord_indexing_with_new_session, connector_id, search_space_id, str(user.id), indexing_from, indexing_to ) response_message = "Discord indexing started in the background." @@ -410,6 +410,7 @@ async def update_connector_last_indexed( async def run_slack_indexing_with_new_session( connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): @@ -418,12 +419,13 @@ async def run_slack_indexing_with_new_session( This prevents session leaks by creating a dedicated session for the background task. """ async with async_session_maker() as session: - await run_slack_indexing(session, connector_id, search_space_id, start_date, end_date) + await run_slack_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) async def run_slack_indexing( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): @@ -434,6 +436,7 @@ async def run_slack_indexing( session: Database session connector_id: ID of the Slack connector search_space_id: ID of the search space + user_id: ID of the user start_date: Start date for indexing end_date: End date for indexing """ @@ -443,6 +446,7 @@ async def run_slack_indexing( session=session, connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, start_date=start_date, end_date=end_date, update_last_indexed=False # Don't update timestamp in the indexing function @@ -460,6 +464,7 @@ async def run_slack_indexing( async def run_notion_indexing_with_new_session( connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): @@ -468,12 +473,13 @@ async def run_notion_indexing_with_new_session( This prevents session leaks by creating a dedicated session for the background task. 
""" async with async_session_maker() as session: - await run_notion_indexing(session, connector_id, search_space_id, start_date, end_date) + await run_notion_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) async def run_notion_indexing( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): @@ -484,6 +490,7 @@ async def run_notion_indexing( session: Database session connector_id: ID of the Notion connector search_space_id: ID of the search space + user_id: ID of the user start_date: Start date for indexing end_date: End date for indexing """ @@ -493,6 +500,7 @@ async def run_notion_indexing( session=session, connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, start_date=start_date, end_date=end_date, update_last_indexed=False # Don't update timestamp in the indexing function @@ -511,26 +519,28 @@ async def run_notion_indexing( async def run_github_indexing_with_new_session( connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): """Wrapper to run GitHub indexing with its own database session.""" logger.info(f"Background task started: Indexing GitHub connector {connector_id} into space {search_space_id} from {start_date} to {end_date}") async with async_session_maker() as session: - await run_github_indexing(session, connector_id, search_space_id, start_date, end_date) + await run_github_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) logger.info(f"Background task finished: Indexing GitHub connector {connector_id}") async def run_github_indexing( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): """Runs the GitHub indexing task and updates the timestamp.""" try: indexed_count, error_message = await index_github_repos( - session, connector_id, search_space_id, start_date, end_date, update_last_indexed=False + session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False ) if error_message: logger.error(f"GitHub indexing failed for connector {connector_id}: {error_message}") @@ -549,26 +559,28 @@ async def run_github_indexing( async def run_linear_indexing_with_new_session( connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): """Wrapper to run Linear indexing with its own database session.""" logger.info(f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}") async with async_session_maker() as session: - await run_linear_indexing(session, connector_id, search_space_id, start_date, end_date) + await run_linear_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) logger.info(f"Background task finished: Indexing Linear connector {connector_id}") async def run_linear_indexing( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): """Runs the Linear indexing task and updates the timestamp.""" try: indexed_count, error_message = await index_linear_issues( - session, connector_id, search_space_id, start_date, end_date, update_last_indexed=False + session, connector_id, search_space_id, user_id, start_date, end_date, update_last_indexed=False ) if error_message: logger.error(f"Linear indexing failed for connector {connector_id}: {error_message}") @@ -587,6 +599,7 @@ async def run_linear_indexing( async def 
run_discord_indexing_with_new_session( connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): @@ -595,12 +608,13 @@ async def run_discord_indexing_with_new_session( This prevents session leaks by creating a dedicated session for the background task. """ async with async_session_maker() as session: - await run_discord_indexing(session, connector_id, search_space_id, start_date, end_date) + await run_discord_indexing(session, connector_id, search_space_id, user_id, start_date, end_date) async def run_discord_indexing( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str, end_date: str ): @@ -610,6 +624,7 @@ async def run_discord_indexing( session: Database session connector_id: ID of the Discord connector search_space_id: ID of the search space + user_id: ID of the user start_date: Start date for indexing end_date: End date for indexing """ @@ -619,6 +634,7 @@ async def run_discord_indexing( session=session, connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, start_date=start_date, end_date=end_date, update_last_indexed=False # Don't update timestamp in the indexing function diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 21688df..f62172a 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -13,6 +13,7 @@ from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead +from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigUpdate, LLMConfigRead __all__ = [ "AISDKChatRequest", @@ -48,4 +49,8 @@ __all__ = [ "SearchSourceConnectorCreate", "SearchSourceConnectorUpdate", "SearchSourceConnectorRead", + "LLMConfigBase", + "LLMConfigCreate", + "LLMConfigUpdate", + "LLMConfigRead", ] \ No newline at end of file diff --git a/surfsense_backend/app/schemas/llm_config.py b/surfsense_backend/app/schemas/llm_config.py new file mode 100644 index 0000000..f4032cb --- /dev/null +++ b/surfsense_backend/app/schemas/llm_config.py @@ -0,0 +1,34 @@ +from datetime import datetime +import uuid +from typing import Optional, Dict, Any +from pydantic import BaseModel, ConfigDict, Field +from .base import IDModel, TimestampModel +from app.db import LiteLLMProvider + +class LLMConfigBase(BaseModel): + name: str = Field(..., max_length=100, description="User-friendly name for the LLM configuration") + provider: LiteLLMProvider = Field(..., description="LiteLLM provider type") + custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM") + model_name: str = Field(..., max_length=100, description="Model name without provider prefix") + api_key: str = Field(..., description="API key for the provider") + api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL") + litellm_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional LiteLLM parameters") + +class LLMConfigCreate(LLMConfigBase): + pass + +class LLMConfigUpdate(BaseModel): + name: Optional[str] = Field(None, max_length=100, description="User-friendly name for the LLM configuration") + provider: 
Optional[LiteLLMProvider] = Field(None, description="LiteLLM provider type") + custom_provider: Optional[str] = Field(None, max_length=100, description="Custom provider name when provider is CUSTOM") + model_name: Optional[str] = Field(None, max_length=100, description="Model name without provider prefix") + api_key: Optional[str] = Field(None, description="API key for the provider") + api_base: Optional[str] = Field(None, max_length=500, description="Optional API base URL") + litellm_params: Optional[Dict[str, Any]] = Field(None, description="Additional LiteLLM parameters") + +class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel): + id: int + created_at: datetime + user_id: uuid.UUID + + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/surfsense_backend/app/tasks/background_tasks.py b/surfsense_backend/app/tasks/background_tasks.py index 2641895..77e06e7 100644 --- a/surfsense_backend/app/tasks/background_tasks.py +++ b/surfsense_backend/app/tasks/background_tasks.py @@ -7,6 +7,7 @@ from app.schemas import ExtensionDocumentContent from app.config import config from app.prompts import SUMMARY_PROMPT_TEMPLATE from app.utils.document_converters import convert_document_to_markdown, generate_content_hash +from app.utils.llm_service import get_user_long_context_llm from langchain_core.documents import Document as LangChainDocument from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader from langchain_community.document_transformers import MarkdownifyTransformer @@ -18,9 +19,8 @@ import logging md = MarkdownifyTransformer() - async def add_crawled_url_document( - session: AsyncSession, url: str, search_space_id: int + session: AsyncSession, url: str, search_space_id: int, user_id: str ) -> Optional[Document]: try: if not validators.url(url): @@ -84,8 +84,13 @@ async def add_crawled_url_document( logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") return existing_document + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + raise RuntimeError(f"No long context LLM configured for user {user_id}") + # Generate summary - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke( {"document": combined_document_string} ) @@ -130,7 +135,7 @@ async def add_crawled_url_document( async def add_extension_received_document( - session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int + session: AsyncSession, content: ExtensionDocumentContent, search_space_id: int, user_id: str ) -> Optional[Document]: """ Process and store document content received from the SurfSense Extension. @@ -186,8 +191,13 @@ async def add_extension_received_document( logging.info(f"Document with content hash {content_hash} already exists. 
Skipping processing.") return existing_document + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + raise RuntimeError(f"No long context LLM configured for user {user_id}") + # Generate summary - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke( {"document": combined_document_string} ) @@ -230,7 +240,7 @@ async def add_extension_received_document( async def add_received_markdown_file_document( - session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int + session: AsyncSession, file_name: str, file_in_markdown: str, search_space_id: int, user_id: str ) -> Optional[Document]: try: content_hash = generate_content_hash(file_in_markdown) @@ -245,8 +255,13 @@ async def add_received_markdown_file_document( logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") return existing_document + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + raise RuntimeError(f"No long context LLM configured for user {user_id}") + # Generate summary - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke({"document": file_in_markdown}) summary_content = summary_result.content summary_embedding = config.embedding_model_instance.embed(summary_content) @@ -292,6 +307,7 @@ async def add_received_file_document_using_unstructured( file_name: str, unstructured_processed_elements: List[LangChainDocument], search_space_id: int, + user_id: str, ) -> Optional[Document]: try: file_in_markdown = await convert_document_to_markdown( @@ -312,8 +328,13 @@ async def add_received_file_document_using_unstructured( # TODO: Check if file_markdown exceeds token limit of embedding model + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + raise RuntimeError(f"No long context LLM configured for user {user_id}") + # Generate summary - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke({"document": file_in_markdown}) summary_content = summary_result.content summary_embedding = config.embedding_model_instance.embed(summary_content) @@ -360,6 +381,7 @@ async def add_received_file_document_using_llamacloud( file_name: str, llamacloud_markdown_document: str, search_space_id: int, + user_id: str, ) -> Optional[Document]: """ Process and store document content parsed by LlamaCloud. @@ -389,8 +411,13 @@ async def add_received_file_document_using_llamacloud( logging.info(f"Document with content hash {content_hash} already exists. 
Skipping processing.") return existing_document + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + raise RuntimeError(f"No long context LLM configured for user {user_id}") + # Generate summary - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke({"document": file_in_markdown}) summary_content = summary_result.content summary_embedding = config.embedding_model_instance.embed(summary_content) @@ -433,7 +460,7 @@ async def add_received_file_document_using_llamacloud( async def add_youtube_video_document( - session: AsyncSession, url: str, search_space_id: int + session: AsyncSession, url: str, search_space_id: int, user_id: str ): """ Process a YouTube video URL, extract transcripts, and store as a document. @@ -541,8 +568,13 @@ async def add_youtube_video_document( logging.info(f"Document with content hash {content_hash} already exists. Skipping processing.") return existing_document + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + raise RuntimeError(f"No long context LLM configured for user {user_id}") + # Generate summary - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke( {"document": combined_document_string} ) diff --git a/surfsense_backend/app/tasks/connectors_indexing_tasks.py b/surfsense_backend/app/tasks/connectors_indexing_tasks.py index 2572c7b..21243c3 100644 --- a/surfsense_backend/app/tasks/connectors_indexing_tasks.py +++ b/surfsense_backend/app/tasks/connectors_indexing_tasks.py @@ -3,9 +3,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.future import select from datetime import datetime, timedelta, timezone -from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType +from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType, SearchSpace from app.config import config from app.prompts import SUMMARY_PROMPT_TEMPLATE +from app.utils.llm_service import get_user_long_context_llm from app.connectors.slack_history import SlackHistory from app.connectors.notion_history import NotionHistoryConnector from app.connectors.github_connector import GitHubConnector @@ -24,6 +25,7 @@ async def index_slack_messages( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str = None, end_date: str = None, update_last_indexed: bool = True @@ -211,8 +213,16 @@ async def index_slack_messages( documents_skipped += 1 continue + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + logger.error(f"No long context LLM configured for user {user_id}") + skipped_channels.append(f"{channel_name} (no LLM configured)") + documents_skipped += 1 + continue + # Generate summary - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke({"document": combined_document_string}) summary_content = summary_result.content summary_embedding = config.embedding_model_instance.embed(summary_content) @@ -289,6 +299,7 @@ async def index_notion_pages( session: AsyncSession, connector_id: int, 
search_space_id: int, + user_id: str, start_date: str = None, end_date: str = None, update_last_indexed: bool = True @@ -476,9 +487,17 @@ async def index_notion_pages( documents_skipped += 1 continue + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + logger.error(f"No long context LLM configured for user {user_id}") + skipped_pages.append(f"{page_title} (no LLM configured)") + documents_skipped += 1 + continue + # Generate summary logger.debug(f"Generating summary for page {page_title}") - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke({"document": combined_document_string}) summary_content = summary_result.content summary_embedding = config.embedding_model_instance.embed(summary_content) @@ -549,6 +568,7 @@ async def index_github_repos( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str = None, end_date: str = None, update_last_indexed: bool = True @@ -717,6 +737,7 @@ async def index_linear_issues( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str = None, end_date: str = None, update_last_indexed: bool = True @@ -955,6 +976,7 @@ async def index_discord_messages( session: AsyncSession, connector_id: int, search_space_id: int, + user_id: str, start_date: str = None, end_date: str = None, update_last_indexed: bool = True @@ -1142,8 +1164,16 @@ async def index_discord_messages( documents_skipped += 1 continue + # Get user's long context LLM + user_llm = await get_user_long_context_llm(session, user_id) + if not user_llm: + logger.error(f"No long context LLM configured for user {user_id}") + skipped_channels.append(f"{guild_name}#{channel_name} (no LLM configured)") + documents_skipped += 1 + continue + # Generate summary using summary_chain - summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance + summary_chain = SUMMARY_PROMPT_TEMPLATE | user_llm summary_result = await summary_chain.ainvoke({"document": combined_document_string}) summary_content = summary_result.content summary_embedding = await asyncio.to_thread( diff --git a/surfsense_backend/app/tasks/podcast_tasks.py b/surfsense_backend/app/tasks/podcast_tasks.py index 12364e7..a6be546 100644 --- a/surfsense_backend/app/tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/podcast_tasks.py @@ -21,7 +21,8 @@ async def generate_chat_podcast( session: AsyncSession, chat_id: int, search_space_id: int, - podcast_title: str + podcast_title: str, + user_id: int ): # Fetch the chat with the specified ID query = select(Chat).filter( @@ -57,12 +58,14 @@ async def generate_chat_podcast( # Pass it to the SurfSense Podcaster config = { "configurable": { - "podcast_title" : "Surfsense", + "podcast_title": "SurfSense", + "user_id": str(user_id), } } # Initialize state with database session and streaming service initial_state = State( source_content=chat_history_str, + db_session=session ) # Run the graph directly diff --git a/surfsense_backend/app/utils/llm_service.py b/surfsense_backend/app/utils/llm_service.py new file mode 100644 index 0000000..7867d09 --- /dev/null +++ b/surfsense_backend/app/utils/llm_service.py @@ -0,0 +1,120 @@ +from typing import Optional +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from langchain_community.chat_models import ChatLiteLLM +import logging + +from app.db import User, 
LLMConfig + +logger = logging.getLogger(__name__) + +class LLMRole: + LONG_CONTEXT = "long_context" + FAST = "fast" + STRATEGIC = "strategic" + +async def get_user_llm_instance( + session: AsyncSession, + user_id: str, + role: str +) -> Optional[ChatLiteLLM]: + """ + Get a ChatLiteLLM instance for a specific user and role. + + Args: + session: Database session + user_id: User ID + role: LLM role ('long_context', 'fast', or 'strategic') + + Returns: + ChatLiteLLM instance or None if not found + """ + try: + # Get user with their LLM preferences + result = await session.execute( + select(User).where(User.id == user_id) + ) + user = result.scalars().first() + + if not user: + logger.error(f"User {user_id} not found") + return None + + # Get the appropriate LLM config ID based on role + llm_config_id = None + if role == LLMRole.LONG_CONTEXT: + llm_config_id = user.long_context_llm_id + elif role == LLMRole.FAST: + llm_config_id = user.fast_llm_id + elif role == LLMRole.STRATEGIC: + llm_config_id = user.strategic_llm_id + else: + logger.error(f"Invalid LLM role: {role}") + return None + + if not llm_config_id: + logger.error(f"No {role} LLM configured for user {user_id}") + return None + + # Get the LLM configuration + result = await session.execute( + select(LLMConfig).where( + LLMConfig.id == llm_config_id, + LLMConfig.user_id == user_id + ) + ) + llm_config = result.scalars().first() + + if not llm_config: + logger.error(f"LLM config {llm_config_id} not found for user {user_id}") + return None + + # Build the model string for litellm + if llm_config.custom_provider: + model_string = f"{llm_config.custom_provider}/{llm_config.model_name}" + else: + # Map provider enum to litellm format + provider_map = { + "OPENAI": "openai", + "ANTHROPIC": "anthropic", + "GROQ": "groq", + "COHERE": "cohere", + "GOOGLE": "gemini", + "OLLAMA": "ollama", + "MISTRAL": "mistral", + # Add more mappings as needed + } + provider_prefix = provider_map.get(llm_config.provider.value, llm_config.provider.value.lower()) + model_string = f"{provider_prefix}/{llm_config.model_name}" + + # Create ChatLiteLLM instance + litellm_kwargs = { + "model": model_string, + "api_key": llm_config.api_key, + } + + # Add optional parameters + if llm_config.api_base: + litellm_kwargs["api_base"] = llm_config.api_base + + # Add any additional litellm parameters + if llm_config.litellm_params: + litellm_kwargs.update(llm_config.litellm_params) + + return ChatLiteLLM(**litellm_kwargs) + + except Exception as e: + logger.error(f"Error getting LLM instance for user {user_id}, role {role}: {str(e)}") + return None + +async def get_user_long_context_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]: + """Get user's long context LLM instance.""" + return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT) + +async def get_user_fast_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]: + """Get user's fast LLM instance.""" + return await get_user_llm_instance(session, user_id, LLMRole.FAST) + +async def get_user_strategic_llm(session: AsyncSession, user_id: str) -> Optional[ChatLiteLLM]: + """Get user's strategic LLM instance.""" + return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC) \ No newline at end of file diff --git a/surfsense_backend/app/utils/query_service.py b/surfsense_backend/app/utils/query_service.py index 4442c8f..61ba5e4 100644 --- a/surfsense_backend/app/utils/query_service.py +++ b/surfsense_backend/app/utils/query_service.py @@ -1,6 +1,8 @@ import datetime from 
langchain.schema import HumanMessage, SystemMessage, AIMessage from app.config import config +from app.utils.llm_service import get_user_strategic_llm +from sqlalchemy.ext.asyncio import AsyncSession from typing import Any, List, Optional @@ -10,14 +12,21 @@ class QueryService: """ @staticmethod - async def reformulate_query_with_chat_history(user_query: str, chat_history_str: Optional[str] = None) -> str: + async def reformulate_query_with_chat_history( + user_query: str, + session: AsyncSession, + user_id: str, + chat_history_str: Optional[str] = None + ) -> str: """ - Reformulate the user query using the STRATEGIC_LLM to make it more + Reformulate the user query using the user's strategic LLM to make it more effective for information retrieval and research purposes. Args: user_query: The original user query - chat_history: Optional list of previous chat messages + session: Database session for accessing user LLM configs + user_id: User ID to get their specific LLM configuration + chat_history_str: Optional chat history string Returns: str: The reformulated query @@ -26,8 +35,11 @@ class QueryService: return user_query try: - # Get the strategic LLM instance from config - llm = config.strategic_llm_instance + # Get the user's strategic LLM instance + llm = await get_user_strategic_llm(session, user_id) + if not llm: + print(f"Warning: No strategic LLM configured for user {user_id}. Using original query.") + return user_query # Create system message with instructions system_message = SystemMessage( diff --git a/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx index a9ebd9c..dcdcc21 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/researcher/[chat_id]/page.tsx @@ -19,7 +19,9 @@ import { FolderOpen, Upload, ChevronDown, - Filter + Filter, + Brain, + Zap } from 'lucide-react'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { Button } from '@/components/ui/button'; @@ -42,6 +44,13 @@ import { DropdownMenuSeparator, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; import { Badge } from "@/components/ui/badge"; import { Skeleton } from "@/components/ui/skeleton"; import { @@ -62,6 +71,7 @@ import { MarkdownViewer } from '@/components/markdown-viewer'; import { Logo } from '@/components/Logo'; import { useSearchSourceConnectors } from '@/hooks'; import { useDocuments } from '@/hooks/use-documents'; +import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs'; interface SourceItem { id: number; @@ -374,6 +384,8 @@ const ChatPage = () => { const [currentDate, setCurrentDate] = useState(''); const terminalMessagesRef = useRef(null); const { connectorSourceItems, isLoading: isLoadingConnectors } = useSearchSourceConnectors(); + const { llmConfigs } = useLLMConfigs(); + const { preferences, updatePreferences } = useLLMPreferences(); const INITIAL_SOURCES_DISPLAY = 3; @@ -457,6 +469,8 @@ const ChatPage = () => { setCurrentTime(new Date().toTimeString().split(' ')[0]); }, []); + + // Add this CSS to remove input shadow and improve the UI useEffect(() => { if (typeof document !== 'undefined') { @@ -710,6 +724,7 @@ const ChatPage = () => { if (!input.trim() || status !== 'ready') return; // Validation: require at least one 
connector OR at least one document + // Note: Fast LLM selection updates user preferences automatically // if (selectedConnectors.length === 0 && selectedDocuments.length === 0) { // alert("Please select at least one connector or document"); // return; @@ -1569,6 +1584,75 @@ const ChatPage = () => { onChange={setResearchMode} /> + + {/* Fast LLM Selector */} +
+ +
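For reference, a minimal, hypothetical sketch of how a "Fast LLM" selector like the one added in the hunk above can be wired together, using only pieces this patch itself introduces: the `Select` primitives imported from `@/components/ui/select` and the `useLLMConfigs` / `useLLMPreferences` hooks from `@/hooks/use-llm-configs`. It assumes `updatePreferences` accepts a partial preferences object and that `fast_llm_id` is the numeric config id used elsewhere in the patch; the actual component in the PR may differ in layout, icons, and class names.

// Illustrative sketch only — not the PR's actual component.
// Assumes the shadcn Select primitives and the use-llm-configs hooks
// added by this patch; updatePreferences is assumed to accept a
// partial preferences object.
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
import { useLLMConfigs, useLLMPreferences } from "@/hooks/use-llm-configs";

export function FastLLMSelector() {
  const { llmConfigs } = useLLMConfigs();
  const { preferences, updatePreferences } = useLLMPreferences();

  return (
    <Select
      value={preferences.fast_llm_id?.toString() ?? ""}
      // Persist the choice server-side so future chats pick it up automatically
      onValueChange={(value) => updatePreferences({ fast_llm_id: parseInt(value, 10) })}
    >
      <SelectTrigger className="w-44">
        <SelectValue placeholder="Fast LLM" />
      </SelectTrigger>
      <SelectContent>
        {llmConfigs.map((config) => (
          <SelectItem key={config.id} value={config.id.toString()}>
            {config.name} ({config.model_name})
          </SelectItem>
        ))}
      </SelectContent>
    </Select>
  );
}

Because the selection is saved to the user's server-side preferences rather than local component state, it lines up with the "Fast LLM selection updates user preferences automatically" note in the hunk above.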
diff --git a/surfsense_web/app/dashboard/layout.tsx b/surfsense_web/app/dashboard/layout.tsx new file mode 100644 index 0000000..0a43474 --- /dev/null +++ b/surfsense_web/app/dashboard/layout.tsx @@ -0,0 +1,90 @@ +"use client"; + +import { useEffect, useState } from 'react'; +import { useRouter } from 'next/navigation'; +import { useLLMPreferences } from '@/hooks/use-llm-configs'; +import { Loader2 } from 'lucide-react'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; + +interface DashboardLayoutProps { + children: React.ReactNode; +} + +export default function DashboardLayout({ children }: DashboardLayoutProps) { + const router = useRouter(); + const { loading, error, isOnboardingComplete } = useLLMPreferences(); + const [isCheckingAuth, setIsCheckingAuth] = useState(true); + + useEffect(() => { + // Check if user is authenticated + const token = localStorage.getItem('surfsense_bearer_token'); + if (!token) { + router.push('/login'); + return; + } + setIsCheckingAuth(false); + }, [router]); + + useEffect(() => { + // Wait for preferences to load, then check if onboarding is complete + if (!loading && !error && !isCheckingAuth) { + if (!isOnboardingComplete()) { + router.push('/onboard'); + } + } + }, [loading, error, isCheckingAuth, isOnboardingComplete, router]); + + // Show loading screen while checking authentication or loading preferences + if (isCheckingAuth || loading) { + return ( +
+ + + Loading Dashboard + Checking your configuration... + + + + + +
+ ); + } + + // Show error screen if there's an error loading preferences + if (error) { + return ( +
+ + + Configuration Error + Failed to load your LLM configuration + + +

{error}

+
+
+
+ ); + } + + // Only render children if onboarding is complete + if (isOnboardingComplete()) { + return <>{children}; + } + + // This should not be reached due to redirect, but just in case + return ( +
+ + + Redirecting... + Taking you to complete your setup + + + + + +
+ ); +} \ No newline at end of file diff --git a/surfsense_web/app/onboard/page.tsx b/surfsense_web/app/onboard/page.tsx new file mode 100644 index 0000000..22cccd4 --- /dev/null +++ b/surfsense_web/app/onboard/page.tsx @@ -0,0 +1,227 @@ +"use client"; + +import React, { useState, useEffect } from 'react'; +import { useRouter } from 'next/navigation'; +import { motion, AnimatePresence } from 'framer-motion'; +import { Button } from '@/components/ui/button'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { Progress } from '@/components/ui/progress'; +import { CheckCircle, ArrowRight, ArrowLeft, Bot, Sparkles, Zap, Brain } from 'lucide-react'; +import { Logo } from '@/components/Logo'; +import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs'; +import { AddProviderStep } from '@/components/onboard/add-provider-step'; +import { AssignRolesStep } from '@/components/onboard/assign-roles-step'; +import { CompletionStep } from '@/components/onboard/completion-step'; + +const TOTAL_STEPS = 3; + +const OnboardPage = () => { + const router = useRouter(); + const { llmConfigs, loading: configsLoading } = useLLMConfigs(); + const { preferences, loading: preferencesLoading, isOnboardingComplete, refreshPreferences } = useLLMPreferences(); + const [currentStep, setCurrentStep] = useState(1); + const [hasUserProgressed, setHasUserProgressed] = useState(false); + + // Check if user is authenticated + useEffect(() => { + const token = localStorage.getItem('surfsense_bearer_token'); + if (!token) { + router.push('/login'); + return; + } + }, [router]); + + // Track if user has progressed beyond step 1 + useEffect(() => { + if (currentStep > 1) { + setHasUserProgressed(true); + } + }, [currentStep]); + + // Redirect to dashboard if onboarding is already complete and user hasn't progressed (fresh page load) + useEffect(() => { + if (!preferencesLoading && isOnboardingComplete() && !hasUserProgressed) { + router.push('/dashboard'); + } + }, [preferencesLoading, isOnboardingComplete, hasUserProgressed, router]); + + + + const progress = (currentStep / TOTAL_STEPS) * 100; + + const stepTitles = [ + "Add LLM Provider", + "Assign LLM Roles", + "Setup Complete" + ]; + + const stepDescriptions = [ + "Configure your first model provider", + "Assign specific roles to your LLM configurations", + "You're all set to start using SurfSense!" + ]; + + const canProceedToStep2 = !configsLoading && llmConfigs.length > 0; + const canProceedToStep3 = !preferencesLoading && preferences.long_context_llm_id && preferences.fast_llm_id && preferences.strategic_llm_id; + + + + const handleNext = () => { + if (currentStep < TOTAL_STEPS) { + setCurrentStep(currentStep + 1); + } + }; + + const handlePrevious = () => { + if (currentStep > 1) { + setCurrentStep(currentStep - 1); + } + }; + + const handleComplete = () => { + router.push('/dashboard'); + }; + + if (configsLoading || preferencesLoading) { + return ( +
+ + + +

Loading your configuration...

+
+
+
+ ); + } + + return ( +
+ + {/* Header */} +
+
+ +

Welcome to SurfSense

+
+

Let's configure your SurfSense to get started

+
+ + {/* Progress */} + + +
+
Step {currentStep} of {TOTAL_STEPS}
+
{Math.round(progress)}% Complete
+
+ +
+ {Array.from({ length: TOTAL_STEPS }, (_, i) => { + const stepNum = i + 1; + const isCompleted = stepNum < currentStep; + const isCurrent = stepNum === currentStep; + + return ( +
+
+ {isCompleted ? : stepNum} +
+
+

+ {stepTitles[i]} +

+
+
+ ); + })} +
+
+
+ + {/* Step Content */} + + + + {currentStep === 1 && } + {currentStep === 2 && } + {currentStep === 3 && } + {stepTitles[currentStep - 1]} + + + {stepDescriptions[currentStep - 1]} + + + + + + {currentStep === 1 && } + {currentStep === 2 && } + {currentStep === 3 && } + + + + + + {/* Navigation */} +
+ + +
+ {currentStep < TOTAL_STEPS && ( + + )} + + {currentStep === TOTAL_STEPS && ( + + )} +
+
+
+
+ ); +}; + +export default OnboardPage; \ No newline at end of file diff --git a/surfsense_web/app/settings/page.tsx b/surfsense_web/app/settings/page.tsx new file mode 100644 index 0000000..ff3bad2 --- /dev/null +++ b/surfsense_web/app/settings/page.tsx @@ -0,0 +1,60 @@ +"use client"; + +import React from 'react'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; +import { Separator } from '@/components/ui/separator'; +import { Bot, Settings, Brain } from 'lucide-react'; +import { ModelConfigManager } from '@/components/settings/model-config-manager'; +import { LLMRoleManager } from '@/components/settings/llm-role-manager'; + +export default function SettingsPage() { + return ( +
+
+
+ {/* Header Section */} +
+
+
+ +
+
+

Settings

+

+ Manage your LLM configurations and role assignments. +

+
+
+ +
+ + {/* Settings Content */} + +
+ + + + Model Configs + Models + + + + LLM Roles + Roles + + +
+ + + + + + + + +
+
+
+
+ ); +} \ No newline at end of file diff --git a/surfsense_web/components/onboard/add-provider-step.tsx b/surfsense_web/components/onboard/add-provider-step.tsx new file mode 100644 index 0000000..39cc8b8 --- /dev/null +++ b/surfsense_web/components/onboard/add-provider-step.tsx @@ -0,0 +1,255 @@ +"use client"; + +import React, { useState } from 'react'; +import { motion } from 'framer-motion'; +import { Button } from '@/components/ui/button'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; + +import { Badge } from '@/components/ui/badge'; +import { Plus, Trash2, Bot, AlertCircle } from 'lucide-react'; +import { useLLMConfigs, CreateLLMConfig } from '@/hooks/use-llm-configs'; +import { toast } from 'sonner'; +import { Alert, AlertDescription } from '@/components/ui/alert'; + +const LLM_PROVIDERS = [ + { value: 'OPENAI', label: 'OpenAI', example: 'gpt-4o, gpt-4, gpt-3.5-turbo' }, + { value: 'ANTHROPIC', label: 'Anthropic', example: 'claude-3-5-sonnet-20241022, claude-3-opus-20240229' }, + { value: 'GROQ', label: 'Groq', example: 'llama3-70b-8192, mixtral-8x7b-32768' }, + { value: 'COHERE', label: 'Cohere', example: 'command-r-plus, command-r' }, + { value: 'HUGGINGFACE', label: 'HuggingFace', example: 'microsoft/DialoGPT-medium' }, + { value: 'AZURE_OPENAI', label: 'Azure OpenAI', example: 'gpt-4, gpt-35-turbo' }, + { value: 'GOOGLE', label: 'Google', example: 'gemini-pro, gemini-pro-vision' }, + { value: 'AWS_BEDROCK', label: 'AWS Bedrock', example: 'anthropic.claude-v2' }, + { value: 'OLLAMA', label: 'Ollama', example: 'llama2, codellama' }, + { value: 'MISTRAL', label: 'Mistral', example: 'mistral-large-latest, mistral-medium' }, + { value: 'TOGETHER_AI', label: 'Together AI', example: 'togethercomputer/llama-2-70b-chat' }, + { value: 'REPLICATE', label: 'Replicate', example: 'meta/llama-2-70b-chat' }, + { value: 'CUSTOM', label: 'Custom Provider', example: 'your-custom-model' }, +]; + +export function AddProviderStep() { + const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs(); + const [isAddingNew, setIsAddingNew] = useState(false); + const [formData, setFormData] = useState({ + name: '', + provider: '', + custom_provider: '', + model_name: '', + api_key: '', + api_base: '', + litellm_params: {} + }); + const [isSubmitting, setIsSubmitting] = useState(false); + + const handleInputChange = (field: keyof CreateLLMConfig, value: string) => { + setFormData(prev => ({ ...prev, [field]: value })); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) { + toast.error('Please fill in all required fields'); + return; + } + + setIsSubmitting(true); + const result = await createLLMConfig(formData); + setIsSubmitting(false); + + if (result) { + setFormData({ + name: '', + provider: '', + custom_provider: '', + model_name: '', + api_key: '', + api_base: '', + litellm_params: {} + }); + setIsAddingNew(false); + } + }; + + const selectedProvider = LLM_PROVIDERS.find(p => p.value === formData.provider); + + return ( +
+ {/* Info Alert */} + + + + Add at least one LLM provider to continue. You can configure multiple providers and choose specific roles for each one in the next step. + + + + {/* Existing Configurations */} + {llmConfigs.length > 0 && ( +
+

Your LLM Configurations

+
+ {llmConfigs.map((config) => ( + + + +
+
+
+ +

{config.name}

+ {config.provider} +
+

+ Model: {config.model_name} + {config.api_base && ` • Base: ${config.api_base}`} +

+
+ +
+
+
+
+ ))} +
+
+ )} + + {/* Add New Provider */} + {!isAddingNew ? ( + + + +

Add LLM Provider

+

+ Configure your first model provider to get started +

+ +
+
+ ) : ( + + + Add New LLM Provider + + Configure a new language model provider for your AI assistant + + + +
+
+
+ + handleInputChange('name', e.target.value)} + required + /> +
+ +
+ + +
+
+ + {formData.provider === 'CUSTOM' && ( +
+ + handleInputChange('custom_provider', e.target.value)} + required + /> +
+ )} + +
+ + handleInputChange('model_name', e.target.value)} + required + /> + {selectedProvider && ( +

+ Examples: {selectedProvider.example} +

+ )} +
+ +
+ + handleInputChange('api_key', e.target.value)} + required + /> +
+ +
+ + handleInputChange('api_base', e.target.value)} + /> +
+ +
+ + +
+
+
+
+ )} +
+ ); +} \ No newline at end of file diff --git a/surfsense_web/components/onboard/assign-roles-step.tsx b/surfsense_web/components/onboard/assign-roles-step.tsx new file mode 100644 index 0000000..255fdee --- /dev/null +++ b/surfsense_web/components/onboard/assign-roles-step.tsx @@ -0,0 +1,232 @@ +"use client"; + +import React, { useState, useEffect } from 'react'; +import { motion } from 'framer-motion'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; +import { Badge } from '@/components/ui/badge'; +import { Brain, Zap, Bot, AlertCircle, CheckCircle } from 'lucide-react'; +import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs'; +import { Alert, AlertDescription } from '@/components/ui/alert'; + +const ROLE_DESCRIPTIONS = { + long_context: { + icon: Brain, + title: 'Long Context LLM', + description: 'Handles complex tasks requiring extensive context understanding and reasoning', + color: 'bg-blue-100 text-blue-800 border-blue-200', + examples: 'Document analysis, research synthesis, complex Q&A' + }, + fast: { + icon: Zap, + title: 'Fast LLM', + description: 'Optimized for quick responses and real-time interactions', + color: 'bg-green-100 text-green-800 border-green-200', + examples: 'Quick searches, simple questions, instant responses' + }, + strategic: { + icon: Bot, + title: 'Strategic LLM', + description: 'Advanced reasoning for planning and strategic decision making', + color: 'bg-purple-100 text-purple-800 border-purple-200', + examples: 'Planning workflows, strategic analysis, complex problem solving' + } +}; + +interface AssignRolesStepProps { + onPreferencesUpdated?: () => Promise; +} + +export function AssignRolesStep({ onPreferencesUpdated }: AssignRolesStepProps) { + const { llmConfigs } = useLLMConfigs(); + const { preferences, updatePreferences } = useLLMPreferences(); + + + const [assignments, setAssignments] = useState({ + long_context_llm_id: preferences.long_context_llm_id || '', + fast_llm_id: preferences.fast_llm_id || '', + strategic_llm_id: preferences.strategic_llm_id || '' + }); + + + useEffect(() => { + setAssignments({ + long_context_llm_id: preferences.long_context_llm_id || '', + fast_llm_id: preferences.fast_llm_id || '', + strategic_llm_id: preferences.strategic_llm_id || '' + }); + }, [preferences]); + + const handleRoleAssignment = async (role: string, configId: string) => { + const newAssignments = { + ...assignments, + [role]: configId === '' ? '' : parseInt(configId) + }; + + setAssignments(newAssignments); + + // Auto-save if this assignment completes all roles + const hasAllAssignments = newAssignments.long_context_llm_id && newAssignments.fast_llm_id && newAssignments.strategic_llm_id; + + if (hasAllAssignments) { + const numericAssignments = { + long_context_llm_id: typeof newAssignments.long_context_llm_id === 'string' ? parseInt(newAssignments.long_context_llm_id) : newAssignments.long_context_llm_id, + fast_llm_id: typeof newAssignments.fast_llm_id === 'string' ? parseInt(newAssignments.fast_llm_id) : newAssignments.fast_llm_id, + strategic_llm_id: typeof newAssignments.strategic_llm_id === 'string' ? 
parseInt(newAssignments.strategic_llm_id) : newAssignments.strategic_llm_id, + }; + + const success = await updatePreferences(numericAssignments); + + // Refresh parent preferences state + if (success && onPreferencesUpdated) { + await onPreferencesUpdated(); + } + } + }; + + + + const isAssignmentComplete = assignments.long_context_llm_id && assignments.fast_llm_id && assignments.strategic_llm_id; + + if (llmConfigs.length === 0) { + return ( +
+ +

No LLM Configurations Found

+

+ Please add at least one LLM provider in the previous step before assigning roles. +

+
+ ); + } + + return ( +
+ {/* Info Alert */} + + + + Assign your LLM configurations to specific roles. Each role serves different purposes in your workflow. + + + + {/* Role Assignment Cards */} +
+ {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { + const IconComponent = role.icon; + const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments]; + const assignedConfig = llmConfigs.find(config => config.id === currentAssignment); + + return ( + + + +
+
+
+ +
+
+ {role.title} + {role.description} +
+
+ {currentAssignment && ( + + )} +
+
+ +
+ Use cases: {role.examples} +
+ +
+ + +
+ + {assignedConfig && ( +
+
+ + Assigned: + {assignedConfig.provider} + {assignedConfig.name} +
+
+ Model: {assignedConfig.model_name} +
+
+ )} +
+
+
+ ); + })} +
+ + + + {/* Status Indicator */} + {isAssignmentComplete && ( +
+
+ + All roles assigned and saved! +
+
+ )} + + {/* Progress Indicator */} +
+
+ Progress: +
+ {Object.keys(ROLE_DESCRIPTIONS).map((key, index) => ( +
+ ))} +
+ + {Object.values(assignments).filter(Boolean).length} of {Object.keys(ROLE_DESCRIPTIONS).length} roles assigned + +
+
+
+ ); +} \ No newline at end of file diff --git a/surfsense_web/components/onboard/completion-step.tsx b/surfsense_web/components/onboard/completion-step.tsx new file mode 100644 index 0000000..1a14753 --- /dev/null +++ b/surfsense_web/components/onboard/completion-step.tsx @@ -0,0 +1,125 @@ +"use client"; + +import React from 'react'; +import { motion } from 'framer-motion'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { Badge } from '@/components/ui/badge'; +import { CheckCircle, Bot, Brain, Zap, Sparkles, ArrowRight } from 'lucide-react'; +import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs'; + + +const ROLE_ICONS = { + long_context: Brain, + fast: Zap, + strategic: Bot +}; + +export function CompletionStep() { + const { llmConfigs } = useLLMConfigs(); + const { preferences } = useLLMPreferences(); + + const assignedConfigs = { + long_context: llmConfigs.find(c => c.id === preferences.long_context_llm_id), + fast: llmConfigs.find(c => c.id === preferences.fast_llm_id), + strategic: llmConfigs.find(c => c.id === preferences.strategic_llm_id) + }; + + return ( +
+ {/* Success Message */} + +
+ +
+

Setup Complete!

+
+ + {/* Configuration Summary */} + + + + + + Your LLM Configuration + + + Here's a summary of your setup + + + + {Object.entries(assignedConfigs).map(([role, config]) => { + if (!config) return null; + + const IconComponent = ROLE_ICONS[role as keyof typeof ROLE_ICONS]; + const roleDisplayNames = { + long_context: 'Long Context LLM', + fast: 'Fast LLM', + strategic: 'Strategic LLM' + }; + + return ( + +
+
+ +
+
+

{roleDisplayNames[role as keyof typeof roleDisplayNames]}

+

{config.name}

+
+
+
+ {config.provider} + {config.model_name} +
+
+ ); + })} +
+
+
+ + + {/* Next Steps */} + + + +
+
+ +
+

Ready to Get Started?

+
+

+ Click "Complete Setup" to enter your dashboard and start exploring! +

+
+ ✓ {llmConfigs.length} LLM provider{llmConfigs.length > 1 ? 's' : ''} configured + ✓ All roles assigned + ✓ Ready to use +
+
+
+
+
+ ); +} \ No newline at end of file diff --git a/surfsense_web/components/settings/llm-role-manager.tsx b/surfsense_web/components/settings/llm-role-manager.tsx new file mode 100644 index 0000000..581732d --- /dev/null +++ b/surfsense_web/components/settings/llm-role-manager.tsx @@ -0,0 +1,465 @@ +"use client"; + +import React, { useState, useEffect } from 'react'; +import { motion } from 'framer-motion'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; +import { Badge } from '@/components/ui/badge'; +import { Button } from '@/components/ui/button'; +import { + Brain, + Zap, + Bot, + AlertCircle, + CheckCircle, + Settings2, + RefreshCw, + Save, + RotateCcw, + Loader2 +} from 'lucide-react'; +import { useLLMConfigs, useLLMPreferences } from '@/hooks/use-llm-configs'; +import { Alert, AlertDescription } from '@/components/ui/alert'; +import { toast } from 'sonner'; + +const ROLE_DESCRIPTIONS = { + long_context: { + icon: Brain, + title: 'Long Context LLM', + description: 'Handles complex tasks requiring extensive context understanding and reasoning', + color: 'bg-blue-100 text-blue-800 border-blue-200', + examples: 'Document analysis, research synthesis, complex Q&A', + characteristics: ['Large context window', 'Deep reasoning', 'Complex analysis'] + }, + fast: { + icon: Zap, + title: 'Fast LLM', + description: 'Optimized for quick responses and real-time interactions', + color: 'bg-green-100 text-green-800 border-green-200', + examples: 'Quick searches, simple questions, instant responses', + characteristics: ['Low latency', 'Quick responses', 'Real-time chat'] + }, + strategic: { + icon: Bot, + title: 'Strategic LLM', + description: 'Advanced reasoning for planning and strategic decision making', + color: 'bg-purple-100 text-purple-800 border-purple-200', + examples: 'Planning workflows, strategic analysis, complex problem solving', + characteristics: ['Strategic thinking', 'Long-term planning', 'Complex reasoning'] + } +}; + +export function LLMRoleManager() { + const { llmConfigs, loading: configsLoading, error: configsError, refreshConfigs } = useLLMConfigs(); + const { preferences, loading: preferencesLoading, error: preferencesError, updatePreferences, refreshPreferences } = useLLMPreferences(); + + const [assignments, setAssignments] = useState({ + long_context_llm_id: preferences.long_context_llm_id || '', + fast_llm_id: preferences.fast_llm_id || '', + strategic_llm_id: preferences.strategic_llm_id || '' + }); + + const [hasChanges, setHasChanges] = useState(false); + const [isSaving, setIsSaving] = useState(false); + + useEffect(() => { + const newAssignments = { + long_context_llm_id: preferences.long_context_llm_id || '', + fast_llm_id: preferences.fast_llm_id || '', + strategic_llm_id: preferences.strategic_llm_id || '' + }; + setAssignments(newAssignments); + setHasChanges(false); + }, [preferences]); + + const handleRoleAssignment = (role: string, configId: string) => { + const newAssignments = { + ...assignments, + [role]: configId === 'unassigned' ? 
'' : parseInt(configId) + }; + + setAssignments(newAssignments); + + // Check if there are changes compared to current preferences + const currentPrefs = { + long_context_llm_id: preferences.long_context_llm_id || '', + fast_llm_id: preferences.fast_llm_id || '', + strategic_llm_id: preferences.strategic_llm_id || '' + }; + + const hasChangesNow = Object.keys(newAssignments).some( + key => newAssignments[key as keyof typeof newAssignments] !== currentPrefs[key as keyof typeof currentPrefs] + ); + + setHasChanges(hasChangesNow); + }; + + const handleSave = async () => { + setIsSaving(true); + + const numericAssignments = { + long_context_llm_id: typeof assignments.long_context_llm_id === 'string' + ? (assignments.long_context_llm_id ? parseInt(assignments.long_context_llm_id) : undefined) + : assignments.long_context_llm_id, + fast_llm_id: typeof assignments.fast_llm_id === 'string' + ? (assignments.fast_llm_id ? parseInt(assignments.fast_llm_id) : undefined) + : assignments.fast_llm_id, + strategic_llm_id: typeof assignments.strategic_llm_id === 'string' + ? (assignments.strategic_llm_id ? parseInt(assignments.strategic_llm_id) : undefined) + : assignments.strategic_llm_id, + }; + + const success = await updatePreferences(numericAssignments); + + if (success) { + setHasChanges(false); + toast.success('LLM role assignments saved successfully!'); + } + + setIsSaving(false); + }; + + const handleReset = () => { + setAssignments({ + long_context_llm_id: preferences.long_context_llm_id || '', + fast_llm_id: preferences.fast_llm_id || '', + strategic_llm_id: preferences.strategic_llm_id || '' + }); + setHasChanges(false); + }; + + const isAssignmentComplete = assignments.long_context_llm_id && assignments.fast_llm_id && assignments.strategic_llm_id; + const assignedConfigIds = Object.values(assignments).filter(id => id !== ''); + const availableConfigs = llmConfigs.filter(config => config.id && config.id.toString().trim() !== ''); + + const isLoading = configsLoading || preferencesLoading; + const hasError = configsError || preferencesError; + + return ( +
+ {/* Header */} +
+
+
+
+ +
+
+

LLM Role Management

+

+ Assign your LLM configurations to specific roles for different purposes. +

+
+
+
+
+ + +
+
+ + {/* Error Alert */} + {hasError && ( + + + + {configsError || preferencesError} + + + )} + + {/* Loading State */} + {isLoading && ( + + +
+ + + {configsLoading && preferencesLoading ? 'Loading configurations and preferences...' : + configsLoading ? 'Loading configurations...' : + 'Loading preferences...'} + +
+
+
+ )} + + {/* Stats Overview */} + {!isLoading && !hasError && ( +
+ + +
+
+

{availableConfigs.length}

+

Available Models

+
+
+ +
+
+
+
+ + + +
+
+

{assignedConfigIds.length}

+

Assigned Roles

+
+
+ +
+
+
+
+ + + +
+
+

+ {Math.round((assignedConfigIds.length / 3) * 100)}% +

+

Completion

+
+
+ {isAssignmentComplete ? ( + + ) : ( + + )} +
+
+
+
+ + + +
+
+

+ {isAssignmentComplete ? 'Ready' : 'Setup'} +

+

Status

+
+
+ {isAssignmentComplete ? ( + + ) : ( + + )} +
+
+
+
+
+ )} + + {/* Info Alert */} + {!isLoading && !hasError && ( +
+ {availableConfigs.length === 0 ? ( + + + + No LLM configurations found. Please add at least one LLM provider in the Model Configs tab before assigning roles. + + + ) : !isAssignmentComplete ? ( + + + + Complete all role assignments to enable full functionality. Each role serves different purposes in your workflow. + + + ) : ( + + + + All roles are assigned and ready to use! Your LLM configuration is complete. + + + )} + + {/* Role Assignment Cards */} + {availableConfigs.length > 0 && ( +
+ {Object.entries(ROLE_DESCRIPTIONS).map(([key, role]) => { + const IconComponent = role.icon; + const currentAssignment = assignments[`${key}_llm_id` as keyof typeof assignments]; + const assignedConfig = availableConfigs.find(config => config.id === currentAssignment); + + return ( + + + +
+
+
+ +
+
+ {role.title} + {role.description} +
+
+ {currentAssignment && ( + + )} +
+
+ +
+
+ Use cases: {role.examples} +
+
+ {role.characteristics.map((char, idx) => ( + + {char} + + ))} +
+
+ +
+ + +
+ + {assignedConfig && ( +
+
+ + Assigned: + {assignedConfig.provider} + {assignedConfig.name} +
+
+ Model: {assignedConfig.model_name} +
+ {assignedConfig.api_base && ( +
+ Base: {assignedConfig.api_base} +
+ )} +
+ )} +
+
+
+ ); + })} +
+ )} + + {/* Action Buttons */} + {hasChanges && ( +
+ + +
+ )} + + {/* Status Indicator */} + {isAssignmentComplete && !hasChanges && ( +
+
+ + All roles assigned and saved! +
+
+ )} + + {/* Progress Indicator */} +
+
+ Progress: +
+ {Object.keys(ROLE_DESCRIPTIONS).map((key, index) => ( +
+ ))} +
+ + {assignedConfigIds.length} of {Object.keys(ROLE_DESCRIPTIONS).length} roles assigned + +
+
+
+ )} +
+ ); +} \ No newline at end of file diff --git a/surfsense_web/components/settings/model-config-manager.tsx b/surfsense_web/components/settings/model-config-manager.tsx new file mode 100644 index 0000000..13b8e5a --- /dev/null +++ b/surfsense_web/components/settings/model-config-manager.tsx @@ -0,0 +1,631 @@ +"use client"; + +import React, { useState, useEffect } from 'react'; +import { motion, AnimatePresence } from 'framer-motion'; +import { Button } from '@/components/ui/button'; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; +import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; +import { Badge } from '@/components/ui/badge'; +import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle } from '@/components/ui/dialog'; +import { + Plus, + Trash2, + Bot, + AlertCircle, + Edit3, + Settings2, + Eye, + EyeOff, + CheckCircle, + Clock, + AlertTriangle, + RefreshCw, + Loader2 +} from 'lucide-react'; +import { useLLMConfigs, CreateLLMConfig, UpdateLLMConfig, LLMConfig } from '@/hooks/use-llm-configs'; +import { toast } from 'sonner'; +import { Alert, AlertDescription } from '@/components/ui/alert'; + +const LLM_PROVIDERS = [ + { + value: 'OPENAI', + label: 'OpenAI', + example: 'gpt-4o, gpt-4, gpt-3.5-turbo', + description: 'Most popular and versatile AI models' + }, + { + value: 'ANTHROPIC', + label: 'Anthropic', + example: 'claude-3-5-sonnet-20241022, claude-3-opus-20240229', + description: 'Constitutional AI with strong reasoning' + }, + { + value: 'GROQ', + label: 'Groq', + example: 'llama3-70b-8192, mixtral-8x7b-32768', + description: 'Ultra-fast inference speeds' + }, + { + value: 'COHERE', + label: 'Cohere', + example: 'command-r-plus, command-r', + description: 'Enterprise-focused language models' + }, + { + value: 'HUGGINGFACE', + label: 'HuggingFace', + example: 'microsoft/DialoGPT-medium', + description: 'Open source model hub' + }, + { + value: 'AZURE_OPENAI', + label: 'Azure OpenAI', + example: 'gpt-4, gpt-35-turbo', + description: 'Enterprise OpenAI through Azure' + }, + { + value: 'GOOGLE', + label: 'Google', + example: 'gemini-pro, gemini-pro-vision', + description: 'Google\'s Gemini AI models' + }, + { + value: 'AWS_BEDROCK', + label: 'AWS Bedrock', + example: 'anthropic.claude-v2', + description: 'AWS managed AI service' + }, + { + value: 'OLLAMA', + label: 'Ollama', + example: 'llama2, codellama', + description: 'Run models locally' + }, + { + value: 'MISTRAL', + label: 'Mistral', + example: 'mistral-large-latest, mistral-medium', + description: 'European AI excellence' + }, + { + value: 'TOGETHER_AI', + label: 'Together AI', + example: 'togethercomputer/llama-2-70b-chat', + description: 'Decentralized AI platform' + }, + { + value: 'REPLICATE', + label: 'Replicate', + example: 'meta/llama-2-70b-chat', + description: 'Run models via API' + }, + { + value: 'CUSTOM', + label: 'Custom Provider', + example: 'your-custom-model', + description: 'Your own model endpoint' + }, +]; + +export function ModelConfigManager() { + const { llmConfigs, loading, error, createLLMConfig, updateLLMConfig, deleteLLMConfig, refreshConfigs } = useLLMConfigs(); + const [isAddingNew, setIsAddingNew] = useState(false); + const [editingConfig, setEditingConfig] = useState(null); + const [showApiKey, setShowApiKey] = useState>({}); + const [formData, setFormData] = useState({ + name: '', + 
provider: '', + custom_provider: '', + model_name: '', + api_key: '', + api_base: '', + litellm_params: {} + }); + const [isSubmitting, setIsSubmitting] = useState(false); + + // Populate form when editing + useEffect(() => { + if (editingConfig) { + setFormData({ + name: editingConfig.name, + provider: editingConfig.provider, + custom_provider: editingConfig.custom_provider || '', + model_name: editingConfig.model_name, + api_key: editingConfig.api_key, + api_base: editingConfig.api_base || '', + litellm_params: editingConfig.litellm_params || {} + }); + } + }, [editingConfig]); + + const handleInputChange = (field: keyof CreateLLMConfig, value: string) => { + setFormData(prev => ({ ...prev, [field]: value })); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + if (!formData.name || !formData.provider || !formData.model_name || !formData.api_key) { + toast.error('Please fill in all required fields'); + return; + } + + setIsSubmitting(true); + + let result; + if (editingConfig) { + // Update existing config + result = await updateLLMConfig(editingConfig.id, formData); + } else { + // Create new config + result = await createLLMConfig(formData); + } + + setIsSubmitting(false); + + if (result) { + setFormData({ + name: '', + provider: '', + custom_provider: '', + model_name: '', + api_key: '', + api_base: '', + litellm_params: {} + }); + setIsAddingNew(false); + setEditingConfig(null); + } + }; + + const handleDelete = async (id: number) => { + if (confirm('Are you sure you want to delete this configuration? This action cannot be undone.')) { + await deleteLLMConfig(id); + } + }; + + const toggleApiKeyVisibility = (configId: number) => { + setShowApiKey(prev => ({ + ...prev, + [configId]: !prev[configId] + })); + }; + + const selectedProvider = LLM_PROVIDERS.find(p => p.value === formData.provider); + + const getProviderInfo = (providerValue: string) => { + return LLM_PROVIDERS.find(p => p.value === providerValue); + }; + + const maskApiKey = (apiKey: string) => { + if (apiKey.length <= 8) return '*'.repeat(apiKey.length); + return apiKey.substring(0, 4) + '*'.repeat(apiKey.length - 8) + apiKey.substring(apiKey.length - 4); + }; + + return ( +
+      {/*
+        [JSX markup lost in patch extraction; summary of the recoverable structure]
+        - Header: Settings2 icon, the title "Model Configurations", the description
+          "Manage your LLM provider configurations and API settings.", and an Add button
+          that opens the add/edit dialog via setIsAddingNew(true).
+        - Error alert: shown whenever `error` is set.
+        - Loading card: a spinner with "Loading configurations..." while `loading` is true.
+        - Stats overview (when loaded): cards for total configurations (llmConfigs.length),
+          unique providers (new Set(llmConfigs.map(c => c.provider)).size), and system
+          status ("Active").
+        - "Your Configurations" list: an empty state ("No Configurations Yet" / "Get started
+          by adding your first LLM provider configuration to begin using the system.") or one
+          card per config showing its name, provider badge, model_name, the provider
+          description, the API key (masked with maskApiKey, toggled with
+          toggleApiKeyVisibility), the optional api_base, the creation date, an "Active"
+          indicator, and edit/delete actions (setEditingConfig / handleDelete).
+        - Add/Edit dialog: opened by isAddingNew or editingConfig and reset on close; its form
+          submits through handleSubmit and has inputs for name, provider (select over
+          LLM_PROVIDERS), custom_provider (only when provider === 'CUSTOM'), model_name
+          (with selectedProvider.example hints), api_key, and api_base, plus submit and
+          cancel buttons.
+      */}
+  );
+}
\ No newline at end of file
diff --git a/surfsense_web/components/sidebar/nav-user.tsx b/surfsense_web/components/sidebar/nav-user.tsx
index 934a582..fe16530 100644
--- a/surfsense_web/components/sidebar/nav-user.tsx
+++ b/surfsense_web/components/sidebar/nav-user.tsx
@@ -4,6 +4,7 @@ import {
   BadgeCheck,
   ChevronsUpDown,
   LogOut,
+  Settings,
 } from "lucide-react"
 import {
@@ -93,6 +94,10 @@ export function NavUser({
+              {/* [tags below reconstructed; the originals were stripped during extraction] */}
+              <DropdownMenuItem onClick={() => router.push(`/settings`)}>
+                <Settings />
+                Settings
+              </DropdownMenuItem>
               {/* [context lines lost in patch extraction: the existing "Log out" menu item] */}
diff --git a/surfsense_web/components/ui/progress.tsx b/surfsense_web/components/ui/progress.tsx
new file mode 100644
index 0000000..d255dea
--- /dev/null
+++ b/surfsense_web/components/ui/progress.tsx
@@ -0,0 +1,29 @@
+"use client"
+
+import * as React from "react"
+import { cn } from "@/lib/utils"
+
+interface ProgressProps extends React.HTMLAttributes<HTMLDivElement> {
+  value?: number
+}
+
+const Progress = React.forwardRef<HTMLDivElement, ProgressProps>(
+  ({ className, value = 0, ...props }, ref) => (
+    {/* [markup lost in patch extraction: the progress track element and its inner value
+        bar, rendered from ref, className, value, and the remaining props] */}
+ ) +) +Progress.displayName = "Progress" + +export { Progress } \ No newline at end of file diff --git a/surfsense_web/hooks/use-llm-configs.ts b/surfsense_web/hooks/use-llm-configs.ts new file mode 100644 index 0000000..ecccb50 --- /dev/null +++ b/surfsense_web/hooks/use-llm-configs.ts @@ -0,0 +1,246 @@ +"use client" +import { useState, useEffect } from 'react'; +import { toast } from 'sonner'; + +export interface LLMConfig { + id: number; + name: string; + provider: string; + custom_provider?: string; + model_name: string; + api_key: string; + api_base?: string; + litellm_params?: Record; + created_at: string; + user_id: string; +} + +export interface LLMPreferences { + long_context_llm_id?: number; + fast_llm_id?: number; + strategic_llm_id?: number; + long_context_llm?: LLMConfig; + fast_llm?: LLMConfig; + strategic_llm?: LLMConfig; +} + +export interface CreateLLMConfig { + name: string; + provider: string; + custom_provider?: string; + model_name: string; + api_key: string; + api_base?: string; + litellm_params?: Record; +} + +export interface UpdateLLMConfig { + name?: string; + provider?: string; + custom_provider?: string; + model_name?: string; + api_key?: string; + api_base?: string; + litellm_params?: Record; +} + +export function useLLMConfigs() { + const [llmConfigs, setLlmConfigs] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + const fetchLLMConfigs = async () => { + try { + setLoading(true); + const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`, { + headers: { + Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`, + }, + method: "GET", + }); + + if (!response.ok) { + throw new Error("Failed to fetch LLM configurations"); + } + + const data = await response.json(); + setLlmConfigs(data); + setError(null); + } catch (err: any) { + setError(err.message || 'Failed to fetch LLM configurations'); + console.error('Error fetching LLM configurations:', err); + } finally { + setLoading(false); + } + }; + + useEffect(() => { + fetchLLMConfigs(); + }, []); + + const createLLMConfig = async (config: CreateLLMConfig): Promise => { + try { + const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`, + }, + body: JSON.stringify(config), + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.detail || 'Failed to create LLM configuration'); + } + + const newConfig = await response.json(); + setLlmConfigs(prev => [...prev, newConfig]); + toast.success('LLM configuration created successfully'); + return newConfig; + } catch (err: any) { + toast.error(err.message || 'Failed to create LLM configuration'); + console.error('Error creating LLM configuration:', err); + return null; + } + }; + + const deleteLLMConfig = async (id: number): Promise => { + try { + const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/${id}`, { + method: 'DELETE', + headers: { + Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`, + }, + }); + + if (!response.ok) { + throw new Error('Failed to delete LLM configuration'); + } + + setLlmConfigs(prev => prev.filter(config => config.id !== id)); + toast.success('LLM configuration deleted successfully'); + return true; + } catch (err: any) { + 
toast.error(err.message || 'Failed to delete LLM configuration'); + console.error('Error deleting LLM configuration:', err); + return false; + } + }; + + const updateLLMConfig = async (id: number, config: UpdateLLMConfig): Promise => { + try { + const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/${id}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`, + }, + body: JSON.stringify(config), + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.detail || 'Failed to update LLM configuration'); + } + + const updatedConfig = await response.json(); + setLlmConfigs(prev => prev.map(c => c.id === id ? updatedConfig : c)); + toast.success('LLM configuration updated successfully'); + return updatedConfig; + } catch (err: any) { + toast.error(err.message || 'Failed to update LLM configuration'); + console.error('Error updating LLM configuration:', err); + return null; + } + }; + + return { + llmConfigs, + loading, + error, + createLLMConfig, + updateLLMConfig, + deleteLLMConfig, + refreshConfigs: fetchLLMConfigs + }; +} + +export function useLLMPreferences() { + const [preferences, setPreferences] = useState({}); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + const fetchPreferences = async () => { + try { + setLoading(true); + const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, { + headers: { + Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`, + }, + method: "GET", + }); + + if (!response.ok) { + throw new Error("Failed to fetch LLM preferences"); + } + + const data = await response.json(); + setPreferences(data); + setError(null); + } catch (err: any) { + setError(err.message || 'Failed to fetch LLM preferences'); + console.error('Error fetching LLM preferences:', err); + } finally { + setLoading(false); + } + }; + + useEffect(() => { + fetchPreferences(); + }, []); + + const updatePreferences = async (newPreferences: Partial): Promise => { + try { + const response = await fetch(`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${localStorage.getItem('surfsense_bearer_token')}`, + }, + body: JSON.stringify(newPreferences), + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.detail || 'Failed to update LLM preferences'); + } + + const updatedPreferences = await response.json(); + setPreferences(updatedPreferences); + toast.success('LLM preferences updated successfully'); + return true; + } catch (err: any) { + toast.error(err.message || 'Failed to update LLM preferences'); + console.error('Error updating LLM preferences:', err); + return false; + } + }; + + const isOnboardingComplete = (): boolean => { + return !!( + preferences.long_context_llm_id && + preferences.fast_llm_id && + preferences.strategic_llm_id + ); + }; + + return { + preferences, + loading, + error, + updatePreferences, + refreshPreferences: fetchPreferences, + isOnboardingComplete + }; +} \ No newline at end of file
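End of patch. For orientation only (not part of the diff): a minimal sketch of how a client component might consume the new useLLMConfigs and useLLMPreferences hooks. The component name LLMQuickSetup, the example provider and model values, and the placeholder API key are illustrative assumptions, not code from this patch.

"use client";

import { useLLMConfigs, useLLMPreferences } from "@/hooks/use-llm-configs";

// Hypothetical consumer (illustrative only): lists existing configurations and
// assigns one configuration to all three LLM roles tracked in LLMPreferences.
export function LLMQuickSetup() {
  const { llmConfigs, loading, createLLMConfig } = useLLMConfigs();
  const { updatePreferences, isOnboardingComplete } = useLLMPreferences();

  const assignDefaults = async () => {
    let config = llmConfigs[0];
    if (!config) {
      // Fields mirror the CreateLLMConfig interface; the values are placeholders.
      const created = await createLLMConfig({
        name: "Default OpenAI",
        provider: "OPENAI",
        model_name: "gpt-4o-mini",
        api_key: "sk-...", // supplied by the user in a real UI
      });
      if (!created) return;
      config = created;
    }
    // Point every role at the same configuration.
    await updatePreferences({
      long_context_llm_id: config.id,
      fast_llm_id: config.id,
      strategic_llm_id: config.id,
    });
  };

  if (loading) return <p>Loading LLM configurations…</p>;

  return (
    <div>
      <p>{isOnboardingComplete() ? "All LLM roles are assigned." : "Onboarding is incomplete."}</p>
      <button onClick={assignDefaults}>Assign defaults</button>
      <ul>
        {llmConfigs.map((c) => (
          <li key={c.id}>
            {c.name} ({c.provider} / {c.model_name})
          </li>
        ))}
      </ul>
    </div>
  );
}

Both hooks read the bearer token from localStorage and call the /api/v1 endpoints shown above, so this sketch assumes it runs in the browser after login.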