from typing import Dict, Any, Optional
import os
import sys
from pathlib import Path
import requests
from datetime import datetime, timedelta
import json
from concurrent.futures import ThreadPoolExecutor

# Add parent directory to path for imports when running as script
if __name__ == "__main__":
    sys.path.append(str(Path(__file__).parent.parent))
    from search_providers.base_provider import BaseSearchProvider
else:
    from .base_provider import BaseSearchProvider


class BraveSearchProvider(BaseSearchProvider):
    """
    Brave implementation of the search provider interface.
    Handles both web and news-specific searches using Brave's APIs.
    """

    WEB_SEARCH_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
    NEWS_SEARCH_ENDPOINT = "https://api.search.brave.com/res/v1/news/search"
    SUMMARIZER_ENDPOINT = "https://api.search.brave.com/res/v1/summarizer/search"

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Brave search provider.

        Args:
            api_key: Optional Brave API key. If not provided, will try to get from environment.
        """
        self.api_key = api_key or os.getenv("BRAVE_API_KEY")
        self.pro_api_key = os.getenv("BRAVE_AI_PRO_API_KEY")  # Optional, used for AI summary requests
        self.headers = {
            'X-Subscription-Token': self.api_key,
            'Accept': 'application/json'
        } if self.api_key else None
        self.proheaders = {
            'X-Subscription-Token': self.pro_api_key,
            'Accept': 'application/json'
        } if self.pro_api_key else None

    def is_configured(self) -> bool:
        """Check if Brave API is properly configured."""
        return self.headers is not None

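    # AI summaries are a two-step flow: a web search with "summary=1" returns a
    # summarizer key, which is then exchanged at the summarizer endpoint for the
    # generated summary. Both calls use the optional Pro key (BRAVE_AI_PRO_API_KEY).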
    def get_brave_summary(self, query):
        """Fetch an AI-generated summary for a query, or None if unavailable."""
        if not self.proheaders:
            # Summaries need the Pro subscription key; skip the extra requests without it.
            return None

        # Query parameters
        params = {
            "q": query,
            "summary": 1
        }

        # Make the initial web search request to get the summarizer key
        search_response = requests.get(self.WEB_SEARCH_ENDPOINT, headers=self.proheaders, params=params)

        if search_response.status_code == 200:
            data = search_response.json()

            if "summarizer" in data and "key" in data["summarizer"]:
                summarizer_key = data["summarizer"]["key"]

                # Make request to the summarizer endpoint
                summarizer_params = {
                    "key": summarizer_key,
                    "entity_info": 1
                }

                summary_response = requests.get(
                    self.SUMMARIZER_ENDPOINT,
                    headers=self.proheaders,
                    params=summarizer_params
                )

                if summary_response.status_code == 200:
                    summary_data = summary_response.json()
                    try:
                        return summary_data['summary'][0]['data']
                    except (KeyError, IndexError):
                        return None

        return None

    def search(self, query: str, **kwargs) -> Dict[str, Any]:
        """
        Perform a search using the Brave API.

        Args:
            query: The search query string
            **kwargs: Additional search parameters:
                - topic: Optional search topic (e.g., "news")
                - max_results: Maximum number of results (default: 10)
                - market: Country code for the search market (default: "us")
                - days: Number of days to look back (for news searches)

        Returns:
            Dict containing search results or error information
        """
        if not self.is_configured():
            return {'error': 'Brave API key not configured'}

        try:
            # Set default search parameters
            search_params = {
                'count': str(kwargs.get('max_results', 10)),
                'country': kwargs.get('market', 'us'),  # Brave uses country codes
                'q': query
            }

            # Determine if this is a news search
            if kwargs.get('topic') == 'news':
                # Add freshness parameter for news if days specified
                if 'days' in kwargs:
                    days = kwargs['days']
                    if days <= 1:
                        search_params['freshness'] = 'pd'  # past day
                    elif days <= 7:
                        search_params['freshness'] = 'pw'  # past week
                    else:
                        search_params['freshness'] = 'pm'  # past month

                response = requests.get(
                    self.NEWS_SEARCH_ENDPOINT,
                    headers=self.headers,
                    params=search_params
                )

                response_data = response.json()
                result = self._process_news_results(response_data, days=kwargs.get('days', 3), topic=query)
            else:
                response = requests.get(
                    self.WEB_SEARCH_ENDPOINT,
                    headers=self.headers,
                    params=search_params
                )
                response_data = response.json()
                result = self._process_general_results(response_data)

            # Include the summarizer response if it exists
            summary_response = self.get_brave_summary(query)
            if summary_response:
                result['summarizer'] = summary_response

            return result

        except requests.exceptions.RequestException as e:
            return {'error': f'API request failed: {str(e)}'}
        except Exception as e:
            return {'error': f'An unexpected error occurred: {str(e)}'}

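    # Each processed result carries either an AI 'summary' (attempted only for
    # the first two web results) or its raw 'extra_snippets', never both.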
    def _process_general_results(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Process results for general web searches."""
        web_results = response.get('web', {}).get('results', [])
        with ThreadPoolExecutor() as executor:
            # Use index as key instead of the result dictionary
            futures = {i: executor.submit(self.get_brave_summary, result.get('title', ''))
                       for i, result in enumerate(web_results[:2])}

            results = []
            for i, result in enumerate(web_results):
                summary = None
                if i < 2:
                    try:
                        summary = futures[i].result()
                    except Exception as e:
                        print(f"Error getting summary: {e}")

                processed_result = {
                    'title': result.get('title', ''),
                    'url': result.get('url', ''),
                    'content': result.get('description', ''),
                    'score': result.get('score', 1.0),
                    'extra_snippets': None,
                    'summary': None
                }
                if summary:
                    processed_result['summary'] = summary
                else:
                    processed_result['extra_snippets'] = result.get('extra_snippets', [])
                results.append(processed_result)
            return {'results': results}

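    # News results are sorted newest-first by converting the human-readable
    # 'age' field (e.g. "2 hours") to minutes before summaries are requested.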
    def _process_news_results(self, response: Dict[str, Any], days: int, topic: str) -> Dict[str, Any]:
        """Process results for news-specific searches."""
        news_results = response.get('results', [])

        def convert_age_to_minutes(age_str: str) -> int:
            """
            Convert an age string to minutes.

            Args:
                age_str: Age string in the format "X minutes", "X hours", or "X days"
                         (e.g. "3 hours" -> 180)

            Returns:
                Age in minutes, or 0 if the string cannot be parsed
            """
            try:
                age_value = int(age_str.split()[0])
                age_unit = age_str.split()[1]
            except (ValueError, IndexError):
                return 0  # Treat unparseable ages as "just now"
            if age_unit.startswith('minute'):
                return age_value
            elif age_unit.startswith('hour'):
                return age_value * 60
            elif age_unit.startswith('day'):
                return age_value * 1440  # 24 hours * 60 minutes
            else:
                return 0  # Default to 0 if unknown unit

        # Sort news results based on the age field (newest first)
        news_results.sort(key=lambda x: convert_age_to_minutes(x.get('age', '0 minutes')))

        with ThreadPoolExecutor() as executor:
            # Use enumerate to create futures with index as key
            futures = {i: executor.submit(self.get_brave_summary, article_data.get('title', ''))
                       for i, article_data in enumerate(news_results)}

            articles = []
            for i, article_data in enumerate(news_results):
                try:
                    summary = futures[i].result()
                except Exception as e:
                    print(f"Error getting summary: {e}")
                    summary = None

                article = {
                    'title': article_data.get('title', ''),
                    'url': article_data.get('url', ''),
                    'published_date': article_data.get('age', ''),
                    'breaking': article_data.get('breaking', False),
                    'content': article_data.get('description', ''),
                    'extra_snippets': None,
                    'summary': None,
                    'score': article_data.get('score', 1.0)
                }
                if summary:
                    article['summary'] = summary
                else:
                    article['extra_snippets'] = article_data.get('extra_snippets', [])
                articles.append(article)

            return {
                'articles': articles,
                'time_period': f"Past {days} days",
                'topic': topic
            }

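# Manual smoke test against the live Brave API. Requires BRAVE_API_KEY (and
# optionally BRAVE_AI_PRO_API_KEY for summaries) to be set in the environment.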
if __name__ == "__main__":
    import time

    # Test code using the actual API
    provider = BraveSearchProvider()
    if not provider.is_configured():
        print("Error: Brave API key not configured")
        sys.exit(1)

    # Test general search
    print("\n=== Testing General Search ===")
    general_result = provider.search(
        "What is artificial intelligence?",
        max_results=1  # Increased max_results to test summary limiting
    )

    if 'error' in general_result:
        print(f"Error in general search: {general_result['error']}")
    else:
        print("\nTop Results:")
        for idx, result in enumerate(general_result['results'], 1):
            print(f"\n{idx}. {result['title']}")
            print(f" URL: {result['url']}")
            print(f" Preview: {result['content']}...")
            print(f" Score: {result['score']}")
            if result['extra_snippets']:
                print(" Extra Snippets:")
                for snippet in result['extra_snippets']:
                    print(f" - {snippet}")
            if result['summary']:  # Check if summary exists before printing
                print(f" Summary: {result['summary']}...")
            time.sleep(1)

    # Test news search
    print("\n\n=== Testing News Search ===")
    start_time = time.time()
    news_result = provider.search(
        "mike tyson fight",
        topic="news",
        days=3,
        max_results=1
    )
    end_time = time.time()

    if 'error' in news_result:
        print(f"Error in news search: {news_result['error']}")
    else:
        print("\nRecent Articles:")
        for idx, article in enumerate(news_result['articles'], 1):
            print(f"\n{idx}. {article['title']}")
            print(f" Published: {article['published_date']}")
            print(f" Breaking: {article['breaking']}")
            print(f" URL: {article['url']}")
            print(f" Preview: {article['content'][:400]}...")
            if article['extra_snippets']:
                print(" Extra Snippets:")
                for snippet in article['extra_snippets']:
                    print(f" - {snippet}")
            if article['summary']:
                print(f" Summary: {article['summary']}...")

    print(f"Execution time: {round(end_time - start_time, 1)} seconds")