@@ -176,17 +176,91 @@ def system_prompt(self, value: str | None):
176176 if value is not None :
177177 self ._turns .insert (0 , Turn ("system" , value ))
178178
def tokens(self) -> list[int]:
    """
    Get the token counts for each turn in the chat.

    Returns
    -------
    list[int]
        A list of token counts, one per (non-system) turn in the chat, in
        turn order (user, assistant, user, assistant, ...). The count for
        the 1st turn includes the tokens of the system prompt (if any).

    Raises
    ------
    ValueError
        If the chat's turns (i.e., `.get_turns()`) are not in the expected
        format: an even number of turns, starting with a `role="user"`
        turn, where every 2nd turn has `role="assistant"` and carries
        token counts. This may happen if the chat history is manually set
        (i.e., `.set_turns()`). In this case, you can inspect the "raw"
        token values via the `.get_turns()` method (each turn has a
        `.tokens` attribute).

    Notes
    -----
    User turns do not carry their own counts, so they are derived from the
    assistant turns. This assumes each assistant turn's `.tokens` is an
    `(input, output)` pair whose input count covers the whole conversation
    up to that point (TODO confirm against the `Turn` definition).
    """
    turns = self.get_turns(include_system_prompt=False)

    if not turns:
        return []

    err_info = (
        "This can happen if the chat history is manually set (i.e., `.set_turns()`). "
        "Consider getting the 'raw' token values via the `.get_turns()` method "
        "(each turn has a `.tokens` attribute)."
    )

    # Sanity checks for the assumptions made to figure out user token counts
    if len(turns) == 1:
        raise ValueError(
            "Expected at least two turns in the chat history. " + err_info
        )

    if len(turns) % 2 != 0:
        raise ValueError(
            "Expected an even number of turns in the chat history. " + err_info
        )

    first_user, first_assistant = turns[0], turns[1]

    if first_user.role != "user":
        raise ValueError(
            "Expected the 1st non-system turn to have role='user'. " + err_info
        )

    if first_assistant.role != "assistant":
        raise ValueError(
            "Expected the 2nd non-system turn to have role='assistant'. " + err_info
        )

    if first_assistant.tokens is None:
        raise ValueError(
            "Expected the 1st assistant turn to contain token counts. " + err_info
        )

    res: list[int] = [
        # Implied token count for the 1st user input
        first_assistant.tokens[0],
        # The token count for the 1st assistant response
        first_assistant.tokens[1],
    ]

    # Total tokens accounted for by the most recently validated assistant
    # turn. Carrying it forward means each assistant turn is validated
    # exactly once, before its counts are used.
    prev_total = sum(first_assistant.tokens)

    # Remaining assistant turns sit at indices 3, 5, 7, ...
    for turn in turns[3::2]:
        if turn.role != "assistant":
            raise ValueError(
                "Expected even turns to have role='assistant'. " + err_info
            )
        if turn.tokens is None:
            raise ValueError(
                "Expected role='assistant' turns to contain token counts. "
                + err_info
            )
        # Implied token count for the user input: whatever the input count
        # grew by beyond the previous assistant turn's input + output.
        res.append(turn.tokens[0] - prev_total)
        # The token count for the assistant response
        res.append(turn.tokens[1])
        prev_total = sum(turn.tokens)

    return res
190264
191265 def app (
192266 self ,
0 commit comments