|
1 | 1 | from typing_extensions import Generic, TypeVar |
2 | 2 |
|
3 | 3 | import pydantic_core |
4 | | -import re |
| 4 | +import ast |
| 5 | +import io |
| 6 | +import tokenize |
5 | 7 | import inspect |
6 | 8 |
|
7 | 9 | from typechat._internal.model import PromptSection, TypeChatLanguageModel |
@@ -128,38 +130,95 @@ def _create_repair_prompt(self, validation_error: str) -> str: |
128 | 130 |
|
129 | 131 | def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=False): |
130 | 132 |
|
| 133 | + def _extract_tokens_between_line_numbers(gen, start_lineno, end_lineno): |
| 134 | + # Extract tokens between start_lineno and end_lineno obtained from the tokenize generator |
| 135 | + tokens = [] |
| 136 | + for tok in gen: |
| 137 | + if tok.start[0] < start_lineno: # Skip tokens before start_lineno |
| 138 | + continue |
| 139 | + if tok.start[0] >= start_lineno and tok.end[0] <= end_lineno: |
| 140 | + # Add token if it is within the range |
| 141 | + tokens.append((tok.type, tok.string)) |
| 142 | + elif tok.start[0] > end_lineno: # Stop if token is beyond end_lineno |
| 143 | + break |
| 144 | + |
| 145 | + return tokens |
| 146 | + |
131 | 147 | schema_path = inspect.getfile(schema_class) |
132 | 148 |
|
133 | | - with open(schema_path, 'r') as file: |
134 | | - schema_class_source = file.read() |
| 149 | + with open(schema_path, 'r') as f: |
| 150 | + schema_class_source = f.read() |
| 151 | + gen = tokenize.tokenize(io.BytesIO( |
| 152 | + schema_class_source.encode('utf-8')).readline) |
| 153 | + |
| 154 | + tree = ast.parse(schema_class_source) |
135 | 155 |
|
136 | 156 | if debug: |
137 | | - print("File contents before modification:") |
| 157 | + print("Source code before transformation:") |
138 | 158 | print("--"*50) |
139 | 159 | print(schema_class_source) |
140 | 160 | print("--"*50) |
141 | 161 |
|
142 | | - pattern = r"(\w+\s*:\s*.*?)(?=\s*#\s*(.+?)(?:\n|\Z))" |
143 | | - commented_fields = re.findall(pattern, schema_class_source) |
144 | | - annotated_fields = [] |
145 | | - |
146 | | - for field, comment in commented_fields: |
147 | | - field_separator = field.split(":") |
148 | | - field_name = field_separator[0].strip() |
149 | | - field_type = field_separator[1].strip() |
150 | | - |
151 | | - annotated_fields.append( |
152 | | - f"{field_name}: Annotated[{field_type}, Doc(\"{comment}\")]") |
153 | | - |
154 | | - for field, annotation in zip(commented_fields, annotated_fields): |
155 | | - schema_class_source = schema_class_source.replace(field[0], annotation) |
| 162 | + has_comments = False # Flag later used to perform imports of Annotated and Doc if needed |
| 163 | + |
| 164 | + for node in tree.body: |
| 165 | + if isinstance(node, ast.ClassDef): |
| 166 | + for n in node.body: |
| 167 | + if isinstance(n, ast.AnnAssign): # Check if the node is an annotated assignment |
| 168 | + assgn_comment = None |
| 169 | + tokens = _extract_tokens_between_line_numbers( |
| 170 | + # Extract tokens between the line numbers of the annotated assignment |
| 171 | + gen, n.lineno, n.end_lineno |
| 172 | + ) |
| 173 | + for toknum, tokval in tokens: |
| 174 | + if toknum == tokenize.COMMENT: |
| 175 | + # Extract the comment |
| 176 | + assgn_comment = tokval |
| 177 | + break |
| 178 | + |
| 179 | + if assgn_comment: |
| 180 | + # If a comment is found, transform the annotation to include the comment |
| 181 | + assgn_subscript = n.annotation |
| 182 | + has_comments = True |
| 183 | + n.annotation = ast.Subscript( |
| 184 | + value=ast.Name(id="Annotated", ctx=ast.Load()), |
| 185 | + slice=ast.Tuple( |
| 186 | + elts=[ |
| 187 | + assgn_subscript, |
| 188 | + ast.Call( |
| 189 | + func=ast.Name( |
| 190 | + id="Doc", ctx=ast.Load() |
| 191 | + ), |
| 192 | + args=[ |
| 193 | + ast.Constant( |
| 194 | + value=assgn_comment.strip("#").strip() |
| 195 | + ) |
| 196 | + ], |
| 197 | + keywords=[] |
| 198 | + ) |
| 199 | + ], |
| 200 | + ctx=ast.Load() |
| 201 | + ), |
| 202 | + ctx=ast.Load() |
| 203 | + ) |
| 204 | + |
| 205 | + if has_comments: |
| 206 | + for node in tree.body: |
| 207 | + if isinstance(node, ast.ImportFrom): |
| 208 | + if node.module == "typing_extensions": |
| 209 | + if ast.alias(name="Annotated") not in node.names: |
| 210 | + node.names.append(ast.alias(name="Annotated")) |
| 211 | + if ast.alias(name="Doc") not in node.names: |
| 212 | + node.names.append(ast.alias(name="Doc")) |
| 213 | + |
| 214 | + transformed_schema_source = ast.unparse(tree) |
156 | 215 |
|
157 | 216 | if debug: |
158 | | - print("File contents after modification:") |
| 217 | + print("Source code after transformation:") |
159 | 218 | print("--"*50) |
160 | | - print(schema_class_source) |
| 219 | + print(transformed_schema_source) |
161 | 220 | print("--"*50) |
162 | 221 |
|
163 | 222 | namespace = {} |
164 | | - exec(schema_class_source, namespace) |
| 223 | + exec(transformed_schema_source, namespace) |
165 | 224 | return namespace[schema_class.__name__] |
0 commit comments