Skip to content

Commit b2f35b6

Browse files
authored
♻️ Refactor history sanitizing (#12)
1 parent 529bcf5 commit b2f35b6

File tree

4 files changed

+50
-42
lines changed

4 files changed

+50
-42
lines changed

app/server/chat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,14 @@ async def create_chat_completion(
108108
logger.exception(f"Error generating content from Gemini API: {e}")
109109
raise
110110

111-
# Format and clean the output
111+
# Format the response from API
112112
model_output = GeminiClientWrapper.extract_output(response, include_thoughts=True)
113113
stored_output = GeminiClientWrapper.extract_output(response, include_thoughts=False)
114114

115-
# After cleaning, persist the conversation
115+
# After formatting, persist the conversation to LMDB
116116
try:
117117
last_message = Message(role="assistant", content=stored_output)
118-
cleaned_history = db.clean_assistant_messages(request.messages)
118+
cleaned_history = db.sanitize_assistant_messages(request.messages)
119119
conv = ConversationInStore(
120120
model=model.model_name,
121121
client_id=client.id,

app/services/client.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def __init__(self, client_id: str, **kwargs):
1616
self.id = client_id
1717

1818
async def init(self, **kwargs):
19-
# Inject default configuration values
19+
"""
20+
Inject default configuration values.
21+
"""
2022
kwargs.setdefault("timeout", g_config.gemini.timeout)
2123
kwargs.setdefault("auto_refresh", g_config.gemini.auto_refresh)
2224
kwargs.setdefault("verbose", g_config.gemini.verbose)
@@ -67,7 +69,9 @@ async def process_message(
6769
return model_input, files
6870

6971
@staticmethod
70-
async def process_conversation(messages: list[Message], tempdir: Path | None = None):
72+
async def process_conversation(
73+
messages: list[Message], tempdir: Path | None = None
74+
) -> tuple[str, list[Path | str]]:
7175
"""
7276
Process the entire conversation and return a formatted string and list of
7377
files. The last message is assumed to be the assistant's response.
@@ -86,7 +90,7 @@ async def process_conversation(messages: list[Message], tempdir: Path | None = N
8690
return "\n".join(conversation), files
8791

8892
@staticmethod
89-
def extract_output(response: ModelOutput, include_thoughts: bool = True):
93+
def extract_output(response: ModelOutput, include_thoughts: bool = True) -> str:
9094
"""
9195
Extract and format the output text from the Gemini response.
9296
"""

app/services/lmdb.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import hashlib
2+
import re
23
from contextlib import contextmanager
34
from datetime import datetime
45
from pathlib import Path
@@ -11,31 +12,24 @@
1112
from ..models import ConversationInStore, Message
1213
from ..utils import g_config
1314
from ..utils.singleton import Singleton
14-
import re
1515

16-
def _normalize_content(content: str) -> str:
17-
"""Remove <think>...</think> tags and strip whitespace from content."""
18-
# Remove think tags
19-
cleaned_content = re.sub(r"<think>.*?</think>\n?", "", content, flags=re.DOTALL)
20-
# Strip leading/trailing whitespace
21-
return cleaned_content.strip()
2216

23-
def hash_message(message: Message) -> str:
17+
def _hash_message(message: Message) -> str:
2418
"""Generate a hash for a single message."""
2519
# Convert message to dict and sort keys for consistent hashing
2620
message_dict = message.model_dump(mode="json")
2721
message_bytes = orjson.dumps(message_dict, option=orjson.OPT_SORT_KEYS)
2822
return hashlib.sha256(message_bytes).hexdigest()
2923

3024

31-
def hash_conversation(client_id: str, model: str, messages: List[Message]) -> str:
25+
def _hash_conversation(client_id: str, model: str, messages: List[Message]) -> str:
3226
"""Generate a hash for a list of messages and client id."""
3327
# Create a combined hash from all individual message hashes
3428
combined_hash = hashlib.sha256()
3529
combined_hash.update(client_id.encode("utf-8"))
3630
combined_hash.update(model.encode("utf-8"))
3731
for message in messages:
38-
message_hash = hash_message(message)
32+
message_hash = _hash_message(message)
3933
combined_hash.update(message_hash.encode("utf-8"))
4034
return combined_hash.hexdigest()
4135

@@ -123,7 +117,7 @@ def store(
123117
raise ValueError("Messages list cannot be empty")
124118

125119
# Generate hash for the message list
126-
message_hash = hash_conversation(conv.client_id, conv.model, conv.messages)
120+
message_hash = _hash_conversation(conv.client_id, conv.model, conv.messages)
127121
storage_key = custom_key or message_hash
128122

129123
# Prepare data for storage
@@ -178,23 +172,6 @@ def get(self, key: str) -> Optional[ConversationInStore]:
178172
logger.error(f"Failed to retrieve messages for key {key}: {e}")
179173
return None
180174

181-
def clean_assistant_messages(self, messages: List[Message]) -> List[Message]:
182-
"""Create a new list of messages with assistant content cleaned."""
183-
cleaned_messages = []
184-
for msg in messages:
185-
if msg.role == "assistant" and isinstance(msg.content, str):
186-
# Create a new Message object with cleaned content
187-
normalized_content = _normalize_content(msg.content)
188-
# Only create a new object if content actually changed
189-
if normalized_content != msg.content:
190-
cleaned_msg = Message(role=msg.role, content=normalized_content, name=msg.name)
191-
cleaned_messages.append(cleaned_msg)
192-
else:
193-
cleaned_messages.append(msg)
194-
else:
195-
cleaned_messages.append(msg)
196-
return cleaned_messages
197-
198175
def find(self, model: str, messages: List[Message]) -> Optional[ConversationInStore]:
199176
"""
200177
Search conversation data by message list.
@@ -215,7 +192,7 @@ def find(self, model: str, messages: List[Message]) -> Optional[ConversationInSt
215192
return conv
216193

217194
# --- Find with cleaned messages ---
218-
cleaned_messages = self.clean_assistant_messages(messages)
195+
cleaned_messages = self.sanitize_assistant_messages(messages)
219196
if conv := self._find_by_message_list(model, cleaned_messages):
220197
logger.debug("Found conversation with cleaned message history.")
221198
return conv
@@ -228,14 +205,12 @@ def _find_by_message_list(
228205
) -> Optional[ConversationInStore]:
229206
"""Internal find implementation based on a message list."""
230207
for c in g_config.gemini.clients:
231-
message_hash = hash_conversation(c.id, model, messages)
208+
message_hash = _hash_conversation(c.id, model, messages)
232209

233210
key = f"{self.HASH_LOOKUP_PREFIX}{message_hash}"
234211
try:
235212
with self._get_transaction(write=False) as txn:
236-
mapped = txn.get(key.encode("utf-8"))
237-
if mapped:
238-
logger.debug(f"Found mapped key '{mapped.decode('utf-8')}' for hash '{message_hash}'.")
213+
if mapped := txn.get(key.encode("utf-8")): # type: ignore
239214
return self.get(mapped.decode("utf-8")) # type: ignore
240215
except Exception as e:
241216
logger.error(
@@ -283,7 +258,7 @@ def delete(self, key: str) -> Optional[ConversationInStore]:
283258

284259
storage_data = orjson.loads(data) # type: ignore
285260
conv = ConversationInStore.model_validate(storage_data)
286-
message_hash = hash_conversation(conv.client_id, conv.model, conv.messages)
261+
message_hash = _hash_conversation(conv.client_id, conv.model, conv.messages)
287262

288263
# Delete main data
289264
txn.delete(key.encode("utf-8"))
@@ -362,3 +337,32 @@ def close(self) -> None:
362337
def __del__(self):
363338
"""Cleanup on destruction."""
364339
self.close()
340+
341+
@staticmethod
342+
def remove_think_tags(text: str) -> str:
343+
"""
344+
Remove <think>...</think> tags at the start of text and strip whitespace.
345+
"""
346+
cleaned_content = re.sub(r"^(\s*<think>.*?</think>\n?)", "", text, flags=re.DOTALL)
347+
return cleaned_content.strip()
348+
349+
@staticmethod
350+
def sanitize_assistant_messages(messages: list[Message]) -> list[Message]:
351+
"""
352+
Create a new list of messages with assistant content cleaned of <think> tags.
353+
This is useful for storing the chat history.
354+
"""
355+
cleaned_messages = []
356+
for msg in messages:
357+
if msg.role == "assistant" and isinstance(msg.content, str):
358+
normalized_content = LMDBConversationStore.remove_think_tags(msg.content)
359+
# Only create a new object if content actually changed
360+
if normalized_content != msg.content:
361+
cleaned_msg = Message(role=msg.role, content=normalized_content, name=msg.name)
362+
cleaned_messages.append(cleaned_msg)
363+
else:
364+
cleaned_messages.append(msg)
365+
else:
366+
cleaned_messages.append(msg)
367+
368+
return cleaned_messages

app/utils/helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def add_tag(role: str, content: str, unclose: bool = False) -> str:
1616

1717

1818
def estimate_tokens(text: str) -> int:
19-
# TODO: Refactor this function to use a proper tokenizer
20-
return len(text.split())
19+
"""Estimate the number of tokens heuristically based on character count"""
20+
return int(len(text) / 3)
2121

2222

2323
async def save_file_to_tempfile(

0 commit comments

Comments
 (0)