Skip to content

Commit 08e104d

Browse files
committed
Added feature to utilize Python syntax for field comments in schema instead of verbose Annotated[Doc()]-based comment style. Currently implemented for examples/math
1 parent 1fed2c7 commit 08e104d

File tree

6 files changed

+105
-3
lines changed

6 files changed

+105
-3
lines changed

python/examples/math/demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sys
55
from typing import cast
66
from dotenv import dotenv_values
7-
import schema as math
7+
import schema_with_comments as math
88
from typechat import Failure, create_language_model, process_requests
99
from program import TypeChatProgramTranslator, TypeChatProgramValidator, evaluate_json_program
1010

python/examples/math/program.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ class TypeChatProgramTranslator(TypeChatJsonTranslator[JsonProgram]):
121121
_api_declaration_str: str
122122

123123
def __init__(self, model: TypeChatLanguageModel, validator: TypeChatProgramValidator, api_type: type):
124+
api_type = self._convert_pythonic_comments_to_annotated_docs(api_type)
124125
super().__init__(model=model, validator=validator, target_type=api_type, _raise_on_schema_errors = False)
125126
# TODO: the conversion result here has errors!
126127
conversion_result = python_type_to_typescript_schema(api_type)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import re
2+
import inspect
3+
from schema_with_comments import MathAPI
4+
5+
6+
def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=True):
7+
8+
schema_path = inspect.getfile(schema_class)
9+
10+
with open(schema_path, 'r') as file:
11+
schema_class_source = file.read()
12+
13+
if debug:
14+
print("File contents before modification:")
15+
print("--"*50)
16+
print(schema_class_source)
17+
print("--"*50)
18+
19+
pattern = r"(\w+\s*:\s*.*?)(?=\s*#\s*(.+?)(?:\n|\Z))"
20+
commented_fields = re.findall(pattern, schema_class_source)
21+
annotated_fields = []
22+
23+
for field, comment in commented_fields:
24+
field_separator = field.split(":")
25+
field_name = field_separator[0].strip()
26+
field_type = field_separator[1].strip()
27+
28+
annotated_fields.append(
29+
f"{field_name}: Annotated[{field_type}, Doc(\"{comment}\")]")
30+
31+
for field, annotation in zip(commented_fields, annotated_fields):
32+
schema_class_source = schema_class_source.replace(field[0], annotation)
33+
34+
if debug:
35+
print("File contents after modification:")
36+
print("--"*50)
37+
print(schema_class_source)
38+
print("--"*50)
39+
40+
namespace = {}
41+
exec(schema_class_source, namespace)
42+
return namespace[schema_class.__name__]
43+
44+
45+
if __name__ == "__main__":
46+
print(_convert_pythonic_comments_to_annotated_docs(MathAPI))

python/examples/math/schema.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing_extensions import TypedDict, Annotated, Callable, Doc
22

3-
43
class MathAPI(TypedDict):
54
"""
65
This is API for a simple calculator
@@ -12,4 +11,4 @@ class MathAPI(TypedDict):
1211
div: Annotated[Callable[[float, float], float], Doc("Divide two numbers")]
1312
neg: Annotated[Callable[[float], float], Doc("Negate a number")]
1413
id: Annotated[Callable[[float], float], Doc("Identity function")]
15-
unknown: Annotated[Callable[[str], float], Doc("Unknown request")]
14+
unknown: Annotated[Callable[[str], float], Doc("Unknown request")]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing_extensions import TypedDict, Annotated, Callable, Doc
2+
3+
class MathAPI(TypedDict):
4+
"""
5+
This is API for a simple calculator
6+
"""
7+
8+
# this is a comment
9+
10+
add: Callable[[float, float], float] # Add two numbers
11+
sub: Callable[[float, float], float] # Subtract two numbers
12+
mul: Callable[[float, float], float] # Multiply two numbers
13+
div: Callable[[float, float], float] # Divide two numbers
14+
neg: Callable[[float], float] # Negate a number
15+
id: Callable[[float], float] # Identity function
16+
unknown: Callable[[str], float] # Unknown request

python/src/typechat/_internal/translator.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing_extensions import Generic, TypeVar
22

33
import pydantic_core
4+
import re
5+
import inspect
46

57
from typechat._internal.model import PromptSection, TypeChatLanguageModel
68
from typechat._internal.result import Failure, Result, Success
@@ -123,3 +125,41 @@ def _create_repair_prompt(self, validation_error: str) -> str:
123125
The following is a revised JSON object:
124126
"""
125127
return prompt
128+
129+
def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=False):
130+
131+
schema_path = inspect.getfile(schema_class)
132+
133+
with open(schema_path, 'r') as file:
134+
schema_class_source = file.read()
135+
136+
if debug:
137+
print("File contents before modification:")
138+
print("--"*50)
139+
print(schema_class_source)
140+
print("--"*50)
141+
142+
pattern = r"(\w+\s*:\s*.*?)(?=\s*#\s*(.+?)(?:\n|\Z))"
143+
commented_fields = re.findall(pattern, schema_class_source)
144+
annotated_fields = []
145+
146+
for field, comment in commented_fields:
147+
field_separator = field.split(":")
148+
field_name = field_separator[0].strip()
149+
field_type = field_separator[1].strip()
150+
151+
annotated_fields.append(
152+
f"{field_name}: Annotated[{field_type}, Doc(\"{comment}\")]")
153+
154+
for field, annotation in zip(commented_fields, annotated_fields):
155+
schema_class_source = schema_class_source.replace(field[0], annotation)
156+
157+
if debug:
158+
print("File contents after modification:")
159+
print("--"*50)
160+
print(schema_class_source)
161+
print("--"*50)
162+
163+
namespace = {}
164+
exec(schema_class_source, namespace)
165+
return namespace[schema_class.__name__]

0 commit comments

Comments
 (0)