Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-11 01:24:36 +00:00)
improved websearch endpoint
Commit 709dab6289 (parent 5451a8e8a9): 1 changed file with 52 additions and 17 deletions.
koboldcpp.py (69 lines changed: +52, -17)
@@ -1278,18 +1278,37 @@ def websearch(query):
     import urllib.request
     import difflib
     from html.parser import HTMLParser
+    from concurrent.futures import ThreadPoolExecutor
     num_results = 3
     searchresults = []
 
     def fetch_searched_webpage(url):
+        if args.debugmode:
+            utfprint(f"WebSearch URL: {url}")
         try:
-            req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
-            with urllib.request.urlopen(req) as response:
+            req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)'})
+            with urllib.request.urlopen(req, timeout=15) as response:
                 html_content = response.read().decode('utf-8', errors='ignore')
                 return html_content
+        except urllib.error.HTTPError: #we got blocked? try 1 more time with a different user agent
+            try:
+                req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36'})
+                with urllib.request.urlopen(req, timeout=15) as response:
+                    html_content = response.read().decode('utf-8', errors='ignore')
+                    return html_content
+            except Exception as e:
+                if args.debugmode != -1 and not args.quiet:
+                    print(f"Error fetching text from URL {url}: {e}")
+                return ""
         except Exception as e:
-            print(f"Error fetching text from URL {url}: {e}")
+            if args.debugmode != -1 and not args.quiet:
+                print(f"Error fetching text from URL {url}: {e}")
             return ""
+    def fetch_webpages_parallel(urls):
+        with ThreadPoolExecutor() as executor:
+            # Submit tasks and gather results
+            results = list(executor.map(fetch_searched_webpage, urls))
+            return results
 
     class VisibleTextParser(HTMLParser):
         def __init__(self):
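The behavioural point behind the new fetch_webpages_parallel helper is that ThreadPoolExecutor.map returns results in the same order as the input URLs, which is what lets the later fetchedcontent[i] lookup stay aligned with searchurls[i] and descs[i]. A minimal standalone sketch of the same pattern, assuming nothing from koboldcpp itself (fetch_page and the example URLs are illustrative placeholders):

# Sketch of the ordered parallel-fetch pattern used above.
import urllib.request
from concurrent.futures import ThreadPoolExecutor

def fetch_page(url):
    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
        with urllib.request.urlopen(req, timeout=15) as response:
            return response.read().decode('utf-8', errors='ignore')
    except Exception:
        return ""  # a failed fetch yields an empty string, keeping list positions intact

def fetch_pages_parallel(urls):
    with ThreadPoolExecutor() as executor:
        # map() blocks until all workers finish and preserves input order,
        # so results[i] always corresponds to urls[i]
        return list(executor.map(fetch_page, urls))

if __name__ == "__main__":
    pages = fetch_pages_parallel(["https://example.com", "https://example.org"])
    print([len(p) for p in pages])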
@@ -1361,6 +1380,7 @@ def websearch(query):
         titles = parser.titles[:num_results]
         searchurls = parser.urls[:num_results]
         descs = parser.descs[:num_results]
+        fetchedcontent = fetch_webpages_parallel(searchurls)
         for i in range(len(descs)):
             # dive into the results to try and get even more details
             title = titles[i]
@@ -1369,13 +1389,13 @@ def websearch(query):
             pagedesc = ""
             try:
                 desclen = len(desc)
-                html_content = fetch_searched_webpage(url)
+                html_content = fetchedcontent[i]
                 parser2 = VisibleTextParser()
                 parser2.feed(html_content)
                 scraped = parser2.get_text().strip()
-                s = difflib.SequenceMatcher(None, scraped.lower(), desc.lower())
+                s = difflib.SequenceMatcher(None, scraped.lower(), desc.lower(), autojunk=False)
                 matches = s.find_longest_match(0, len(scraped), 0, desclen)
-                if matches.size > 100 and desclen-matches.size < 50: #good enough match
+                if matches.size > 100 and desclen-matches.size < 100: #good enough match
                     # expand description by some chars both sides
                     expandamtbefore = 250
                     expandamtafter = 750
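The idea in this hunk is to locate the search-engine description inside the scraped page text and, when most of it is found verbatim, widen a window around the match; autojunk=False disables SequenceMatcher's popular-element heuristic, which can otherwise prevent long matches from being found in large pages. A small sketch of that idea follows; the start/end clamping is an illustrative guess at the elided expansion lines, since only the thresholds and expand amounts appear in the diff:

# Sketch of the snippet-expansion idea, with assumed clamping arithmetic.
import difflib

def expand_snippet(scraped, desc):
    desclen = len(desc)
    s = difflib.SequenceMatcher(None, scraped.lower(), desc.lower(), autojunk=False)
    match = s.find_longest_match(0, len(scraped), 0, desclen)
    if match.size > 100 and desclen - match.size < 100:  # most of the snippet found verbatim
        expandamtbefore = 250
        expandamtafter = 750
        start = max(0, match.a - expandamtbefore)                       # match.a = offset in scraped text
        end = min(len(scraped), match.a + match.size + expandamtafter)
        return scraped[start:end].strip()
    return ""  # no good match; caller falls back to the short description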
@@ -1388,7 +1408,8 @@ def websearch(query):
             searchresults.append({"title":title,"url":url,"desc":desc,"content":pagedesc})
 
     except Exception as e:
-        print(f"Error fetching URL {search_url}: {e}")
+        if args.debugmode != -1 and not args.quiet:
+            print(f"Error fetching URL {search_url}: {e}")
         return ""
     return searchresults
 
@@ -2146,13 +2167,27 @@ Enter Prompt:<br>
 
         elif self.path.startswith(("/websearch")):
            if args.websearch:
-                parsed_url = urlparse.urlparse(self.path)
-                parsed_dict = urlparse.parse_qs(parsed_url.query)
-                searchstr = (parsed_dict['q'][0]) if 'q' in parsed_dict else ""
-                if args.debugmode:
-                    print(f"Searching web for: {searchstr}")
-                searchres = websearch(searchstr)
-                response_body = (json.dumps(searchres).encode())
+                # ensure authorized
+                auth_ok = True
+                if password and password !="":
+                    auth_header = None
+                    auth_ok = False
+                    if 'Authorization' in self.headers:
+                        auth_header = self.headers['Authorization']
+                    elif 'authorization' in self.headers:
+                        auth_header = self.headers['authorization']
+                    if auth_header is not None and auth_header.startswith('Bearer '):
+                        token = auth_header[len('Bearer '):].strip()
+                        if token==password:
+                            auth_ok = True
+                if auth_ok:
+                    parsed_url = urlparse.urlparse(self.path)
+                    parsed_dict = urlparse.parse_qs(parsed_url.query)
+                    searchstr = (parsed_dict['q'][0]) if 'q' in parsed_dict else ""
+                    searchres = websearch(searchstr)
+                    response_body = (json.dumps(searchres).encode())
+                else:
+                    response_body = (json.dumps([]).encode())
            else:
                 response_body = (json.dumps([]).encode())
 
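With this change, when the server has a password set, the /websearch GET endpoint only runs a search if the request carries a matching Bearer token; otherwise it answers with an empty JSON list. A minimal client-side sketch of calling the protected endpoint (the host, port, and password values are placeholders for your own setup):

# Sketch: querying /websearch with the server password as a Bearer token.
import json
import urllib.parse
import urllib.request

base_url = "http://localhost:5001"   # assumed default koboldcpp port
password = "mysecret"                # whatever was passed via --password
query = urllib.parse.urlencode({"q": "latest llama.cpp release"})

req = urllib.request.Request(
    f"{base_url}/websearch?{query}",
    headers={"Authorization": f"Bearer {password}"},
)
with urllib.request.urlopen(req, timeout=15) as resp:
    results = json.loads(resp.read().decode("utf-8"))

# Each entry mirrors the dict appended in websearch(): title, url, desc, content.
for r in results:
    print(r["title"], r["url"])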
@@ -4721,6 +4756,9 @@ def main(launch_args,start_server=True):
     print("==========")
     time.sleep(1)
 
+    if args.password and args.password!="":
+        password = args.password.strip()
+
     #handle loading text model
     if args.model_param:
         if not os.path.exists(args.model_param):
@@ -4766,9 +4804,6 @@ def main(launch_args,start_server=True):
         args.mmproj = os.path.abspath(args.mmproj)
         mmprojpath = args.mmproj
 
-    if args.password and args.password!="":
-        password = args.password.strip()
-
     if not args.blasthreads or args.blasthreads <= 0:
         args.blasthreads = args.threads
 