Mirror of https://github.com/MODSetter/SurfSense.git (synced 2025-09-01 18:19:08 +00:00)

SurfSense v3 - Highlight: Local LLM Support

This commit is contained in:
parent 04df919cf9
commit 7f38091d3d

13 changed files with 692 additions and 1345 deletions
14  .gitignore  vendored

@ -6,7 +6,19 @@ venv/
ENV/
env.bak/
venv.bak/
data/
.data

__pycache__
__pycache__/
.__pycache__
.__pycache__

backend/examples

backend/old
backend/RAGAgent
backend/testfiles

vectorstores/*
vectorstores/
.vectorstores
@ -1 +1 @@
Subproject commit 044dc1b24788a38f8b2735113091ba1510462a5a
Subproject commit aeb87beee51bbf040307157d6c54b3a3a82ef620
@ -1,7 +1,10 @@
#true if you want to run the local setup with Ollama llama3.1
IS_LOCAL_SETUP = 'false'

#POSTGRES DB TO TRACK USERS
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"

# API KEY TO VERIFY
# API KEY TO PREVENT USER REGISTRATION SPAM
API_SECRET_KEY = "surfsense"

# Your JWT secret and algorithm
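For the new local-LLM path the same file would be filled in along these lines (illustrative values only, not part of this commit; the model name is whatever you have pulled into Ollama):

# Illustrative .env for a local Ollama setup
IS_LOCAL_SETUP = 'true'
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
API_SECRET_KEY = "surfsense"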
303  backend/HIndices.py  Normal file

@ -0,0 +1,303 @@
|
|||
from langchain_chroma import Chroma
|
||||
from langchain_ollama import OllamaLLM, OllamaEmbeddings
|
||||
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
from langchain.docstore.document import Document
|
||||
from langchain_experimental.text_splitter import SemanticChunker
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import FlashrankRerank
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import Depends
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from database import SessionLocal
|
||||
from models import Documents, User
|
||||
from prompts import CONTEXT_ANSWER_PROMPT
|
||||
load_dotenv()
|
||||
|
||||
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
|
||||
|
||||
# Dependency
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
class HIndices:
|
||||
def __init__(self, username, api_key='local'):
|
||||
"""
|
||||
"""
|
||||
self.username = username
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
self.llm = OllamaLLM(model="mistral-nemo",temperature=0)
|
||||
self.embeddings = OllamaEmbeddings(model="mistral-nemo")
|
||||
else:
|
||||
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=api_key)
|
||||
self.embeddings = OpenAIEmbeddings(api_key=api_key)
|
||||
|
||||
self.summary_store = Chroma(
|
||||
collection_name="summary_store",
|
||||
embedding_function=self.embeddings,
|
||||
persist_directory="./vectorstores/" + username + "/summary_store_db", # Where to save data locally
|
||||
)
|
||||
self.detailed_store = Chroma(
|
||||
collection_name="detailed_store",
|
||||
embedding_function=self.embeddings,
|
||||
persist_directory="./vectorstores/" + username + "/detailed_store_db", # Where to save data locally
|
||||
)
|
||||
|
||||
# self.summary_store_size = len(self.summary_store.get()['documents'])
|
||||
# self.detailed_store_size = len(self.detailed_store.get()['documents'])
|
||||
|
||||
def encode_docs_hierarchical(self, documents, files_type, search_space='GENERAL', db: Session = Depends(get_db)):
|
||||
"""
|
||||
Creates and Saves/Updates docs in hierarchical indices and postgres table
|
||||
"""
|
||||
|
||||
def summarize_doc(page_no,doc):
|
||||
"""
|
||||
Summarizes a single document.
|
||||
|
||||
Args:
|
||||
page_no: Page no in Summary Vector store
|
||||
doc: The document to be summarized.
|
||||
|
||||
Returns:
|
||||
A summarized Document object.
|
||||
"""
|
||||
|
||||
|
||||
report_template = """You are a forensic investigator, expert at writing detailed reports of documents. You are given a document; write a detailed report of it.
|
||||
|
||||
DOCUMENT: {document}
|
||||
|
||||
Detailed Report:"""
|
||||
|
||||
|
||||
report_prompt = PromptTemplate(
|
||||
input_variables=["document"],
|
||||
template=report_template
|
||||
)
|
||||
|
||||
# Chain that generates the detailed report for a document
|
||||
report_chain = report_prompt | self.llm
|
||||
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
# Local LLMs struggle with whole-document summaries, so summarize chunk by chunk and combine
|
||||
text_splitter = SemanticChunker(embeddings=self.embeddings)
|
||||
chunks = text_splitter.split_documents([doc])
|
||||
combined_summary = ""
|
||||
for i, chunk in enumerate(chunks):
|
||||
print("GENERATING SUMMARY FOR CHUNK "+ str(i))
|
||||
chunk_summary = report_chain.invoke({"document": chunk})
|
||||
combined_summary += "\n\n" + chunk_summary + "\n\n"
|
||||
|
||||
response = combined_summary
|
||||
|
||||
return Document(
|
||||
id=str(page_no),
|
||||
page_content=response,
|
||||
metadata={
|
||||
"filetype": files_type,
|
||||
"page": page_no,
|
||||
"summary": True,
|
||||
"search_space": search_space,
|
||||
"BrowsingSessionId": doc.metadata['BrowsingSessionId'],
|
||||
"VisitedWebPageURL": doc.metadata['VisitedWebPageURL'],
|
||||
"VisitedWebPageTitle": doc.metadata['VisitedWebPageTitle'],
|
||||
"VisitedWebPageDateWithTimeInISOString": doc.metadata['VisitedWebPageDateWithTimeInISOString'],
|
||||
"VisitedWebPageReffererURL": doc.metadata['VisitedWebPageReffererURL'],
|
||||
"VisitedWebPageVisitDurationInMilliseconds": doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
|
||||
}
|
||||
)
|
||||
else:
|
||||
response = report_chain.invoke({"document": doc})
|
||||
|
||||
return Document(
|
||||
id=str(page_no),
|
||||
page_content=response.content,
|
||||
metadata={
|
||||
"filetype": files_type,
|
||||
"page": page_no,
|
||||
"summary": True,
|
||||
"search_space": search_space,
|
||||
"BrowsingSessionId": doc.metadata['BrowsingSessionId'],
|
||||
"VisitedWebPageURL": doc.metadata['VisitedWebPageURL'],
|
||||
"VisitedWebPageTitle": doc.metadata['VisitedWebPageTitle'],
|
||||
"VisitedWebPageDateWithTimeInISOString": doc.metadata['VisitedWebPageDateWithTimeInISOString'],
|
||||
"VisitedWebPageReffererURL": doc.metadata['VisitedWebPageReffererURL'],
|
||||
"VisitedWebPageVisitDurationInMilliseconds": doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# DocumentPgEntry = []
|
||||
# searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == search_space).first()
|
||||
|
||||
# for doc in documents:
|
||||
# pgdocmeta = stringify(doc.metadata)
|
||||
|
||||
# if(searchspace):
|
||||
# DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=search_space, document_metadata=pgdocmeta, page_content=doc.page_content))
|
||||
# else:
|
||||
# DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=search_space), document_metadata=pgdocmeta, page_content=doc.page_content))
|
||||
|
||||
|
||||
prev_doc_idx = len(documents) + 1
|
||||
# #Save docs in PG
|
||||
user = db.query(User).filter(User.username == self.username).first()
|
||||
|
||||
if(len(user.documents) < prev_doc_idx):
|
||||
summary_last_id = 0
|
||||
detail_id_counter = 0
|
||||
else:
|
||||
summary_last_id = int(user.documents[-prev_doc_idx].id)
|
||||
detail_id_counter = int(user.documents[-prev_doc_idx].desc_vector_end)
|
||||
|
||||
|
||||
# Process documents
|
||||
summaries = []
|
||||
batch_summaries = [summarize_doc(i + summary_last_id, doc) for i, doc in enumerate(documents)]
|
||||
summaries.extend(batch_summaries)
|
||||
|
||||
detailed_chunks = []
|
||||
|
||||
for i, summary in enumerate(summaries):
|
||||
|
||||
# Semantic chunking for better contextual compression
|
||||
text_splitter = SemanticChunker(embeddings=self.embeddings)
|
||||
chunks = text_splitter.split_documents([documents[i]])
|
||||
|
||||
user.documents[-(len(summaries) - i)].desc_vector_start = detail_id_counter
|
||||
user.documents[-(len(summaries) - i)].desc_vector_end = detail_id_counter + len(chunks)
|
||||
# summary_entry = db.query(Documents).filter(Documents.id == int(user.documents[-1].id)).first()
|
||||
# summary_entry.desc_vector_start = detail_id_counter
|
||||
# summary_entry.desc_vector_end = detail_id_counter + len(chunks)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Update metadata for detailed chunks
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk.id = str(detail_id_counter)
|
||||
chunk.metadata.update({
|
||||
"chunk_id": detail_id_counter,
|
||||
"summary": False,
|
||||
"page": summary.metadata['page'],
|
||||
})
|
||||
|
||||
detail_id_counter += 1
|
||||
|
||||
detailed_chunks.extend(chunks)
|
||||
|
||||
#update vector stores
|
||||
self.summary_store.add_documents(summaries)
|
||||
self.detailed_store.add_documents(detailed_chunks)
|
||||
|
||||
return self.summary_store, self.detailed_store
|
||||
|
||||
def delete_vector_stores(self, summary_ids_to_delete: list[str], db: Session = Depends(get_db)):
|
||||
self.summary_store.delete(ids=summary_ids_to_delete)
|
||||
for id in summary_ids_to_delete:
|
||||
summary_entry = db.query(Documents).filter(Documents.id == int(id) + 1).first()
|
||||
|
||||
desc_ids_to_del = [str(id) for id in range(summary_entry.desc_vector_start, summary_entry.desc_vector_end)]
|
||||
|
||||
self.detailed_store.delete(ids=desc_ids_to_del)
|
||||
db.delete(summary_entry)
|
||||
db.commit()
|
||||
|
||||
return "success"
|
||||
|
||||
|
||||
|
||||
def is_query_answerable(self, query, context):
|
||||
prompt = PromptTemplate(
|
||||
template="""You are a grader assessing relevance of a retrieved document to a user question. \n
|
||||
Here is the retrieved document: \n\n {context} \n\n
|
||||
Here is the user question: {question} \n
|
||||
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
|
||||
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
|
||||
Only return 'yes' or 'no'""",
|
||||
input_variables=["context", "question"],
|
||||
)
|
||||
|
||||
ans_chain = prompt | self.llm
|
||||
|
||||
finalans = ans_chain.invoke({"question": query, "context": context})
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
return finalans
|
||||
else:
|
||||
return finalans.content
|
||||
|
||||
def local_search(self, query, search_space='GENERAL'):
|
||||
top_summaries_compressor = FlashrankRerank(top_n=5)
|
||||
details_compressor = FlashrankRerank(top_n=30)
|
||||
top_summaries_retreiver = ContextualCompressionRetriever(
|
||||
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
|
||||
)
|
||||
|
||||
top_summaries_compressed_docs = top_summaries_retreiver.invoke(query)
|
||||
|
||||
for summary in top_summaries_compressed_docs:
|
||||
# For each summary, retrieve relevant detailed chunks
|
||||
page_number = summary.metadata["page"]
|
||||
|
||||
detailed_compression_retriever = ContextualCompressionRetriever(
|
||||
base_compressor=details_compressor, base_retriever=self.detailed_store.as_retriever(search_kwargs={'filter': {'page': page_number}})
|
||||
)
|
||||
|
||||
detailed_compressed_docs = detailed_compression_retriever.invoke(
|
||||
query
|
||||
)
|
||||
|
||||
contextdocs = top_summaries_compressed_docs + detailed_compressed_docs
|
||||
|
||||
context_to_answer = ""
|
||||
for i, doc in enumerate(contextdocs):
|
||||
context_to_answer += "DOCUMENT " + str(i) + " PAGECONTENT CHUNK: \n\n ===================================== \n\n" + doc.page_content + '\n\n ==============================================='
|
||||
|
||||
if(self.is_query_answerable(query=query, context=context_to_answer).lower() == 'yes'):
|
||||
ans_chain = CONTEXT_ANSWER_PROMPT | self.llm
|
||||
|
||||
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
return finalans
|
||||
else:
|
||||
return finalans.content
|
||||
else:
|
||||
continue
|
||||
|
||||
return "I couldn't find any answer"
|
||||
|
||||
def global_search(self,query, search_space='GENERAL'):
|
||||
top_summaries_compressor = FlashrankRerank(top_n=20)
|
||||
|
||||
top_summaries_retreiver = ContextualCompressionRetriever(
|
||||
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
|
||||
)
|
||||
|
||||
top_summaries_compressed_docs = top_summaries_retreiver.invoke(query)
|
||||
|
||||
context_to_answer = ""
|
||||
for i, doc in enumerate(top_summaries_compressed_docs):
|
||||
context_to_answer += "DOCUMENT " + str(i) + " PAGECONTENT: \n\n ===================================== \n\n" + doc.page_content + '\n\n ==============================================='
|
||||
|
||||
ans_chain = CONTEXT_ANSWER_PROMPT | self.llm
|
||||
|
||||
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
|
||||
|
||||
if(IS_LOCAL_SETUP == 'true'):
|
||||
return finalans, top_summaries_compressed_docs
|
||||
else:
|
||||
return finalans.content, top_summaries_compressed_docs
|
||||
|
|
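A minimal usage sketch of the new HIndices class (not part of this commit; assumes IS_LOCAL_SETUP='true' in .env, a running Ollama with mistral-nemo pulled, and a user named "alice" already present in Postgres):

# Hypothetical usage of backend/HIndices.py in a local setup
from HIndices import HIndices

index = HIndices(username="alice")  # builds Ollama LLM/embeddings and the per-user Chroma stores

# Broad question answered from the reranked summary store
answer, sources = index.global_search("What did I read about FastAPI?", search_space="GENERAL")
print(answer)

# Narrower question that drills into the detailed chunk store behind each matching summary
print(index.local_search("Which FastAPI article covered dependency injection?", search_space="GENERAL"))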
@ -1,832 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
|
||||
|
||||
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
PromptTemplate,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
examples = [
|
||||
{
|
||||
"text": (
|
||||
"Adam is a software engineer in Microsoft since 2009, "
|
||||
"and last year he got an award as the Best Talent"
|
||||
),
|
||||
"head": "Adam",
|
||||
"head_type": "Person",
|
||||
"relation": "WORKS_FOR",
|
||||
"tail": "Microsoft",
|
||||
"tail_type": "Company",
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
"Adam is a software engineer in Microsoft since 2009, "
|
||||
"and last year he got an award as the Best Talent"
|
||||
),
|
||||
"head": "Adam",
|
||||
"head_type": "Person",
|
||||
"relation": "HAS_AWARD",
|
||||
"tail": "Best Talent",
|
||||
"tail_type": "Award",
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
"Microsoft is a tech company that provide "
|
||||
"several products such as Microsoft Word"
|
||||
),
|
||||
"head": "Microsoft Word",
|
||||
"head_type": "Product",
|
||||
"relation": "PRODUCED_BY",
|
||||
"tail": "Microsoft",
|
||||
"tail_type": "Company",
|
||||
},
|
||||
{
|
||||
"text": "Microsoft Word is a lightweight app that accessible offline",
|
||||
"head": "Microsoft Word",
|
||||
"head_type": "Product",
|
||||
"relation": "HAS_CHARACTERISTIC",
|
||||
"tail": "lightweight app",
|
||||
"tail_type": "Characteristic",
|
||||
},
|
||||
{
|
||||
"text": "Microsoft Word is a lightweight app that accessible offline",
|
||||
"head": "Microsoft Word",
|
||||
"head_type": "Product",
|
||||
"relation": "HAS_CHARACTERISTIC",
|
||||
"tail": "accessible offline",
|
||||
"tail_type": "Characteristic",
|
||||
},
|
||||
]
|
||||
|
||||
system_prompt = (
|
||||
"# Knowledge Graph Instructions for GPT-4\n"
|
||||
"## 1. Overview\n"
|
||||
"You are a top-tier algorithm designed for extracting information in structured "
|
||||
"formats to build a knowledge graph.\n"
|
||||
"You are given a text containing a user's browsing history with a unique session ID, including the following details for each webpage in this session: "
|
||||
"Webpage URL visited, its referring URL, duration of time in milliseconds spent on this Webpage URL, date and time of the visit on this Webpage URL , title of the webpage, and webpage content in Markdown format"
|
||||
"Try to capture as much information from the text as possible without "
|
||||
"sacrificing accuracy. Do not add any information that is not explicitly "
|
||||
"mentioned in the text.\n"
|
||||
"- **Nodes** represent entities and concepts.\n"
|
||||
"- The aim is to achieve simplicity and clarity in the knowledge graph, making it\n"
|
||||
"accessible for a vast audience.\n"
|
||||
"## 2. Labeling Nodes\n"
|
||||
"- **Consistency**: Ensure you use available types for node labels.\n"
|
||||
"Ensure you use basic or elementary types for node labels.\n"
|
||||
"- For example, when you identify an entity representing a person, "
|
||||
"always label it as **'person'**. Avoid using more specific terms "
|
||||
"like 'mathematician' or 'scientist'."
|
||||
"- **Node IDs**: Never utilize integers as node IDs. Node IDs should be "
|
||||
"names or human-readable identifiers found in the text.\n"
|
||||
"- **Relationships** represent connections between entities or concepts.\n"
|
||||
"Ensure consistency and generality in relationship types when constructing "
|
||||
"knowledge graphs. Instead of using specific and momentary types "
|
||||
"such as 'BECAME_PROFESSOR', use more general and timeless relationship types "
|
||||
"like 'PROFESSOR'. Make sure to use general and timeless relationship types!\n"
|
||||
"## 3. Coreference Resolution\n"
|
||||
"- **Maintain Entity Consistency**: When extracting entities, it's vital to "
|
||||
"ensure consistency.\n"
|
||||
'If an entity, such as "John Doe", is mentioned multiple times in the text '
|
||||
'but is referred to by different names or pronouns (e.g., "Joe", "he"),'
|
||||
"always use the most complete identifier for that entity throughout the "
|
||||
'knowledge graph. In this example, use "John Doe" as the entity ID.\n'
|
||||
"Remember, the knowledge graph should be coherent and easily understandable, "
|
||||
"so maintaining consistency in entity references is crucial.\n"
|
||||
"## 4. Strict Compliance\n"
|
||||
"Adhere to the rules strictly. Non-compliance will result in termination."
|
||||
)
|
||||
|
||||
default_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
system_prompt,
|
||||
),
|
||||
(
|
||||
"human",
|
||||
(
|
||||
"Note: The information given to you is information about a User's Web Browsing History."
|
||||
"Tip: Make sure to answer in the correct format and do "
|
||||
"not include any explanations. "
|
||||
"Use the given format to extract information from the "
|
||||
"following input: {input}"
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_additional_info(input_type: str) -> str:
|
||||
# Check if the input_type is one of the allowed values
|
||||
if input_type not in ["node", "relationship", "property"]:
|
||||
raise ValueError("input_type must be 'node', 'relationship', or 'property'")
|
||||
|
||||
# Perform actions based on the input_type
|
||||
if input_type == "node":
|
||||
return (
|
||||
"Ensure you use basic or elementary types for node labels.\n"
|
||||
"For example, when you identify an entity representing a person, "
|
||||
"always label it as **'Person'**. Avoid using more specific terms "
|
||||
"like 'Mathematician' or 'Scientist'"
|
||||
)
|
||||
elif input_type == "relationship":
|
||||
return (
|
||||
"Instead of using specific and momentary types such as "
|
||||
"'BECAME_PROFESSOR', use more general and timeless relationship types "
|
||||
"like 'PROFESSOR'. However, do not sacrifice any accuracy for generality"
|
||||
)
|
||||
elif input_type == "property":
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def optional_enum_field(
|
||||
enum_values: Optional[List[str]] = None,
|
||||
description: str = "",
|
||||
input_type: str = "node",
|
||||
llm_type: Optional[str] = None,
|
||||
**field_kwargs: Any,
|
||||
) -> Any:
|
||||
"""Utility function to conditionally create a field with an enum constraint."""
|
||||
# Only openai supports enum param
|
||||
if enum_values and llm_type == "openai-chat":
|
||||
return Field(
|
||||
...,
|
||||
enum=enum_values,
|
||||
description=f"{description}. Available options are {enum_values}",
|
||||
**field_kwargs,
|
||||
)
|
||||
elif enum_values:
|
||||
return Field(
|
||||
...,
|
||||
description=f"{description}. Available options are {enum_values}",
|
||||
**field_kwargs,
|
||||
)
|
||||
else:
|
||||
additional_info = _get_additional_info(input_type)
|
||||
return Field(..., description=description + additional_info, **field_kwargs)
|
||||
|
||||
|
||||
class _Graph(BaseModel):
|
||||
nodes: Optional[List]
|
||||
relationships: Optional[List]
|
||||
|
||||
|
||||
class UnstructuredRelation(BaseModel):
|
||||
head: str = Field(
|
||||
description=(
|
||||
"extracted head entity like Microsoft, Apple, John. "
|
||||
"Must use human-readable unique identifier."
|
||||
)
|
||||
)
|
||||
head_type: str = Field(
|
||||
description="type of the extracted head entity like Person, Company, etc"
|
||||
)
|
||||
relation: str = Field(description="relation between the head and the tail entities")
|
||||
tail: str = Field(
|
||||
description=(
|
||||
"extracted tail entity like Microsoft, Apple, John. "
|
||||
"Must use human-readable unique identifier."
|
||||
)
|
||||
)
|
||||
tail_type: str = Field(
|
||||
description="type of the extracted tail entity like Person, Company, etc"
|
||||
)
|
||||
|
||||
|
||||
def create_unstructured_prompt(
|
||||
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
|
||||
) -> ChatPromptTemplate:
|
||||
node_labels_str = str(node_labels) if node_labels else ""
|
||||
rel_types_str = str(rel_types) if rel_types else ""
|
||||
base_string_parts = [
|
||||
"You are a top-tier algorithm designed for extracting information in "
|
||||
"structured formats to build a knowledge graph. Your task is to identify "
|
||||
"the entities and relations requested with the user prompt from a given "
|
||||
"text. You must generate the output in a JSON format containing a list "
|
||||
'with JSON objects. Each object should have the keys: "head", '
|
||||
'"head_type", "relation", "tail", and "tail_type". The "head" '
|
||||
"key must contain the text of the extracted entity with one of the types "
|
||||
"from the provided list in the user prompt.",
|
||||
f'The "head_type" key must contain the type of the extracted head entity, '
|
||||
f"which must be one of the types from {node_labels_str}."
|
||||
if node_labels
|
||||
else "",
|
||||
f'The "relation" key must contain the type of relation between the "head" '
|
||||
f'and the "tail", which must be one of the relations from {rel_types_str}.'
|
||||
if rel_types
|
||||
else "",
|
||||
f'The "tail" key must represent the text of an extracted entity which is '
|
||||
f'the tail of the relation, and the "tail_type" key must contain the type '
|
||||
f"of the tail entity from {node_labels_str}."
|
||||
if node_labels
|
||||
else "",
|
||||
"Attempt to extract as many entities and relations as you can. Maintain "
|
||||
"Entity Consistency: When extracting entities, it's vital to ensure "
|
||||
'consistency. If an entity, such as "John Doe", is mentioned multiple '
|
||||
"times in the text but is referred to by different names or pronouns "
|
||||
'(e.g., "Joe", "he"), always use the most complete identifier for '
|
||||
"that entity. The knowledge graph should be coherent and easily "
|
||||
"understandable, so maintaining consistency in entity references is "
|
||||
"crucial.",
|
||||
"IMPORTANT NOTES:\n- Don't add any explanation and text.",
|
||||
]
|
||||
system_prompt = "\n".join(filter(None, base_string_parts))
|
||||
|
||||
system_message = SystemMessage(content=system_prompt)
|
||||
parser = JsonOutputParser(pydantic_object=UnstructuredRelation)
|
||||
|
||||
human_string_parts = [
|
||||
"Based on the following example, extract entities and "
|
||||
"relations from the provided text.\n\n",
|
||||
"Use the following entity types, don't use other entity "
|
||||
"that is not defined below:"
|
||||
"# ENTITY TYPES:"
|
||||
"{node_labels}"
|
||||
if node_labels
|
||||
else "",
|
||||
"Use the following relation types, don't use other relation "
|
||||
"that is not defined below:"
|
||||
"# RELATION TYPES:"
|
||||
"{rel_types}"
|
||||
if rel_types
|
||||
else "",
|
||||
"Below are a number of examples of text and their extracted "
|
||||
"entities and relationships."
|
||||
"{examples}\n"
|
||||
"For the following text, extract entities and relations as "
|
||||
"in the provided example."
|
||||
"{format_instructions}\nText: {input}",
|
||||
]
|
||||
human_prompt_string = "\n".join(filter(None, human_string_parts))
|
||||
human_prompt = PromptTemplate(
|
||||
template=human_prompt_string,
|
||||
input_variables=["input"],
|
||||
partial_variables={
|
||||
"format_instructions": parser.get_format_instructions(),
|
||||
"node_labels": node_labels,
|
||||
"rel_types": rel_types,
|
||||
"examples": examples,
|
||||
},
|
||||
)
|
||||
|
||||
human_message_prompt = HumanMessagePromptTemplate(prompt=human_prompt)
|
||||
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[system_message, human_message_prompt]
|
||||
)
|
||||
return chat_prompt
|
||||
|
||||
|
||||
def create_simple_model(
|
||||
node_labels: Optional[List[str]] = None,
|
||||
rel_types: Optional[List[str]] = None,
|
||||
node_properties: Union[bool, List[str]] = False,
|
||||
llm_type: Optional[str] = None,
|
||||
relationship_properties: Union[bool, List[str]] = False,
|
||||
) -> Type[_Graph]:
|
||||
"""
|
||||
Create a simple graph model with optional constraints on node
|
||||
and relationship types.
|
||||
|
||||
Args:
|
||||
node_labels (Optional[List[str]]): Specifies the allowed node types.
|
||||
Defaults to None, allowing all node types.
|
||||
rel_types (Optional[List[str]]): Specifies the allowed relationship types.
|
||||
Defaults to None, allowing all relationship types.
|
||||
node_properties (Union[bool, List[str]]): Specifies if node properties should
|
||||
be included. If a list is provided, only properties with keys in the list
|
||||
will be included. If True, all properties are included. Defaults to False.
|
||||
relationship_properties (Union[bool, List[str]]): Specifies if relationship
|
||||
properties should be included. If a list is provided, only properties with
|
||||
keys in the list will be included. If True, all properties are included.
|
||||
Defaults to False.
|
||||
llm_type (Optional[str]): The type of the language model. Defaults to None.
|
||||
Only openai supports enum param: openai-chat.
|
||||
|
||||
Returns:
|
||||
Type[_Graph]: A graph model with the specified constraints.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'id' is included in the node or relationship properties list.
|
||||
"""
|
||||
|
||||
node_fields: Dict[str, Tuple[Any, Any]] = {
|
||||
"id": (
|
||||
str,
|
||||
Field(..., description="Name or human-readable unique identifier."),
|
||||
),
|
||||
"type": (
|
||||
str,
|
||||
optional_enum_field(
|
||||
node_labels,
|
||||
description="The type or label of the node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
if node_properties:
|
||||
if isinstance(node_properties, list) and "id" in node_properties:
|
||||
raise ValueError("The node property 'id' is reserved and cannot be used.")
|
||||
# Map True to empty array
|
||||
node_properties_mapped: List[str] = (
|
||||
[] if node_properties is True else node_properties
|
||||
)
|
||||
|
||||
class Property(BaseModel):
|
||||
"""A single property consisting of key and value"""
|
||||
|
||||
key: str = optional_enum_field(
|
||||
node_properties_mapped,
|
||||
description="Property key.",
|
||||
input_type="property",
|
||||
llm_type=llm_type,
|
||||
)
|
||||
value: str = Field(..., description="value")
|
||||
|
||||
node_fields["properties"] = (
|
||||
Optional[List[Property]],
|
||||
Field(None, description="List of node properties"),
|
||||
)
|
||||
SimpleNode = create_model("SimpleNode", **node_fields) # type: ignore
|
||||
|
||||
relationship_fields: Dict[str, Tuple[Any, Any]] = {
|
||||
"source_node_id": (
|
||||
str,
|
||||
Field(
|
||||
...,
|
||||
description="Name or human-readable unique identifier of source node",
|
||||
),
|
||||
),
|
||||
"source_node_type": (
|
||||
str,
|
||||
optional_enum_field(
|
||||
node_labels,
|
||||
description="The type or label of the source node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
"target_node_id": (
|
||||
str,
|
||||
Field(
|
||||
...,
|
||||
description="Name or human-readable unique identifier of target node",
|
||||
),
|
||||
),
|
||||
"target_node_type": (
|
||||
str,
|
||||
optional_enum_field(
|
||||
node_labels,
|
||||
description="The type or label of the target node.",
|
||||
input_type="node",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
"type": (
|
||||
str,
|
||||
optional_enum_field(
|
||||
rel_types,
|
||||
description="The type of the relationship.",
|
||||
input_type="relationship",
|
||||
llm_type=llm_type,
|
||||
),
|
||||
),
|
||||
}
|
||||
if relationship_properties:
|
||||
if (
|
||||
isinstance(relationship_properties, list)
|
||||
and "id" in relationship_properties
|
||||
):
|
||||
raise ValueError(
|
||||
"The relationship property 'id' is reserved and cannot be used."
|
||||
)
|
||||
# Map True to empty array
|
||||
relationship_properties_mapped: List[str] = (
|
||||
[] if relationship_properties is True else relationship_properties
|
||||
)
|
||||
|
||||
class RelationshipProperty(BaseModel):
|
||||
"""A single property consisting of key and value"""
|
||||
|
||||
key: str = optional_enum_field(
|
||||
relationship_properties_mapped,
|
||||
description="Property key.",
|
||||
input_type="property",
|
||||
llm_type=llm_type,
|
||||
)
|
||||
value: str = Field(..., description="value")
|
||||
|
||||
relationship_fields["properties"] = (
|
||||
Optional[List[RelationshipProperty]],
|
||||
Field(None, description="List of relationship properties"),
|
||||
)
|
||||
SimpleRelationship = create_model("SimpleRelationship", **relationship_fields) # type: ignore
|
||||
|
||||
class DynamicGraph(_Graph):
|
||||
"""Represents a graph document consisting of nodes and relationships."""
|
||||
|
||||
nodes: Optional[List[SimpleNode]] = Field(description="List of nodes") # type: ignore
|
||||
relationships: Optional[List[SimpleRelationship]] = Field( # type: ignore
|
||||
description="List of relationships"
|
||||
)
|
||||
|
||||
return DynamicGraph
|
||||
|
||||
|
||||
def map_to_base_node(node: Any) -> Node:
|
||||
"""Map the SimpleNode to the base Node."""
|
||||
properties = {}
|
||||
if hasattr(node, "properties") and node.properties:
|
||||
for p in node.properties:
|
||||
properties[format_property_key(p.key)] = p.value
|
||||
return Node(id=node.id, type=node.type, properties=properties)
|
||||
|
||||
|
||||
def map_to_base_relationship(rel: Any) -> Relationship:
|
||||
"""Map the SimpleRelationship to the base Relationship."""
|
||||
source = Node(id=rel.source_node_id, type=rel.source_node_type)
|
||||
target = Node(id=rel.target_node_id, type=rel.target_node_type)
|
||||
properties = {}
|
||||
if hasattr(rel, "properties") and rel.properties:
|
||||
for p in rel.properties:
|
||||
properties[format_property_key(p.key)] = p.value
|
||||
return Relationship(
|
||||
source=source, target=target, type=rel.type, properties=properties
|
||||
)
|
||||
|
||||
|
||||
def _parse_and_clean_json(
|
||||
argument_json: Dict[str, Any],
|
||||
) -> Tuple[List[Node], List[Relationship]]:
|
||||
nodes = []
|
||||
for node in argument_json["nodes"]:
|
||||
if not node.get("id"): # Id is mandatory, skip this node
|
||||
continue
|
||||
node_properties = {}
|
||||
if "properties" in node and node["properties"]:
|
||||
for p in node["properties"]:
|
||||
node_properties[format_property_key(p["key"])] = p["value"]
|
||||
nodes.append(
|
||||
Node(
|
||||
id=node["id"],
|
||||
type=node.get("type"),
|
||||
properties=node_properties,
|
||||
)
|
||||
)
|
||||
relationships = []
|
||||
for rel in argument_json["relationships"]:
|
||||
# Mandatory props
|
||||
if (
|
||||
not rel.get("source_node_id")
|
||||
or not rel.get("target_node_id")
|
||||
or not rel.get("type")
|
||||
):
|
||||
continue
|
||||
|
||||
# Node type copying if needed from node list
|
||||
if not rel.get("source_node_type"):
|
||||
try:
|
||||
rel["source_node_type"] = [
|
||||
el.get("type")
|
||||
for el in argument_json["nodes"]
|
||||
if el["id"] == rel["source_node_id"]
|
||||
][0]
|
||||
except IndexError:
|
||||
rel["source_node_type"] = None
|
||||
if not rel.get("target_node_type"):
|
||||
try:
|
||||
rel["target_node_type"] = [
|
||||
el.get("type")
|
||||
for el in argument_json["nodes"]
|
||||
if el["id"] == rel["target_node_id"]
|
||||
][0]
|
||||
except IndexError:
|
||||
rel["target_node_type"] = None
|
||||
|
||||
rel_properties = {}
|
||||
if "properties" in rel and rel["properties"]:
|
||||
for p in rel["properties"]:
|
||||
rel_properties[format_property_key(p["key"])] = p["value"]
|
||||
|
||||
source_node = Node(
|
||||
id=rel["source_node_id"],
|
||||
type=rel["source_node_type"],
|
||||
)
|
||||
target_node = Node(
|
||||
id=rel["target_node_id"],
|
||||
type=rel["target_node_type"],
|
||||
)
|
||||
relationships.append(
|
||||
Relationship(
|
||||
source=source_node,
|
||||
target=target_node,
|
||||
type=rel["type"],
|
||||
properties=rel_properties,
|
||||
)
|
||||
)
|
||||
return nodes, relationships
|
||||
|
||||
|
||||
def _format_nodes(nodes: List[Node]) -> List[Node]:
|
||||
return [
|
||||
Node(
|
||||
id=el.id.title() if isinstance(el.id, str) else el.id,
|
||||
type=el.type.capitalize() # type: ignore[arg-type]
|
||||
if el.type
|
||||
else None, # handle empty strings # type: ignore[arg-type]
|
||||
properties=el.properties,
|
||||
)
|
||||
for el in nodes
|
||||
]
|
||||
|
||||
|
||||
def _format_relationships(rels: List[Relationship]) -> List[Relationship]:
|
||||
return [
|
||||
Relationship(
|
||||
source=_format_nodes([el.source])[0],
|
||||
target=_format_nodes([el.target])[0],
|
||||
type=el.type.replace(" ", "_").upper(),
|
||||
properties=el.properties,
|
||||
)
|
||||
for el in rels
|
||||
]
|
||||
|
||||
|
||||
def format_property_key(s: str) -> str:
|
||||
words = s.split()
|
||||
if not words:
|
||||
return s
|
||||
first_word = words[0].lower()
|
||||
capitalized_words = [word.capitalize() for word in words[1:]]
|
||||
return "".join([first_word] + capitalized_words)
|
||||
|
||||
|
||||
def _convert_to_graph_document(
|
||||
raw_schema: Dict[Any, Any],
|
||||
) -> Tuple[List[Node], List[Relationship]]:
|
||||
# If there are validation errors
|
||||
if not raw_schema["parsed"]:
|
||||
try:
|
||||
try: # OpenAI type response
|
||||
argument_json = json.loads(
|
||||
raw_schema["raw"].additional_kwargs["tool_calls"][0]["function"][
|
||||
"arguments"
|
||||
]
|
||||
)
|
||||
except Exception: # Google type response
|
||||
argument_json = json.loads(
|
||||
raw_schema["raw"].additional_kwargs["function_call"]["arguments"]
|
||||
)
|
||||
|
||||
nodes, relationships = _parse_and_clean_json(argument_json)
|
||||
except Exception: # If we can't parse JSON
|
||||
return ([], [])
|
||||
else: # If there are no validation errors use parsed pydantic object
|
||||
parsed_schema: _Graph = raw_schema["parsed"]
|
||||
nodes = (
|
||||
[map_to_base_node(node) for node in parsed_schema.nodes if node.id]
|
||||
if parsed_schema.nodes
|
||||
else []
|
||||
)
|
||||
|
||||
relationships = (
|
||||
[
|
||||
map_to_base_relationship(rel)
|
||||
for rel in parsed_schema.relationships
|
||||
if rel.type and rel.source_node_id and rel.target_node_id
|
||||
]
|
||||
if parsed_schema.relationships
|
||||
else []
|
||||
)
|
||||
# Title / Capitalize
|
||||
return _format_nodes(nodes), _format_relationships(relationships)
|
||||
|
||||
|
||||
class LLMGraphTransformer:
|
||||
"""Transform documents into graph-based documents using a LLM.
|
||||
|
||||
It allows specifying constraints on the types of nodes and relationships to include
|
||||
in the output graph. The class supports extracting properties for both nodes and
|
||||
relationships.
|
||||
|
||||
Args:
|
||||
llm (BaseLanguageModel): An instance of a language model supporting structured
|
||||
output.
|
||||
allowed_nodes (List[str], optional): Specifies which node types are
|
||||
allowed in the graph. Defaults to an empty list, allowing all node types.
|
||||
allowed_relationships (List[str], optional): Specifies which relationship types
|
||||
are allowed in the graph. Defaults to an empty list, allowing all relationship
|
||||
types.
|
||||
prompt (Optional[ChatPromptTemplate], optional): The prompt to pass to
|
||||
the LLM with additional instructions.
|
||||
strict_mode (bool, optional): Determines whether the transformer should apply
|
||||
filtering to strictly adhere to `allowed_nodes` and `allowed_relationships`.
|
||||
Defaults to True.
|
||||
node_properties (Union[bool, List[str]]): If True, the LLM can extract any
|
||||
node properties from text. Alternatively, a list of valid properties can
|
||||
be provided for the LLM to extract, restricting extraction to those specified.
|
||||
relationship_properties (Union[bool, List[str]]): If True, the LLM can extract
|
||||
any relationship properties from text. Alternatively, a list of valid
|
||||
properties can be provided for the LLM to extract, restricting extraction to
|
||||
those specified.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from langchain_experimental.graph_transformers import LLMGraphTransformer
|
||||
from langchain_core.documents import Document
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
llm=ChatOpenAI(temperature=0)
|
||||
transformer = LLMGraphTransformer(
|
||||
llm=llm,
|
||||
allowed_nodes=["Person", "Organization"])
|
||||
|
||||
doc = Document(page_content="Elon Musk is suing OpenAI")
|
||||
graph_documents = transformer.convert_to_graph_documents([doc])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLanguageModel,
|
||||
allowed_nodes: List[str] = [],
|
||||
allowed_relationships: List[str] = [],
|
||||
prompt: Optional[ChatPromptTemplate] = None,
|
||||
strict_mode: bool = True,
|
||||
node_properties: Union[bool, List[str]] = False,
|
||||
relationship_properties: Union[bool, List[str]] = False,
|
||||
) -> None:
|
||||
self.allowed_nodes = allowed_nodes
|
||||
self.allowed_relationships = allowed_relationships
|
||||
self.strict_mode = strict_mode
|
||||
self._function_call = True
|
||||
# Check if the LLM really supports structured output
|
||||
try:
|
||||
llm.with_structured_output(_Graph)
|
||||
except NotImplementedError:
|
||||
self._function_call = False
|
||||
if not self._function_call:
|
||||
if node_properties or relationship_properties:
|
||||
raise ValueError(
|
||||
"The 'node_properties' and 'relationship_properties' parameters "
|
||||
"cannot be used in combination with a LLM that doesn't support "
|
||||
"native function calling."
|
||||
)
|
||||
try:
|
||||
import json_repair # type: ignore
|
||||
|
||||
self.json_repair = json_repair
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import json_repair python package. "
|
||||
"Please install it with `pip install json-repair`."
|
||||
)
|
||||
prompt = prompt or create_unstructured_prompt(
|
||||
allowed_nodes, allowed_relationships
|
||||
)
|
||||
self.chain = prompt | llm
|
||||
else:
|
||||
# Define chain
|
||||
try:
|
||||
llm_type = llm._llm_type # type: ignore
|
||||
except AttributeError:
|
||||
llm_type = None
|
||||
schema = create_simple_model(
|
||||
allowed_nodes,
|
||||
allowed_relationships,
|
||||
node_properties,
|
||||
llm_type,
|
||||
relationship_properties,
|
||||
)
|
||||
structured_llm = llm.with_structured_output(schema, include_raw=True)
|
||||
prompt = prompt or default_prompt
|
||||
self.chain = prompt | structured_llm
|
||||
|
||||
def process_response(
|
||||
self, document: Document, config: Optional[RunnableConfig] = None
|
||||
) -> GraphDocument:
|
||||
"""
|
||||
Processes a single document, transforming it into a graph document using
|
||||
an LLM based on the model's schema and constraints.
|
||||
"""
|
||||
text = document.page_content
|
||||
raw_schema = self.chain.invoke({"input": text}, config=config)
|
||||
if self._function_call:
|
||||
raw_schema = cast(Dict[Any, Any], raw_schema)
|
||||
nodes, relationships = _convert_to_graph_document(raw_schema)
|
||||
else:
|
||||
nodes_set = set()
|
||||
relationships = []
|
||||
if not isinstance(raw_schema, str):
|
||||
raw_schema = raw_schema.content
|
||||
parsed_json = self.json_repair.loads(raw_schema)
|
||||
for rel in parsed_json:
|
||||
# Nodes need to be deduplicated using a set
|
||||
nodes_set.add((rel["head"], rel["head_type"]))
|
||||
nodes_set.add((rel["tail"], rel["tail_type"]))
|
||||
|
||||
source_node = Node(id=rel["head"], type=rel["head_type"])
|
||||
target_node = Node(id=rel["tail"], type=rel["tail_type"])
|
||||
relationships.append(
|
||||
Relationship(
|
||||
source=source_node, target=target_node, type=rel["relation"]
|
||||
)
|
||||
)
|
||||
# Create nodes list
|
||||
nodes = [Node(id=el[0], type=el[1]) for el in list(nodes_set)]
|
||||
|
||||
# Strict mode filtering
|
||||
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
|
||||
if self.allowed_nodes:
|
||||
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
|
||||
nodes = [
|
||||
node for node in nodes if node.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.source.type.lower() in lower_allowed_nodes
|
||||
and rel.target.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
if self.allowed_relationships:
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.type.lower()
|
||||
in [el.lower() for el in self.allowed_relationships]
|
||||
]
|
||||
|
||||
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
|
||||
|
||||
def convert_to_graph_documents(
|
||||
self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
|
||||
) -> List[GraphDocument]:
|
||||
"""Convert a sequence of documents into graph documents.
|
||||
|
||||
Args:
|
||||
documents (Sequence[Document]): The original documents.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Sequence[GraphDocument]: The transformed documents as graphs.
|
||||
"""
|
||||
return [self.process_response(document, config) for document in documents]
|
||||
|
||||
async def aprocess_response(
|
||||
self, document: Document, config: Optional[RunnableConfig] = None
|
||||
) -> GraphDocument:
|
||||
"""
|
||||
Asynchronously processes a single document, transforming it into a
|
||||
graph document.
|
||||
"""
|
||||
text = document.page_content
|
||||
raw_schema = await self.chain.ainvoke({"input": text}, config=config)
|
||||
raw_schema = cast(Dict[Any, Any], raw_schema)
|
||||
nodes, relationships = _convert_to_graph_document(raw_schema)
|
||||
|
||||
if self.strict_mode and (self.allowed_nodes or self.allowed_relationships):
|
||||
if self.allowed_nodes:
|
||||
lower_allowed_nodes = [el.lower() for el in self.allowed_nodes]
|
||||
nodes = [
|
||||
node for node in nodes if node.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.source.type.lower() in lower_allowed_nodes
|
||||
and rel.target.type.lower() in lower_allowed_nodes
|
||||
]
|
||||
if self.allowed_relationships:
|
||||
relationships = [
|
||||
rel
|
||||
for rel in relationships
|
||||
if rel.type.lower()
|
||||
in [el.lower() for el in self.allowed_relationships]
|
||||
]
|
||||
|
||||
return GraphDocument(nodes=nodes, relationships=relationships, source=document)
|
||||
|
||||
async def aconvert_to_graph_documents(
|
||||
self, documents: Sequence[Document], config: Optional[RunnableConfig] = None
|
||||
) -> List[GraphDocument]:
|
||||
"""
|
||||
Asynchronously convert a sequence of documents into graph documents.
|
||||
"""
|
||||
tasks = [
|
||||
asyncio.create_task(self.aprocess_response(document, config))
|
||||
for document in documents
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return results
|
27  backend/Utils/stringify.py  Normal file

@ -0,0 +1,27 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
try:
|
||||
from collections.abc import Mapping, Sequence
|
||||
except ImportError:
|
||||
from collections import Mapping, Sequence
|
||||
|
||||
import json
|
||||
|
||||
COMPACT_SEPARATORS = (',', ':')
|
||||
|
||||
def order_by_key(kv):
|
||||
key, val = kv
|
||||
return key
|
||||
|
||||
def recursive_order(node):
|
||||
if isinstance(node, Mapping):
|
||||
ordered_mapping = OrderedDict(sorted(node.items(), key=order_by_key))
|
||||
for key, value in ordered_mapping.items():
|
||||
ordered_mapping[key] = recursive_order(value)
|
||||
return ordered_mapping
|
||||
elif isinstance(node, Sequence) and not isinstance(node, (str, bytes)):
|
||||
return [recursive_order(item) for item in node]
|
||||
return node
|
||||
|
||||
def stringify(node):
|
||||
return json.dumps(recursive_order(node), separators=COMPACT_SEPARATORS)
|
|
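Example of what the new helper produces (illustrative; assumes it is imported from the backend package root):

# Illustrative use of backend/Utils/stringify.py
from Utils.stringify import stringify

meta = {"VisitedWebPageTitle": "Example", "BrowsingSessionId": "abc", "nested": {"b": 2, "a": 1}}
print(stringify(meta))
# Keys are sorted recursively and the JSON is compact:
# {"BrowsingSessionId":"abc","VisitedWebPageTitle":"Example","nested":{"a":1,"b":2}}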
@ -1,9 +1,11 @@
|
|||
from sqlalchemy import create_engine, Column, Integer, String
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from envs import POSTGRES_DATABASE_URL
|
||||
|
||||
POSTGRES_DATABASE_URL = os.environ.get("POSTGRES_DATABASE_URL")
|
||||
|
||||
engine = create_engine(
|
||||
POSTGRES_DATABASE_URL
|
||||
|
|
|
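The rest of database.py is not shown in this hunk; a typical SQLAlchemy wiring consistent with the get_db() dependency used in HIndices.py would look like this sketch (not the file's actual tail):

# Sketch only: how engine/SessionLocal/Base are commonly wired together
engine = create_engine(POSTGRES_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()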
@ -1,6 +1,7 @@
|
|||
from datetime import datetime
|
||||
from typing import List
|
||||
from database import Base, engine
|
||||
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, create_engine
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
class BaseModel(Base):
|
||||
|
@ -10,12 +11,12 @@ class BaseModel(Base):
|
|||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
__tablename__ = "notifications"
|
||||
# class Notification(BaseModel):
|
||||
# __tablename__ = "notifications"
|
||||
|
||||
text = Column(String)
|
||||
user_id = Column(ForeignKey('users.id'))
|
||||
user = relationship('User')
|
||||
# text = Column(String)
|
||||
# user_id = Column(ForeignKey('users.id'))
|
||||
# user = relationship('User')
|
||||
|
||||
|
||||
class Chat(BaseModel):
|
||||
|
@ -24,18 +25,44 @@ class Chat(BaseModel):
|
|||
type = Column(String)
|
||||
title = Column(String)
|
||||
chats_list = Column(String)
|
||||
|
||||
user_id = Column(ForeignKey('users.id'))
|
||||
user = relationship('User')
|
||||
|
||||
|
||||
class Documents(BaseModel):
|
||||
__tablename__ = "documents"
|
||||
|
||||
title = Column(String)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
file_type = Column(String)
|
||||
document_metadata = Column(String)
|
||||
page_content = Column(String)
|
||||
desc_vector_start = Column(Integer, default=0)
|
||||
desc_vector_end = Column(Integer, default=0)
|
||||
|
||||
search_space_id = Column(ForeignKey('searchspaces.id'))
|
||||
search_space = relationship('SearchSpace')
|
||||
|
||||
user_id = Column(ForeignKey('users.id'))
|
||||
user = relationship('User')
|
||||
|
||||
class SearchSpace(BaseModel):
|
||||
__tablename__ = "searchspaces"
|
||||
|
||||
search_space = Column(String, unique=True)
|
||||
|
||||
documents = relationship(Documents)
|
||||
|
||||
class User(BaseModel):
|
||||
__tablename__ = "users"
|
||||
|
||||
username = Column(String, unique=True, index=True)
|
||||
hashed_password = Column(String)
|
||||
graph_config = Column(String)
|
||||
llm_config = Column(String)
|
||||
chats = relationship(Chat)
|
||||
notifications = relationship(Notification)
|
||||
chats = relationship(Chat, order_by="Chat.id")
|
||||
documents = relationship(Documents, order_by="Documents.id")
|
||||
|
||||
|
||||
# Create the database tables if they don't exist
|
||||
User.metadata.create_all(bind=engine)
|
||||
|
|
|
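The new desc_vector_start/desc_vector_end columns record which detailed-chunk ids in Chroma belong to each summary document; a hypothetical lookup against the updated model:

# Hypothetical query of the Chroma chunk-id range stored per document
from database import SessionLocal
from models import User

db = SessionLocal()
user = db.query(User).filter(User.username == "alice").first()
for doc in user.documents:
    print(doc.title, doc.desc_vector_start, doc.desc_vector_end)
db.close()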
@ -1,5 +1,4 @@
|
|||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
|
@ -7,132 +6,51 @@ from datetime import datetime, timezone
|
|||
|
||||
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
|
||||
|
||||
GRAPH_QUERY_GEN_TEMPLATE = DATE_TODAY + """You are a top tier Prompt Engineering Expert.
|
||||
A User's Data is stored in a Knowledge Graph.
|
||||
Your main task is to read the User Question below and give a optimized Question prompt in Natural Language.
|
||||
Question prompt will be used by a LLM to easlily get data from Knowledge Graph's.
|
||||
# Create a prompt template for sub-query decomposition
|
||||
SUBQUERY_DECOMPOSITION_TEMPLATE = DATE_TODAY + """You are an AI assistant tasked with breaking down complex queries into simpler sub-queries for a vector store.
|
||||
Given the original query, decompose it into 2-4 simpler sub-queries for vector search that helps in expanding context.
|
||||
|
||||
Make sure to only return the promt text thats it. Never change the meaning of users question.
|
||||
Original query: {original_query}
|
||||
|
||||
Here are the examples of the User's Data Documents that is stored in Knowledge Graph:
|
||||
{context}
|
||||
IMPORTANT INSTRUCTION: Make sure to only return sub-queries and no explanation.
|
||||
|
||||
Note: Do not include any explanations or apologies in your responses.
|
||||
Do not include any text except the generated promt text.
|
||||
EXAMPLE:
|
||||
|
||||
Question: {question}
|
||||
Prompt For Cypher Query Construction:"""
|
||||
User Query: What are the impacts of climate change on the environment?
|
||||
|
||||
GRAPH_QUERY_GEN_PROMPT = PromptTemplate(
|
||||
input_variables=["context", "question"], template=GRAPH_QUERY_GEN_TEMPLATE
|
||||
)
|
||||
|
||||
CYPHER_QA_TEMPLATE = DATE_TODAY + """You are an assistant that helps to form nice and human understandable answers.
|
||||
The information part contains the provided information that you must use to construct an answer.
|
||||
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
|
||||
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
|
||||
Only give the answer if it satisfies the user requirements in Question. Else return exactly 'don't know' as answer.
|
||||
|
||||
Here are the examples:
|
||||
|
||||
Question: What type of general topics I explore the most?
|
||||
Context:[['Topic': 'Langchain', 'topicCount': 5], ['Topic': 'Graphrag', 'topicCount': 2], ['Topic': 'Ai', 'topicCount': 2], ['Topic': 'Fastapi', 'topicCount': 2], ['Topic': 'Nextjs', 'topicCount': 1]]
|
||||
Helpful Answer: You mostly explore about Langchain, Graphrag, Ai, Fastapi and Nextjs.
|
||||
|
||||
Follow this example when generating answers.
|
||||
If the provided information is empty or incomplete, return exactly 'don't know' as answer.
|
||||
|
||||
Information:
|
||||
{context}
|
||||
|
||||
Question: {question}
|
||||
Helpful Answer:"""
|
||||
|
||||
CYPHER_QA_PROMPT = PromptTemplate(
|
||||
input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
|
||||
)
|
||||
|
||||
SIMILARITY_SEARCH_RAG = DATE_TODAY + """You are an assistant for question-answering tasks.
|
||||
Use the following pieces of retrieved context to answer the question.
|
||||
If you don't know the answer, return exactly 'don't know' as answer.
|
||||
Question: {question}
|
||||
Context: {context}
|
||||
Answer:"""
|
||||
|
||||
|
||||
SIMILARITY_SEARCH_PROMPT = PromptTemplate(
|
||||
input_variables=["context", "question"], template=SIMILARITY_SEARCH_RAG
|
||||
)
|
||||
|
||||
# doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
|
||||
|
||||
|
||||
CYPHER_GENERATION_TEMPLATE = DATE_TODAY + """Task:Generate Cypher statement to query a graph database.
|
||||
Instructions:
|
||||
Use only the provided relationship types and properties in the schema.
|
||||
Do not use any other relationship types or properties that are not provided.
|
||||
|
||||
Schema:
|
||||
{schema}
|
||||
Note: Do not include any explanations or apologies in your responses.
|
||||
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
|
||||
Do not include any text except the generated Cypher statement.
|
||||
|
||||
The question is:
|
||||
{question}"""
|
||||
CYPHER_GENERATION_PROMPT = PromptTemplate(
|
||||
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
|
||||
)
|
||||
|
||||
|
||||
DOC_DESCRIPTION_TEMPLATE = DATE_TODAY + """Task:Give Detailed Description of the page content of the given document.
|
||||
Instructions:
|
||||
Provide as much details about metadata & page content as if you need to give human readable report of this Browsing session event.
|
||||
|
||||
Document:
|
||||
{document}
|
||||
AI Answer:
|
||||
What are the impacts of climate change on biodiversity?
|
||||
How does climate change affect the oceans?
|
||||
What are the effects of climate change on agriculture?
|
||||
What are the impacts of climate change on human health?
|
||||
"""
|
||||
DOC_DESCRIPTION_PROMPT = PromptTemplate(
|
||||
input_variables=["document"], template=DOC_DESCRIPTION_TEMPLATE
|
||||
|
||||
# SUBQUERY_DECOMPOSITION_TEMPLATE_TWO = DATE_TODAY + """You are an AI language model assistant. Your task is to generate five
|
||||
# different versions of the given user question to retrieve relevant documents from a vector
|
||||
# database. By generating multiple perspectives on the user question, your goal is to help
|
||||
# the user overcome some of the limitations of the distance-based similarity search.
|
||||
# Provide these alternative questions separated by newlines.
|
||||
# Original question: {original_query}"""
|
||||
|
||||
|
||||
SUBQUERY_DECOMPOSITION_PROMT = PromptTemplate(
|
||||
input_variables=["original_query"],
|
||||
template=SUBQUERY_DECOMPOSITION_TEMPLATE
|
||||
)
|
||||
|
||||
CONTEXT_ANSWER_TEMPLATE = DATE_TODAY + """You are a PhD in English literature. Your task is to give a detailed report and explanation for the user query, based on the given context.
|
||||
|
||||
DOCUMENT_METADATA_EXTRACTION_SYSTEM_MESSAGE = DATE_TODAY + """You are a helpful assistant. You are given a Cypher statement result after quering the Neo4j graph database.
|
||||
Generate a very good Query that can be used to perform similarity search on the vector store of the Neo4j graph database"""
|
||||
IMPORTANT INSTRUCTION: Only return answer if you can find it in given context otherwise just say you don't know.
|
||||
|
||||
DOCUMENT_METADATA_EXTRACTION_PROMT = ChatPromptTemplate.from_messages([("system", DOCUMENT_METADATA_EXTRACTION_SYSTEM_MESSAGE), ("human", "{input}")])
|
||||
|
||||
|
||||
|
||||
VECTOR_QUERY_GENERATION_TEMPLATE = DATE_TODAY + """You are a helpful assistant. You are given a user query and the examples of document on which user is asking query about.
|
||||
Give instruction to machine how to search for the data based on user query.
|
||||
|
||||
Document Examples:
|
||||
{examples}
|
||||
|
||||
Note: Only return the Query and nothing else. No explanation.
|
||||
Context: {context}
|
||||
|
||||
User Query: {query}
|
||||
Helpful Answer:"""
|
||||
|
||||
VECTOR_QUERY_GENERATION_PROMT = PromptTemplate(
|
||||
input_variables=["examples", "query"], template=VECTOR_QUERY_GENERATION_TEMPLATE
|
||||
)
|
||||
Detailed Report:"""
|
||||
|
||||
|
||||
NOTIFICATION_GENERATION_TEMPLATE = """You are a highly attentive assistant. You are provided with a collection of User Browsing History Events containing page content. Your task is to thoroughly analyze these events and generate a concise list of critical notifications that the User must be aware of.
|
||||
|
||||
User Browsing History Events Documents:
|
||||
{documents}
|
||||
|
||||
Instructions:
|
||||
Return only the notification text, and nothing else.
|
||||
Exclude any notifications that are not essential.
|
||||
|
||||
Response:"""
|
||||
|
||||
NOTIFICATION_GENERATION_PROMT = PromptTemplate(
|
||||
input_variables=["documents"], template=NOTIFICATION_GENERATION_TEMPLATE
|
||||
CONTEXT_ANSWER_PROMPT = PromptTemplate(
|
||||
input_variables=["context","query"],
|
||||
template=CONTEXT_ANSWER_TEMPLATE
|
||||
)
|
||||
|
||||
|
||||
|
@ -142,3 +60,5 @@ NOTIFICATION_GENERATION_PROMT = PromptTemplate(
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
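For orientation, here is a minimal sketch (not part of the commit) of how these PromptTemplate objects are meant to be driven with the LangChain pipe syntax used elsewhere in this diff; the local mistral-nemo model and the example inputs are assumptions.

# Sketch only: wire the new prompts to a model and invoke them.
from langchain_ollama import OllamaLLM

llm = OllamaLLM(model="mistral-nemo", temperature=0)  # assumption: IS_LOCAL_SETUP == 'true'

# Break a complex question into sub-queries.
decomposer = SUBQUERY_DECOMPOSITION_PROMT | llm
sub_queries_text = decomposer.invoke({"original_query": "How does SurfSense index my browsing history?"})

# Answer over retrieved context.
answerer = CONTEXT_ANSWER_PROMPT | llm
report = answerer.invoke({"query": "Summarise my last session", "context": "placeholder retrieved chunks"})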
@@ -1,6 +1,11 @@
from pydantic import BaseModel, Field
from typing import List, Optional

class UserCreate(BaseModel):
username: str
password: str
apisecretkey: str

class DocMeta(BaseModel):
BrowsingSessionId: Optional[str] = Field(default=None, description="BrowsingSessionId of Document")
VisitedWebPageURL: Optional[str] = Field(default=None, description="VisitedWebPageURL of Document")
@@ -8,77 +13,53 @@ class DocMeta(BaseModel):
VisitedWebPageDateWithTimeInISOString: Optional[str] = Field(default=None, description="VisitedWebPageDateWithTimeInISOString of Document")
VisitedWebPageReffererURL: Optional[str] = Field(default=None, description="VisitedWebPageReffererURL of Document")
VisitedWebPageVisitDurationInMilliseconds: Optional[int] = Field(default=None, description="VisitedWebPageVisitDurationInMilliseconds of Document"),

class DocWithContent(BaseModel):
BrowsingSessionId: Optional[str] = Field(default=None, description="BrowsingSessionId of Document")
VisitedWebPageURL: Optional[str] = Field(default=None, description="VisitedWebPageURL of Document")
VisitedWebPageTitle: Optional[str] = Field(default=None, description="VisitedWebPageTitle of Document")
VisitedWebPageDateWithTimeInISOString: Optional[str] = Field(default=None, description="VisitedWebPageDateWithTimeInISOString of Document")
VisitedWebPageReffererURL: Optional[str] = Field(default=None, description="VisitedWebPageReffererURL of Document")
VisitedWebPageVisitDurationInMilliseconds: Optional[int] = Field(default=None, description="VisitedWebPageVisitDurationInMilliseconds of Document"),
VisitedWebPageContent: Optional[str] = Field(default=None, description="Visited WebPage Content in markdown of Document")

class PrecisionQuery(BaseModel):
sessionid: Optional[str] = Field(default=None)
webpageurl: Optional[str] = Field(default=None)
daterange: Optional[List[str]]
timerange: Optional[List[int]]
neourl: str
neouser: str
neopass: str

class DocumentsToDelete(BaseModel):
ids_to_delete: List[str]
openaikey: str
apisecretkey: str

class PrecisionResponse(BaseModel):
documents: List[DocMeta]

token: str

class UserQuery(BaseModel):
query: str
neourl: str
neouser: str
neopass: str
search_space: str
openaikey: str
apisecretkey: str
token: str

class ChatHistory(BaseModel):
type: str
content: str | List[DocMeta]
content: str | List[DocMeta] | List[str]

class UserQueryWithChatHistory(BaseModel):
chat: List[ChatHistory]
query: str
neourl: str
neouser: str
neopass: str
openaikey: str
apisecretkey: str

token: str

class DescriptionResponse(BaseModel):
response: str

class RetrivedDocListItem(BaseModel):
metadata: DocMeta
pageContent: str

class RetrivedDocList(BaseModel):
documents: List[RetrivedDocListItem]
neourl: str
neouser: str
neopass: str
search_space: str | None
openaikey: str
token: str

class UserQueryResponse(BaseModel):
response: str
relateddocs: List[DocMeta]

class VectorSearchQuery(BaseModel):
searchquery: str

class NewUserData(BaseModel):
token: str
userid: str
chats: str
notifications: str
relateddocs: List[DocWithContent]

class NewUserChat(BaseModel):
token: str
@@ -86,19 +67,7 @@ class NewUserChat(BaseModel):
title: str
chats_list: str

class ChatToUpdate(BaseModel):
chatid: str
token: str
# type: str
# title: str
chats_list: str

class GraphDocs(BaseModel):
documents: List[RetrivedDocListItem]
token: str

class Notifications(BaseModel):
notifications: List[str]

chats_list: str
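A hypothetical request body for the /chat/docs endpoint, built from the UserQueryWithChatHistory and ChatHistory models above; all values are placeholders.

# Sketch only: the chat list mixes 'system' (retrieved context), 'human' and 'ai' turns,
# mirroring the type checks in doc_chat_with_history further down in this diff.
example_doc_chat_payload = {
    "chat": [
        {"type": "system", "content": "placeholder retrieved document chunks"},
        {"type": "human", "content": "What does this page say about hierarchical indexing?"},
        {"type": "ai", "content": "It explains the summary and detailed vector stores."},
    ],
    "query": "And how are the summaries generated?",
    "neourl": "bolt://localhost:7687",
    "neouser": "neo4j",
    "neopass": "password",
    "openaikey": "sk-placeholder",
    "apisecretkey": "surfsense",
    "token": "jwt-from-the-auth-flow",
}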
@@ -1,3 +1,5 @@
python-dotenv
pydantic
bcrypt
cryptography
fastapi
@@ -11,5 +13,7 @@ langchain-core
langchain-community
langchain-experimental
langchain_openai
psycopg2
neo4j
langchain_ollama
langchain_chroma
flashrank
psycopg2
@@ -1,42 +1,40 @@
from __future__ import annotations

from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Neo4jVector
from envs import ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, API_SECRET_KEY, SECRET_KEY
from prompts import CYPHER_QA_PROMPT, DATE_TODAY, DOC_DESCRIPTION_PROMPT, GRAPH_QUERY_GEN_PROMPT, NOTIFICATION_GENERATION_PROMT, SIMILARITY_SEARCH_PROMPT, CYPHER_GENERATION_PROMPT, DOCUMENT_METADATA_EXTRACTION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, GraphDocs, NewUserChat, NewUserData, Notifications, PrecisionQuery, PrecisionResponse, UserQuery, DocMeta, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory, VectorSearchQuery
from langchain_experimental.text_splitter import SemanticChunker
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PROMT
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

# Our Imports
from LLMGraphTransformer import LLMGraphTransformer
from langchain_openai import ChatOpenAI
from DataExample import examples
# import nest_asyncio
# from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
# from langchain_community.graphs import GremlinGraph
# from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
# from langchain_core.documents import Document
# from langchain_openai import AzureChatOpenAI
# Hierarchical Indices class
from HIndices import HIndices

from Utils.stringify import stringify

# Auth Libs
from fastapi import FastAPI, Depends, HTTPException, Request, status
from fastapi import FastAPI, Depends, HTTPException, status
from sqlalchemy.orm import Session
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from passlib.context import CryptContext
from models import Chat, Notification, User
from database import SessionLocal, engine
from pydantic import BaseModel
from models import Chat, Documents, SearchSpace, User
from database import SessionLocal
from fastapi.middleware.cors import CORSMiddleware
from langchain_openai import AzureChatOpenAI

import os
from dotenv import load_dotenv
load_dotenv()

IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
ALGORITHM = os.environ.get("ALGORITHM")
API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
SECRET_KEY = os.environ.get("SECRET_KEY")

app = FastAPI()

# Dependency
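The handlers below repeatedly switch between a local Ollama model and OpenAI based on IS_LOCAL_SETUP. A minimal sketch of that pattern, with an illustrative helper name (server.py itself inlines this logic at each call site):

# Sketch only: pick_llm is a hypothetical helper, not defined in the commit.
def pick_llm(openai_key: str | None = None):
    # Local runs use Ollama's mistral-nemo; hosted runs use the key sent in the request.
    if IS_LOCAL_SETUP == 'true':
        return OllamaLLM(model="mistral-nemo", temperature=0)
    return ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=openai_key)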
@@ -47,251 +45,117 @@ def get_db():
finally:
db.close()

class UserCreate(BaseModel):
username: str
password: str
apisecretkey: str

# General GraphCypherQAChain
@app.post("/")
@app.post("/chat/")
def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):

if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")

graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=data.openaikey
)

# Query Expansion
# searchchain = GRAPH_QUERY_GEN_PROMPT | llm

# qry = searchchain.invoke({"question": data.query, "context": examples})

query = data.query #qry.content

embeddings = OpenAIEmbeddings(
model="text-embedding-ada-002",
api_key=data.openaikey,
)

chain = GraphCypherQAChain.from_llm(
graph=graph,
cypher_prompt=CYPHER_GENERATION_PROMPT,
cypher_llm=llm,
verbose=True,
validate_cypher=True,
qa_prompt=CYPHER_QA_PROMPT,
qa_llm=llm,
return_intermediate_steps=True,
top_k=5,
)

vector_index = Neo4jVector.from_existing_graph(
embeddings,
graph=graph,
search_type="hybrid",
node_label="Document",
text_node_properties=["text"],
embedding_node_property="embedding",
)

graphdocs = vector_index.similarity_search(data.query,k=15)
docsDict = {}

for d in graphdocs:
if d.metadata['BrowsingSessionId'] not in docsDict:
newVal = d.metadata.copy()
newVal['VisitedWebPageContent'] = d.page_content
docsDict[d.metadata['BrowsingSessionId']] = newVal
else:
docsDict[d.metadata['BrowsingSessionId']]['VisitedWebPageContent'] += d.page_content

docstoreturn = []

for x in docsDict.values():
docstoreturn.append(DocMeta(
BrowsingSessionId=x['BrowsingSessionId'],
VisitedWebPageURL=x['VisitedWebPageURL'],
VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
VisitedWebPageTitle=x['VisitedWebPageTitle'],
VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageContent=x['VisitedWebPageContent']
))

try:
responsegrp = chain.invoke({"query": query})

if "don't know" in responsegrp["result"]:
raise Exception("No response from graph")
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

structured_llm = llm.with_structured_output(VectorSearchQuery)
doc_extract_chain = DOCUMENT_METADATA_EXTRACTION_PROMT | structured_llm
query = data.query
search_space = data.search_space

newquery = doc_extract_chain.invoke(responsegrp["intermediate_steps"][1]["context"])

graphdocs = vector_index.similarity_search(newquery.searchquery,k=15)

docsDict = {}

for d in graphdocs:
if d.metadata['BrowsingSessionId'] not in docsDict:
newVal = d.metadata.copy()
newVal['VisitedWebPageContent'] = d.page_content
docsDict[d.metadata['BrowsingSessionId']] = newVal
else:
docsDict[d.metadata['BrowsingSessionId']]['VisitedWebPageContent'] += d.page_content

docstoreturn = []

for x in docsDict.values():
docstoreturn.append(DocMeta(
BrowsingSessionId=x['BrowsingSessionId'],
VisitedWebPageURL=x['VisitedWebPageURL'],
VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
VisitedWebPageTitle=x['VisitedWebPageTitle'],
VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageContent=x['VisitedWebPageContent']
))

return UserQueryResponse(relateddocs=docstoreturn,response=responsegrp["result"])
except:
# Fallback to Similarity Search RAG
searchchain = SIMILARITY_SEARCH_PROMPT | llm

response = searchchain.invoke({"question": data.query, "context": docstoreturn})

return UserQueryResponse(relateddocs=docstoreturn,response=response.content)

# Precision Search
@app.post("/precision")
def get_precision_search_response(data: PrecisionQuery, response_model=PrecisionResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")

graph = Neo4jGraph(url=data.neourl, username=data.neouser, password=data.neopass)

GRAPH_QUERY = "MATCH (d:Document) WHERE d.VisitedWebPageDateWithTimeInISOString >= " + "'" + data.daterange[0] + "'" + " AND d.VisitedWebPageDateWithTimeInISOString <= " + "'" + data.daterange[1] + "'"

if(data.timerange[0] >= data.timerange[1]):
GRAPH_QUERY += " AND d.VisitedWebPageVisitDurationInMilliseconds >= 0"
else:
GRAPH_QUERY += " AND d.VisitedWebPageVisitDurationInMilliseconds >= "+ str(data.timerange[0]) + " AND d.VisitedWebPageVisitDurationInMilliseconds <= " + str(data.timerange[1])

if(data.webpageurl):
GRAPH_QUERY += " AND d.VisitedWebPageURL CONTAINS " + "'" + data.webpageurl.lower() + "'"

if(data.sessionid):
GRAPH_QUERY += " AND d.BrowsingSessionId = " + "'" + data.sessionid + "'"

GRAPH_QUERY += " RETURN d;"

graphdocs = graph.query(GRAPH_QUERY)

docsDict = {}

for d in graphdocs:
if d['d']['VisitedWebPageVisitDurationInMilliseconds'] not in docsDict:
docsDict[d['d']['VisitedWebPageVisitDurationInMilliseconds']] = d['d']
if(IS_LOCAL_SETUP == 'true'):
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
else:
docsDict[d['d']['VisitedWebPageVisitDurationInMilliseconds']]['text'] += d['d']['text']

docs = []

for x in docsDict.values():
docs.append(DocMeta(
BrowsingSessionId=x['BrowsingSessionId'],
VisitedWebPageURL=x['VisitedWebPageURL'],
VisitedWebPageVisitDurationInMilliseconds=x['VisitedWebPageVisitDurationInMilliseconds'],
VisitedWebPageTitle=x['VisitedWebPageTitle'],
VisitedWebPageReffererURL=x['VisitedWebPageReffererURL'],
VisitedWebPageDateWithTimeInISOString=x['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageContent=x['text']
))

return PrecisionResponse(documents=docs)
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)

# Multi DOC Chat
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")

llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=data.openaikey
)

chatHistory = []

for chat in data.chat:
if(chat.type == 'system'):
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Context:""" + str(chat.content)))

if(chat.type == 'ai'):
chatHistory.append(AIMessage(content=chat.content))

if(chat.type == 'human'):
chatHistory.append(HumanMessage(content=chat.content))

chatHistory.append(("human", "{input}"));

qa_prompt = ChatPromptTemplate.from_messages(chatHistory)
# Create an LLMChain for sub-query decomposition
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm

descriptionchain = qa_prompt | llm
def decompose_query(original_query: str):
"""
Decompose the original query into simpler sub-queries.

Args:
original_query (str): The original complex query

Returns:
List[str]: A list of simpler sub-queries
"""
if(IS_LOCAL_SETUP == 'true'):
response = subquery_decomposer_chain.invoke(original_query)
else:
response = subquery_decomposer_chain.invoke(original_query).content

sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
return sub_queries

response = descriptionchain.invoke({"input": data.query})

return DescriptionResponse(response=response.content)

# Create Hierarchical Indices
if(IS_LOCAL_SETUP == 'true'):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)

# For Those Who Want HyDe Questions
# sub_queries = decompose_query(query)

sub_queries = []
sub_queries.append(query)

# DOC DESCRIPTION
@app.post("/kb/doc")
def get_doc_description(data: UserQuery, response_model=DescriptionResponse):
if(data.apisecretkey != API_SECRET_KEY):
raise HTTPException(status_code=401, detail="Unauthorized")

document = data.query
llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=data.openaikey
)

descriptionchain = DOC_DESCRIPTION_PROMPT | llm

response = descriptionchain.invoke({"document": document})

return DescriptionResponse(response=response.content)

duplicate_related_summary_docs = []
context_to_answer = ""
for sub_query in sub_queries:
localreturn = index.local_search(query=sub_query, search_space=search_space)
globalreturn, related_summary_docs = index.global_search(query=sub_query, search_space=search_space)

context_to_answer += localreturn + "\n\n" + globalreturn

# SAVE DOCS TO GRAPH DB
@app.post("/kb/")
def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
duplicate_related_summary_docs.extend(related_summary_docs)

combined_docs_seen_metadata = set()
combined_docs_unique_documents = []

for doc in duplicate_related_summary_docs:
# Convert metadata to a tuple of its items (this allows it to be added to a set)
doc.metadata['relevance_score'] = 0.0
metadata_tuple = tuple(sorted(doc.metadata.items()))
if metadata_tuple not in combined_docs_seen_metadata:
combined_docs_seen_metadata.add(metadata_tuple)
combined_docs_unique_documents.append(doc)

returnDocs = []
for doc in combined_docs_unique_documents:
entry = DocWithContent(
BrowsingSessionId=doc.metadata['BrowsingSessionId'],
VisitedWebPageURL=doc.metadata['VisitedWebPageURL'],
VisitedWebPageContent=doc.page_content,
VisitedWebPageTitle=doc.metadata['VisitedWebPageTitle'],
VisitedWebPageDateWithTimeInISOString=doc.metadata['VisitedWebPageDateWithTimeInISOString'],
VisitedWebPageReffererURL=doc.metadata['VisitedWebPageReffererURL'],
VisitedWebPageVisitDurationInMilliseconds=doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
)

returnDocs.append(entry)

ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm

finalans = ans_chain.invoke({"query": query, "context": context_to_answer})

if(IS_LOCAL_SETUP == 'true'):
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
else:
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)

except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

# SAVE DOCS
@app.post("/save/")
def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):

try:
payload = jwt.decode(apires.token, SECRET_KEY, algorithms=[ALGORITHM])
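A hypothetical client call against the /chat/ endpoint above, assuming the server is running locally on port 8000 (see the uvicorn call at the bottom of this file); requests is not in requirements.txt and is used purely for illustration, and all field values are placeholders.

# Sketch only.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/chat/",
    json={
        "query": "Summarise what I browsed about FastAPI this week",
        "neourl": "bolt://localhost:7687",
        "neouser": "neo4j",
        "neopass": "password",
        "search_space": "GENERAL",
        "openaikey": "sk-placeholder",  # unused on the local Ollama path
        "apisecretkey": "surfsense",
        "token": "jwt-from-the-auth-flow",
    },
    timeout=120,
)
# UserQueryResponse carries the answer plus the related summary docs.
print(resp.json()["response"])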
@@ -299,58 +163,49 @@ def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

print("STARTED")
# print(apires)
graph = Neo4jGraph(url=apires.neourl, username=apires.neouser, password=apires.neopass)

llm = ChatOpenAI(
model="gpt-4o-mini",
temperature=0,
max_tokens=None,
timeout=None,
api_key=apires.openaikey
)

embeddings = OpenAIEmbeddings(
model="text-embedding-ada-002",
api_key=apires.openaikey,
)

llm_transformer = LLMGraphTransformer(llm=llm)

DocumentPgEntry = []
raw_documents = []

searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == apires.search_space).first()

for doc in apires.documents:
raw_documents.append(Document(page_content=doc.pageContent, metadata=doc.metadata))

text_splitter = SemanticChunker(embeddings=embeddings)

documents = text_splitter.split_documents(raw_documents)
graph_documents = llm_transformer.convert_to_graph_documents(documents)

graph.add_graph_documents(
graph_documents,
baseEntityLabel=True,
include_source=True
)

structured_llm = llm.with_structured_output(Notifications)
notifs_extraction_chain = NOTIFICATION_GENERATION_PROMT | structured_llm

notifications = notifs_extraction_chain.invoke({"documents": raw_documents})

notifsdb = []

for text in notifications.notifications:
notifsdb.append(Notification(text=text))
content = f"USER BROWSING SESSION EVENT: \n"
content += f"=======================================METADATA==================================== \n"
content += f"User Browsing Session ID : {doc.metadata.BrowsingSessionId} \n"
content += f"User Visited website with url : {doc.metadata.VisitedWebPageURL} \n"
content += f"This visited website url had title : {doc.metadata.VisitedWebPageTitle} \n"
content += f"User Visited this website from referring url : {doc.metadata.VisitedWebPageReffererURL} \n"
content += f"User Visited this website url at this Date and Time : {doc.metadata.VisitedWebPageDateWithTimeInISOString} \n"
content += f"User Visited this website for : {str(doc.metadata.VisitedWebPageVisitDurationInMilliseconds)} milliseconds. \n"
content += f"===================================================================================== \n"
content += f"Webpage Content of the visited webpage url in markdown format : \n\n {doc.pageContent} \n\n"
content += f"===================================================================================== \n"
raw_documents.append(Document(page_content=content,metadata=doc.metadata.__dict__))

pgdocmeta = stringify(doc.metadata.__dict__)

if(searchspace):
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=searchspace, document_metadata=pgdocmeta, page_content=content))
else:
DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=apires.search_space.upper()), document_metadata=pgdocmeta, page_content=content))

user = db.query(User).filter(User.username == username).first()
user.notifications.extend(notifsdb)

db.commit()
#Save docs in PG
user = db.query(User).filter(User.username == username).first()
user.documents.extend(DocumentPgEntry)

db.commit()

# Create Hierarchical Indices
if(IS_LOCAL_SETUP == 'true'):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=apires.openaikey)

#Save Indices in vector Stores
index.encode_docs_hierarchical(documents=raw_documents, files_type='WEBPAGE',search_space=apires.search_space.upper(), db=db)

print("FINISHED")
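A hypothetical payload for /save/, following the RetrivedDocList and RetrivedDocListItem models: each item carries the page content plus the browsing metadata that the handler above flattens into the stored document text. Values are placeholders.

# Sketch only.
example_save_payload = {
    "documents": [
        {
            "metadata": {
                "BrowsingSessionId": "session-123",
                "VisitedWebPageURL": "https://fastapi.tiangolo.com/",
                "VisitedWebPageTitle": "FastAPI",
                "VisitedWebPageDateWithTimeInISOString": "2024-09-01T10:15:00.000Z",
                "VisitedWebPageReffererURL": "https://www.google.com/",
                "VisitedWebPageVisitDurationInMilliseconds": 45000,
            },
            "pageContent": "# FastAPI\nFastAPI framework, high performance...",
        }
    ],
    "neourl": "bolt://localhost:7687",
    "neouser": "neo4j",
    "neopass": "password",
    "search_space": "GENERAL",
    "openaikey": "sk-placeholder",
    "token": "jwt-from-the-auth-flow",
}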
@@ -360,21 +215,83 @@ def populate_graph(apires: RetrivedDocList, db: Session = Depends(get_db)):

except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

# Function to populate db (comment out when running on server)
@app.post("/add/")
def add_user(user: UserCreate, db: Session = Depends(get_db)):
db_user = User(username=user.username, hashed_password=user.password, graph_config="", llm_config="")
db.add(db_user)
db.commit()
return "Success"

# Multi DOC Chat
@app.post("/chat/docs")
def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=DescriptionResponse):
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

if(IS_LOCAL_SETUP == 'true'):
llm = OllamaLLM(model="mistral-nemo",temperature=0)
else:
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)

chatHistory = []

for chat in data.chat:
if(chat.type == 'system'):
chatHistory.append(SystemMessage(content=DATE_TODAY + """You are a helpful assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Context:""" + str(chat.content)))

if(chat.type == 'ai'):
chatHistory.append(AIMessage(content=chat.content))

if(chat.type == 'human'):
chatHistory.append(HumanMessage(content=chat.content))

chatHistory.append(("human", "{input}"));

qa_prompt = ChatPromptTemplate.from_messages(chatHistory)

descriptionchain = qa_prompt | llm

response = descriptionchain.invoke({"input": data.query})

if(IS_LOCAL_SETUP == 'true'):
return DescriptionResponse(response=response)
else:
return DescriptionResponse(response=response.content)

except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

# Multi DOC Chat

@app.post("/delete/docs")
def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_db)):
try:
payload = jwt.decode(data.token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

if(IS_LOCAL_SETUP == 'true'):
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)

message = index.delete_vector_stores(summary_ids_to_delete=data.ids_to_delete,db=db )

return {
"message": message
}

except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

#AUTH CODE
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Recommended for Local Setups
# Manual Origins
# origins = [
# "http://localhost:3000", # Adjust the port if your frontend runs on a different one
# "https://yourfrontenddomain.com",
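A minimal sketch of how the commented-out manual origin list above could be attached, using the CORSMiddleware already imported at the top of server.py; the permissive method and header settings are an assumption aimed at local development only.

# Sketch only.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # the extension/frontend origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)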
@@ -395,7 +312,7 @@ def get_user_by_username(db: Session, username: str):

def create_user(db: Session, user: UserCreate):
hashed_password = pwd_context.hash(user.password)
db_user = User(username=user.username, hashed_password=hashed_password, graph_config="", llm_config="")
db_user = User(username=user.username, hashed_password=hashed_password)
db.add(db_user)
db.commit()
return "complete"
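create_user hashes passwords with the module-level pwd_context, whose definition is outside this hunk. A minimal sketch of that passlib setup and the matching verification step used at login; the bcrypt scheme is an assumption consistent with bcrypt appearing in requirements.txt.

# Sketch only.
from passlib.context import CryptContext

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")  # assumed configuration

hashed = pwd_context.hash("example-password")
assert pwd_context.verify("example-password", hashed)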
@@ -462,9 +379,6 @@ async def verify_user_token(token: str):
verify_token(token=token)
return {"message": "Token is valid"}

@app.post("/user/chat/save")
def populate_user_chat(chat: NewUserChat, db: Session = Depends(get_db)):
try:
@@ -533,34 +447,32 @@ async def get_user_with_token(token: str, db: Session = Depends(get_db)):
"userid": user.id,
"username": user.username,
"chats": user.chats,
"notifications": user.notifications
"documents": user.documents
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

@app.get("/user/notification/delete/{token}/{notificationid}")
async def delete_chat_of_user(token: str, notificationid: str, db: Session = Depends(get_db)):
@app.get("/searchspaces/{token}")
async def get_user_with_token(token: str, db: Session = Depends(get_db)):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

notificationindb = db.query(Notification).filter(Notification.id == notificationid).first()
db.delete(notificationindb)
db.commit()
search_spaces = db.query(SearchSpace).all()
return {
"message": "Notification Deleted"
"search_spaces": search_spaces
}
except JWTError:
raise HTTPException(status_code=403, detail="Token is invalid or expired")

if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="127.0.0.1", port=8000)
uvicorn.run(app, host="127.0.0.1", port=8000)

@@ -1 +1 @@
Subproject commit 83205acb7b9712020940461b74e0aa66f9316a77
Subproject commit cd5d99a4bc794dad2fdf80abcbfddef3c9fea7b2