Skip to content

Commit bb87a9d

Browse files
committed
Incorporated ast and inspect-based source code transformation for more robust field comment handling
1 parent 08e104d commit bb87a9d

File tree

3 files changed

+189
-23
lines changed

3 files changed

+189
-23
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import tokenize
2+
import ast
3+
import io
4+
import inspect
5+
from schema_with_comments import MathAPI
6+
7+
8+
def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=True):
9+
10+
def _extract_tokens_between_line_numbers(gen, start_lineno, end_lineno):
11+
# Extract tokens between start_lineno and end_lineno obtained from the tokenize generator
12+
tokens = []
13+
for tok in gen:
14+
if tok.start[0] < start_lineno: # Skip tokens before start_lineno
15+
continue
16+
if tok.start[0] >= start_lineno and tok.end[0] <= end_lineno:
17+
# Add token if it is within the range
18+
tokens.append((tok.type, tok.string))
19+
elif tok.start[0] > end_lineno: # Stop if token is beyond end_lineno
20+
break
21+
22+
return tokens
23+
24+
schema_path = inspect.getfile(schema_class)
25+
26+
with open(schema_path, 'r') as f:
27+
schema_class_source = f.read()
28+
gen = tokenize.tokenize(io.BytesIO(
29+
schema_class_source.encode('utf-8')).readline)
30+
31+
tree = ast.parse(schema_class_source)
32+
33+
if debug:
34+
print("Source code before transformation:")
35+
print("--"*50)
36+
print(schema_class_source)
37+
print("--"*50)
38+
39+
has_comments = False # Flag later used to perform imports of Annotated and Doc if needed
40+
41+
for node in tree.body:
42+
if isinstance(node, ast.ClassDef):
43+
for n in node.body:
44+
if isinstance(n, ast.AnnAssign): # Check if the node is an annotated assignment
45+
assgn_comment = None
46+
tokens = _extract_tokens_between_line_numbers(
47+
# Extract tokens between the line numbers of the annotated assignment
48+
gen, n.lineno, n.end_lineno
49+
)
50+
for toknum, tokval in tokens:
51+
if toknum == tokenize.COMMENT:
52+
# Extract the comment
53+
assgn_comment = tokval
54+
break
55+
56+
if assgn_comment:
57+
# If a comment is found, transform the annotation to include the comment
58+
assgn_subscript = n.annotation
59+
has_comments = True
60+
n.annotation = ast.Subscript(
61+
value=ast.Name(id="Annotated", ctx=ast.Load()),
62+
slice=ast.Tuple(
63+
elts=[
64+
assgn_subscript,
65+
ast.Call(
66+
func=ast.Name(
67+
id="Doc", ctx=ast.Load()
68+
),
69+
args=[
70+
ast.Constant(
71+
value=assgn_comment.strip("#").strip()
72+
)
73+
],
74+
keywords=[]
75+
)
76+
],
77+
ctx=ast.Load()
78+
),
79+
ctx=ast.Load()
80+
)
81+
82+
if has_comments:
83+
for node in tree.body:
84+
if isinstance(node, ast.ImportFrom):
85+
if node.module == "typing_extensions":
86+
if ast.alias(name="Annotated") not in node.names:
87+
node.names.append(ast.alias(name="Annotated"))
88+
if ast.alias(name="Doc") not in node.names:
89+
node.names.append(ast.alias(name="Doc"))
90+
91+
transformed_schema_source = ast.unparse(tree)
92+
93+
if debug:
94+
print("Source code after transformation:")
95+
print("--"*50)
96+
print(transformed_schema_source)
97+
print("--"*50)
98+
99+
namespace = {}
100+
exec(transformed_schema_source, namespace)
101+
return namespace[schema_class.__name__]
102+
103+
104+
if __name__ == "__main__":
105+
print(_convert_pythonic_comments_to_annotated_docs(MathAPI))

python/examples/math/schema_with_comments.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing_extensions import TypedDict, Annotated, Callable, Doc
1+
from typing_extensions import TypedDict, Callable
22

33
class MathAPI(TypedDict):
44
"""
@@ -13,4 +13,6 @@ class MathAPI(TypedDict):
1313
div: Callable[[float, float], float] # Divide two numbers
1414
neg: Callable[[float], float] # Negate a number
1515
id: Callable[[float], float] # Identity function
16-
unknown: Callable[[str], float] # Unknown request
16+
unknown: Callable[
17+
[str], float
18+
] # Unknown request

python/src/typechat/_internal/translator.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing_extensions import Generic, TypeVar
22

33
import pydantic_core
4-
import re
4+
import ast
5+
import io
6+
import tokenize
57
import inspect
68

79
from typechat._internal.model import PromptSection, TypeChatLanguageModel
@@ -128,38 +130,95 @@ def _create_repair_prompt(self, validation_error: str) -> str:
128130

129131
def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=False):
130132

133+
def _extract_tokens_between_line_numbers(gen, start_lineno, end_lineno):
134+
# Extract tokens between start_lineno and end_lineno obtained from the tokenize generator
135+
tokens = []
136+
for tok in gen:
137+
if tok.start[0] < start_lineno: # Skip tokens before start_lineno
138+
continue
139+
if tok.start[0] >= start_lineno and tok.end[0] <= end_lineno:
140+
# Add token if it is within the range
141+
tokens.append((tok.type, tok.string))
142+
elif tok.start[0] > end_lineno: # Stop if token is beyond end_lineno
143+
break
144+
145+
return tokens
146+
131147
schema_path = inspect.getfile(schema_class)
132148

133-
with open(schema_path, 'r') as file:
134-
schema_class_source = file.read()
149+
with open(schema_path, 'r') as f:
150+
schema_class_source = f.read()
151+
gen = tokenize.tokenize(io.BytesIO(
152+
schema_class_source.encode('utf-8')).readline)
153+
154+
tree = ast.parse(schema_class_source)
135155

136156
if debug:
137-
print("File contents before modification:")
157+
print("Source code before transformation:")
138158
print("--"*50)
139159
print(schema_class_source)
140160
print("--"*50)
141161

142-
pattern = r"(\w+\s*:\s*.*?)(?=\s*#\s*(.+?)(?:\n|\Z))"
143-
commented_fields = re.findall(pattern, schema_class_source)
144-
annotated_fields = []
145-
146-
for field, comment in commented_fields:
147-
field_separator = field.split(":")
148-
field_name = field_separator[0].strip()
149-
field_type = field_separator[1].strip()
150-
151-
annotated_fields.append(
152-
f"{field_name}: Annotated[{field_type}, Doc(\"{comment}\")]")
153-
154-
for field, annotation in zip(commented_fields, annotated_fields):
155-
schema_class_source = schema_class_source.replace(field[0], annotation)
162+
has_comments = False # Flag later used to perform imports of Annotated and Doc if needed
163+
164+
for node in tree.body:
165+
if isinstance(node, ast.ClassDef):
166+
for n in node.body:
167+
if isinstance(n, ast.AnnAssign): # Check if the node is an annotated assignment
168+
assgn_comment = None
169+
tokens = _extract_tokens_between_line_numbers(
170+
# Extract tokens between the line numbers of the annotated assignment
171+
gen, n.lineno, n.end_lineno
172+
)
173+
for toknum, tokval in tokens:
174+
if toknum == tokenize.COMMENT:
175+
# Extract the comment
176+
assgn_comment = tokval
177+
break
178+
179+
if assgn_comment:
180+
# If a comment is found, transform the annotation to include the comment
181+
assgn_subscript = n.annotation
182+
has_comments = True
183+
n.annotation = ast.Subscript(
184+
value=ast.Name(id="Annotated", ctx=ast.Load()),
185+
slice=ast.Tuple(
186+
elts=[
187+
assgn_subscript,
188+
ast.Call(
189+
func=ast.Name(
190+
id="Doc", ctx=ast.Load()
191+
),
192+
args=[
193+
ast.Constant(
194+
value=assgn_comment.strip("#").strip()
195+
)
196+
],
197+
keywords=[]
198+
)
199+
],
200+
ctx=ast.Load()
201+
),
202+
ctx=ast.Load()
203+
)
204+
205+
if has_comments:
206+
for node in tree.body:
207+
if isinstance(node, ast.ImportFrom):
208+
if node.module == "typing_extensions":
209+
if ast.alias(name="Annotated") not in node.names:
210+
node.names.append(ast.alias(name="Annotated"))
211+
if ast.alias(name="Doc") not in node.names:
212+
node.names.append(ast.alias(name="Doc"))
213+
214+
transformed_schema_source = ast.unparse(tree)
156215

157216
if debug:
158-
print("File contents after modification:")
217+
print("Source code after transformation:")
159218
print("--"*50)
160-
print(schema_class_source)
219+
print(transformed_schema_source)
161220
print("--"*50)
162221

163222
namespace = {}
164-
exec(schema_class_source, namespace)
223+
exec(transformed_schema_source, namespace)
165224
return namespace[schema_class.__name__]

0 commit comments

Comments
 (0)