feat: gpt-researcher custom response. Now very close to Perplexity.

DESKTOP-RTLN3BA\$punk 2024-10-24 22:19:29 -07:00
parent dfb0967dbe
commit 46c9b228df
5 changed files with 215 additions and 194 deletions

View file

@@ -1,16 +1,35 @@
#true if you want to run a local setup with Ollama
IS_LOCAL_SETUP = 'false'
#Your Unstructured IO API Key. Use any value if you are running Unstructured locally or don't need file uploads.
UNSTRUCTURED_API_KEY = ""
#POSTGRES DB TO TRACK USERS
POSTGRES_DATABASE_URL = "postgresql+psycopg2://postgres:postgres@localhost:5432/surfsense"
# API KEY TO PREVENT USER REGISTRATION SPAM
# API KEY TO VERIFY
API_SECRET_KEY = "surfsense"
# Your JWT secret and algorithm
SECRET_KEY = "your_secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 720
ACCESS_TOKEN_EXPIRE_MINUTES = "720"
# SEARCH ENGINES TO USE - FUTURE FEATURE - LEAVE EMPTY FOR NOW
TAVILY_API_KEY=""
# UNCOMMENT THE RESPECTIVE BELOW LINES FOR LOCAL/OPENAI SETUP
# For OpenAI LLM SETUP
OPENAI_API_KEY="sk-proj-GHG....."
FAST_LLM="openai:gpt-4o-mini"
SMART_LLM="openai:gpt-4o-mini"
EMBEDDING="openai:text-embedding-3-large"
# For Local Setups
# OPENAI_API_KEY="123"
# OLLAMA_BASE_URL="http://localhost:11434"
# FAST_LLM="ollama:qwen2.5:7b"
# SMART_LLM="ollama:qwen2.5:7b"
# EMBEDDING="ollama:qwen2.5:7b"
# TEMPERATURE="0"
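For reference, a minimal sketch of how the backend derives its mode from these values (it mirrors the extract_model_name logic added later in this commit; the variable names here are illustrative):

import os
from dotenv import load_dotenv

load_dotenv()
fast_llm = os.environ.get("FAST_LLM", "")  # e.g. "openai:gpt-4o-mini" or "ollama:qwen2.5:7b"
is_local_setup = fast_llm.startswith("ollama")  # the provider prefix toggles local vs OpenAI
model_name = fast_llm.split(":", 1)[1] if ":" in fast_llm else fast_llm  # "gpt-4o-mini" or "qwen2.5:7b"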

View file

@@ -1,3 +1,7 @@
import asyncio
from datetime import datetime
from typing import List
from gpt_researcher import GPTResearcher
from langchain_chroma import Chroma
from langchain_ollama import OllamaLLM, OllamaEmbeddings
from langchain_community.vectorstores.utils import filter_complex_metadata
@@ -14,12 +18,23 @@ from langchain_core.prompts import PromptTemplate
import os
from dotenv import load_dotenv
from pydmodels import AIAnswer, Reference
from database import SessionLocal
from models import Documents, User
from prompts import CONTEXT_ANSWER_PROMPT
load_dotenv()
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
FAST_LLM = os.environ.get("FAST_LLM", "")
EMBEDDING = os.environ.get("EMBEDDING", "")
IS_LOCAL_SETUP = FAST_LLM.startswith("ollama")
def extract_model_name(model_string: str) -> str:
_, model_name = model_string.split(":", 1) # split once at the first colon and keep the model part
return model_name
MODEL_NAME = extract_model_name(FAST_LLM)
EMBEDDING_MODEL = extract_model_name(EMBEDDING)
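# Illustrative values, given the sample .env above:
# FAST_LLM="openai:gpt-4o-mini" -> MODEL_NAME == "gpt-4o-mini"
# EMBEDDING="openai:text-embedding-3-large" -> EMBEDDING_MODEL == "text-embedding-3-large"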
# Dependency
def get_db():
@@ -35,12 +50,12 @@ class HIndices:
"""
"""
self.username = username
if(IS_LOCAL_SETUP == 'true'):
self.llm = OllamaLLM(model="mistral-nemo",temperature=0)
self.embeddings = OllamaEmbeddings(model="mistral-nemo")
if IS_LOCAL_SETUP:
self.llm = OllamaLLM(model=MODEL_NAME, temperature=0)
self.embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
else:
self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=api_key)
self.embeddings = OpenAIEmbeddings(api_key=api_key)
self.llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=api_key)
self.embeddings = OpenAIEmbeddings(api_key=api_key, model=EMBEDDING_MODEL)
self.summary_store = Chroma(
collection_name="summary_store",
@@ -92,18 +107,9 @@ class HIndices:
report_chain = report_prompt | self.llm
if(IS_LOCAL_SETUP == 'true'):
# Local LLMs struggle with summaries, so we need this slow and painful chunked procedure
text_splitter = SemanticChunker(embeddings=self.embeddings)
chunks = text_splitter.split_documents([doc])
combined_summary = ""
for i, chunk in enumerate(chunks):
print("GENERATING SUMMARY FOR CHUNK "+ str(i))
chunk_summary = report_chain.invoke({"document": chunk})
combined_summary += "\n\n" + chunk_summary + "\n\n"
response = combined_summary
if IS_LOCAL_SETUP:
response = report_chain.invoke({"document": doc})
metadict = {
"page": page_no,
@@ -111,10 +117,10 @@ class HIndices:
"search_space": search_space,
}
metadict.update(doc.metadata)
# metadict['languages'] = metadict['languages'][0]
return Document(
id=str(page_no),
page_content=response,
@@ -177,17 +183,8 @@ class HIndices:
report_chain = report_prompt | self.llm
if(IS_LOCAL_SETUP == 'true'):
# Local LLMs struggle with summaries, so we need this slow and painful chunked procedure
text_splitter = SemanticChunker(embeddings=self.embeddings)
chunks = text_splitter.split_documents([doc])
combined_summary = ""
for i, chunk in enumerate(chunks):
print("GENERATING SUMMARY FOR CHUNK "+ str(i))
chunk_summary = report_chain.invoke({"document": chunk})
combined_summary += "\n\n" + chunk_summary + "\n\n"
response = combined_summary
if IS_LOCAL_SETUP:
response = report_chain.invoke({"document": doc})
return Document(
id=str(page_no),
@@ -205,6 +202,7 @@ class HIndices:
"VisitedWebPageVisitDurationInMilliseconds": doc.metadata['VisitedWebPageVisitDurationInMilliseconds'],
}
)
else:
response = report_chain.invoke({"document": doc})
@@ -230,19 +228,6 @@ class HIndices:
Creates and Saves/Updates docs in hierarchical indices and postgres table
"""
# DocumentPgEntry = []
# searchspace = db.query(SearchSpace).filter(SearchSpace.search_space == search_space).first()
# for doc in documents:
# pgdocmeta = stringify(doc.metadata)
# if(searchspace):
# DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=search_space, document_metadata=pgdocmeta, page_content=doc.page_content))
# else:
# DocumentPgEntry.append(Documents(file_type='WEBPAGE',title=doc.metadata.VisitedWebPageTitle,search_space=SearchSpace(search_space=search_space), document_metadata=pgdocmeta, page_content=doc.page_content))
prev_doc_idx = len(documents) + 1
# Save docs in PG
user = db.query(User).filter(User.username == self.username).first()
@@ -262,22 +247,20 @@ class HIndices:
else:
batch_summaries = [self.summarize_file_doc(page_no = i + summary_last_id, doc=doc, search_space=search_space) for i, doc in enumerate(documents)]
# batch_summaries = [summarize_doc(i + summary_last_id, doc) for i, doc in enumerate(documents)]
summaries.extend(batch_summaries)
detailed_chunks = []
for i, summary in enumerate(summaries):
# Semantic chunking for better contextual compression
text_splitter = SemanticChunker(embeddings=self.embeddings)
chunks = text_splitter.split_documents([documents[i]])
user.documents[-(len(summaries) - i)].desc_vector_start = detail_id_counter
user.documents[-(len(summaries) - i)].desc_vector_end = detail_id_counter + len(chunks)
# summary_entry = db.query(Documents).filter(Documents.id == int(user.documents[-1].id)).first()
# summary_entry.desc_vector_start = detail_id_counter
# summary_entry.desc_vector_end = detail_id_counter + len(chunks)
db.commit()
@@ -290,6 +273,30 @@ class HIndices:
"page": summary.metadata['page'],
})
if(files_type == 'WEBPAGE'):
ieee_content = (
f"=======================================DOCUMENT METADATA==================================== \n"
f"Source: {chunk.metadata['VisitedWebPageURL']} \n"
f"Title: {chunk.metadata['VisitedWebPageTitle']} \n"
f"Visited Date and Time : {chunk.metadata['VisitedWebPageDateWithTimeInISOString']} \n"
f"============================DOCUMENT PAGE CONTENT CHUNK===================================== \n"
f"Page Content Chunk: \n\n{chunk.page_content}\n\n"
f"===================================================================================== \n"
)
else:
ieee_content = (
f"=======================================DOCUMENT METADATA==================================== \n"
f"Source: {chunk.metadata['filename']} \n"
f"Title: {chunk.metadata['filename']} \n"
f"Visited Date and Time : {datetime.now()} \n"
f"============================DOCUMENT PAGE CONTENT CHUNK===================================== \n"
f"Page Content Chunk: \n\n{chunk.page_content}\n\n"
f"===================================================================================== \n"
)
chunk.page_content = ieee_content
detail_id_counter += 1
detailed_chunks.extend(chunks)
@@ -313,33 +320,67 @@ class HIndices:
return "success"
def is_query_answerable(self, query, context):
prompt = PromptTemplate(
template="""You are a grader assessing relevance of a retrieved document to a user question. \n
Here is the retrieved document: \n\n {context} \n\n
Here is the user question: {question} \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.
Only return 'yes' or 'no'.""",
input_variables=["context", "question"],
)
ans_chain = prompt | self.llm
finalans = ans_chain.invoke({"question": query, "context": context})
if IS_LOCAL_SETUP:
return finalans
else:
return finalans.content
def summary_vector_search(self, query, search_space='GENERAL'):
top_summaries_compressor = FlashrankRerank(top_n=20)
top_summaries_retriever = ContextualCompressionRetriever(
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
)
return top_summaries_retriever.invoke(query)
def local_search(self, query, search_space='GENERAL'):
top_summaries_compressor = FlashrankRerank(top_n=5)
details_compressor = FlashrankRerank(top_n=30)
def deduplicate_references_and_update_answer(self, answer: str, references: List[Reference]) -> tuple[str, List[Reference]]:
"""
Deduplicates references and updates the answer text to maintain correct reference numbering.
Args:
answer: The text containing reference citations
references: List of Reference objects
Returns:
tuple: (updated_answer, deduplicated_references)
"""
# Track unique references and create ID mapping using a dictionary comprehension
unique_refs = {}
id_mapping = {
ref.id: unique_refs.setdefault(
ref.url, Reference(id=str(len(unique_refs) + 1), title=ref.title, url=ref.url)
).id
for ref in references
}
# Apply new mappings to the answer text
updated_answer = answer
for old_id, new_id in sorted(id_mapping.items(), key=lambda x: len(x[0]), reverse=True):
updated_answer = updated_answer.replace(f'[{old_id}]', f'[{new_id}]')
return updated_answer, list(unique_refs.values())
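# Illustration of the mapping above with hypothetical values (assumes the incoming
# ids are numbered in order of first appearance, as the LLM is prompted to do):
# refs = [Reference(id="1", title="A", url="https://a.example"),
#         Reference(id="2", title="A", url="https://a.example"),
#         Reference(id="3", title="B", url="https://b.example")]
# self.deduplicate_references_and_update_answer("See [1], [2] and [3].", refs)
# -> ("See [1], [1] and [2].", [two Reference objects, one per unique URL])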
async def get_vectorstore_report(self, query: str, report_type: str, report_source: str, documents: List[Document]) -> str:
researcher = GPTResearcher(query=query, report_type=report_type, report_source=report_source, documents=documents, report_format="IEEE")
await researcher.conduct_research()
report = await researcher.write_report()
return report
async def get_web_report(self, query: str, report_type: str, report_source: str) -> str:
researcher = GPTResearcher(query=query, report_type=report_type, report_source=report_source, report_format="IEEE")
await researcher.conduct_research()
report = await researcher.write_report()
return report
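# Both helpers are coroutines; the synchronous callers below drive them with
# asyncio.run, e.g. (illustrative):
# report = asyncio.run(self.get_web_report(query=q, report_type="custom_report", report_source="web"))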
def new_search(self, query, search_space='GENERAL'):
report_type = "custom_report"
report_source = "langchain_documents"
contextdocs = []
top_summaries_compressor = FlashrankRerank(top_n=5)
details_compressor = FlashrankRerank(top_n=50)
top_summaries_retriever = ContextualCompressionRetriever(
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
)
top_summaries_compressed_docs = top_summaries_retriever.invoke(query)
for summary in top_summaries_compressed_docs:
@@ -354,61 +395,40 @@ class HIndices:
query
)
contextdocs = top_summaries_compressed_docs + detailed_compressed_docs
context_to_answer = ""
for i, doc in enumerate(contextdocs):
content = f":DOCUMENT {str(i)}\n"
content += f"=======================================METADATA==================================== \n"
content += f"{doc.metadata} \n"
content += f"===================================================================================== \n"
content += f"DOCUMENT CONTENT: \n\n {doc.page_content} \n\n"
content += f"===================================================================================== \n"
context_to_answer += content
content = ""
if(self.is_query_answerable(query=query, context=context_to_answer).lower() == 'yes'):
ans_chain = CONTEXT_ANSWER_PROMPT | self.llm
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
if(IS_LOCAL_SETUP == 'true'):
return finalans
else:
return finalans.content
else:
continue
return "I couldn't find any answer"
def global_search(self, query, search_space='GENERAL'):
top_summaries_compressor = FlashrankRerank(top_n=20)
top_summaries_retriever = ContextualCompressionRetriever(
base_compressor=top_summaries_compressor, base_retriever=self.summary_store.as_retriever(search_kwargs={'filter': {'search_space': search_space}})
)
top_summaries_compressed_docs = top_summaries_retriever.invoke(query)
context_to_answer = ""
for i, doc in enumerate(top_summaries_compressed_docs):
content = f":DOCUMENT {str(i)}\n"
content += f"=======================================METADATA==================================== \n"
content += f"{doc.metadata} \n"
content += f"===================================================================================== \n"
content += f"DOCUMENT CONTENT: \n\n {doc.page_content} \n\n"
content += f"===================================================================================== \n"
context_to_answer += content
ans_chain = CONTEXT_ANSWER_PROMPT | self.llm
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
if(IS_LOCAL_SETUP == 'true'):
return finalans, top_summaries_compressed_docs
else:
return finalans.content, top_summaries_compressed_docs
contextdocs.extend(detailed_compressed_docs)
custom_prompt = """
Please answer the following user query in the format shown below, using in-text citations and IEEE-style references based on the provided documents.
USER QUERY : """ + query + """
Ensure the answer includes:
- A detailed yet concise explanation with IEEE-style in-text citations (e.g., [1], [2]).
- A list of non-duplicated sources at the end, following IEEE format. Hyperlink each source using: [Website Name](URL).
- Where applicable, provide sources in the text to back up key points.
Ensure your response is structured something like this (here the user query is: Explain the impact of artificial intelligence on modern healthcare.):
---
**Answer:**
Artificial intelligence (AI) has significantly transformed modern healthcare by enhancing diagnostic accuracy, personalizing patient care, and optimizing operational efficiency. AI algorithms can analyze vast datasets to identify patterns that may be missed by human practitioners, leading to improved diagnostic outcomes [1]. For instance, AI systems have been deployed in radiology to detect anomalies in medical imaging with high precision [2]. Moreover, AI-driven tools facilitate personalized treatment plans by considering individual patient data, thereby improving patient outcomes [3].
**References:**
1. (2024, October 23). [Highly Effective Prompt for Summarizing GPT-4 Optimized: r/ChatGPT.](https://www.reddit.com/r/ChatGPT/comments/13na8yp/highly_effective_prompt_for_summarizing_gpt4/)
2. (2024, October 23). [MODSetter/SurfSense: Personal AI Assistant for Internet Surfers and Researchers.](https://github.com/MODSetter/SurfSense)
3. filename.pdf
---
"""
local_report = asyncio.run(self.get_vectorstore_report(query=custom_prompt, report_type=report_type, report_source=report_source, documents=contextdocs))
# web_report = asyncio.run(self.get_web_report(query=custom_prompt, report_type=report_type, report_source="web"))
# structured_llm = self.llm.with_structured_output(AIAnswer)
# out = structured_llm.invoke("Extract the exact (i.e. unchanged) answer string and reference information from: \n\n\n" + local_report)
# mod_out = self.deduplicate_references_and_update_answer(answer=out.answer, references=out.references)
return local_report

View file

@@ -23,6 +23,19 @@ class DocMeta(BaseModel):
# VisitedWebPageVisitDurationInMilliseconds: Optional[int] = Field(default=None, description="VisitedWebPageVisitDurationInMilliseconds of Document"),
# VisitedWebPageContent: Optional[str] = Field(default=None, description="Visited WebPage Content in markdown of Document")
class Reference(BaseModel):
id: str = Field(..., description="reference no")
title: str = Field(..., description="reference title")
url: str = Field(..., description="reference url")
class AIAnswer(BaseModel):
answer: str = Field(..., description="Answer text including its in-text citation numbers like [1], [2], etc.")
references: List[Reference] = Field(..., description="References")
class DocWithContent(BaseModel):
DocMetadata: Optional[str] = Field(default=None, description="Document Metadata")
Content: Optional[str] = Field(default=None, description="Document Page Content")
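A minimal hedged sketch of how these models fit together (illustrative values only; the structured-output path that would populate AIAnswer from an LLM is still commented out in the HIndices diff above):

from pydmodels import AIAnswer, Reference

refs = [Reference(id="1", title="SurfSense", url="https://github.com/MODSetter/SurfSense")]
ans = AIAnswer(answer="SurfSense is a personal AI assistant for researchers [1].", references=refs)
print(ans.answer, [ref.url for ref in ans.references])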

View file

@@ -19,3 +19,9 @@ flashrank
psycopg2
unstructured-client
langchain-unstructured
langgraph
gpt_researcher
langgraph-cli
weasyprint
json5
loguru

View file

@@ -5,14 +5,13 @@ from langchain_core.documents import Document
from langchain_ollama import OllamaLLM
from langchain_openai import ChatOpenAI
from sqlalchemy import insert
from prompts import CONTEXT_ANSWER_PROMPT, DATE_TODAY, SUBQUERY_DECOMPOSITION_PROMT
from prompts import DATE_TODAY
from pydmodels import ChatToUpdate, DescriptionResponse, DocWithContent, DocumentsToDelete, NewUserChat, UserCreate, UserQuery, RetrivedDocList, UserQueryResponse, UserQueryWithChatHistory
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_unstructured import UnstructuredLoader
# Hierarchical Indices class
from HIndices import HIndices
from Utils.stringify import stringify
# Auth Libs
@@ -31,13 +30,20 @@ import os
from dotenv import load_dotenv
load_dotenv()
IS_LOCAL_SETUP = os.environ.get("IS_LOCAL_SETUP")
FAST_LLM = os.environ.get("FAST_LLM", "")
IS_LOCAL_SETUP = FAST_LLM.startswith("ollama")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES"))
ALGORITHM = os.environ.get("ALGORITHM")
API_SECRET_KEY = os.environ.get("API_SECRET_KEY")
SECRET_KEY = os.environ.get("SECRET_KEY")
UNSTRUCTURED_API_KEY = os.environ.get("UNSTRUCTURED_API_KEY")
def extract_model_name(model_string: str) -> str:
_, model_name = model_string.split(":", 1) # split once at the first colon and keep the model part
return model_name
MODEL_NAME = extract_model_name(FAST_LLM)
app = FastAPI()
# Dependency
@@ -71,6 +77,7 @@ async def upload_files(files: list[UploadFile], token: str = Depends(oauth2_sche
chunking_strategy="basic",
max_characters=90000,
include_orig_elements=False,
strategy="fast",
)
filedocs = loader.load()
@@ -117,7 +124,7 @@ async def upload_files(files: list[UploadFile], token: str = Depends(oauth2_sche
db.commit()
# Create hierarchical indices
if IS_LOCAL_SETUP == 'true':
if IS_LOCAL_SETUP:
index = HIndices(username=username)
else:
index = HIndices(username=username, api_key=api_key)
@@ -145,60 +152,21 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
query = data.query
search_space = data.search_space
if(IS_LOCAL_SETUP == 'true'):
sub_query_llm = OllamaLLM(model="mistral-nemo",temperature=0)
qa_llm = OllamaLLM(model="mistral-nemo",temperature=0)
else:
sub_query_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
qa_llm = ChatOpenAI(temperature=0.5, model_name="gpt-4o-mini", api_key=data.openaikey)
# Create an LLMChain for sub-query decomposition
subquery_decomposer_chain = SUBQUERY_DECOMPOSITION_PROMT | sub_query_llm
#Experimental
def decompose_query(original_query: str):
"""
Decompose the original query into simpler sub-queries.
Args:
original_query (str): The original complex query
Returns:
List[str]: A list of simpler sub-queries
"""
if(IS_LOCAL_SETUP == 'true'):
response = subquery_decomposer_chain.invoke(original_query)
else:
response = subquery_decomposer_chain.invoke(original_query).content
sub_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.strip().startswith('Sub-queries:')]
return sub_queries
# Create Hierarchical Indices
if(IS_LOCAL_SETUP == 'true'):
if IS_LOCAL_SETUP:
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)
# For those who want HyDE sub-queries
# sub_queries = decompose_query(query)
# Implement HyDE on top if you're feeling adventurous
sub_queries = []
sub_queries.append(query)
duplicate_related_summary_docs = []
context_to_answer = ""
for sub_query in sub_queries:
localreturn = index.local_search(query=sub_query, search_space=search_space)
globalreturn, related_summary_docs = index.global_search(query=sub_query, search_space=search_space)
context_to_answer += localreturn + "\n\n" + globalreturn
# I know this is not the best way to do it, but I am too lazy to change it now
related_summary_docs = index.summary_vector_search(query=sub_query, search_space=search_space)
duplicate_related_summary_docs.extend(related_summary_docs)
@@ -223,15 +191,10 @@ def get_user_query_response(data: UserQuery, response_model=UserQueryResponse):
returnDocs.append(entry)
ans_chain = CONTEXT_ANSWER_PROMPT | qa_llm
finalans = ans_chain.invoke({"query": query, "context": context_to_answer})
if(IS_LOCAL_SETUP == 'true'):
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
else:
return UserQueryResponse(response=finalans.content, relateddocs=returnDocs)
finalans = index.new_search(query=query, search_space=search_space)
return UserQueryResponse(response=finalans, relateddocs=returnDocs)
except JWTError:
@@ -310,7 +273,7 @@ def save_data(apires: RetrivedDocList, db: Session = Depends(get_db)):
db.commit()
# Create hierarchical indices
if IS_LOCAL_SETUP == 'true':
if IS_LOCAL_SETUP:
index = HIndices(username=username)
else:
index = HIndices(username=username, api_key=apires.openaikey)
@@ -336,10 +299,10 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'):
llm = OllamaLLM(model="mistral-nemo",temperature=0)
if IS_LOCAL_SETUP:
llm = OllamaLLM(model=MODEL_NAME, temperature=0)
else:
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", api_key=data.openaikey)
llm = ChatOpenAI(temperature=0, model_name=MODEL_NAME, api_key=data.openaikey)
chatHistory = []
@@ -365,7 +328,7 @@ def doc_chat_with_history(data: UserQueryWithChatHistory, response_model=Descrip
response = descriptionchain.invoke({"input": data.query})
if(IS_LOCAL_SETUP == 'true'):
if IS_LOCAL_SETUP:
return DescriptionResponse(response=response)
else:
return DescriptionResponse(response=response.content)
@@ -384,7 +347,7 @@ def delete_all_related_data(data: DocumentsToDelete, db: Session = Depends(get_d
if username is None:
raise HTTPException(status_code=403, detail="Token is invalid or expired")
if(IS_LOCAL_SETUP == 'true'):
if IS_LOCAL_SETUP:
index = HIndices(username=username)
else:
index = HIndices(username=username,api_key=data.openaikey)