11import asyncio
2+ import importlib
23import json
34import re
45from copy import deepcopy
2324from kotaemon .indices .ingests .files import KH_DEFAULT_FILE_EXTRACTORS
2425
2526from ...utils import SUPPORTED_LANGUAGE_MAP , get_file_names_regex , get_urls
27+ from ...utils .commands import WEB_SEARCH_COMMAND
2628from .chat_panel import ChatPanel
2729from .common import STATE
2830from .control import ConversationControl
2931from .report import ReportIssue
3032
33+ KH_WEB_SEARCH_BACKEND = getattr (flowsettings , "KH_WEB_SEARCH_BACKEND" , None )
34+ WebSearch = None
35+ if KH_WEB_SEARCH_BACKEND :
36+ try :
37+ module_name , class_name = KH_WEB_SEARCH_BACKEND .rsplit ("." , 1 )
38+ module = importlib .import_module (module_name )
39+ WebSearch = getattr (module , class_name )
40+ except (ImportError , AttributeError ) as e :
41+ print (f"Error importing { class_name } from { module_name } : { e } " )
42+
3143DEFAULT_SETTING = "(default)"
3244INFO_PANEL_SCALES = {True : 8 , False : 4 }
3345
@@ -113,6 +125,7 @@ def __init__(self, app):
113125 value = getattr (flowsettings , "KH_FEATURE_CHAT_SUGGESTION" , False )
114126 )
115127 self ._info_panel_expanded = gr .State (value = True )
128+ self ._command_state = gr .State (value = None )
116129
117130 def on_building_ui (self ):
118131 with gr .Row ():
@@ -299,6 +312,7 @@ def on_register_events(self):
299312 # file selector from the first index
300313 self ._indices_input [0 ],
301314 self ._indices_input [1 ],
315+ self ._command_state ,
302316 ],
303317 concurrency_limit = 20 ,
304318 show_progress = "hidden" ,
@@ -315,6 +329,7 @@ def on_register_events(self):
315329 self .citation ,
316330 self .language ,
317331 self .state_chat ,
332+ self ._command_state ,
318333 self ._app .user_id ,
319334 ]
320335 + self ._indices_input ,
@@ -647,13 +662,19 @@ def submit_msg(
647662
648663 chat_input_text = chat_input .get ("text" , "" )
649664 file_ids = []
665+ used_command = None
650666
651667 first_selector_choices_map = {
652668 item [0 ]: item [1 ] for item in first_selector_choices
653669 }
654670
655671 # get all file names with pattern @"filename" in input_str
656672 file_names , chat_input_text = get_file_names_regex (chat_input_text )
673+
674+ # check if web search command is in file_names
675+ if WEB_SEARCH_COMMAND in file_names :
676+ used_command = WEB_SEARCH_COMMAND
677+
657678 # get all urls in input_str
658679 urls , chat_input_text = get_urls (chat_input_text )
659680
@@ -707,13 +728,17 @@ def submit_msg(
707728 conv_update = gr .update ()
708729 new_conv_name = conv_name
709730
710- return [
711- {},
712- chat_history ,
713- new_conv_id ,
714- conv_update ,
715- new_conv_name ,
716- ] + selector_output
731+ return (
732+ [
733+ {},
734+ chat_history ,
735+ new_conv_id ,
736+ conv_update ,
737+ new_conv_name ,
738+ ]
739+ + selector_output
740+ + [used_command ]
741+ )
717742
718743 def toggle_delete (self , conv_id ):
719744 if conv_id :
@@ -877,6 +902,7 @@ def create_pipeline(
877902 session_use_citation : str ,
878903 session_language : str ,
879904 state : dict ,
905+ command_state : str | None ,
880906 user_id : int ,
881907 * selecteds ,
882908 ):
@@ -934,17 +960,26 @@ def create_pipeline(
934960
935961 # get retrievers
936962 retrievers = []
937- for index in self ._app .index_manager .indices :
938- index_selected = []
939- if isinstance (index .selector , int ):
940- index_selected = selecteds [index .selector ]
941- if isinstance (index .selector , tuple ):
942- for i in index .selector :
943- index_selected .append (selecteds [i ])
944- iretrievers = index .get_retriever_pipelines (
945- settings , user_id , index_selected
946- )
947- retrievers += iretrievers
963+
964+ if command_state == WEB_SEARCH_COMMAND :
965+ # set retriever for web search
966+ if not WebSearch :
967+ raise ValueError ("Web search back-end is not available." )
968+
969+ web_search = WebSearch ()
970+ retrievers .append (web_search )
971+ else :
972+ for index in self ._app .index_manager .indices :
973+ index_selected = []
974+ if isinstance (index .selector , int ):
975+ index_selected = selecteds [index .selector ]
976+ if isinstance (index .selector , tuple ):
977+ for i in index .selector :
978+ index_selected .append (selecteds [i ])
979+ iretrievers = index .get_retriever_pipelines (
980+ settings , user_id , index_selected
981+ )
982+ retrievers += iretrievers
948983
949984 # prepare states
950985 reasoning_state = {
@@ -966,7 +1001,8 @@ def chat_fn(
9661001 use_mind_map ,
9671002 use_citation ,
9681003 language ,
969- state ,
1004+ chat_state ,
1005+ command_state ,
9701006 user_id ,
9711007 * selecteds ,
9721008 ):
@@ -976,7 +1012,7 @@ def chat_fn(
9761012
9771013 # if chat_input is empty, assume regen mode
9781014 if chat_output :
979- state ["app" ]["regen" ] = True
1015+ chat_state ["app" ]["regen" ] = True
9801016
9811017 queue : asyncio .Queue [Optional [dict ]] = asyncio .Queue ()
9821018
@@ -988,7 +1024,8 @@ def chat_fn(
9881024 use_mind_map ,
9891025 use_citation ,
9901026 language ,
991- state ,
1027+ chat_state ,
1028+ command_state ,
9921029 user_id ,
9931030 * selecteds ,
9941031 )
@@ -1005,7 +1042,7 @@ def chat_fn(
10051042 refs ,
10061043 plot_gr ,
10071044 plot ,
1008- state ,
1045+ chat_state ,
10091046 )
10101047
10111048 for response in pipeline .stream (chat_input , conversation_id , chat_history ):
@@ -1032,14 +1069,14 @@ def chat_fn(
10321069 plot = response .content
10331070 plot_gr = self ._json_to_plot (plot )
10341071
1035- state [pipeline .get_info ()["id" ]] = reasoning_state ["pipeline" ]
1072+ chat_state [pipeline .get_info ()["id" ]] = reasoning_state ["pipeline" ]
10361073
10371074 yield (
10381075 chat_history + [(chat_input , text or msg_placeholder )],
10391076 refs ,
10401077 plot_gr ,
10411078 plot ,
1042- state ,
1079+ chat_state ,
10431080 )
10441081
10451082 if not text :
@@ -1052,7 +1089,7 @@ def chat_fn(
10521089 refs ,
10531090 plot_gr ,
10541091 plot ,
1055- state ,
1092+ chat_state ,
10561093 )
10571094
10581095 def check_and_suggest_name_conv (self , chat_history ):
0 commit comments