From 729ec1cd0520e0d1e9e75deab7311aaeacf7300d Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 17 Dec 2024 17:34:13 -0600 Subject: [PATCH 01/14] Improvements to token usage reporting --- chatlas/_anthropic.py | 37 ++++++++++++++++-- chatlas/_chat.py | 61 +++++++++++++++++++++++++++--- chatlas/_google.py | 27 +++++++++++++- chatlas/_openai.py | 48 ++++++++++++++++++++++-- chatlas/_provider.py | 9 +++++ chatlas/_tokens.py | 87 ------------------------------------------- pyproject.toml | 1 + tests/test_tokens.py | 34 ----------------- 8 files changed, 169 insertions(+), 135 deletions(-) delete mode 100644 chatlas/_tokens.py delete mode 100644 tests/test_tokens.py diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index bdb6cbb6..a439badb 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -18,9 +18,8 @@ ) from ._logging import log_model_default from ._provider import Provider -from ._tokens import tokens_log from ._tools import Tool, basemodel_to_param_schema -from ._turn import Turn, normalize_turns +from ._turn import Turn, normalize_turns, user_turn if TYPE_CHECKING: from anthropic.types import ( @@ -380,6 +379,38 @@ async def stream_turn_async(self, completion, has_data_model, stream) -> Turn: def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def token_count( + self, + *args: Content | str, + tools: dict[str, Tool], + has_data_model: bool, + ) -> int: + turn = user_turn(*args) + + kwargs = self._chat_perform_args( + stream=False, + turns=[turn], + tools=tools, + data_model=None if not has_data_model else BaseModel, + ) + + args_to_keep = [ + "messages", + "model", + "system", + "tools", + "tool_choice", + ] + + kwargs_final = {} + for arg in args_to_keep: + if arg in kwargs: + kwargs_final[arg] = kwargs[arg] + + res = self._client.messages.count_tokens(**kwargs_final) + + return res.input_tokens + def _as_message_params(self, turns: list[Turn]) -> list["MessageParam"]: messages: list["MessageParam"] = [] for turn in turns: @@ -476,8 +507,6 @@ def _as_turn(self, completion: Message, has_data_model=False) -> Turn: tokens = completion.usage.input_tokens, completion.usage.output_tokens - tokens_log(self, tokens) - return Turn( "assistant", contents, diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 6efffa0b..c3595e21 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -176,17 +176,66 @@ def system_prompt(self, value: str | None): if value is not None: self._turns.insert(0, Turn("system", value)) - def tokens(self) -> list[tuple[int, int] | None]: + @property + def token_usage(self) -> tuple[int, int]: """ - Get the tokens for each turn in the chat. + Get the current token usage for the chat. + + Returns + ------- + tuple[int, int] + The input and output token usage for the chat. + """ + turn = self.get_last_turn(role="assistant") + if turn is None: + return 0, 0 + return turn.tokens or (0, 0) + + def token_count( + self, + *args: Content | str, + extract_data: bool = False, + ) -> int: + """ + Get the token count for the given input. + + Get the token count for the given input. This can be useful for + understanding how many tokens your input will cost before sending it to + the model. + + Parameters + ---------- + args + The input to get a token count for. + extract_data + Whether or not the input is for data extraction (i.e., `.extract_data()`). Returns ------- - list[tuple[int, int] | None] - A list of tuples, where each tuple contains the start and end token - indices for a turn. 
+ int + The token count for the input. + + Examples + -------- + ```python + from chatlas import ChatOpenAI + + chat = ChatOpenAI() + # Estimate the token count before sending the input + print(chat.token_count("What is 2 + 2?")) + + # Once input is sent, you can get the actual input and output token counts + chat.chat("What is 2 + 2?", echo="none") + print(chat.last_turn().tokens) + ``` + """ - return [turn.tokens for turn in self._turns] + + return self.provider.token_count( + *args, + tools=self._tools, + has_data_model=extract_data, + ) def app( self, diff --git a/chatlas/_google.py b/chatlas/_google.py index f4e96d86..8491443c 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -18,7 +18,7 @@ from ._logging import log_model_default from ._provider import Provider from ._tools import Tool, basemodel_to_param_schema -from ._turn import Turn, normalize_turns +from ._turn import Turn, normalize_turns, user_turn if TYPE_CHECKING: from google.generativeai.types.content_types import ( @@ -332,6 +332,31 @@ async def stream_turn_async( def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def token_count( + self, + *args: Content | str, + tools: dict[str, Tool], + has_data_model: bool, + ): + turn = user_turn(*args) + + kwargs = self._chat_perform_args( + stream=False, + turns=[turn], + tools=tools, + data_model=None if not has_data_model else BaseModel, + ) + + args_to_keep = ["contents", "tools"] + + kwargs_final = {} + for arg in args_to_keep: + if arg in kwargs: + kwargs_final[arg] = kwargs[arg] + + res = self._client.count_tokens(**kwargs_final) + return res.total_tokens + def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]: contents: list["ContentDict"] = [] for turn in turns: diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 370ffb74..869e2a01 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -8,6 +8,7 @@ from ._chat import Chat from ._content import ( Content, + ContentImage, ContentImageInline, ContentImageRemote, ContentJson, @@ -18,7 +19,6 @@ from ._logging import log_model_default from ._merge import merge_dicts from ._provider import Provider -from ._tokens import tokens_log from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, normalize_turns from ._utils import MISSING, MISSING_TYPE, is_testing @@ -349,6 +349,50 @@ async def stream_turn_async(self, completion, has_data_model, stream): def value_turn(self, completion, has_data_model) -> Turn: return self._as_turn(completion, has_data_model) + def token_count( + self, + *args: Content | str, + tools: dict[str, Tool], + has_data_model: bool, + ) -> int: + try: + import tiktoken + except ImportError: + raise ImportError( + "The tiktoken package is required for token counting. " + "Please install it with `pip install tiktoken`." + ) + + encoding = tiktoken.encoding_for_model(self._model) + + res: int = 0 + for arg in args: + if isinstance(arg, str): + res += len(encoding.encode(arg)) + elif isinstance(arg, ContentText): + res += len(encoding.encode(arg.text)) + elif isinstance(arg, ContentImage): + res += self._image_token_count(arg) + elif isinstance(arg, ContentToolResult): + res += len(encoding.encode(arg.get_final_value())) + else: + raise NotImplementedError( + f"Token counting for {type(arg)} not yet implemented." 
+ ) + + return res + + @staticmethod + def _image_token_count(image: ContentImage) -> int: + if isinstance(image, ContentImageRemote) and image.detail == "low": + return 85 + else: + # This is just the max token count for an image The highest possible + # resolution is 768 x 2048, and 8 tiles of size 512px can fit inside + # TODO: this is obviously a very conservative estimate and could be improved + # https://platform.openai.com/docs/guides/vision/calculating-costs + return 170 * 8 + 85 + @staticmethod def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]: from openai.types.chat import ( @@ -506,8 +550,6 @@ def _as_turn( usage = completion.x_groq["usage"] # type: ignore tokens = usage["prompt_tokens"], usage["completion_tokens"] - tokens_log(self, tokens) - return Turn( "assistant", contents, diff --git a/chatlas/_provider.py b/chatlas/_provider.py index 1cc6c18a..5dae85a5 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -14,6 +14,7 @@ from pydantic import BaseModel +from ._content import Content from ._tools import Tool from ._turn import Turn @@ -141,3 +142,11 @@ def value_turn( completion: ChatCompletionT, has_data_model: bool, ) -> Turn: ... + + @abstractmethod + def token_count( + self, + *args: Content | str, + tools: dict[str, Tool], + has_data_model: bool, + ) -> int: ... diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py deleted file mode 100644 index 35bb0bd6..00000000 --- a/chatlas/_tokens.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -import copy -from threading import Lock -from typing import TYPE_CHECKING - -from ._logging import logger -from ._typing_extensions import TypedDict - -if TYPE_CHECKING: - from ._provider import Provider - - -class TokenUsage(TypedDict): - """ - Token usage for a given provider (name). - """ - - name: str - input: int - output: int - - -class ThreadSafeTokenCounter: - def __init__(self): - self._lock = Lock() - self._tokens: dict[str, TokenUsage] = {} - - def log_tokens(self, name: str, input_tokens: int, output_tokens: int) -> None: - logger.info( - f"Provider '{name}' generated a response of {output_tokens} tokens " - f"from an input of {input_tokens} tokens." - ) - - with self._lock: - if name not in self._tokens: - self._tokens[name] = { - "name": name, - "input": input_tokens, - "output": output_tokens, - } - else: - self._tokens[name]["input"] += input_tokens - self._tokens[name]["output"] += output_tokens - - def get_usage(self) -> list[TokenUsage] | None: - with self._lock: - if not self._tokens: - return None - # Create a deep copy to avoid external modifications - return copy.deepcopy(list(self._tokens.values())) - - -# Global instance -_token_counter = ThreadSafeTokenCounter() - - -def tokens_log(provider: "Provider", tokens: tuple[int, int]) -> None: - """ - Log token usage for a provider in a thread-safe manner. - """ - name = provider.__class__.__name__.replace("Provider", "") - _token_counter.log_tokens(name, tokens[0], tokens[1]) - - -def tokens_reset() -> None: - """ - Reset the token usage counter - """ - global _token_counter # noqa: PLW0603 - _token_counter = ThreadSafeTokenCounter() - - -def token_usage() -> list[TokenUsage] | None: - """ - Report on token usage in the current session - - Call this function to find out the cumulative number of tokens that you - have sent and received in the current session. - - Returns - ------- - list[TokenUsage] | None - A list of dictionaries with the following keys: "name", "input", and "output". 
- If no tokens have been logged, then None is returned. - """ - return _token_counter.get_usage() diff --git a/pyproject.toml b/pyproject.toml index 4b4e608d..c0c8b4a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dev = [ "anthropic[bedrock]", "google-generativeai>=0.8.3", "numpy>1.24.4", + "tiktoken", ] docs = [ "griffe>=1", diff --git a/tests/test_tokens.py b/tests/test_tokens.py deleted file mode 100644 index 8495c1ca..00000000 --- a/tests/test_tokens.py +++ /dev/null @@ -1,34 +0,0 @@ -from chatlas._openai import OpenAIAzureProvider, OpenAIProvider -from chatlas._tokens import token_usage, tokens_log, tokens_reset - - -def test_usage_is_none(): - tokens_reset() - assert token_usage() is None - - -def test_can_retrieve_and_log_tokens(): - tokens_reset() - - provider = OpenAIProvider(model="foo") - - tokens_log(provider, (10, 50)) - tokens_log(provider, (0, 10)) - usage = token_usage() - assert usage is not None - assert len(usage) == 1 - assert usage[0]["name"] == "OpenAI" - assert usage[0]["input"] == 10 - assert usage[0]["output"] == 60 - - provider2 = OpenAIAzureProvider(endpoint="foo", api_version="bar") - - tokens_log(provider2, (5, 25)) - usage = token_usage() - assert usage is not None - assert len(usage) == 2 - assert usage[1]["name"] == "OpenAIAzure" - assert usage[1]["input"] == 5 - assert usage[1]["output"] == 25 - - tokens_reset() From e59a4ad7a30010102d47be12bf8a5e31d370aee1 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 17 Dec 2024 17:43:12 -0600 Subject: [PATCH 02/14] Update changelog --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 616af04a..64832e6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Breaking changes + +* The `token_usage()` and `tokens_reset()` functions have been removed. Use the new `.token_usage()` method on the `Chat` instance instead. (#23) +* The `.tokens()` method on the `Chat` instance was removed because you usually only care about `.token_usage()`. If you do indeed want the input/output tokens for each turn, you can `.get_turns()` on the chat instance, and then get the `.tokens` of each turn. (#23) + ### New features +* The `Chat` class gains a `.token_count()` method to help estimate input tokens before sending it to the LLM. (#23) + ### Bug fixes * `ChatOllama` no longer fails when a `OPENAI_API_KEY` environment variable is not set. From 710e6fe6c90d316e94a89f49ea4ef50593c5a98a Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 17 Dec 2024 17:50:17 -0600 Subject: [PATCH 03/14] Clean up docstring --- chatlas/_chat.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index c3595e21..5984b094 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -199,9 +199,8 @@ def token_count( """ Get the token count for the given input. - Get the token count for the given input. This can be useful for - understanding how many tokens your input will cost before sending it to - the model. + This is useful for estimating the number of tokens your input will cost + before sending it to the model. 
Parameters ---------- @@ -224,11 +223,11 @@ def token_count( # Estimate the token count before sending the input print(chat.token_count("What is 2 + 2?")) - # Once input is sent, you can get the actual input and output token counts + # Once input is sent, you can get the actual input and output + # token counts from the chat object chat.chat("What is 2 + 2?", echo="none") - print(chat.last_turn().tokens) + print(chat.token_usage) ``` - """ return self.provider.token_count( From 970fe36771ccb8b3d9854b29c64b12664b5d7715 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 17 Dec 2024 17:53:27 -0600 Subject: [PATCH 04/14] Make token_usage() a method not a property Just in case we want parameters --- chatlas/_chat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 5984b094..6edf5f07 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -176,7 +176,6 @@ def system_prompt(self, value: str | None): if value is not None: self._turns.insert(0, Turn("system", value)) - @property def token_usage(self) -> tuple[int, int]: """ Get the current token usage for the chat. @@ -199,7 +198,7 @@ def token_count( """ Get the token count for the given input. - This is useful for estimating the number of tokens your input will cost + This is useful for estimating the number of tokens your input will cost before sending it to the model. Parameters @@ -226,7 +225,7 @@ def token_count( # Once input is sent, you can get the actual input and output # token counts from the chat object chat.chat("What is 2 + 2?", echo="none") - print(chat.token_usage) + print(chat.token_usage()) ``` """ From 5136c958212f71981876e4e37315409e43a429e3 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 17 Dec 2024 17:54:23 -0600 Subject: [PATCH 05/14] Fix imports --- chatlas/__init__.py | 2 -- chatlas/types/__init__.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/chatlas/__init__.py b/chatlas/__init__.py index ef921da5..a11f60a2 100644 --- a/chatlas/__init__.py +++ b/chatlas/__init__.py @@ -10,7 +10,6 @@ from ._openai import ChatAzureOpenAI, ChatOpenAI from ._perplexity import ChatPerplexity from ._provider import Provider -from ._tokens import token_usage from ._tools import Tool from ._turn import Turn @@ -31,7 +30,6 @@ "interpolate", "interpolate_file", "Provider", - "token_usage", "Tool", "Turn", "types", diff --git a/chatlas/types/__init__.py b/chatlas/types/__init__.py index 2b66c82e..874e5db0 100644 --- a/chatlas/types/__init__.py +++ b/chatlas/types/__init__.py @@ -10,7 +10,6 @@ ContentToolResult, ImageContentTypes, ) -from .._tokens import TokenUsage from .._utils import MISSING, MISSING_TYPE __all__ = ( @@ -26,7 +25,6 @@ "ChatResponseAsync", "ImageContentTypes", "SubmitInputArgsT", - "TokenUsage", "MISSING_TYPE", "MISSING", ) From be43913af1a6059cf12d965165d200b479071ffb Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 17 Dec 2024 18:23:45 -0600 Subject: [PATCH 06/14] Rollback breaking changes --- CHANGELOG.md | 5 --- chatlas/__init__.py | 2 + chatlas/_anthropic.py | 3 ++ chatlas/_chat.py | 15 +++---- chatlas/_google.py | 3 ++ chatlas/_openai.py | 3 ++ chatlas/_tokens.py | 87 +++++++++++++++++++++++++++++++++++++++ chatlas/types/__init__.py | 2 + tests/test_tokens.py | 34 +++++++++++++++ 9 files changed, 140 insertions(+), 14 deletions(-) create mode 100644 chatlas/_tokens.py create mode 100644 tests/test_tokens.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 64832e6f..1dd9a9ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,11 +9,6 @@ and this project 
adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] -### Breaking changes - -* The `token_usage()` and `tokens_reset()` functions have been removed. Use the new `.token_usage()` method on the `Chat` instance instead. (#23) -* The `.tokens()` method on the `Chat` instance was removed because you usually only care about `.token_usage()`. If you do indeed want the input/output tokens for each turn, you can `.get_turns()` on the chat instance, and then get the `.tokens` of each turn. (#23) - ### New features * The `Chat` class gains a `.token_count()` method to help estimate input tokens before sending it to the LLM. (#23) diff --git a/chatlas/__init__.py b/chatlas/__init__.py index a11f60a2..ef921da5 100644 --- a/chatlas/__init__.py +++ b/chatlas/__init__.py @@ -10,6 +10,7 @@ from ._openai import ChatAzureOpenAI, ChatOpenAI from ._perplexity import ChatPerplexity from ._provider import Provider +from ._tokens import token_usage from ._tools import Tool from ._turn import Turn @@ -30,6 +31,7 @@ "interpolate", "interpolate_file", "Provider", + "token_usage", "Tool", "Turn", "types", diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index a439badb..f87ac742 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -18,6 +18,7 @@ ) from ._logging import log_model_default from ._provider import Provider +from ._tokens import tokens_log from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, normalize_turns, user_turn @@ -507,6 +508,8 @@ def _as_turn(self, completion: Message, has_data_model=False) -> Turn: tokens = completion.usage.input_tokens, completion.usage.output_tokens + tokens_log(self, tokens) + return Turn( "assistant", contents, diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 6edf5f07..3b9119a8 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -176,19 +176,16 @@ def system_prompt(self, value: str | None): if value is not None: self._turns.insert(0, Turn("system", value)) - def token_usage(self) -> tuple[int, int]: + def tokens(self) -> list[tuple[int, int] | None]: """ - Get the current token usage for the chat. - + Get the tokens for each turn in the chat. Returns ------- - tuple[int, int] - The input and output token usage for the chat. + list[tuple[int, int] | None] + A list of tuples, where each tuple contains the start and end token + indices for a turn. 
""" - turn = self.get_last_turn(role="assistant") - if turn is None: - return 0, 0 - return turn.tokens or (0, 0) + return [turn.tokens for turn in self._turns] def token_count( self, diff --git a/chatlas/_google.py b/chatlas/_google.py index 8491443c..c126d891 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -17,6 +17,7 @@ ) from ._logging import log_model_default from ._provider import Provider +from ._tokens import tokens_log from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, normalize_turns, user_turn @@ -446,6 +447,8 @@ def _as_turn( usage.candidates_token_count, ) + tokens_log(self, tokens) + finish = message.candidates[0].finish_reason return Turn( diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 869e2a01..eb983e61 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -19,6 +19,7 @@ from ._logging import log_model_default from ._merge import merge_dicts from ._provider import Provider +from ._tokens import tokens_log from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, normalize_turns from ._utils import MISSING, MISSING_TYPE, is_testing @@ -550,6 +551,8 @@ def _as_turn( usage = completion.x_groq["usage"] # type: ignore tokens = usage["prompt_tokens"], usage["completion_tokens"] + tokens_log(self, tokens) + return Turn( "assistant", contents, diff --git a/chatlas/_tokens.py b/chatlas/_tokens.py new file mode 100644 index 00000000..35bb0bd6 --- /dev/null +++ b/chatlas/_tokens.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import copy +from threading import Lock +from typing import TYPE_CHECKING + +from ._logging import logger +from ._typing_extensions import TypedDict + +if TYPE_CHECKING: + from ._provider import Provider + + +class TokenUsage(TypedDict): + """ + Token usage for a given provider (name). + """ + + name: str + input: int + output: int + + +class ThreadSafeTokenCounter: + def __init__(self): + self._lock = Lock() + self._tokens: dict[str, TokenUsage] = {} + + def log_tokens(self, name: str, input_tokens: int, output_tokens: int) -> None: + logger.info( + f"Provider '{name}' generated a response of {output_tokens} tokens " + f"from an input of {input_tokens} tokens." + ) + + with self._lock: + if name not in self._tokens: + self._tokens[name] = { + "name": name, + "input": input_tokens, + "output": output_tokens, + } + else: + self._tokens[name]["input"] += input_tokens + self._tokens[name]["output"] += output_tokens + + def get_usage(self) -> list[TokenUsage] | None: + with self._lock: + if not self._tokens: + return None + # Create a deep copy to avoid external modifications + return copy.deepcopy(list(self._tokens.values())) + + +# Global instance +_token_counter = ThreadSafeTokenCounter() + + +def tokens_log(provider: "Provider", tokens: tuple[int, int]) -> None: + """ + Log token usage for a provider in a thread-safe manner. + """ + name = provider.__class__.__name__.replace("Provider", "") + _token_counter.log_tokens(name, tokens[0], tokens[1]) + + +def tokens_reset() -> None: + """ + Reset the token usage counter + """ + global _token_counter # noqa: PLW0603 + _token_counter = ThreadSafeTokenCounter() + + +def token_usage() -> list[TokenUsage] | None: + """ + Report on token usage in the current session + + Call this function to find out the cumulative number of tokens that you + have sent and received in the current session. + + Returns + ------- + list[TokenUsage] | None + A list of dictionaries with the following keys: "name", "input", and "output". 
+ If no tokens have been logged, then None is returned. + """ + return _token_counter.get_usage() diff --git a/chatlas/types/__init__.py b/chatlas/types/__init__.py index 874e5db0..2b66c82e 100644 --- a/chatlas/types/__init__.py +++ b/chatlas/types/__init__.py @@ -10,6 +10,7 @@ ContentToolResult, ImageContentTypes, ) +from .._tokens import TokenUsage from .._utils import MISSING, MISSING_TYPE __all__ = ( @@ -25,6 +26,7 @@ "ChatResponseAsync", "ImageContentTypes", "SubmitInputArgsT", + "TokenUsage", "MISSING_TYPE", "MISSING", ) diff --git a/tests/test_tokens.py b/tests/test_tokens.py new file mode 100644 index 00000000..8495c1ca --- /dev/null +++ b/tests/test_tokens.py @@ -0,0 +1,34 @@ +from chatlas._openai import OpenAIAzureProvider, OpenAIProvider +from chatlas._tokens import token_usage, tokens_log, tokens_reset + + +def test_usage_is_none(): + tokens_reset() + assert token_usage() is None + + +def test_can_retrieve_and_log_tokens(): + tokens_reset() + + provider = OpenAIProvider(model="foo") + + tokens_log(provider, (10, 50)) + tokens_log(provider, (0, 10)) + usage = token_usage() + assert usage is not None + assert len(usage) == 1 + assert usage[0]["name"] == "OpenAI" + assert usage[0]["input"] == 10 + assert usage[0]["output"] == 60 + + provider2 = OpenAIAzureProvider(endpoint="foo", api_version="bar") + + tokens_log(provider2, (5, 25)) + usage = token_usage() + assert usage is not None + assert len(usage) == 2 + assert usage[1]["name"] == "OpenAIAzure" + assert usage[1]["input"] == 5 + assert usage[1]["output"] == 25 + + tokens_reset() From 71d7d83f98783417231c1c8590e427f0bc91f64b Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 17 Dec 2024 18:27:21 -0600 Subject: [PATCH 07/14] Cleanup --- CHANGELOG.md | 1 + chatlas/_chat.py | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1dd9a9ca..70f62b86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `ChatOllama` no longer fails when a `OPENAI_API_KEY` environment variable is not set. * `ChatOpenAI` now correctly includes the relevant `detail` on `ContentImageRemote()` input. +* `ChatGoogle` now correctly logs its `token_usage()`. (#23) ## [0.2.0] - 2024-12-11 diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 3b9119a8..bc6986f0 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -179,6 +179,7 @@ def system_prompt(self, value: str | None): def tokens(self) -> list[tuple[int, int] | None]: """ Get the tokens for each turn in the chat. + Returns ------- list[tuple[int, int] | None] From 69436ffa30ca0fec1d8d37a85a91ee7d237b4cc9 Mon Sep 17 00:00:00 2001 From: Carson Sievert Date: Thu, 19 Dec 2024 09:51:18 -0600 Subject: [PATCH 08/14] Doc improvements --- chatlas/_chat.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index bc6986f0..e54be9ec 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -194,10 +194,11 @@ def token_count( extract_data: bool = False, ) -> int: """ - Get the token count for the given input. + Get an estimated token count for the given input. - This is useful for estimating the number of tokens your input will cost - before sending it to the model. + Estimate the token size of input content. This can help determine whether input(s) + and/or conversation history (i.e., `.get_turns()`) should be reduced in size before + sending it to the model. 
Parameters ---------- From 1d27b5b4bed72d0f85c5855ad77800a64c3e10ba Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 19 Dec 2024 10:22:18 -0600 Subject: [PATCH 09/14] Add .token_count_async(); require the whole data_model --- chatlas/_anthropic.py | 41 +++++++++++++++++++++++++++++---------- chatlas/_chat.py | 45 ++++++++++++++++++++++++++++++++++++++----- chatlas/_google.py | 42 +++++++++++++++++++++++++++++++--------- chatlas/_openai.py | 10 +++++++++- chatlas/_provider.py | 10 +++++++++- 5 files changed, 122 insertions(+), 26 deletions(-) diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py index f87ac742..997aa1bc 100644 --- a/chatlas/_anthropic.py +++ b/chatlas/_anthropic.py @@ -384,15 +384,43 @@ def token_count( self, *args: Content | str, tools: dict[str, Tool], - has_data_model: bool, + data_model: Optional[type[BaseModel]], ) -> int: + kwargs = self._token_count_args( + *args, + tools=tools, + data_model=data_model, + ) + res = self._client.messages.count_tokens(**kwargs) + return res.input_tokens + + async def token_count_async( + self, + *args: Content | str, + tools: dict[str, Tool], + data_model: Optional[type[BaseModel]], + ) -> int: + kwargs = self._token_count_args( + *args, + tools=tools, + data_model=data_model, + ) + res = await self._async_client.messages.count_tokens(**kwargs) + return res.input_tokens + + def _token_count_args( + self, + *args: Content | str, + tools: dict[str, Tool], + data_model: Optional[type[BaseModel]], + ) -> dict[str, Any]: turn = user_turn(*args) kwargs = self._chat_perform_args( stream=False, turns=[turn], tools=tools, - data_model=None if not has_data_model else BaseModel, + data_model=data_model, ) args_to_keep = [ @@ -403,14 +431,7 @@ def token_count( "tool_choice", ] - kwargs_final = {} - for arg in args_to_keep: - if arg in kwargs: - kwargs_final[arg] = kwargs[arg] - - res = self._client.messages.count_tokens(**kwargs_final) - - return res.input_tokens + return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs} def _as_message_params(self, turns: list[Turn]) -> list["MessageParam"]: messages: list["MessageParam"] = [] diff --git a/chatlas/_chat.py b/chatlas/_chat.py index e54be9ec..9dbadda8 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -191,12 +191,12 @@ def tokens(self) -> list[tuple[int, int] | None]: def token_count( self, *args: Content | str, - extract_data: bool = False, + data_model: Optional[type[BaseModel]] = None, ) -> int: """ Get an estimated token count for the given input. - Estimate the token size of input content. This can help determine whether input(s) + Estimate the token size of input content. This can help determine whether input(s) and/or conversation history (i.e., `.get_turns()`) should be reduced in size before sending it to the model. @@ -204,8 +204,10 @@ def token_count( ---------- args The input to get a token count for. - extract_data - Whether or not the input is for data extraction (i.e., `.extract_data()`). + data_model + If the input is meant for data extraction (i.e., `.extract_data()`), then + this should be the Pydantic model that describes the structure of the data to + extract. Returns ------- @@ -231,7 +233,40 @@ def token_count( return self.provider.token_count( *args, tools=self._tools, - has_data_model=extract_data, + data_model=data_model, + ) + + async def token_count_async( + self, + *args: Content | str, + data_model: Optional[type[BaseModel]] = None, + ) -> int: + """ + Get an estimated token count for the given input asynchronously. 
+ + Estimate the token size of input content. This can help determine whether input(s) + and/or conversation history (i.e., `.get_turns()`) should be reduced in size before + sending it to the model. + + Parameters + ---------- + args + The input to get a token count for. + data_model + If this input is meant for data extraction (i.e., `.extract_data_async()`), + then this should be the Pydantic model that describes the structure of the data + to extract. + + Returns + ------- + int + The token count for the input. + """ + + return await self.provider.token_count_async( + *args, + tools=self._tools, + data_model=data_model, ) def app( diff --git a/chatlas/_google.py b/chatlas/_google.py index c126d891..313048f7 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -337,26 +337,50 @@ def token_count( self, *args: Content | str, tools: dict[str, Tool], - has_data_model: bool, + data_model: Optional[type[BaseModel]], + ): + kwargs = self._token_count_args( + *args, + tools=tools, + data_model=data_model, + ) + + res = self._client.count_tokens(**kwargs) + return res.total_tokens + + async def token_count_async( + self, + *args: Content | str, + tools: dict[str, Tool], + data_model: Optional[type[BaseModel]], ): + kwargs = self._token_count_args( + *args, + tools=tools, + data_model=data_model, + ) + + res = await self._client.count_tokens_async(**kwargs) + return res.total_tokens + + def _token_count_args( + self, + *args: Content | str, + tools: dict[str, Tool], + data_model: Optional[type[BaseModel]], + ) -> dict[str, Any]: turn = user_turn(*args) kwargs = self._chat_perform_args( stream=False, turns=[turn], tools=tools, - data_model=None if not has_data_model else BaseModel, + data_model=data_model, ) args_to_keep = ["contents", "tools"] - kwargs_final = {} - for arg in args_to_keep: - if arg in kwargs: - kwargs_final[arg] = kwargs[arg] - - res = self._client.count_tokens(**kwargs_final) - return res.total_tokens + return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs} def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]: contents: list["ContentDict"] = [] diff --git a/chatlas/_openai.py b/chatlas/_openai.py index eb983e61..1b97e44f 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -354,7 +354,7 @@ def token_count( self, *args: Content | str, tools: dict[str, Tool], - has_data_model: bool, + data_model: Optional[type[BaseModel]], ) -> int: try: import tiktoken @@ -383,6 +383,14 @@ def token_count( return res + async def token_count_async( + self, + *args: Content | str, + tools: dict[str, Tool], + data_model: Optional[type[BaseModel]], + ) -> int: + return self.token_count(*args, tools=tools, data_model=data_model) + @staticmethod def _image_token_count(image: ContentImage) -> int: if isinstance(image, ContentImageRemote) and image.detail == "low": diff --git a/chatlas/_provider.py b/chatlas/_provider.py index 5dae85a5..95cf88cb 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -148,5 +148,13 @@ def token_count( self, *args: Content | str, tools: dict[str, Tool], - has_data_model: bool, + data_model: Optional[type[BaseModel]], + ) -> int: ... + + @abstractmethod + async def token_count_async( + self, + *args: Content | str, + tools: dict[str, Tool], + data_model: Optional[type[BaseModel]], ) -> int: ... 
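
For reference, a minimal usage sketch of the `.token_count()` / `.token_count_async()` API as it stands after this patch: `token_count()` now accepts the same Pydantic `data_model` you would pass to `.extract_data()`, and `token_count_async()` mirrors it for async code. The model name, the `Person` schema, and the API-key/`tiktoken` setup below are illustrative assumptions, not part of the diff:

```python
import asyncio

from pydantic import BaseModel

from chatlas import ChatOpenAI


class Person(BaseModel):
    name: str
    age: int


# Assumes OPENAI_API_KEY is set; for ChatOpenAI, counting happens
# locally via tiktoken (no request is sent).
chat = ChatOpenAI(model="gpt-4o-mini")

# Estimate the token cost of a plain input before sending it to the model.
print(chat.token_count("What is 2 + 2?"))

# Estimate a data-extraction input by passing the schema you would
# hand to .extract_data().
print(chat.token_count("John, age 15", data_model=Person))

# The async variant has the same signature.
print(asyncio.run(chat.token_count_async("What is 2 + 2?")))
```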
From cb225ae927a260fdb5708220ab6b3ef8d7a43a24 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 19 Dec 2024 11:46:36 -0600 Subject: [PATCH 10/14] Slightly more accurate/conservative token count for OpenAI --- chatlas/_openai.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/chatlas/_openai.py b/chatlas/_openai.py index 1b97e44f..ef1e49f4 100644 --- a/chatlas/_openai.py +++ b/chatlas/_openai.py @@ -21,7 +21,7 @@ from ._provider import Provider from ._tokens import tokens_log from ._tools import Tool, basemodel_to_param_schema -from ._turn import Turn, normalize_turns +from ._turn import Turn, normalize_turns, user_turn from ._utils import MISSING, MISSING_TYPE, is_testing if TYPE_CHECKING: @@ -366,22 +366,21 @@ def token_count( encoding = tiktoken.encoding_for_model(self._model) - res: int = 0 - for arg in args: - if isinstance(arg, str): - res += len(encoding.encode(arg)) - elif isinstance(arg, ContentText): - res += len(encoding.encode(arg.text)) - elif isinstance(arg, ContentImage): - res += self._image_token_count(arg) - elif isinstance(arg, ContentToolResult): - res += len(encoding.encode(arg.get_final_value())) - else: - raise NotImplementedError( - f"Token counting for {type(arg)} not yet implemented." - ) + turn = user_turn(*args) - return res + # Count the tokens in image contents + image_tokens = sum( + self._image_token_count(x) + for x in turn.contents + if isinstance(x, ContentImage) + ) + + # For other contents, get the token count from the actual message param + other_contents = [x for x in turn.contents if not isinstance(x, ContentImage)] + other_full = self._as_message_param([Turn("user", other_contents)]) + other_tokens = len(encoding.encode(str(other_full))) + + return other_tokens + image_tokens async def token_count_async( self, From 8aa36042c66ebc8ec83f0baa29d2adf8060039d7 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 19 Dec 2024 17:21:24 -0600 Subject: [PATCH 11/14] Add tests --- tests/test_tokens.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 2c4295c4..47d260e8 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,4 +1,4 @@ -from chatlas import ChatOpenAI, Turn +from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn from chatlas._openai import OpenAIAzureProvider, OpenAIProvider from chatlas._tokens import token_usage, tokens_log, tokens_reset @@ -26,10 +26,20 @@ def test_tokens_method(): ) assert chat.tokens(values="discrete") == [2, 10, 2, 10] - assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)] +def test_token_count_method(): + chat = ChatOpenAI(model="gpt-4o-mini") + assert chat.token_count("What is 1 + 1?") == 31 + + chat = ChatAnthropic(model="claude-3-5-sonnet-20241022") + assert chat.token_count("What is 1 + 1?") == 16 + + chat = ChatGoogle(model="gemini-1.5-flash") + assert chat.token_count("What is 1 + 1?") == 9 + + def test_usage_is_none(): tokens_reset() assert token_usage() is None From c6ccc0b1653636bf4111fdd2cf50801c4fe71c31 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 19 Dec 2024 17:28:48 -0600 Subject: [PATCH 12/14] Add note --- chatlas/_chat.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 6ba08c6a..6d48aa51 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -320,6 +320,12 @@ def token_count( int The token count for the input. + Note + ---- + Remember that the token count is an estimate. 
Also, models based on
+        `ChatOpenAI()` currently do not take tools into account when
+        estimating token counts.
+
         Examples
         --------
         ```python

From bc9548970cbe762073aef2c589402840584dd985 Mon Sep 17 00:00:00 2001
From: Carson
Date: Thu, 19 Dec 2024 17:32:55 -0600
Subject: [PATCH 13/14] Tweak changelog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a422b01b..33f1d468 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,8 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### New features

-* The `Chat` class gains a `.token_count()` method to help estimate token cost of new input before generating a response for it. (#23)
 * `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`), remains the same (the result can be summed to determine the overall token cost of the conversation).
+* `Chat` gains a `.token_count()` method to help estimate token cost of new input. (#23)

 ### Bug fixes

From 94fdfc808a68ea6c5aa2f08c9374d1c073b112c3 Mon Sep 17 00:00:00 2001
From: Carson
Date: Thu, 19 Dec 2024 17:34:37 -0600
Subject: [PATCH 14/14] Tweak docstring

---
 chatlas/_chat.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index 6d48aa51..8dfd109a 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -329,9 +329,9 @@ def token_count(
         Examples
         --------
         ```python
-        from chatlas import ChatOpenAI
+        from chatlas import ChatAnthropic

-        chat = ChatOpenAI()
+        chat = ChatAnthropic()
         # Estimate the token count before sending the input
         print(chat.token_count("What is 2 + 2?"))