diff --git a/src/bot_backend.py b/src/bot_backend.py
index 1f8b9a1..1dae292 100644
--- a/src/bot_backend.py
+++ b/src/bot_backend.py
@@ -5,7 +5,9 @@
 import shutil
 from jupyter_backend import *
 from typing import *
-from notebook_serializer import add_markdown_to_notebook, add_code_cell_to_notebook
+from notebook_serializer import add_markdown_to_notebook, add_code_cell_to_notebook, nb
+from notebook_serializer import notebook_path
+from bs4 import BeautifulSoup
 
 functions = [
     {
@@ -115,7 +117,10 @@ class BotBackend(GPTResponseLog):
     def __init__(self):
         super().__init__()
         self.unique_id = hash(id(self))
-        self.jupyter_work_dir = f'cache/work_dir_{self.unique_id}'
+        if notebook_path:
+            self.jupyter_work_dir = os.path.dirname(notebook_path)
+        else:
+            self.jupyter_work_dir = f'cache/work_dir_{self.unique_id}'
         self.jupyter_kernel = JupyterKernel(work_dir=self.jupyter_work_dir)
         self.gpt_model_choice = "GPT-3.5"
         self.revocable_files = []
@@ -123,6 +128,63 @@ def __init__(self):
         self._init_api_config()
         self._init_kwargs_for_chat_completion()
 
+        for cell in nb['cells']:
+            if cell['cell_type'] == 'code':
+                _, _ = self.jupyter_kernel.execute_code(cell['source'])
+                self.conversation.append(
+                    {'role': "function", 'name': "python", 'content': cell['source']}
+                )
+
+                for output in cell['outputs']:
+                    if output['output_type'] == 'display_data':
+                        for mime_type, output_data in output['data'].items():
+                            if 'text' in mime_type:
+                                if mime_type == 'text/html':
+                                    soup = BeautifulSoup(output_data, 'html.parser')
+                                    text_output = soup.get_text().strip()
+                                else:
+                                    text_output = output_data
+                                self.conversation.append(
+                                    {
+                                        "role": "function",
+                                        'name': "python",
+                                        "content": text_output,
+                                    }
+                                )
+                            if 'image' in mime_type:
+                                self.conversation.append(
+                                    {
+                                        "role": "function",
+                                        'name': "python",
+                                        "content": "[image]",
+                                    }
+                                )
+                    if output['output_type'] == 'error':
+                        for traceback_line in output['traceback']:
+                            self.conversation.append(
+                                {
+                                    "role": "function",
+                                    "name": "python",
+                                    "content": traceback_line,
+                                }
+                            )
+
+            if cell['cell_type'] == 'markdown':
+                source = cell['source']
+                if source.startswith("##### User:\n"):
+                    stripped_source = source[len("##### User:\n"):]
+                    self.conversation.append(
+                        {'role': "user", 'content': stripped_source}
+                    )
+                if source.startswith("##### Assistant:\n"):
+                    stripped_source = source[len("##### Assistant:\n"):]
+                    self.conversation.append(
+                        {'role': 'assistant', 'content': stripped_source}
+                    )
+
+        print("conversation:", json.dumps(self.conversation, indent=1))
+
+
     def _init_conversation(self):
         first_system_msg = {'role': 'system', 'content': system_msg}
         self.context_window_tokens = 0  # num of tokens actually sent to GPT
@@ -175,6 +237,7 @@ def add_text_message(self, user_text):
         self.revocable_files.clear()
         self.update_finish_reason(finish_reason='new_input')
         add_markdown_to_notebook(user_text, title="User")
+
 
     def add_file_message(self, path, bot_msg):
         filename = os.path.basename(path)
diff --git a/src/functional.py b/src/functional.py
index 1caf1ba..8401361 100644
--- a/src/functional.py
+++ b/src/functional.py
@@ -43,6 +43,7 @@ def chat_completion(bot_backend: BotBackend):
     assert model_name in config['model_context_window'], \
         f"{model_name} lacks context window information. Please check the config.json file."
 
+    print(json.dumps(kwargs_for_chat_completion, indent=1))
     response = openai.ChatCompletion.create(**kwargs_for_chat_completion)
 
     return response
diff --git a/src/jupyter_backend.py b/src/jupyter_backend.py
index c080d8a..4075d72 100644
--- a/src/jupyter_backend.py
+++ b/src/jupyter_backend.py
@@ -1,6 +1,7 @@
 import jupyter_client
 import re
-
+import os
+from notebook_serializer import notebook_path
 
 def delete_color_control_char(string):
     ansi_escape = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]')
@@ -86,12 +87,18 @@ def execute_code(self, code):
         return '\n'.join(text_to_gpt), content_to_display
 
     def _create_work_dir(self):
-        # set work dir in jupyter environment
-        init_code = f"import os\n" \
-                    f"if not os.path.exists('{self.work_dir}'):\n" \
-                    f"    os.mkdir('{self.work_dir}')\n" \
-                    f"os.chdir('{self.work_dir}')\n" \
-                    f"del os"
+        if notebook_path:
+            init_code = f"""
+            import os
+            os.chdir('{self.work_dir}')
+            """
+        else:
+            # set work dir in jupyter environment
+            init_code = f"import os\n" \
+                        f"if not os.path.exists('{self.work_dir}'):\n" \
+                        f"    os.mkdir('{self.work_dir}')\n" \
+                        f"os.chdir('{self.work_dir}')\n" \
+                        f"del os"
         self.execute_code_(init_code)
 
     def restart_jupyter_kernel(self):
diff --git a/src/notebook_serializer.py b/src/notebook_serializer.py
index 2d8da56..d115630 100644
--- a/src/notebook_serializer.py
+++ b/src/notebook_serializer.py
@@ -2,20 +2,54 @@
 from nbformat import v4 as nbf
 import ansi2html
 import os
+import json
 import argparse
 
 # main code
 parser = argparse.ArgumentParser()
 parser.add_argument("-n", "--notebook", help="Path to the output notebook", default=None, type=str)
 args = parser.parse_args()
 
+nb = nbf.new_notebook()
+notebook_path = ""
+
 if args.notebook:
     notebook_path = os.path.join(os.getcwd(), args.notebook)
     base, ext = os.path.splitext(notebook_path)
     if ext.lower() != '.ipynb':
         notebook_path += '.ipynb'
+
+    if os.path.isfile(notebook_path):
+        with open(notebook_path, 'r') as notebook_file:
+            nb = nbformat.read(notebook_file, as_version=4)
 
-# Global variable for code cells
-nb = nbf.new_notebook()
+
+def deserialize_notebook_into_conv_history():
+    history = []
+    for cell in nb['cells']:
+        # Handle markdown
+        if cell['cell_type'] == 'markdown':
+            append_to_history(history, cell['source'], cell)
+        # Handle code
+        if cell['cell_type'] == 'code':
+            append_to_history(history, "```python\n" + cell['source'] + "\n```", cell)
+            # Handle outputs
+            for output in cell['outputs']:
+                # Handle display data
+                if output['output_type'] == 'display_data':
+                    for mime_type, output_data in output['data'].items():
+                        if 'text' in mime_type:
+                            append_to_history(history, output_data, cell)
+                # Handle error
+                if output['output_type'] == 'error':
+                    for traceback_line in output['traceback']:
+                        append_to_history(history, ansi_to_html(traceback_line), cell)
+    return history
+
+
+def append_to_history(history, obj, cell):
+    is_from_user = 'author' in cell['metadata'] and cell['metadata']['author'] == 'user'
+    if is_from_user:
+        history.append((obj, None))
+    else:
+        history.append((None, obj))
+
 
 def ansi_to_html(ansi_text):
     converter = ansi2html.Ansi2HTMLConverter()
diff --git a/src/response_parser.py b/src/response_parser.py
index 0e13773..3d63741 100644
--- a/src/response_parser.py
+++ b/src/response_parser.py
@@ -129,7 +129,6 @@ def execute(self, bot_backend: BotBackend, history: List, whether_exit: bool):
                 bot_backend.function_name
             ](code_str)
 
-            # add function call to conversion
             bot_backend.add_function_call_response_message(function_response=text_to_gpt, save_tokens=True)
 
             add_function_response_to_bot_history(
diff --git a/src/web_ui.py b/src/web_ui.py
index 1e98fa2..01a1479 100644
--- a/src/web_ui.py
+++ b/src/web_ui.py
@@ -1,4 +1,5 @@
 import gradio as gr
+from notebook_serializer import deserialize_notebook_into_conv_history
 from response_parser import *
 
 
@@ -137,7 +138,8 @@ def bot(state_dict: Dict, history: List) -> List:
     # UI components
     state = gr.State(value={"bot_backend": None})
     with gr.Tab("Chat"):
-        chatbot = gr.Chatbot([], elem_id="chatbot", label="Local Code Interpreter", height=750)
+        history = deserialize_notebook_into_conv_history()
+        chatbot = gr.Chatbot(history, elem_id="chatbot", label="Local Code Interpreter", height=750)
         with gr.Row():
             with gr.Column(scale=0.85):
                 text_box = gr.Textbox(
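
For reference, a minimal standalone sketch of the notebook-to-history round trip this patch implements, useful for checking the deserialization logic outside the Gradio app. It assumes the same 'author': 'user' cell-metadata convention used in the diff; notebook_to_chat_history is a hypothetical helper name and is not defined by the patch itself.

import nbformat

def notebook_to_chat_history(path):
    # Read an existing .ipynb and rebuild the (user, assistant) message pairs
    # that gr.Chatbot accepts as its initial value.
    nb = nbformat.read(path, as_version=4)
    history = []
    for cell in nb.cells:
        if cell.cell_type == 'markdown':
            content = cell.source
        elif cell.cell_type == 'code':
            content = "```python\n" + cell.source + "\n```"
        else:
            continue
        # 'author' metadata marks user-authored cells (assumption carried over from the diff)
        if cell.metadata.get('author') == 'user':
            history.append((content, None))   # rendered on the user side
        else:
            history.append((None, content))   # rendered on the assistant side
    return history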