2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,11 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### New features

* `Chat`'s `.tokens()` method gains a `values` argument. Set it to `"discrete"` to get a result that can be summed to determine the token cost of submitting the current turns. The default (`"cumulative"`) remains the same: the result can be summed to determine the overall token cost of the conversation.
* `Chat` gains a `.token_count()` method to help estimate token cost of new input. (#23)

### Bug fixes

* `ChatOllama` no longer fails when an `OPENAI_API_KEY` environment variable is not set.
* `ChatOpenAI` now correctly includes the relevant `detail` on `ContentImageRemote()` input.
* `ChatGoogle` now correctly logs its `token_usage()`. (#23)


## [0.2.0] - 2024-12-11
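Taken together, the two new `Chat` features above let you estimate the cost of an input before sending it and then audit per-turn costs afterwards. A quick sketch (method names as added in this PR; actual token numbers vary by model):

```python
from chatlas import ChatAnthropic

chat = ChatAnthropic()

# Estimate the token cost of an input *before* submitting it
print(chat.token_count("What is 2 + 2?"))

chat.chat("What is 2 + 2?", echo="none")

# Per-turn counts; summing these gives the cost of re-submitting
# the current turns
print(chat.tokens(values="discrete"))

# The default; summing gives the overall cost of the conversation
print(chat.tokens(values="cumulative"))
```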
55 changes: 54 additions & 1 deletion chatlas/_anthropic.py
@@ -20,7 +20,7 @@
from ._provider import Provider
from ._tokens import tokens_log
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, normalize_turns
from ._turn import Turn, normalize_turns, user_turn

if TYPE_CHECKING:
from anthropic.types import (
@@ -380,6 +380,59 @@ async def stream_turn_async(self, completion, has_data_model, stream) -> Turn:
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> int:
kwargs = self._token_count_args(
*args,
tools=tools,
data_model=data_model,
)
res = self._client.messages.count_tokens(**kwargs)
return res.input_tokens

async def token_count_async(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> int:
kwargs = self._token_count_args(
*args,
tools=tools,
data_model=data_model,
)
res = await self._async_client.messages.count_tokens(**kwargs)
return res.input_tokens

def _token_count_args(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> dict[str, Any]:
turn = user_turn(*args)

kwargs = self._chat_perform_args(
stream=False,
turns=[turn],
tools=tools,
data_model=data_model,
)

args_to_keep = [
"messages",
"model",
"system",
"tools",
"tool_choice",
]

return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}

def _as_message_params(self, turns: list[Turn]) -> list["MessageParam"]:
messages: list["MessageParam"] = []
for turn in turns:
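For reference, `token_count()` above is a thin wrapper around the Anthropic SDK's server-side counting endpoint; `_token_count_args()` simply reuses `_chat_perform_args()` and keeps the keys that endpoint accepts. A minimal sketch of the underlying call (assumes `ANTHROPIC_API_KEY` is set; the model name is just an example):

```python
import anthropic

client = anthropic.Anthropic()

# The same endpoint that token_count() ultimately delegates to
res = client.messages.count_tokens(
    model="claude-3-5-sonnet-20241022",
    messages=[{"role": "user", "content": "What is 1 + 1?"}],
)
print(res.input_tokens)  # matches the value asserted in tests/test_tokens.py
```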
87 changes: 87 additions & 0 deletions chatlas/_chat.py
@@ -294,6 +294,93 @@ def tokens(

return res

def token_count(
self,
*args: Content | str,
data_model: Optional[type[BaseModel]] = None,
) -> int:
"""
Get an estimated token count for the given input.

Estimate the token count of the input content. This can help determine whether the
input(s) and/or conversation history (i.e., `.get_turns()`) should be reduced in
size before being sent to the model.

Parameters
----------
args
The input to get a token count for.
data_model
If the input is meant for data extraction (i.e., `.extract_data()`), then
this should be the Pydantic model that describes the structure of the data to
extract.

Returns
-------
int
The token count for the input.

Note
----
Remember that the token count is only an estimate. Also, `ChatOpenAI()`-based
models currently do not take tools into account when estimating token
counts.

Examples
--------
```python
from chatlas import ChatAnthropic

chat = ChatAnthropic()
# Estimate the token count before sending the input
print(chat.token_count("What is 2 + 2?"))

# Once input is sent, you can get the actual input and output
# token counts from the chat object
chat.chat("What is 2 + 2?", echo="none")
print(chat.tokens())
```
"""

return self.provider.token_count(
*args,
tools=self._tools,
data_model=data_model,
)

async def token_count_async(
self,
*args: Content | str,
data_model: Optional[type[BaseModel]] = None,
) -> int:
"""
Get an estimated token count for the given input asynchronously.

Estimate the token count of the input content. This can help determine whether the
input(s) and/or conversation history (i.e., `.get_turns()`) should be reduced in
size before being sent to the model.

Parameters
----------
args
The input to get a token count for.
data_model
If this input is meant for data extraction (i.e., `.extract_data_async()`),
then this should be the Pydantic model that describes the structure of the data
to extract.

Returns
-------
int
The token count for the input.
"""

return await self.provider.token_count_async(
*args,
tools=self._tools,
data_model=data_model,
)

def app(
self,
*,
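A minimal usage sketch for the async variant (assumes API credentials are set; tiktoken is needed for OpenAI-based counting):

```python
import asyncio

from chatlas import ChatOpenAI


async def main():
    chat = ChatOpenAI(model="gpt-4o-mini")
    # Same semantics as token_count(), but usable inside an event loop
    n = await chat.token_count_async("What is 2 + 2?")
    print(n)


asyncio.run(main())
```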
54 changes: 53 additions & 1 deletion chatlas/_google.py
@@ -17,8 +17,9 @@
)
from ._logging import log_model_default
from ._provider import Provider
from ._tokens import tokens_log
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, normalize_turns
from ._turn import Turn, normalize_turns, user_turn

if TYPE_CHECKING:
from google.generativeai.types.content_types import (
@@ -332,6 +333,55 @@ async def stream_turn_async(
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> int:
kwargs = self._token_count_args(
*args,
tools=tools,
data_model=data_model,
)

res = self._client.count_tokens(**kwargs)
return res.total_tokens

async def token_count_async(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> int:
kwargs = self._token_count_args(
*args,
tools=tools,
data_model=data_model,
)

res = await self._client.count_tokens_async(**kwargs)
return res.total_tokens

def _token_count_args(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> dict[str, Any]:
turn = user_turn(*args)

kwargs = self._chat_perform_args(
stream=False,
turns=[turn],
tools=tools,
data_model=data_model,
)

args_to_keep = ["contents", "tools"]

return {arg: kwargs[arg] for arg in args_to_keep if arg in kwargs}

def _google_contents(self, turns: list[Turn]) -> list["ContentDict"]:
contents: list["ContentDict"] = []
for turn in turns:
@@ -421,6 +471,8 @@ def _as_turn(
usage.candidates_token_count,
)

tokens_log(self, tokens)

finish = message.candidates[0].finish_reason

return Turn(
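As with Anthropic, the Google provider defers to the SDK's counting endpoint, keeping only the `contents` and `tools` arguments. A rough sketch of the equivalent direct call (assumes a valid API key in `GOOGLE_API_KEY`):

```python
import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel("gemini-1.5-flash")

# The endpoint behind self._client.count_tokens(**kwargs)
res = model.count_tokens("What is 1 + 1?")
print(res.total_tokens)
```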
54 changes: 53 additions & 1 deletion chatlas/_openai.py
@@ -8,6 +8,7 @@
from ._chat import Chat
from ._content import (
Content,
ContentImage,
ContentImageInline,
ContentImageRemote,
ContentJson,
@@ -20,7 +21,7 @@
from ._provider import Provider
from ._tokens import tokens_log
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, normalize_turns
from ._turn import Turn, normalize_turns, user_turn
from ._utils import MISSING, MISSING_TYPE, is_testing

if TYPE_CHECKING:
@@ -351,6 +352,57 @@ async def stream_turn_async(self, completion, has_data_model, stream):
def value_turn(self, completion, has_data_model) -> Turn:
return self._as_turn(completion, has_data_model)

def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> int:
try:
import tiktoken
except ImportError:
raise ImportError(
"The tiktoken package is required for token counting. "
"Please install it with `pip install tiktoken`."
)

encoding = tiktoken.encoding_for_model(self._model)

turn = user_turn(*args)

# Count the tokens in image contents
image_tokens = sum(
self._image_token_count(x)
for x in turn.contents
if isinstance(x, ContentImage)
)

# For other contents, get the token count from the actual message param
other_contents = [x for x in turn.contents if not isinstance(x, ContentImage)]
other_full = self._as_message_param([Turn("user", other_contents)])
other_tokens = len(encoding.encode(str(other_full)))

return other_tokens + image_tokens

async def token_count_async(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> int:
return self.token_count(*args, tools=tools, data_model=data_model)

@staticmethod
def _image_token_count(image: ContentImage) -> int:
if isinstance(image, ContentImageRemote) and image.detail == "low":
return 85
else:
# This is just the max token count for an image. The highest possible
# resolution is 768 x 2048, and 8 tiles of size 512px can fit inside that.
# TODO: this is obviously a very conservative estimate and could be improved
# https://platform.openai.com/docs/guides/vision/calculating-costs
return 170 * 8 + 85

@staticmethod
def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
from openai.types.chat import (
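Unlike the other two providers, OpenAI offers no server-side counting endpoint, so the estimate above is computed locally. A standalone sketch of the same approach, assuming tiktoken is installed (text tokens via the model's encoding, images via the fixed constants used above):

```python
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-4o-mini")

# Text: encode the stringified message payload and count tokens
text_tokens = len(encoding.encode("What is 1 + 1?"))

# Images: 85 tokens for a low-detail remote image; otherwise the
# conservative ceiling of 8 high-res tiles plus the base cost
low_detail_image_tokens = 85
max_image_tokens = 170 * 8 + 85  # 1445

print(text_tokens + max_image_tokens)
```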
17 changes: 17 additions & 0 deletions chatlas/_provider.py
@@ -14,6 +14,7 @@

from pydantic import BaseModel

from ._content import Content
from ._tools import Tool
from ._turn import Turn

@@ -141,3 +142,19 @@ def value_turn(
completion: ChatCompletionT,
has_data_model: bool,
) -> Turn: ...

@abstractmethod
def token_count(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> int: ...

@abstractmethod
async def token_count_async(
self,
*args: Content | str,
tools: dict[str, Tool],
data_model: Optional[type[BaseModel]],
) -> int: ...
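Since both methods are abstract, every concrete backend now has to implement them. A hypothetical skeleton (illustrative only; the other abstract members of `Provider` are omitted, so this class is not instantiable as-is):

```python
from __future__ import annotations

from typing import Optional

from pydantic import BaseModel

from chatlas._content import Content
from chatlas._provider import Provider
from chatlas._tools import Tool


class MyProvider(Provider):
    def token_count(
        self,
        *args: Content | str,
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
    ) -> int:
        # Delegate to whatever counting facility the backend offers
        raise NotImplementedError

    async def token_count_async(
        self,
        *args: Content | str,
        tools: dict[str, Tool],
        data_model: Optional[type[BaseModel]],
    ) -> int:
        # Fall back to the sync path when the backend has no async API
        return self.token_count(*args, tools=tools, data_model=data_model)
```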
1 change: 1 addition & 0 deletions pyproject.toml
@@ -48,6 +48,7 @@ dev = [
"anthropic[bedrock]",
"google-generativeai>=0.8.3",
"numpy>1.24.4",
"tiktoken",
]
docs = [
"griffe>=1",
14 changes: 12 additions & 2 deletions tests/test_tokens.py
@@ -1,4 +1,4 @@
from chatlas import ChatOpenAI, Turn
from chatlas import ChatAnthropic, ChatGoogle, ChatOpenAI, Turn
from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
from chatlas._tokens import token_usage, tokens_log, tokens_reset

@@ -26,10 +26,20 @@ def test_tokens_method():
)

assert chat.tokens(values="discrete") == [2, 10, 2, 10]

assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]


def test_token_count_method():
chat = ChatOpenAI(model="gpt-4o-mini")
assert chat.token_count("What is 1 + 1?") == 31

chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
assert chat.token_count("What is 1 + 1?") == 16

chat = ChatGoogle(model="gemini-1.5-flash")
assert chat.token_count("What is 1 + 1?") == 9


def test_usage_is_none():
tokens_reset()
assert token_usage() is None