Skip to content

Commit 95191f5

Browse files
authored
feat: add web search (#580) bump:patch
* feat: add web search * feat: update requirements
1 parent 4fe0807 commit 95191f5

File tree

10 files changed

+218
-27
lines changed

10 files changed

+218
-27
lines changed

flowsettings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@
8181
KH_ENABLE_ALEMBIC = False
8282
KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}"
8383
KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files")
84+
KH_WEB_SEARCH_BACKEND = (
85+
"kotaemon.indices.retrievers.tavily_web_search.WebSearch"
86+
# "kotaemon.indices.retrievers.jina_web_search.WebSearch"
87+
)
8488

8589
KH_DOCSTORE = {
8690
# "__type__": "kotaemon.storages.ElasticsearchDocumentStore",

libs/kotaemon/kotaemon/indices/retrievers/__init__.py

Whitespace-only changes.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import requests
2+
from decouple import config
3+
4+
from kotaemon.base import BaseComponent, RetrievedDocument
5+
6+
JINA_API_KEY = config("JINA_API_KEY", default="")
7+
JINA_URL = config("JINA_URL", default="https://r.jina.ai/")
8+
9+
10+
class WebSearch(BaseComponent):
11+
"""WebSearch component for fetching data from the web
12+
using Jina API
13+
"""
14+
15+
def run(
16+
self,
17+
text: str,
18+
*args,
19+
**kwargs,
20+
) -> list[RetrievedDocument]:
21+
if JINA_API_KEY == "":
22+
raise ValueError(
23+
"This feature requires JINA_API_KEY "
24+
"(get free one from https://jina.ai/reader)"
25+
)
26+
27+
# setup the request
28+
api_url = f"https://s.jina.ai/{text}"
29+
headers = {"X-With-Generated-Alt": "true", "Accept": "application/json"}
30+
if JINA_API_KEY:
31+
headers["Authorization"] = f"Bearer {JINA_API_KEY}"
32+
33+
response = requests.get(api_url, headers=headers)
34+
response.raise_for_status()
35+
response_dict = response.json()
36+
37+
return [
38+
RetrievedDocument(
39+
text=(
40+
"###URL: [{url}]({url})\n\n"
41+
"####{title}\n\n"
42+
"{description}\n"
43+
"{content}"
44+
).format(
45+
url=item["url"],
46+
title=item["title"],
47+
description=item["description"],
48+
content=item["content"],
49+
),
50+
metadata={
51+
"file_name": "Web search",
52+
"type": "table",
53+
"llm_trulens_score": 1.0,
54+
},
55+
)
56+
for item in response_dict["data"]
57+
]
58+
59+
def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
60+
return documents
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from decouple import config
2+
3+
from kotaemon.base import BaseComponent, RetrievedDocument
4+
5+
TAVILY_API_KEY = config("TAVILY_API_KEY", default="")
6+
7+
8+
class WebSearch(BaseComponent):
9+
"""WebSearch component for fetching data from the web
10+
using Jina API
11+
"""
12+
13+
def run(
14+
self,
15+
text: str,
16+
*args,
17+
**kwargs,
18+
) -> list[RetrievedDocument]:
19+
if TAVILY_API_KEY == "":
20+
raise ValueError(
21+
"This feature requires TAVILY_API_KEY "
22+
"(get free one from https://app.tavily.com/)"
23+
)
24+
25+
try:
26+
from tavily import TavilyClient
27+
except ImportError:
28+
raise ImportError(
29+
"Please install `pip install tavily-python` to use this feature"
30+
)
31+
32+
tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
33+
results = tavily_client.search(
34+
query=text,
35+
search_depth="advanced",
36+
)["results"]
37+
context = "\n\n".join(
38+
"###URL: [{url}]({url})\n\n{content}".format(
39+
url=result["url"],
40+
content=result["content"],
41+
)
42+
for result in results
43+
)
44+
45+
return [
46+
RetrievedDocument(
47+
text=context,
48+
metadata={
49+
"file_name": "Web search",
50+
"type": "table",
51+
"llm_trulens_score": 1.0,
52+
},
53+
)
54+
]
55+
56+
def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
57+
return documents

libs/kotaemon/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ dependencies = [
5555
"theflow>=0.8.6,<0.9.0",
5656
"trogon>=0.5.0,<0.6",
5757
"umap-learn==0.5.5",
58+
"tavily-python>=0.4.0",
5859
]
5960
readme = "README.md"
6061
authors = [

libs/ktem/ktem/index/file/ui.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from sqlalchemy.orm import Session
2020
from theflow.settings import settings as flowsettings
2121

22+
from ...utils.commands import WEB_SEARCH_COMMAND
23+
2224
DOWNLOAD_MESSAGE = "Press again to download"
2325
MAX_FILENAME_LENGTH = 20
2426

@@ -38,6 +40,13 @@
3840
value: '"' + file_list[i][0] + '"',
3941
});
4042
}
43+
44+
// manually push web search tag
45+
values.push({
46+
key: "web_search",
47+
value: '"web_search"',
48+
});
49+
4150
var tribute = new Tribute({
4251
values: values,
4352
noMatchTemplate: "",
@@ -46,7 +55,9 @@
4655
input_box = document.querySelector('#chat-input textarea');
4756
tribute.attach(input_box);
4857
}
49-
"""
58+
""".replace(
59+
"web_search", WEB_SEARCH_COMMAND
60+
)
5061

5162

5263
class File(gr.File):

libs/ktem/ktem/pages/chat/__init__.py

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import importlib
23
import json
34
import re
45
from copy import deepcopy
@@ -23,11 +24,22 @@
2324
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
2425

2526
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
27+
from ...utils.commands import WEB_SEARCH_COMMAND
2628
from .chat_panel import ChatPanel
2729
from .common import STATE
2830
from .control import ConversationControl
2931
from .report import ReportIssue
3032

33+
KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None)
34+
WebSearch = None
35+
if KH_WEB_SEARCH_BACKEND:
36+
try:
37+
module_name, class_name = KH_WEB_SEARCH_BACKEND.rsplit(".", 1)
38+
module = importlib.import_module(module_name)
39+
WebSearch = getattr(module, class_name)
40+
except (ImportError, AttributeError) as e:
41+
print(f"Error importing {class_name} from {module_name}: {e}")
42+
3143
DEFAULT_SETTING = "(default)"
3244
INFO_PANEL_SCALES = {True: 8, False: 4}
3345

@@ -113,6 +125,7 @@ def __init__(self, app):
113125
value=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False)
114126
)
115127
self._info_panel_expanded = gr.State(value=True)
128+
self._command_state = gr.State(value=None)
116129

117130
def on_building_ui(self):
118131
with gr.Row():
@@ -299,6 +312,7 @@ def on_register_events(self):
299312
# file selector from the first index
300313
self._indices_input[0],
301314
self._indices_input[1],
315+
self._command_state,
302316
],
303317
concurrency_limit=20,
304318
show_progress="hidden",
@@ -315,6 +329,7 @@ def on_register_events(self):
315329
self.citation,
316330
self.language,
317331
self.state_chat,
332+
self._command_state,
318333
self._app.user_id,
319334
]
320335
+ self._indices_input,
@@ -647,13 +662,19 @@ def submit_msg(
647662

648663
chat_input_text = chat_input.get("text", "")
649664
file_ids = []
665+
used_command = None
650666

651667
first_selector_choices_map = {
652668
item[0]: item[1] for item in first_selector_choices
653669
}
654670

655671
# get all file names with pattern @"filename" in input_str
656672
file_names, chat_input_text = get_file_names_regex(chat_input_text)
673+
674+
# check if web search command is in file_names
675+
if WEB_SEARCH_COMMAND in file_names:
676+
used_command = WEB_SEARCH_COMMAND
677+
657678
# get all urls in input_str
658679
urls, chat_input_text = get_urls(chat_input_text)
659680

@@ -707,13 +728,17 @@ def submit_msg(
707728
conv_update = gr.update()
708729
new_conv_name = conv_name
709730

710-
return [
711-
{},
712-
chat_history,
713-
new_conv_id,
714-
conv_update,
715-
new_conv_name,
716-
] + selector_output
731+
return (
732+
[
733+
{},
734+
chat_history,
735+
new_conv_id,
736+
conv_update,
737+
new_conv_name,
738+
]
739+
+ selector_output
740+
+ [used_command]
741+
)
717742

718743
def toggle_delete(self, conv_id):
719744
if conv_id:
@@ -877,6 +902,7 @@ def create_pipeline(
877902
session_use_citation: str,
878903
session_language: str,
879904
state: dict,
905+
command_state: str | None,
880906
user_id: int,
881907
*selecteds,
882908
):
@@ -934,17 +960,26 @@ def create_pipeline(
934960

935961
# get retrievers
936962
retrievers = []
937-
for index in self._app.index_manager.indices:
938-
index_selected = []
939-
if isinstance(index.selector, int):
940-
index_selected = selecteds[index.selector]
941-
if isinstance(index.selector, tuple):
942-
for i in index.selector:
943-
index_selected.append(selecteds[i])
944-
iretrievers = index.get_retriever_pipelines(
945-
settings, user_id, index_selected
946-
)
947-
retrievers += iretrievers
963+
964+
if command_state == WEB_SEARCH_COMMAND:
965+
# set retriever for web search
966+
if not WebSearch:
967+
raise ValueError("Web search back-end is not available.")
968+
969+
web_search = WebSearch()
970+
retrievers.append(web_search)
971+
else:
972+
for index in self._app.index_manager.indices:
973+
index_selected = []
974+
if isinstance(index.selector, int):
975+
index_selected = selecteds[index.selector]
976+
if isinstance(index.selector, tuple):
977+
for i in index.selector:
978+
index_selected.append(selecteds[i])
979+
iretrievers = index.get_retriever_pipelines(
980+
settings, user_id, index_selected
981+
)
982+
retrievers += iretrievers
948983

949984
# prepare states
950985
reasoning_state = {
@@ -966,7 +1001,8 @@ def chat_fn(
9661001
use_mind_map,
9671002
use_citation,
9681003
language,
969-
state,
1004+
chat_state,
1005+
command_state,
9701006
user_id,
9711007
*selecteds,
9721008
):
@@ -976,7 +1012,7 @@ def chat_fn(
9761012

9771013
# if chat_input is empty, assume regen mode
9781014
if chat_output:
979-
state["app"]["regen"] = True
1015+
chat_state["app"]["regen"] = True
9801016

9811017
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
9821018

@@ -988,7 +1024,8 @@ def chat_fn(
9881024
use_mind_map,
9891025
use_citation,
9901026
language,
991-
state,
1027+
chat_state,
1028+
command_state,
9921029
user_id,
9931030
*selecteds,
9941031
)
@@ -1005,7 +1042,7 @@ def chat_fn(
10051042
refs,
10061043
plot_gr,
10071044
plot,
1008-
state,
1045+
chat_state,
10091046
)
10101047

10111048
for response in pipeline.stream(chat_input, conversation_id, chat_history):
@@ -1032,14 +1069,14 @@ def chat_fn(
10321069
plot = response.content
10331070
plot_gr = self._json_to_plot(plot)
10341071

1035-
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
1072+
chat_state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
10361073

10371074
yield (
10381075
chat_history + [(chat_input, text or msg_placeholder)],
10391076
refs,
10401077
plot_gr,
10411078
plot,
1042-
state,
1079+
chat_state,
10431080
)
10441081

10451082
if not text:
@@ -1052,7 +1089,7 @@ def chat_fn(
10521089
refs,
10531090
plot_gr,
10541091
plot,
1055-
state,
1092+
chat_state,
10561093
)
10571094

10581095
def check_and_suggest_name_conv(self, chat_history):

libs/ktem/ktem/pages/chat/chat_panel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def on_building_ui(self):
2525
interactive=True,
2626
scale=20,
2727
file_count="multiple",
28-
placeholder="Type a message (or tag a file with @filename)",
28+
placeholder=(
29+
"Type a message, or search the @web, " "tag a file with @filename"
30+
),
2931
container=False,
3032
show_label=False,
3133
elem_id="chat-input",

libs/ktem/ktem/utils/commands.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WEB_SEARCH_COMMAND = "web"

0 commit comments

Comments
 (0)