diff --git a/CHANGELOG.md b/CHANGELOG.md
index 616af04a..0991fb3f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### New features
 
+* `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`) remains the same (the result can be summed to determine the overall token cost of the conversation).
+
 ### Bug fixes
 
 * `ChatOllama` no longer fails when a `OPENAI_API_KEY` environment variable is not set.
diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index 6efffa0b..be1faaf5 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -16,6 +16,7 @@
     Optional,
     Sequence,
     TypeVar,
+    overload,
 )
 
 from pydantic import BaseModel
@@ -176,17 +177,122 @@ def system_prompt(self, value: str | None):
         if value is not None:
             self._turns.insert(0, Turn("system", value))
 
-    def tokens(self) -> list[tuple[int, int] | None]:
+    @overload
+    def tokens(self) -> list[tuple[int, int] | None]: ...
+
+    @overload
+    def tokens(
+        self,
+        values: Literal["cumulative"],
+    ) -> list[tuple[int, int] | None]: ...
+
+    @overload
+    def tokens(
+        self,
+        values: Literal["discrete"],
+    ) -> list[int]: ...
+
+    def tokens(
+        self,
+        values: Literal["cumulative", "discrete"] = "cumulative",
+    ) -> list[int] | list[tuple[int, int] | None]:
         """
         Get the tokens for each turn in the chat.
 
+        Parameters
+        ----------
+        values
+            If "cumulative" (the default), the result can be summed to get the
+            chat's overall token usage (helpful for computing the overall cost
+            of the chat). If "discrete", the result can be summed to get the
+            number of tokens the turns will cost to generate the next response
+            (helpful for estimating the cost of the next response, or for
+            determining if you are about to exceed the token limit).
+
         Returns
         -------
-        list[tuple[int, int] | None]
-            A list of tuples, where each tuple contains the start and end token
-            indices for a turn.
+        list[int]
+            A list of token counts for each (non-system) turn in the chat. The
+            1st turn includes the token count for the system prompt (if any).
+
+        Raises
+        ------
+        ValueError
+            If the chat's turns (i.e., `.get_turns()`) are not in an expected
+            format. This may happen if the chat history is manually set (i.e.,
+            `.set_turns()`). In this case, you can inspect the "raw" token
+            values via the `.get_turns()` method (each turn has a `.tokens`
+            attribute).
         """
-        return [turn.tokens for turn in self._turns]
+
+        turns = self.get_turns(include_system_prompt=False)
+
+        if values == "cumulative":
+            return [turn.tokens for turn in turns]
+
+        if len(turns) == 0:
+            return []
+
+        err_info = (
+            "This can happen if the chat history is manually set (i.e., `.set_turns()`). "
+            "Consider getting the 'raw' token values via the `.get_turns()` method "
+            "(each turn has a `.tokens` attribute)."
+        )
+
+        # Sanity checks for the assumptions made to figure out user token counts
+        if len(turns) == 1:
+            raise ValueError(
+                "Expected at least two turns in the chat history. " + err_info
+            )
+
+        if len(turns) % 2 != 0:
+            raise ValueError(
+                "Expected an even number of turns in the chat history. " + err_info
+            )
+
+        if turns[0].role != "user":
+            raise ValueError(
+                "Expected the 1st non-system turn to have role='user'. " + err_info
+            )
+
+        if turns[1].role != "assistant":
+            raise ValueError(
+                "Expected the 2nd non-system turn to have role='assistant'. " + err_info
+            )
+
+        if turns[1].tokens is None:
+            raise ValueError(
+                "Expected the 1st assistant turn to contain token counts. " + err_info
+            )
+
+        res: list[int] = [
+            # Implied token count for the 1st user input
+            turns[1].tokens[0],
+            # The token count for the 1st assistant response
+            turns[1].tokens[1],
+        ]
+        for i in range(1, len(turns) - 1, 2):
+            ti = turns[i]
+            tj = turns[i + 2]
+            if ti.role != "assistant" or tj.role != "assistant":
+                raise ValueError(
+                    "Expected even turns to have role='assistant'. " + err_info
+                )
+            if ti.tokens is None or tj.tokens is None:
+                raise ValueError(
+                    "Expected role='assistant' turns to contain token counts. "
+                    + err_info
+                )
+            res.extend(
+                [
+                    # Implied token count for the user input
+                    tj.tokens[0] - sum(ti.tokens),
+                    # The token count for the assistant response
+                    tj.tokens[1],
+                ]
+            )
+
+        return res
 
     def app(
         self,
diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py
index 05c16a94..17468fbd 100644
--- a/tests/test_provider_openai.py
+++ b/tests/test_provider_openai.py
@@ -21,7 +21,7 @@ def test_openai_simple_request():
     chat.chat("What is 1 + 1?")
     turn = chat.get_last_turn()
     assert turn is not None
-    assert turn.tokens == (27, 1)
+    assert turn.tokens == (27, 2)
     assert turn.finish_reason == "stop"
 
 
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index 8495c1ca..2c4295c4 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -1,7 +1,35 @@
+from chatlas import ChatOpenAI, Turn
 from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
 from chatlas._tokens import token_usage, tokens_log, tokens_reset
 
 
+def test_tokens_method():
+    chat = ChatOpenAI()
+    assert chat.tokens(values="discrete") == []
+
+    chat = ChatOpenAI(
+        turns=[
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(2, 10)),
+        ]
+    )
+
+    assert chat.tokens(values="discrete") == [2, 10]
+
+    chat = ChatOpenAI(
+        turns=[
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(2, 10)),
+            Turn(role="user", contents="Hi"),
+            Turn(role="assistant", contents="Hello", tokens=(14, 10)),
+        ]
+    )
+
+    assert chat.tokens(values="discrete") == [2, 10, 2, 10]
+
+    assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]
+
+
 def test_usage_is_none():
     tokens_reset()
     assert token_usage() is None