diff --git a/CHANGELOG.md b/CHANGELOG.md index 81ec2d0c..ec6d0cc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). --> +## [Development] + +### New features + +### Bug fixes + +* Update formatting of exported markdown to use `repr()` instead of `str()` when exporting tool call results. (#30) + ## [0.3.0] - 2024-12-20 ### New features @@ -23,4 +31,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.2.0] - 2024-12-11 -First stable release of `chatlas`, see the website to learn more +First stable release of `chatlas`, see the website to learn more diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 8dfd109a..9989404e 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -948,11 +948,11 @@ def export( is_html = filename.suffix == ".html" # Get contents from each turn - contents = "" + content_arr: list[str] = [] for turn in turns: turn_content = "\n\n".join( [ - str(content) + str(content).strip() for content in turn.contents if include == "all" or isinstance(content, ContentText) ] @@ -963,7 +963,8 @@ def export( turn_content = f"" else: turn_content = f"## {turn.role.capitalize()}\n\n{turn_content}" - contents += f"{turn_content}\n\n" + content_arr.append(turn_content) + contents = "\n\n".join(content_arr) # Shiny chat message components requires container elements if is_html: diff --git a/chatlas/_content.py b/chatlas/_content.py index 650d852b..76ea2bde 100644 --- a/chatlas/_content.py +++ b/chatlas/_content.py @@ -2,6 +2,7 @@ import json from dataclasses import dataclass +from pprint import pformat from typing import Any, Literal, Optional ImageContentTypes = Literal[ @@ -154,7 +155,7 @@ def __str__(self): args_str = self._arguments_str() func_call = f"{self.name}({args_str})" comment = f"# tool request ({self.id})" - return f"\n```python\n{comment}\n{func_call}\n```\n" + return f"```python\n{comment}\n{func_call}\n```\n" def _repr_markdown_(self): return self.__str__() @@ -195,10 +196,20 @@ class ContentToolResult(Content): value: Any = None error: Optional[str] = None + def _get_value_and_language(self) -> tuple[str, str]: + if self.error: + return f"Tool calling failed with error: '{self.error}'", "" + try: + json_val = json.loads(self.value) + return pformat(json_val, indent=2, sort_dicts=False), "python" + except: # noqa: E722 + return str(self.value), "" + def __str__(self): comment = f"# tool result ({self.id})" - val = self.get_final_value() - return f"""\n```python\n{comment}\n"{val}"\n```\n""" + value, language = self._get_value_and_language() + + return f"""```{language}\n{comment}\n{value}\n```""" def _repr_markdown_(self): return self.__str__() @@ -211,9 +222,8 @@ def __repr__(self, indent: int = 0): return res + ">" def get_final_value(self) -> str: - if self.error: - return f"Tool calling failed with error: '{self.error}'" - return str(self.value) + value, _language = self._get_value_and_language() + return value @dataclass @@ -236,7 +246,7 @@ def __str__(self): return json.dumps(self.value, indent=2) def _repr_markdown_(self): - return f"""\n```json\n{self.__str__()}\n```\n""" + return f"""```json\n{self.__str__()}\n```""" def __repr__(self, indent: int = 0): return " " * indent + f"" diff --git a/tests/__snapshots__/test_chat.ambr b/tests/__snapshots__/test_chat.ambr index d966b8f8..88b3a275 100644 --- a/tests/__snapshots__/test_chat.ambr +++ b/tests/__snapshots__/test_chat.ambr @@ -17,8 +17,6 @@ - -

System prompt diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py index 17468fbd..f067c6ee 100644 --- a/tests/test_provider_openai.py +++ b/tests/test_provider_openai.py @@ -1,4 +1,5 @@ import pytest + from chatlas import ChatOpenAI from .conftest import ( @@ -21,7 +22,10 @@ def test_openai_simple_request(): chat.chat("What is 1 + 1?") turn = chat.get_last_turn() assert turn is not None - assert turn.tokens == (27, 2) + assert turn.tokens is not None + assert len(turn.tokens) == 2 + assert turn.tokens[0] == 27 + # Not testing turn.tokens[1] because it's not deterministic. Typically 1 or 2. assert turn.finish_reason == "stop"