mirror of
https://github.com/MODSetter/SurfSense.git
synced 2025-09-01 18:19:08 +00:00
add jira to document type enum and search source connector type enum
This commit is contained in:
parent
c4eab5eaba
commit
90bfec6e7d
1 changed files with 207 additions and 76 deletions
|
@ -2,30 +2,30 @@ from collections.abc import AsyncGenerator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||||
|
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
ARRAY,
|
ARRAY,
|
||||||
|
JSON,
|
||||||
|
TIMESTAMP,
|
||||||
Boolean,
|
Boolean,
|
||||||
Column,
|
Column,
|
||||||
Enum as SQLAlchemyEnum,
|
)
|
||||||
|
from sqlalchemy import Enum as SQLAlchemyEnum
|
||||||
|
from sqlalchemy import (
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
JSON,
|
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
text,
|
text,
|
||||||
TIMESTAMP
|
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
|
||||||
|
|
||||||
from app.config import config
|
|
||||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
|
||||||
from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever
|
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
from fastapi_users.db import (
|
from fastapi_users.db import (
|
||||||
SQLAlchemyBaseOAuthAccountTableUUID,
|
SQLAlchemyBaseOAuthAccountTableUUID,
|
||||||
|
@ -51,9 +51,11 @@ class DocumentType(str, Enum):
|
||||||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||||||
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
||||||
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
|
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
|
||||||
|
JIRA_CONNECTOR = "JIRA_CONNECTOR"
|
||||||
|
|
||||||
|
|
||||||
class SearchSourceConnectorType(str, Enum):
|
class SearchSourceConnectorType(str, Enum):
|
||||||
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
SERPER_API = "SERPER_API" # NOT IMPLEMENTED YET : DON'T REMEMBER WHY : MOST PROBABLY BECAUSE WE NEED TO CRAWL THE RESULTS RETURNED BY IT
|
||||||
TAVILY_API = "TAVILY_API"
|
TAVILY_API = "TAVILY_API"
|
||||||
LINKUP_API = "LINKUP_API"
|
LINKUP_API = "LINKUP_API"
|
||||||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||||||
|
@ -61,13 +63,16 @@ class SearchSourceConnectorType(str, Enum):
|
||||||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||||||
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
||||||
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
|
DISCORD_CONNECTOR = "DISCORD_CONNECTOR"
|
||||||
|
JIRA_CONNECTOR = "JIRA_CONNECTOR"
|
||||||
|
|
||||||
|
|
||||||
class ChatType(str, Enum):
|
class ChatType(str, Enum):
|
||||||
QNA = "QNA"
|
QNA = "QNA"
|
||||||
REPORT_GENERAL = "REPORT_GENERAL"
|
REPORT_GENERAL = "REPORT_GENERAL"
|
||||||
REPORT_DEEP = "REPORT_DEEP"
|
REPORT_DEEP = "REPORT_DEEP"
|
||||||
REPORT_DEEPER = "REPORT_DEEPER"
|
REPORT_DEEPER = "REPORT_DEEPER"
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(str, Enum):
|
class LiteLLMProvider(str, Enum):
|
||||||
OPENAI = "OPENAI"
|
OPENAI = "OPENAI"
|
||||||
ANTHROPIC = "ANTHROPIC"
|
ANTHROPIC = "ANTHROPIC"
|
||||||
|
@ -92,6 +97,7 @@ class LiteLLMProvider(str, Enum):
|
||||||
PETALS = "PETALS"
|
PETALS = "PETALS"
|
||||||
CUSTOM = "CUSTOM"
|
CUSTOM = "CUSTOM"
|
||||||
|
|
||||||
|
|
||||||
class LogLevel(str, Enum):
|
class LogLevel(str, Enum):
|
||||||
DEBUG = "DEBUG"
|
DEBUG = "DEBUG"
|
||||||
INFO = "INFO"
|
INFO = "INFO"
|
||||||
|
@ -99,18 +105,27 @@ class LogLevel(str, Enum):
|
||||||
ERROR = "ERROR"
|
ERROR = "ERROR"
|
||||||
CRITICAL = "CRITICAL"
|
CRITICAL = "CRITICAL"
|
||||||
|
|
||||||
|
|
||||||
class LogStatus(str, Enum):
|
class LogStatus(str, Enum):
|
||||||
IN_PROGRESS = "IN_PROGRESS"
|
IN_PROGRESS = "IN_PROGRESS"
|
||||||
SUCCESS = "SUCCESS"
|
SUCCESS = "SUCCESS"
|
||||||
FAILED = "FAILED"
|
FAILED = "FAILED"
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TimestampMixin:
|
class TimestampMixin:
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def created_at(cls):
|
def created_at(cls):
|
||||||
return Column(TIMESTAMP(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), index=True)
|
return Column(
|
||||||
|
TIMESTAMP(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
default=lambda: datetime.now(timezone.utc),
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(Base):
|
class BaseModel(Base):
|
||||||
__abstract__ = True
|
__abstract__ = True
|
||||||
|
@ -118,6 +133,7 @@ class BaseModel(Base):
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
class Chat(BaseModel, TimestampMixin):
|
class Chat(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "chats"
|
__tablename__ = "chats"
|
||||||
|
|
||||||
|
@ -125,73 +141,115 @@ class Chat(BaseModel, TimestampMixin):
|
||||||
title = Column(String, nullable=False, index=True)
|
title = Column(String, nullable=False, index=True)
|
||||||
initial_connectors = Column(ARRAY(String), nullable=True)
|
initial_connectors = Column(ARRAY(String), nullable=True)
|
||||||
messages = Column(JSON, nullable=False)
|
messages = Column(JSON, nullable=False)
|
||||||
|
|
||||||
search_space_id = Column(Integer, ForeignKey('searchspaces.id', ondelete='CASCADE'), nullable=False)
|
search_space_id = Column(
|
||||||
search_space = relationship('SearchSpace', back_populates='chats')
|
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
search_space = relationship("SearchSpace", back_populates="chats")
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseModel, TimestampMixin):
|
class Document(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "documents"
|
__tablename__ = "documents"
|
||||||
|
|
||||||
title = Column(String, nullable=False, index=True)
|
title = Column(String, nullable=False, index=True)
|
||||||
document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
|
document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
|
||||||
document_metadata = Column(JSON, nullable=True)
|
document_metadata = Column(JSON, nullable=True)
|
||||||
|
|
||||||
content = Column(Text, nullable=False)
|
content = Column(Text, nullable=False)
|
||||||
content_hash = Column(String, nullable=False, index=True, unique=True)
|
content_hash = Column(String, nullable=False, index=True, unique=True)
|
||||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||||
|
|
||||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
|
search_space_id = Column(
|
||||||
|
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
search_space = relationship("SearchSpace", back_populates="documents")
|
search_space = relationship("SearchSpace", back_populates="documents")
|
||||||
chunks = relationship("Chunk", back_populates="document", cascade="all, delete-orphan")
|
chunks = relationship(
|
||||||
|
"Chunk", back_populates="document", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Chunk(BaseModel, TimestampMixin):
|
class Chunk(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "chunks"
|
__tablename__ = "chunks"
|
||||||
|
|
||||||
content = Column(Text, nullable=False)
|
content = Column(Text, nullable=False)
|
||||||
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
embedding = Column(Vector(config.embedding_model_instance.dimension))
|
||||||
|
|
||||||
document_id = Column(Integer, ForeignKey("documents.id", ondelete='CASCADE'), nullable=False)
|
document_id = Column(
|
||||||
|
Integer, ForeignKey("documents.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
document = relationship("Document", back_populates="chunks")
|
document = relationship("Document", back_populates="chunks")
|
||||||
|
|
||||||
|
|
||||||
class Podcast(BaseModel, TimestampMixin):
|
class Podcast(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "podcasts"
|
__tablename__ = "podcasts"
|
||||||
|
|
||||||
title = Column(String, nullable=False, index=True)
|
title = Column(String, nullable=False, index=True)
|
||||||
podcast_transcript = Column(JSON, nullable=False, default={})
|
podcast_transcript = Column(JSON, nullable=False, default={})
|
||||||
file_location = Column(String(500), nullable=False, default="")
|
file_location = Column(String(500), nullable=False, default="")
|
||||||
|
|
||||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
|
search_space_id = Column(
|
||||||
|
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||||
|
|
||||||
|
|
||||||
class SearchSpace(BaseModel, TimestampMixin):
|
class SearchSpace(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "searchspaces"
|
__tablename__ = "searchspaces"
|
||||||
|
|
||||||
name = Column(String(100), nullable=False, index=True)
|
name = Column(String(100), nullable=False, index=True)
|
||||||
description = Column(String(500), nullable=True)
|
description = Column(String(500), nullable=True)
|
||||||
|
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
user = relationship("User", back_populates="search_spaces")
|
user = relationship("User", back_populates="search_spaces")
|
||||||
|
|
||||||
documents = relationship("Document", back_populates="search_space", order_by="Document.id", cascade="all, delete-orphan")
|
documents = relationship(
|
||||||
podcasts = relationship("Podcast", back_populates="search_space", order_by="Podcast.id", cascade="all, delete-orphan")
|
"Document",
|
||||||
chats = relationship('Chat', back_populates='search_space', order_by='Chat.id', cascade="all, delete-orphan")
|
back_populates="search_space",
|
||||||
logs = relationship("Log", back_populates="search_space", order_by="Log.id", cascade="all, delete-orphan")
|
order_by="Document.id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
podcasts = relationship(
|
||||||
|
"Podcast",
|
||||||
|
back_populates="search_space",
|
||||||
|
order_by="Podcast.id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
chats = relationship(
|
||||||
|
"Chat",
|
||||||
|
back_populates="search_space",
|
||||||
|
order_by="Chat.id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
logs = relationship(
|
||||||
|
"Log",
|
||||||
|
back_populates="search_space",
|
||||||
|
order_by="Log.id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "search_source_connectors"
|
__tablename__ = "search_source_connectors"
|
||||||
|
|
||||||
name = Column(String(100), nullable=False, index=True)
|
name = Column(String(100), nullable=False, index=True)
|
||||||
connector_type = Column(SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True)
|
connector_type = Column(
|
||||||
|
SQLAlchemyEnum(SearchSourceConnectorType), nullable=False, unique=True
|
||||||
|
)
|
||||||
is_indexable = Column(Boolean, nullable=False, default=False)
|
is_indexable = Column(Boolean, nullable=False, default=False)
|
||||||
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||||
config = Column(JSON, nullable=False)
|
config = Column(JSON, nullable=False)
|
||||||
|
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
user = relationship("User", back_populates="search_source_connectors")
|
user = relationship("User", back_populates="search_source_connectors")
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel, TimestampMixin):
|
class LLMConfig(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "llm_configs"
|
__tablename__ = "llm_configs"
|
||||||
|
|
||||||
name = Column(String(100), nullable=False, index=True)
|
name = Column(String(100), nullable=False, index=True)
|
||||||
# Provider from the enum
|
# Provider from the enum
|
||||||
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
|
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
|
||||||
|
@ -202,78 +260,142 @@ class LLMConfig(BaseModel, TimestampMixin):
|
||||||
# API Key should be encrypted before storing
|
# API Key should be encrypted before storing
|
||||||
api_key = Column(String, nullable=False)
|
api_key = Column(String, nullable=False)
|
||||||
api_base = Column(String(500), nullable=True)
|
api_base = Column(String(500), nullable=True)
|
||||||
|
|
||||||
# For any other parameters that litellm supports
|
# For any other parameters that litellm supports
|
||||||
litellm_params = Column(JSON, nullable=True, default={})
|
litellm_params = Column(JSON, nullable=True, default={})
|
||||||
|
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("user.id", ondelete='CASCADE'), nullable=False)
|
user_id = Column(
|
||||||
|
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id])
|
user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id])
|
||||||
|
|
||||||
|
|
||||||
class Log(BaseModel, TimestampMixin):
|
class Log(BaseModel, TimestampMixin):
|
||||||
__tablename__ = "logs"
|
__tablename__ = "logs"
|
||||||
|
|
||||||
level = Column(SQLAlchemyEnum(LogLevel), nullable=False, index=True)
|
level = Column(SQLAlchemyEnum(LogLevel), nullable=False, index=True)
|
||||||
status = Column(SQLAlchemyEnum(LogStatus), nullable=False, index=True)
|
status = Column(SQLAlchemyEnum(LogStatus), nullable=False, index=True)
|
||||||
message = Column(Text, nullable=False)
|
message = Column(Text, nullable=False)
|
||||||
source = Column(String(200), nullable=True, index=True) # Service/component that generated the log
|
source = Column(
|
||||||
|
String(200), nullable=True, index=True
|
||||||
|
) # Service/component that generated the log
|
||||||
log_metadata = Column(JSON, nullable=True, default={}) # Additional context data
|
log_metadata = Column(JSON, nullable=True, default={}) # Additional context data
|
||||||
|
|
||||||
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)
|
search_space_id = Column(
|
||||||
|
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
search_space = relationship("SearchSpace", back_populates="logs")
|
search_space = relationship("SearchSpace", back_populates="logs")
|
||||||
|
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
|
|
||||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||||
"OAuthAccount", lazy="joined"
|
"OAuthAccount", lazy="joined"
|
||||||
)
|
)
|
||||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||||
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
|
search_source_connectors = relationship(
|
||||||
llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
|
"SearchSourceConnector", back_populates="user"
|
||||||
|
)
|
||||||
|
llm_configs = relationship(
|
||||||
|
"LLMConfig",
|
||||||
|
back_populates="user",
|
||||||
|
foreign_keys="LLMConfig.user_id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
long_context_llm_id = Column(
|
||||||
fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||||
strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
)
|
||||||
|
fast_llm_id = Column(
|
||||||
|
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
strategic_llm_id = Column(
|
||||||
|
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
long_context_llm = relationship(
|
||||||
|
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||||||
|
)
|
||||||
|
fast_llm = relationship(
|
||||||
|
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
|
||||||
|
)
|
||||||
|
strategic_llm = relationship(
|
||||||
|
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||||||
|
)
|
||||||
|
|
||||||
long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
|
|
||||||
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
|
|
||||||
strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||||
|
|
||||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||||
search_source_connectors = relationship("SearchSourceConnector", back_populates="user")
|
search_source_connectors = relationship(
|
||||||
llm_configs = relationship("LLMConfig", back_populates="user", foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan")
|
"SearchSourceConnector", back_populates="user"
|
||||||
|
)
|
||||||
|
llm_configs = relationship(
|
||||||
|
"LLMConfig",
|
||||||
|
back_populates="user",
|
||||||
|
foreign_keys="LLMConfig.user_id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
long_context_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
long_context_llm_id = Column(
|
||||||
fast_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||||
strategic_llm_id = Column(Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True)
|
)
|
||||||
|
fast_llm_id = Column(
|
||||||
|
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
strategic_llm_id = Column(
|
||||||
|
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True)
|
long_context_llm = relationship(
|
||||||
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
|
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||||||
strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True)
|
)
|
||||||
|
fast_llm = relationship(
|
||||||
|
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
|
||||||
|
)
|
||||||
|
strategic_llm = relationship(
|
||||||
|
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
engine = create_async_engine(DATABASE_URL)
|
||||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
async def setup_indexes():
|
async def setup_indexes():
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
# Create indexes
|
# Create indexes
|
||||||
# Document Summary Indexes
|
# Document Summary Indexes
|
||||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)'))
|
await conn.execute(
|
||||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector(\'english\', content))'))
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))"
|
||||||
|
)
|
||||||
|
)
|
||||||
# Document Chuck Indexes
|
# Document Chuck Indexes
|
||||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)'))
|
await conn.execute(
|
||||||
await conn.execute(text('CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector(\'english\', content))'))
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def create_db_and_tables():
|
async def create_db_and_tables():
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
|
await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
await setup_indexes()
|
await setup_indexes()
|
||||||
|
|
||||||
|
@ -284,14 +406,23 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
|
||||||
|
|
||||||
if config.AUTH_TYPE == "GOOGLE":
|
if config.AUTH_TYPE == "GOOGLE":
|
||||||
|
|
||||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||||
yield SQLAlchemyUserDatabase(session, User)
|
yield SQLAlchemyUserDatabase(session, User)
|
||||||
|
|
||||||
async def get_chucks_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
|
|
||||||
|
async def get_chucks_hybrid_search_retriever(
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
):
|
||||||
return ChucksHybridSearchRetriever(session)
|
return ChucksHybridSearchRetriever(session)
|
||||||
|
|
||||||
async def get_documents_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
|
|
||||||
|
async def get_documents_hybrid_search_retriever(
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
):
|
||||||
return DocumentHybridSearchRetriever(session)
|
return DocumentHybridSearchRetriever(session)
|
||||||
|
|
Loading…
Add table
Reference in a new issue