diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0991fb3f..33f1d468 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,11 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### New features
 
 * `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`) remains the same (the result can be summed to determine the overall token cost of the conversation).
+* `Chat` gains a `.token_count()` method to help estimate the token cost of new input. (#23)
 
 ### Bug fixes
 
 * `ChatOllama` no longer fails when an `OPENAI_API_KEY` environment variable is not set.
 * `ChatOpenAI` now correctly includes the relevant `detail` on `ContentImageRemote()` input.
+* `ChatGoogle` now correctly logs its `token_usage()`. (#23)
 
 ## [0.2.0] - 2024-12-11
diff --git a/chatlas/_anthropic.py b/chatlas/_anthropic.py
index bdb6cbb6..997aa1bc 100644
--- a/chatlas/_anthropic.py
+++ b/chatlas/_anthropic.py
@@ -20,7 +20,7 @@
 from ._provider import Provider
 from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns
+from ._turn import Turn, normalize_turns, user_turn
 
 if TYPE_CHECKING:
     from anthropic.types import (
@@ -380,6 +380,59 @@ async def stream_turn_async(self, completion, has_data_model, stream) -> Turn:
     def value_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+        res = self._client.messages.count_tokens(**kwargs)
+        return res.input_tokens
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+        res = await self._async_client.messages.count_tokens(**kwargs)
+        return res.input_tokens
+
+    def _token_count_args(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> dict[str, Any]:
+        turn = user_turn(*args)
+
+        kwargs = self._chat_perform_args(
+            stream=False,
+            turns=[turn],
+            tools=tools,
+            data_model=data_model,
+        )
+
+        args_to_keep = [
+            "messages",
+            "model",
+            "system",
+            "tools",
+            "tool_choice",
+        ]
+
+        return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}
+
     def _as_message_params(self, turns: list[Turn]) -> list["MessageParam"]:
         messages: list["MessageParam"] = []
         for turn in turns:
diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index be1faaf5..8dfd109a 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -294,6 +294,93 @@ def tokens(
 
         return res
 
+    def token_count(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ) -> int:
+        """
+        Get an estimated token count for the given input.
+
+        Estimate the token size of input content. This can help determine whether
+        input(s) and/or conversation history (i.e., `.get_turns()`) should be reduced
+        in size before being sent to the model.
+
+        Parameters
+        ----------
+        args
+            The input to get a token count for.
+        data_model
+            If the input is meant for data extraction (i.e., `.extract_data()`), then
+            this should be the Pydantic model that describes the structure of the data to
+            extract.
+
+        Returns
+        -------
+        int
+            The token count for the input.
+
+        Note
+        ----
+        Remember that the token count is an estimate. Also, models based on
+        `ChatOpenAI()` currently do not take tools into account when
+        estimating token counts.
+
+        Examples
+        --------
+        ```python
+        from chatlas import ChatAnthropic
+
+        chat = ChatAnthropic()
+        # Estimate the token count before sending the input
+        print(chat.token_count("What is 2 + 2?"))
+
+        # Once input is sent, you can get the actual input and output
+        # token counts from the chat object
+        chat.chat("What is 2 + 2?", echo="none")
+        print(chat.token_usage())
+        ```
+        """
+
+        return self.provider.token_count(
+            *args,
+            tools=self._tools,
+            data_model=data_model,
+        )
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        data_model: Optional[type[BaseModel]] = None,
+    ) -> int:
+        """
+        Get an estimated token count for the given input asynchronously.
+
+        Estimate the token size of input content. This can help determine whether
+        input(s) and/or conversation history (i.e., `.get_turns()`) should be reduced
+        in size before being sent to the model.
+
+        Parameters
+        ----------
+        args
+            The input to get a token count for.
+        data_model
+            If this input is meant for data extraction (i.e., `.extract_data_async()`),
+            then this should be the Pydantic model that describes the structure of the
+            data to extract.
+
+        Returns
+        -------
+        int
+            The token count for the input.
+        """
+
+        return await self.provider.token_count_async(
+            *args,
+            tools=self._tools,
+            data_model=data_model,
+        )
+
     def app(
         self,
         *,
diff --git a/chatlas/_google.py b/chatlas/_google.py
index f4e96d86..313048f7 100644
--- a/chatlas/_google.py
+++ b/chatlas/_google.py
@@ -17,8 +17,9 @@
 )
 from ._logging import log_model_default
 from ._provider import Provider
+from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns
+from ._turn import Turn, normalize_turns, user_turn
 
 if TYPE_CHECKING:
     from google.generativeai.types.content_types import (
@@ -332,6 +333,55 @@ async def stream_turn_async(
     def value_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ):
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+
+        res = self._client.count_tokens(**kwargs)
+        return res.total_tokens
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ):
+        kwargs = self._token_count_args(
+            *args,
+            tools=tools,
+            data_model=data_model,
+        )
+
+        res = await self._client.count_tokens_async(**kwargs)
+        return res.total_tokens
+
+    def _token_count_args(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> dict[str, Any]:
+        turn = user_turn(*args)
+
+        kwargs = self._chat_perform_args(
+            stream=False,
+            turns=[turn],
+            tools=tools,
+            data_model=data_model,
+        )
+
+        args_to_keep = ["contents", "tools"]
+
+        return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}
+
     def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]:
         contents: list["ContentDict"] = []
         for turn in turns:
@@ -421,6 +471,8 @@ def _as_turn(
             usage.candidates_token_count,
         )
 
+        tokens_log(self, tokens)
+
         finish = message.candidates[0].finish_reason
 
         return Turn(
diff --git a/chatlas/_openai.py b/chatlas/_openai.py
index 75030092..8c9a2cc2 100644
--- a/chatlas/_openai.py
+++ b/chatlas/_openai.py
@@ -8,6 +8,7 @@
 from ._chat import Chat
 from ._content import (
     Content,
+    ContentImage,
     ContentImageInline,
     ContentImageRemote,
     ContentJson,
@@ -20,7 +21,7 @@
 from ._provider import Provider
 from ._tokens import tokens_log
 from ._tools import Tool, basemodel_to_param_schema
-from ._turn import Turn, normalize_turns
+from ._turn import Turn, normalize_turns, user_turn
 from ._utils import MISSING, MISSING_TYPE, is_testing
 
 if TYPE_CHECKING:
@@ -351,6 +352,57 @@ async def stream_turn_async(self, completion, has_data_model, stream):
     def value_turn(self, completion, has_data_model) -> Turn:
         return self._as_turn(completion, has_data_model)
 
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        try:
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "The tiktoken package is required for token counting. "
+                "Please install it with `pip install tiktoken`."
+            )
+
+        encoding = tiktoken.encoding_for_model(self._model)
+
+        turn = user_turn(*args)
+
+        # Count the tokens in image contents
+        image_tokens = sum(
+            self._image_token_count(x)
+            for x in turn.contents
+            if isinstance(x, ContentImage)
+        )
+
+        # For other contents, get the token count from the actual message param
+        other_contents = [x for x in turn.contents if not isinstance(x, ContentImage)]
+        other_full = self._as_message_param([Turn("user", other_contents)])
+        other_tokens = len(encoding.encode(str(other_full)))
+
+        return other_tokens + image_tokens
+
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int:
+        return self.token_count(*args, tools=tools, data_model=data_model)
+
+    @staticmethod
+    def _image_token_count(image: ContentImage) -> int:
+        if isinstance(image, ContentImageRemote) and image.detail == "low":
+            return 85
+        else:
+            # This is just the max token count for an image. The highest possible
+            # resolution is 768 x 2048, and 8 tiles of size 512px can fit inside.
+            # TODO: this is obviously a very conservative estimate and could be improved
+            # https://platform.openai.com/docs/guides/vision/calculating-costs
+            return 170 * 8 + 85
+
     @staticmethod
     def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
         from openai.types.chat import (
diff --git a/chatlas/_provider.py b/chatlas/_provider.py
index 1cc6c18a..95cf88cb 100644
--- a/chatlas/_provider.py
+++ b/chatlas/_provider.py
@@ -14,6 +14,7 @@
 
 from pydantic import BaseModel
 
+from ._content import Content
 from ._tools import Tool
 from ._turn import Turn
 
@@ -141,3 +142,19 @@ def value_turn(
         completion: ChatCompletionT,
         has_data_model: bool,
     ) -> Turn: ...
+
+    @abstractmethod
+    def token_count(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int: ...
+
+    @abstractmethod
+    async def token_count_async(
+        self,
+        *args: Content | str,
+        tools: dict[str, Tool],
+        data_model: Optional[type[BaseModel]],
+    ) -> int: ...
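To make the new `Provider` contract concrete, here is a minimal sketch of a backend satisfying the two new abstract methods. This is illustrative only and not part of the diff: `EchoProvider` and its one-token-per-word heuristic are hypothetical, whereas the real backends above delegate to a count-tokens endpoint (Anthropic, Google) or to `tiktoken` (OpenAI). The remaining `Provider` abstract methods are omitted, so this class will not instantiate as written.

```python
from typing import Optional

from pydantic import BaseModel

from chatlas._content import Content
from chatlas._provider import Provider
from chatlas._tools import Tool
from chatlas._turn import user_turn


class EchoProvider(Provider):
    """Hypothetical backend sketching the new token-counting contract."""

    def token_count(
        self,
        *args: Content | str,
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
    ) -> int:
        # Fold the inputs into a single user turn (mirroring the
        # _token_count_args() helpers above), then apply a naive
        # one-token-per-word estimate in place of a real tokenizer.
        turn = user_turn(*args)
        return sum(len(str(content).split()) for content in turn.contents)

    async def token_count_async(
        self,
        *args: Content | str,
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
    ) -> int:
        # No I/O is involved here, so the async variant simply delegates
        # (the OpenAI implementation above does the same).
        return self.token_count(*args, tools=tools, data_model=data_model)
```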
diff --git a/pyproject.toml b/pyproject.toml
index 4b4e608d..c0c8b4a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,6 +48,7 @@ dev = [
     "anthropic[bedrock]",
     "google-generativeai>=0.8.3",
     "numpy>1.24.4",
+    "tiktoken",
 ]
 docs = [
     "griffe>=1",
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index 2c4295c4..47d260e8 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -1,4 +1,4 @@
-from chatlas import ChatOpenAI, Turn
+from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn
 from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
 from chatlas._tokens import token_usage, tokens_log, tokens_reset
 
@@ -26,10 +26,20 @@ def test_tokens_method():
     )
 
     assert chat.tokens(values="discrete") == [2, 10, 2, 10]
-
     assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]
 
 
+def test_token_count_method():
+    chat = ChatOpenAI(model="gpt-4o-mini")
+    assert chat.token_count("What is 1 + 1?") == 31
+
+    chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+    assert chat.token_count("What is 1 + 1?") == 16
+
+    chat = ChatGoogle(model="gemini-1.5-flash")
+    assert chat.token_count("What is 1 + 1?") == 9
+
+
 def test_usage_is_none():
     tokens_reset()
     assert token_usage() is None
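End to end, the feature lets callers weigh an estimated token cost against the actual usage reported once a turn completes. A usage sketch, assuming the model and count pinned in `test_token_count_method` above (exact numbers vary by model and tokenizer version):

```python
from chatlas import ChatOpenAI

chat = ChatOpenAI(model="gpt-4o-mini")

# Estimate the token cost of an input before sending it...
print(chat.token_count("What is 1 + 1?"))  # 31, per the test above

# ...then compare with the actual usage once the turn completes.
chat.chat("What is 1 + 1?", echo="none")
print(chat.token_usage())
```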