diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81ec2d0c..ec6d0cc1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-->
+## [Development]
+
+### New features
+
+### Bug fixes
+
+* Update formatting of exported markdown to use `repr()` instead of `str()` when exporting tool call results. (#30)
+
## [0.3.0] - 2024-12-20
### New features
@@ -23,4 +31,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.2.0] - 2024-12-11
-First stable release of `chatlas`, see the website to learn more
+First stable release of `chatlas`, see the website to learn more
diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index 8dfd109a..9989404e 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -948,11 +948,11 @@ def export(
is_html = filename.suffix == ".html"
# Get contents from each turn
- contents = ""
+ content_arr: list[str] = []
for turn in turns:
turn_content = "\n\n".join(
[
- str(content)
+ str(content).strip()
for content in turn.contents
if include == "all" or isinstance(content, ContentText)
]
@@ -963,7 +963,8 @@ def export(
turn_content = f""
else:
turn_content = f"## {turn.role.capitalize()}\n\n{turn_content}"
- contents += f"{turn_content}\n\n"
+ content_arr.append(turn_content)
+ contents = "\n\n".join(content_arr)
# Shiny chat message components requires container elements
if is_html:
diff --git a/chatlas/_content.py b/chatlas/_content.py
index 650d852b..76ea2bde 100644
--- a/chatlas/_content.py
+++ b/chatlas/_content.py
@@ -2,6 +2,7 @@
import json
from dataclasses import dataclass
+from pprint import pformat
from typing import Any, Literal, Optional
ImageContentTypes = Literal[
@@ -154,7 +155,7 @@ def __str__(self):
args_str = self._arguments_str()
func_call = f"{self.name}({args_str})"
comment = f"# tool request ({self.id})"
- return f"\n```python\n{comment}\n{func_call}\n```\n"
+ return f"```python\n{comment}\n{func_call}\n```\n"
def _repr_markdown_(self):
return self.__str__()
@@ -195,10 +196,20 @@ class ContentToolResult(Content):
value: Any = None
error: Optional[str] = None
+ def _get_value_and_language(self) -> tuple[str, str]:
+ if self.error:
+ return f"Tool calling failed with error: '{self.error}'", ""
+ try:
+ json_val = json.loads(self.value)
+ return pformat(json_val, indent=2, sort_dicts=False), "python"
+ except: # noqa: E722
+ return str(self.value), ""
+
def __str__(self):
comment = f"# tool result ({self.id})"
- val = self.get_final_value()
- return f"""\n```python\n{comment}\n"{val}"\n```\n"""
+ value, language = self._get_value_and_language()
+
+ return f"""```{language}\n{comment}\n{value}\n```"""
def _repr_markdown_(self):
return self.__str__()
@@ -211,9 +222,8 @@ def __repr__(self, indent: int = 0):
return res + ">"
def get_final_value(self) -> str:
- if self.error:
- return f"Tool calling failed with error: '{self.error}'"
- return str(self.value)
+ value, _language = self._get_value_and_language()
+ return value
@dataclass
@@ -236,7 +246,7 @@ def __str__(self):
return json.dumps(self.value, indent=2)
def _repr_markdown_(self):
- return f"""\n```json\n{self.__str__()}\n```\n"""
+ return f"""```json\n{self.__str__()}\n```"""
def __repr__(self, indent: int = 0):
return " " * indent + f""
diff --git a/tests/__snapshots__/test_chat.ambr b/tests/__snapshots__/test_chat.ambr
index d966b8f8..88b3a275 100644
--- a/tests/__snapshots__/test_chat.ambr
+++ b/tests/__snapshots__/test_chat.ambr
@@ -17,8 +17,6 @@
-
-
System prompt
diff --git a/tests/test_provider_openai.py b/tests/test_provider_openai.py
index 17468fbd..f067c6ee 100644
--- a/tests/test_provider_openai.py
+++ b/tests/test_provider_openai.py
@@ -1,4 +1,5 @@
import pytest
+
from chatlas import ChatOpenAI
from .conftest import (
@@ -21,7 +22,10 @@ def test_openai_simple_request():
chat.chat("What is 1 + 1?")
turn = chat.get_last_turn()
assert turn is not None
- assert turn.tokens == (27, 2)
+ assert turn.tokens is not None
+ assert len(turn.tokens) == 2
+ assert turn.tokens[0] == 27
+ # Not testing turn.tokens[1] because it's not deterministic. Typically 1 or 2.
assert turn.finish_reason == "stop"