
Commit 41d4065

committed
The .tokens() method now returns a list of ints, where each int represents the number of tokens each turn takes.
1 parent f5a300f commit 41d4065
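
Below is a hedged sketch of what the new return value looks like for a short conversation. The `ChatOpenAI` constructor is just one way to build a chat object, and the token counts shown are made-up illustrations, not values produced by this commit:

    from chatlas import ChatOpenAI

    chat = ChatOpenAI(system_prompt="Be terse.")
    chat.chat("What is 1 + 1?")
    chat.chat("Now square that.")

    # One int per non-system turn, in order:
    # [user 1, assistant 1, user 2, assistant 2].
    # The 1st user count also includes the system prompt's tokens.
    print(chat.tokens())  # e.g. [17, 2, 26, 3]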

File tree

1 file changed: +79 -5 lines changed


chatlas/_chat.py

Lines changed: 79 additions & 5 deletions
@@ -176,17 +176,91 @@ def system_prompt(self, value: str | None):
         if value is not None:
             self._turns.insert(0, Turn("system", value))
 
-    def tokens(self) -> list[tuple[int, int] | None]:
+    def tokens(self) -> list[int]:
         """
         Get the tokens for each turn in the chat.
 
         Returns
         -------
-        list[tuple[int, int] | None]
-            A list of tuples, where each tuple contains the start and end token
-            indices for a turn.
+        list[int]
+            A list of token counts for each (non-system) turn in the chat. The
+            1st turn includes the token count for the system prompt (if any).
+
+        Raises
+        ------
+        ValueError
+            If the chat's turns (i.e., `.get_turns()`) are not in an expected
+            format. This may happen if the chat history is manually set (i.e.,
+            `.set_turns()`). In this case, you can inspect the "raw" token
+            values via the `.get_turns()` method (each turn has a `.tokens`
+            attribute).
         """
-        return [turn.tokens for turn in self._turns]
+
+        turns = self.get_turns(include_system_prompt=False)
+
+        if len(turns) == 0:
+            return []
+
+        err_info = (
+            "This can happen if the chat history is manually set (i.e., `.set_turns()`). "
+            "Consider getting the 'raw' token values via the `.get_turns()` method "
+            "(each turn has a `.tokens` attribute)."
+        )
+
+        # Sanity checks for the assumptions made to figure out user token counts
+        if len(turns) == 1:
+            raise ValueError(
+                "Expected at least two turns in the chat history. " + err_info
+            )
+
+        if len(turns) % 2 != 0:
+            raise ValueError(
+                "Expected an even number of turns in the chat history. " + err_info
+            )
+
+        if turns[0].role != "user":
+            raise ValueError(
+                "Expected the 1st non-system turn to have role='user'. " + err_info
+            )
+
+        if turns[1].role != "assistant":
+            raise ValueError(
+                "Expected the 2nd non-system turn to have role='assistant'. " + err_info
+            )
+
+        if turns[1].tokens is None:
+            raise ValueError(
+                "Expected the 1st assistant turn to contain token counts. " + err_info
+            )
+
+        res: list[int] = [
+            # Implied token count for the 1st user input
+            turns[1].tokens[0],
+            # The token count for the 1st assistant response
+            turns[1].tokens[1],
+        ]
+        for i in range(1, len(turns) - 1, 2):
+            ti = turns[i]
+            tj = turns[i + 2]
+            if ti.role != "assistant" or tj.role != "assistant":
+                raise ValueError(
+                    "Expected even turns to have role='assistant'. " + err_info
+                )
+            if ti.tokens is None or tj.tokens is None:
+                raise ValueError(
+                    "Expected role='assistant' turns to contain token counts. "
+                    + err_info
+                )
+            res.extend(
+                [
+                    # Implied token count for the user input
+                    tj.tokens[0] - sum(ti.tokens),
+                    # The token count for the assistant response
+                    tj.tokens[1],
+                ]
+            )
+
+        return res
 
     def app(
         self,
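
To make the subtraction in the loop concrete: each assistant turn's `.tokens` is assumed to be an `(input, output)` pair in which `input` counts the full prompt sent to the provider (system prompt plus every prior turn), so a later user turn's own contribution is the new input count minus everything already accounted for. A minimal sketch with made-up numbers, mirroring `tj.tokens[0] - sum(ti.tokens)` above:

    # Cumulative (input, output) token pairs as a provider might report them
    # for two exchanges (illustrative numbers only):
    exchange1 = (20, 10)  # input = system + user 1
    exchange2 = (45, 12)  # input = system + user 1 + assistant 1 + user 2

    # The 2nd user turn's own token count is implied by the difference:
    user2 = exchange2[0] - sum(exchange1)  # 45 - (20 + 10) == 15

    # .tokens() would then report [20, 10, 15, 12] for this chat.
    assert [exchange1[0], exchange1[1], user2, exchange2[1]] == [20, 10, 15, 12]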
