@@ -176,17 +176,87 @@ def system_prompt(self, value: str | None):
176176 if value is not None :
177177 self ._turns .insert (0 , Turn ("system" , value ))
178178
179- def tokens (self ) -> list [tuple [ int , int ] | None ]:
179+ def tokens (self ) -> list [int ]:
180180 """
181181 Get the tokens for each turn in the chat.
182182
183183 Returns
184184 -------
185- list[tuple[int, int] | None]
186- A list of tuples, where each tuple contains the start and end token
187- indices for a turn.
185+ list[int]
186+ A list of token counts for each turn in the chat. Note that the
187+ 1st turn includes the tokens count for the system prompt (if any).
188+
189+ Raises
190+ ------
191+ ValueError
192+ If the chat's turns (i.e., `.get_turns()`) are not in an expected
193+ format. This may happen if the chat history is manually set (i.e.,
194+ `.set_turns()`). In this case, you can inspect the "raw" token
195+ values via the `.get_turns()` method (each turn has a `.tokens`
196+ attribute).
188197 """
189- return [turn .tokens for turn in self ._turns ]
198+
199+ turns = self .get_turns (include_system_prompt = False )
200+
201+ if len (turns ) == 0 :
202+ return []
203+
204+ err_info = (
205+ "This can happen if the chat history is manually set (i.e., `.set_turns()`). "
206+ "Consider getting the 'raw' token values via the `.get_turns()` method "
207+ "(each turn has a `.tokens` attribute)."
208+ )
209+
210+ # Sanity checks for the assumptions made to figure out user token counts
211+ if len (turns ) == 1 :
212+ raise ValueError (
213+ "Expected at least two turns in the chat history. " + err_info
214+ )
215+
216+ if len (turns ) % 2 != 0 :
217+ raise ValueError (
218+ "Expected an even number of turns in the chat history. " + err_info
219+ )
220+
221+ if turns [0 ].role != "user" :
222+ raise ValueError ("Expected the first turn to have role='user'. " + err_info )
223+
224+ if turns [1 ].role != "assistant" :
225+ raise ValueError (
226+ "Expected the 2nd turn to have role='assistant'. " + err_info
227+ )
228+
229+ if turns [1 ].tokens is None :
230+ raise ValueError (
231+ "Expected the 1st assistant turn to contain token counts. " + err_info
232+ )
233+
234+ tokens : list [int ] = [
235+ turns [1 ].tokens [0 ],
236+ sum (turns [1 ].tokens ),
237+ ]
238+ for i in range (1 , len (turns ) - 1 , 2 ):
239+ ti = turns [i ]
240+ tj = turns [i + 2 ]
241+ if ti .role != "assistant" or tj .role != "assistant" :
242+ raise ValueError (
243+ "Expected even turns to have role='assistant'." + err_info
244+ )
245+ if ti .tokens is None or tj .tokens is None :
246+ raise ValueError (
247+ "Expected role='assistant' turns to contain token counts."
248+ + err_info
249+ )
250+ tokens .extend (
251+ [
252+ # Implied token count for the user input
253+ tj .tokens [0 ] - sum (ti .tokens ),
254+ # The token count for the assistant response
255+ tj .tokens [1 ],
256+ ]
257+ )
258+
259+ return tokens
190260
191261 def app (
192262 self ,
0 commit comments