diff --git a/Dockerfile b/Dockerfile index f8c0570..2e8fe2b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,8 +5,9 @@ ENV PYTHONUNBUFFERED=1 RUN apt update && apt -y upgrade RUN apt install -y python3-numpy python3-pip python3-requests RUN pip3 install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html -RUN pip3 install torch-scatter==2.0.4+cpu -f https://pytorch-geometric.com/whl/torch-1.5.0.html -RUN pip3 install dpu-utils typed-ast ptgnn +# https://stackoverflow.com/questions/67074684/pip-has-problems-with-metadata +RUN pip3 install --use-deprecated=legacy-resolver torch-scatter==2.0.4+cpu -f https://pytorch-geometric.com/whl/torch-1.5.0.html +RUN pip3 install dpu-utils typed-ast ptgnn==0.8.5 ENV PYTHONPATH=/usr/src/ ADD https://github.com/typilus/typilus-action/releases/download/v0.1/typilus20200507.pkl.gz /usr/src/model.pkl.gz diff --git a/README.md b/README.md index b7fde4c..2db21b7 100644 --- a/README.md +++ b/README.md @@ -27,17 +27,19 @@ suggestions with only a partial context, at the cost of suggesting some false positives. -### Install Action in your Repository +### How to use the Action in your Repository To use the GitHub action, create a workflow file. For example, ```yaml -name: Typilus Type Annotation Suggestions +name: Annotation Suggestions # Controls when the action will run. Triggers the workflow on push or pull request -# events but only for the master branch +# events but only for the main branch on: pull_request: - branches: [ master ] + paths: + - '**.py' + branches: [master, main] jobs: suggest: @@ -46,11 +48,10 @@ jobs: steps: # Checks-out your repository under $GITHUB_WORKSPACE, so that typilus can access it. - - uses: actions/checkout@v2 - - uses: typilus/typilus-action@master + - uses: actions/checkout@v3 + - uses: Karim-53/typilus-action@master env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - MODEL_PATH: path/to/model.pkl.gz # Optional: provide the path of a custom model instead of the pre-trained model. SUGGESTION_CONFIDENCE_THRESHOLD: 0.8 # Configure this to limit the confidence of suggestions on un-annotated locations. A float in [0, 1]. Default 0.8 DISAGREEMENT_CONFIDENCE_THRESHOLD: 0.95 # Configure this to limit the confidence of suggestions on annotated locations. A float in [0, 1]. Default 0.95 ``` diff --git a/action.yml b/action.yml index f27702f..012234b 100644 --- a/action.yml +++ b/action.yml @@ -1,6 +1,6 @@ # action.yml -name: 'Typilus: Suggest Python Type Annotations' -description: 'Suggest Likely Python Type Annotations' +name: 'Typilus: Suggest Python Type Annotations with AI (fix)' +description: 'Suggest Likely Python Type Annotations using AI' branding: icon: box color: gray-dark diff --git a/entrypoint.py b/entrypoint.py index 77eef64..79d4bda 100644 --- a/entrypoint.py +++ b/entrypoint.py @@ -77,140 +77,146 @@ def __repr__(self) -> str: ) print("Diff GET Status Code: ", diff_rq.status_code) +try: + changed_files = get_changed_files(diff_rq.text) + if len(changed_files) == 0: + print("No relevant changes found.") + sys.exit(0) -changed_files = get_changed_files(diff_rq.text) -if len(changed_files) == 0: - print("No relevant changes found.") - sys.exit(0) + monitoring = Monitoring() + suggestion_confidence_threshold = float(os.getenv("SUGGESTION_CONFIDENCE_THRESHOLD", 0.5)) + diagreement_confidence_threshold = float(os.getenv("DISAGREEMENT_CONFIDENCE_THRESHOLD", 0.95)) -monitoring = Monitoring() -suggestion_confidence_threshold = float(os.getenv("SUGGESTION_CONFIDENCE_THRESHOLD", 0.5)) -diagreement_confidence_threshold = float(os.getenv("DISAGREEMENT_CONFIDENCE_THRESHOLD", 0.95)) - -if debug: - print( - f"Confidence thresholds {suggestion_confidence_threshold:.2f} and {diagreement_confidence_threshold:.2f}." - ) - - -with TemporaryDirectory() as out_dir: - typing_rules_path = os.path.join(os.path.dirname(__file__), "metadata", "typingRules.json") - extract_graphs( - repo_path, typing_rules_path, files_to_extract=set(changed_files), target_folder=out_dir, - ) - - def data_iter(): - for datafile_path in iglob(os.path.join(out_dir, "*.jsonl.gz")): - print(f"Looking into {datafile_path}...") - for graph in load_jsonl_gz(datafile_path): - yield graph + if debug: + print( + f"Confidence thresholds {suggestion_confidence_threshold:.2f} and {diagreement_confidence_threshold:.2f}." + ) - model_path = os.getenv("MODEL_PATH", "/usr/src/model.pkl.gz") - model, nn = Graph2Class.restore_model(model_path, "cpu") - type_suggestions: List[TypeSuggestion] = [] - for graph, predictions in model.predict(data_iter(), nn, "cpu"): - # predictions has the type: Dict[int, Tuple[str, float]] - filepath = graph["filename"] + with TemporaryDirectory() as out_dir: + typing_rules_path = os.path.join(os.path.dirname(__file__), "metadata", "typingRules.json") + extract_graphs( + repo_path, typing_rules_path, files_to_extract=set(changed_files), target_folder=out_dir, + ) + def data_iter(): + for datafile_path in iglob(os.path.join(out_dir, "*.jsonl.gz")): + print(f"Looking into {datafile_path}...") + yield from load_jsonl_gz(datafile_path) + + model_path = os.getenv("MODEL_PATH", "/usr/src/model.pkl.gz") + model, nn = Graph2Class.restore_model(model_path, "cpu") + + type_suggestions: List[TypeSuggestion] = [] + for graph, predictions in model.predict(data_iter(), nn, "cpu"): + # predictions has the type: Dict[int, Tuple[str, float]] + filepath = graph["filename"] + + if debug: + print("Predictions:", predictions) + print("SuperNodes:", graph["supernodes"]) + + for supernode_idx, (predicted_type, predicted_prob) in predictions.items(): + supernode_data = graph["supernodes"][str(supernode_idx)] + if supernode_data["type"] == "variable": + continue # Do not suggest annotations on variables for now. + lineno, colno = supernode_data["location"] + suggestion = TypeSuggestion( + filepath, + supernode_data["name"], + (lineno, colno), + annotation_rewrite(predicted_type), + supernode_data["type"], + predicted_prob, + is_disagreement=supernode_data["annotation"] != "??" + and supernode_data["annotation"] != predicted_type, + ) + + print("Suggestion: ", suggestion) + + if lineno not in changed_files[filepath]: + continue + elif suggestion.name == "%UNK%": + continue + + if ( + supernode_data["annotation"] == "??" + and suggestion.confidence > suggestion_confidence_threshold + ): + type_suggestions.append(suggestion) + elif ( + suggestion.is_disagreement + # and suggestion.confidence > diagreement_confidence_threshold + ): + pass # TODO: Disabled for now: type_suggestions.append(suggestion) + + # Add PR comments if debug: - print("Predictions:", predictions) - print("SuperNodes:", graph["supernodes"]) - - for supernode_idx, (predicted_type, predicted_prob) in predictions.items(): - supernode_data = graph["supernodes"][str(supernode_idx)] - if supernode_data["type"] == "variable": - continue # Do not suggest annotations on variables for now. - lineno, colno = supernode_data["location"] - suggestion = TypeSuggestion( - filepath, - supernode_data["name"], - (lineno, colno), - annotation_rewrite(predicted_type), - supernode_data["type"], - predicted_prob, - is_disagreement=supernode_data["annotation"] != "??" - and supernode_data["annotation"] != predicted_type, - ) - - print("Suggestion: ", suggestion) - - if lineno not in changed_files[filepath]: - continue - elif suggestion.name == "%UNK%": - continue - - if ( - supernode_data["annotation"] == "??" - and suggestion.confidence > suggestion_confidence_threshold - ): - type_suggestions.append(suggestion) - elif ( - suggestion.is_disagreement - # and suggestion.confidence > diagreement_confidence_threshold - ): - pass # TODO: Disabled for now: type_suggestions.append(suggestion) - - # Add PR comments - if debug: - print("# Suggestions:", len(type_suggestions)) - for suggestion in type_suggestions: - print(suggestion) + print("# Suggestions:", len(type_suggestions)) + for suggestion in type_suggestions: + print(suggestion) - comment_url = event_data["pull_request"]["review_comments_url"] - commit_id = event_data["pull_request"]["head"]["sha"] + comment_url = event_data["pull_request"]["review_comments_url"] + commit_id = event_data["pull_request"]["head"]["sha"] - for suggestion in type_suggestions: - if suggestion.symbol_kind == "class-or-function": - suggestion.annotation_lineno = find_annotation_line( - suggestion.filepath[1:], suggestion.file_location, suggestion.name + for suggestion in type_suggestions: + if suggestion.symbol_kind == "class-or-function": + suggestion.annotation_lineno = find_annotation_line( + suggestion.filepath[1:], suggestion.file_location, suggestion.name + ) + else: # when the underlying symbol is a parameter + suggestion.annotation_lineno = suggestion.file_location[0] + + # Group type suggestions by (filepath + lineno) + grouped_suggestions = group_suggestions(type_suggestions) + + def bucket_confidences(confidence: float) -> str: + if confidence >= 0.95: + return ":fire:" + if confidence >= 0.85: + return ":bell:" + return ":confused:" if confidence >= 0.7 else ":question:" + + def report_confidence(suggestions): + suggestions = sorted(suggestions, key=lambda s: -s.confidence) + return "".join( + f"| `{s.name}` | `{s.suggestion}` | {s.confidence:.1%} {bucket_confidences(s.confidence)} | \n" + for s in suggestions ) - else: # when the underlying symbol is a parameter - suggestion.annotation_lineno = suggestion.file_location[0] - - # Group type suggestions by (filepath + lineno) - grouped_suggestions = group_suggestions(type_suggestions) - - def bucket_confidences(confidence: float) -> str: - if confidence >= 0.95: - return ":fire:" - if confidence >= 0.85: - return ":bell:" - if confidence >= 0.7: - return ":confused:" - return ":question:" - - def report_confidence(suggestions): - suggestions = sorted(suggestions, key=lambda s: -s.confidence) - return "".join( - f"| `{s.name}` | `{s.suggestion}` | {s.confidence:.1%} {bucket_confidences(s.confidence)} | \n" - for s in suggestions - ) - for same_line_suggestions in grouped_suggestions: - suggestion = same_line_suggestions[0] - path = suggestion.filepath[1:] # No slash in the beginning - annotation_lineno = suggestion.annotation_lineno - with open(path) as file: - target_line = file.readlines()[annotation_lineno - 1] - data = { - "path": path, - "line": annotation_lineno, - "side": "RIGHT", - "commit_id": commit_id, - "body": "The following type annotation(s) might be useful:\n ```suggestion\n" - f"{annotate_line(target_line, same_line_suggestions)}```\n" - f"### :chart_with_upwards_trend: Prediction Stats\n" - f"| Symbol | Annotation | Confidence |\n" - f"| -- | -- | --: |\n" - f"{report_confidence(same_line_suggestions)}", - } - headers = { - "authorization": f"Bearer {github_token}", - "Accept": "application/vnd.github.v3.raw+json", - } - r = requests.post(comment_url, data=json.dumps(data), headers=headers) - if debug: - print("URL: ", comment_url) - print(f"Data: {data}. Status Code: {r.status_code}. Text: {r.text}") + for same_line_suggestions in grouped_suggestions: + suggestion = same_line_suggestions[0] + path = suggestion.filepath[1:] # No slash in the beginning + annotation_lineno = suggestion.annotation_lineno + with open(path) as file: + target_line = file.readlines()[annotation_lineno - 1] + data = { + "path": path, + "line": annotation_lineno, + "side": "RIGHT", + "commit_id": commit_id, + "body": "The following type annotation(s) might be useful:\n ```suggestion\n" + f"{annotate_line(target_line, same_line_suggestions)}```\n" + f"### :chart_with_upwards_trend: Prediction Stats\n" + f"| Symbol | Annotation | Confidence |\n" + f"| -- | -- | --: |\n" + f"{report_confidence(same_line_suggestions)}", + } + headers = { + "authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github.v3.raw+json", + } + r = requests.post(comment_url, data=json.dumps(data), headers=headers) + if debug: + print("URL: ", comment_url) + print(f"Data: {data}. Status Code: {r.status_code}. Text: {r.text}") +except AssertionError: + import traceback + _, _, tb = sys.exc_info() + traceback.print_tb(tb) # Fixed format + tb_info = traceback.extract_tb(tb) + filename, line, func, text = tb_info[-1] + + print('An error occurred on line {} in statement {}'.format(line, text)) + exit() diff --git a/src/annotationutils.py b/src/annotationutils.py index 2473364..8bec549 100644 --- a/src/annotationutils.py +++ b/src/annotationutils.py @@ -5,11 +5,9 @@ def find_suggestion_for_return(suggestions): - for s in suggestions: - if s.symbol_kind == "class-or-function": - return s - else: - return None + return next( + (s for s in suggestions if s.symbol_kind == "class-or-function"), None + ) def annotate_line(line, suggestions): @@ -33,7 +31,7 @@ def annotate_parameters(line, suggestions): """ Annotate the parameters of a function on a particular line """ - annotated_line = " " + line + annotated_line = f" {line}" length_increase = 0 for s in suggestions: assert line[s.file_location[1] :].startswith(s.name) @@ -48,7 +46,7 @@ def annotate_return(line, suggestion): Annotate the return of a function """ assert line.rstrip().endswith(":") - return line.rstrip()[:-1] + f" -> {suggestion.suggestion}" + ":\n" + return f"{line.rstrip()[:-1]} -> {suggestion.suggestion}" + ":\n" def find_annotation_line(filepath, location, func_name): diff --git a/src/changeutils.py b/src/changeutils.py index eef66ac..2e7b318 100644 --- a/src/changeutils.py +++ b/src/changeutils.py @@ -9,8 +9,7 @@ def get_line_ranges_of_interest(diff_lines: List[str]) -> Set[int]: lines_of_interest = set() current_line = 0 for line in diff_lines: - hunk_start_match = HUNK_MATCH.match(line) - if hunk_start_match: + if hunk_start_match := HUNK_MATCH.match(line): current_line = int(hunk_start_match.group(1)) elif line.startswith("+"): lines_of_interest.add(current_line) @@ -47,6 +46,8 @@ def get_changed_files(diff: str, suffix=".py") -> Dict[str, Set[int]]: elif file_diff_lines[1].startswith("similarity"): assert file_diff_lines[2].startswith("rename") assert file_diff_lines[3].startswith("rename") + if len(file_diff_lines) == 4: + continue # skip file renames \wo any changes assert file_diff_lines[4].startswith("index") assert file_diff_lines[5].startswith("--- a/") assert file_diff_lines[6].startswith("+++ b/") diff --git a/src/graph_generator/dataflowpass.py b/src/graph_generator/dataflowpass.py index 572f088..f3816e3 100644 --- a/src/graph_generator/dataflowpass.py +++ b/src/graph_generator/dataflowpass.py @@ -246,7 +246,7 @@ def visit_Try(self, node: Try): before_exec_handlers = self.__clone_last_uses() after_exec_handlers = self.__clone_last_uses() - for i, exc_handler in enumerate(node.handlers): + for exc_handler in node.handlers: self.visit(exc_handler) after_exec_handlers = self.__merge_uses(after_exec_handlers, self.__last_use) self.__last_use = before_exec_handlers @@ -297,7 +297,7 @@ def visit_AsyncWith(self, node: AsyncWith): self.__visit_with(node) def __visit_with(self, node: Union[With, AsyncWith]): - for i, w_item in enumerate(node.items): + for w_item in node.items: self.visit(w_item) self.__visit_statement_block(node.body) @@ -325,14 +325,14 @@ def visit_Assign(self, node: Assign): self.visit(node.value) for target in node.targets: - if isinstance(target, Attribute) or isinstance(target, Name): + if isinstance(target, (Attribute, Name)): self.__visit_variable_like(target, node) else: self.visit(target) def visit_AugAssign(self, node: AugAssign): self.visit(node.value) - if isinstance(node.target, Name) or isinstance(node.target, Attribute): + if isinstance(node.target, (Name, Attribute)): self.__visit_variable_like(node.target, node) else: self.visit(node.target) @@ -430,16 +430,16 @@ def visit_BinOp(self, node): self.visit(node.right) def visit_BoolOp(self, node): - for idx, value in enumerate(node.values): + for value in node.values: self.visit(value) def visit_Compare(self, node: Compare): self.visit(node.left) - for i, (op, right) in enumerate(zip(node.ops, node.comparators)): + for op, right in zip(node.ops, node.comparators): self.visit(right) def visit_Delete(self, node: Delete): - for i, target in enumerate(node.targets): + for target in node.targets: self.visit(target) def visit_Global(self, node: Global): @@ -492,7 +492,7 @@ def visit_Tuple(self, node): self.__sequence_datastruct_visit(node) def __sequence_datastruct_visit(self, node): - for idx, element in enumerate(node.elts): + for element in node.elts: self.visit(element) # endregion diff --git a/src/graph_generator/extract_graphs.py b/src/graph_generator/extract_graphs.py index c8620ad..f5f63e4 100644 --- a/src/graph_generator/extract_graphs.py +++ b/src/graph_generator/extract_graphs.py @@ -65,13 +65,14 @@ def explore_files( if not os.path.isfile(file_path): continue with open(file_path, encoding="utf-8", errors="ignore") as f: - monitoring.increment_count() - monitoring.enter_file(file_path) # import pdb; pdb.set_trace() if file_path[len(root_dir) :] not in files_to_extract: continue + monitoring.increment_count() + monitoring.enter_file(file_path) + graph = build_graph(f.read(), monitoring, type_lattice) if graph is None or len(graph["supernodes"]) == 0: continue diff --git a/src/graph_generator/graphgenerator.py b/src/graph_generator/graphgenerator.py index 855cdc9..ef82c9c 100644 --- a/src/graph_generator/graphgenerator.py +++ b/src/graph_generator/graphgenerator.py @@ -181,7 +181,7 @@ def parse_symbol_info(sinfo: SymbolInformation) -> Dict[str, Any]: return { "name": sinfo.name, - "annotation": None if not has_annotation else annotation_str, + "annotation": annotation_str if has_annotation else None, "location": first_annotatable_location, "type": sinfo.symbol_type, } @@ -222,9 +222,7 @@ def is_identifier_node(n): return False if keyword.iskeyword(str(n)): return False - if n == self.INDENT or n == self.DEDENT or n == self.NLINE: - return False - return True + return n not in [self.INDENT, self.DEDENT, self.NLINE] all_identifier_like_nodes: Set[TokenNode] = { n for n in self.__node_to_id if is_identifier_node(n) @@ -275,7 +273,7 @@ def visit(self, node: AST): parent = self.__current_parent_node self.__current_parent_node = node try: - method = "visit_" + node.__class__.__name__ + method = f"visit_{node.__class__.__name__}" visitor = getattr(self, method, self.generic_visit) if visitor == self.generic_visit: logging.warning("Unvisited AST type: %s", node.__class__.__name__) @@ -373,7 +371,7 @@ def __get_symbol_for_name(self, name, lineno, col_offset): and name.startswith("__") and not name.endswith("__") ): - name = "_" + self.__scope_symtable[-1].get_name() + name + name = f"_{self.__scope_symtable[-1].get_name()}{name}" current_idx = len(self.__scope_symtable) - 1 while current_idx >= 0: @@ -658,7 +656,6 @@ def visit_Assign(self, node: Assign): and hasattr(node.value, "func") and hasattr(node.value.func, "id") and node.value.func.id == "NewType" - and hasattr(node, "value") and hasattr(node.value, "args") and len(node.value.args) == 2 ): @@ -673,7 +670,7 @@ def visit_Assign(self, node: Assign): assert False for i, target in enumerate(node.targets): - if isinstance(target, Attribute) or isinstance(target, Name): + if isinstance(target, (Attribute, Name)): self.__visit_variable_like( target, target.lineno, @@ -696,7 +693,7 @@ def visit_Assign(self, node: Assign): self.visit(node.value) def visit_AugAssign(self, node: AugAssign): - if isinstance(node.target, Name) or isinstance(node.target, Attribute): + if isinstance(node.target, (Name, Attribute)): self.__visit_variable_like( node.target, node.lineno, node.col_offset, can_annotate_here=False ) @@ -704,7 +701,7 @@ def visit_AugAssign(self, node: AugAssign): self.visit(node.target) self._add_edge(node.target, node.value, EdgeType.COMPUTED_FROM) - self.add_terminal(TokenNode(self.BINOP_SYMBOLS[type(node.op)] + "=")) + self.add_terminal(TokenNode(f"{self.BINOP_SYMBOLS[type(node.op)]}=")) self.visit(node.value) def visit_AnnAssign(self, node: AnnAssign): @@ -729,7 +726,7 @@ def visit_Import(self, node): def visit_ImportFrom(self, node: ImportFrom): for alias in node.names: if node.module is not None: - name = parse_type_annotation_node(node.module + "." + alias.name) + name = parse_type_annotation_node(f"{node.module}.{alias.name}") else: name = parse_type_annotation_node(alias.name) if alias.asname: @@ -965,7 +962,7 @@ def visit_BoolOp(self, node): def visit_Compare(self, node: Compare): self.visit(node.left) - for i, (op, right) in enumerate(zip(node.ops, node.comparators)): + for op, right in zip(node.ops, node.comparators): self.add_terminal(TokenNode(self.CMPOP_SYMBOLS[type(op)])) self.visit(right) @@ -1048,12 +1045,12 @@ def visit_Dict(self, node): self.add_terminal(TokenNode("}")) def visit_FormattedValue(self, node: FormattedValue): - self.add_terminal(TokenNode(str('f"'))) + self.add_terminal(TokenNode('f"')) self.visit(node.value) if node.format_spec is not None: - self.add_terminal(TokenNode(str(":"))) + self.add_terminal(TokenNode(":")) self.visit(node.format_spec) - self.add_terminal(TokenNode(str('"'))) + self.add_terminal(TokenNode('"')) def visit_List(self, node): self.__sequence_datastruct_visit(node, "[", "]") @@ -1066,7 +1063,7 @@ def visit_Tuple(self, node): def __sequence_datastruct_visit(self, node, open_brace: str, close_brace: str): self.add_terminal(TokenNode(open_brace)) - for idx, element in enumerate(node.elts): + for element in node.elts: self.visit(element) self.add_terminal( TokenNode(",") @@ -1090,9 +1087,7 @@ def visit_Num(self, node): self.add_terminal(TokenNode(repr(node.n))) def visit_Str(self, node: Str): - self.add_terminal( - TokenNode('"' + node.s + '"') - ) # Approximate quote addition, but should be good enough. + self.add_terminal(TokenNode(f'"{node.s}"')) # endregion @@ -1111,7 +1106,7 @@ def node_to_label(self, node: Any) -> str: elif node is None: return "None" else: - raise Exception("Unrecognized node type %s" % type(node)) + raise Exception(f"Unrecognized node type {type(node)}") def to_dot( self, @@ -1130,8 +1125,8 @@ def to_dot( nodes_to_be_drawn.add(to_idx) with open(filename, "w") as f: - if len(initial_comment) > 0: - f.write("#" + initial_comment) + if initial_comment != "": + f.write(f"#{initial_comment}") f.write("\n") f.write("digraph program {\n") for node, node_idx in self.__node_to_id.items(): diff --git a/src/graph_generator/graphgenutils.py b/src/graph_generator/graphgenutils.py index 856844f..9f3dfa2 100644 --- a/src/graph_generator/graphgenutils.py +++ b/src/graph_generator/graphgenutils.py @@ -37,12 +37,10 @@ def __hash__(self): return hash(self.name) def __eq__(self, other): - if not isinstance(other, StrSymbol): - return False - return self.name == other.name + return self.name == other.name if isinstance(other, StrSymbol) else False def __str__(self): - return "Symbol: " + self.name + return f"Symbol: {self.name}" class SymbolInformation(NamedTuple): diff --git a/src/graph_generator/type_lattice_generator.py b/src/graph_generator/type_lattice_generator.py index bac2b32..d8c0a12 100644 --- a/src/graph_generator/type_lattice_generator.py +++ b/src/graph_generator/type_lattice_generator.py @@ -78,7 +78,7 @@ def __init__(self, typing_rules_path: str, max_depth_size: int = 2, max_list_siz RemoveGenericWithAnys(), ] ) - assert len(self.__ids_to_nodes) == len(set(repr(r) for r in self.__ids_to_nodes)) + assert len(self.__ids_to_nodes) == len({repr(r) for r in self.__ids_to_nodes}) def create_alias_replacement( self, imported_symbols: Dict[TypeAnnotationNode, TypeAnnotationNode] @@ -100,7 +100,7 @@ def __compute_non_generic_types(self): self.__all_types[parse_type_annotation_node("typing.Tuple")], self.__all_types[parse_type_annotation_node("typing.Callable")], ] - while len(to_visit) > 0: + while to_visit: next_node = to_visit.pop() generic_transitive_closure.add(next_node) to_visit.extend( @@ -129,7 +129,7 @@ def __annotation_to_id(self, annotation: TypeAnnotationNode) -> int: def __all_reachable_from(self, type_idx: int) -> FrozenSet[int]: reachable = set() to_visit = [type_idx] # type: List[int] - while len(to_visit) > 0: + while to_visit: next_type_idx = to_visit.pop() reachable.add(next_type_idx) to_visit.extend( @@ -195,7 +195,7 @@ def __rewrite_verbose(self, type_annotation: TypeAnnotationNode) -> TypeAnnotati def build_graph(self): print( - "Building type graph for project... (%s elements to process)" % len(self.__to_process) + f"Building type graph for project... ({len(self.__to_process)} elements to process)" ) i = 0 @@ -208,8 +208,7 @@ def build_graph(self): i += 1 if i > 500: print( - "Building type graph for project... (%s elements to process)" - % len(self.__to_process) + f"Building type graph for project... ({len(self.__to_process)} elements to process)" ) i = 0 if len(self.__to_process) > 3000: @@ -230,22 +229,21 @@ def build_graph(self): ) was_rewritten = len(all_inherited_types_and_self) > 1 - if was_rewritten: - if ( - not erasure_happened - or len(self.__to_process) < 5000 - or len(all_inherited_types_and_self) < 5 - ): - for type_annotation in all_inherited_types_and_self: - type_annotation = type_annotation.accept_visitor( - self.__max_depth_pruning_visitor, self.__max_annotation_depth, - ) - type_annotation = self.__rewrite_verbose(type_annotation) - type_has_been_seen = type_annotation in self.__all_types - - if not type_has_been_seen: - self.__add_is_a_relationship(next_type, type_annotation) - self.__to_process.append(type_annotation) + if was_rewritten and ( + not erasure_happened + or len(self.__to_process) < 5000 + or len(all_inherited_types_and_self) < 5 + ): + for type_annotation in all_inherited_types_and_self: + type_annotation = type_annotation.accept_visitor( + self.__max_depth_pruning_visitor, self.__max_annotation_depth, + ) + type_annotation = self.__rewrite_verbose(type_annotation) + type_has_been_seen = type_annotation in self.__all_types + + if not type_has_been_seen: + self.__add_is_a_relationship(next_type, type_annotation) + self.__to_process.append(type_annotation) if not was_rewritten and not erasure_happened: # Add a rule to Any @@ -281,11 +279,11 @@ def canonicalize_annotation( def return_json(self) -> Dict[str, Any]: edges = [] for from_type_idx, to_type_idxs in self.is_a_edges.items(): - for to_type_idx in to_type_idxs: - edges.append((from_type_idx, to_type_idx)) - - assert len(self.__ids_to_nodes) == len(set(repr(r) for r in self.__ids_to_nodes)) + edges.extend((from_type_idx, to_type_idx) for to_type_idx in to_type_idxs) + assert len(self.__ids_to_nodes) == len({repr(r) for r in self.__ids_to_nodes}) return { - "nodes": list((repr(type_annotation) for type_annotation in self.__ids_to_nodes)), + "nodes": [ + repr(type_annotation) for type_annotation in self.__ids_to_nodes + ], "edges": edges, } diff --git a/src/graph_generator/typeparsing/erasure.py b/src/graph_generator/typeparsing/erasure.py index 643034a..70b32c9 100644 --- a/src/graph_generator/typeparsing/erasure.py +++ b/src/graph_generator/typeparsing/erasure.py @@ -25,12 +25,16 @@ def visit_subscript_annotation(self, node: SubscriptAnnotationNode): else: next_slices, erasure_happened_at_a_slice = node.slice.accept_visitor(self) - if not erasure_happened_at_a_slice: - return [node, node.value], True # Erase type parameters - return ( - [SubscriptAnnotationNode(value=node.value, slice=s) for s in next_slices], - True, + ( + [ + SubscriptAnnotationNode(value=node.value, slice=s) + for s in next_slices + ], + True, + ) + if erasure_happened_at_a_slice + else ([node, node.value], True) ) def visit_tuple_annotation(self, node: TupleAnnotationNode): diff --git a/src/graph_generator/typeparsing/inheritancerewrite.py b/src/graph_generator/typeparsing/inheritancerewrite.py index 76570da..580a18b 100644 --- a/src/graph_generator/typeparsing/inheritancerewrite.py +++ b/src/graph_generator/typeparsing/inheritancerewrite.py @@ -41,9 +41,7 @@ def visit_subscript_annotation(self, node: SubscriptAnnotationNode): if v in self.__non_generic_types: all_children.append(v) continue - for s in slice_node_options: - all_children.append(SubscriptAnnotationNode(v, s)) - + all_children.extend(SubscriptAnnotationNode(v, s) for s in slice_node_options) return all_children def visit_tuple_annotation(self, node: TupleAnnotationNode): @@ -67,8 +65,7 @@ def visit_list_annotation(self, node: ListAnnotationNode): return r def visit_attribute_annotation(self, node: AttributeAnnotationNode): - v = [node] + list(self.__is_a(node)) - return v + return [node] + list(self.__is_a(node)) def visit_index_annotation(self, node: IndexAnnotationNode): next_values = node.value.accept_visitor(self) diff --git a/src/graph_generator/typeparsing/nodes.py b/src/graph_generator/typeparsing/nodes.py index 250178b..716f558 100644 --- a/src/graph_generator/typeparsing/nodes.py +++ b/src/graph_generator/typeparsing/nodes.py @@ -62,16 +62,17 @@ def accept_visitor(self, visitor: TypeAnnotationVisitor, *args) -> Any: return visitor.visit_subscript_annotation(self, *args) def __repr__(self): - return repr(self.value) + "[" + repr(self.slice) + "]" + return f"{repr(self.value)}[{repr(self.slice)}]" def __hash__(self): return hash(self.value) ^ (hash(self.slice) + 13) def __eq__(self, other): - if not isinstance(other, SubscriptAnnotationNode): - return False - else: - return self.value == other.value and self.slice == other.slice + return ( + self.value == other.value and self.slice == other.slice + if isinstance(other, SubscriptAnnotationNode) + else False + ) @staticmethod def parse(node) -> "SubscriptAnnotationNode": @@ -98,16 +99,14 @@ def __repr__(self): return ", ".join(repr(e) for e in self.elements) def __hash__(self): - if len(self.elements) > 0: - return hash(self.elements) - else: - return 1 + return hash(self.elements) if len(self.elements) > 0 else 1 def __eq__(self, other): - if not isinstance(other, TupleAnnotationNode): - return False - else: - return self.elements == other.elements + return ( + self.elements == other.elements + if isinstance(other, TupleAnnotationNode) + else False + ) @staticmethod def parse(node) -> "TupleAnnotationNode": @@ -132,9 +131,11 @@ def __hash__(self): return hash(self.identifier) def __eq__(self, other): - if not isinstance(other, NameAnnotationNode): - return False - return self.identifier == other.identifier + return ( + self.identifier == other.identifier + if isinstance(other, NameAnnotationNode) + else False + ) @staticmethod def parse(node) -> "NameAnnotationNode": @@ -156,15 +157,14 @@ def __repr__(self): return "[" + ", ".join(repr(e) for e in self.elements) + "]" def __hash__(self): - if len(self.elements) > 0: - return hash(self.elements) - else: - return 2 + return hash(self.elements) if len(self.elements) > 0 else 2 def __eq__(self, other): - if not isinstance(other, ListAnnotationNode): - return False - return self.elements == other.elements + return ( + self.elements == other.elements + if isinstance(other, ListAnnotationNode) + else False + ) @staticmethod def parse(node) -> "ListAnnotationNode": @@ -185,16 +185,17 @@ def accept_visitor(self, visitor: TypeAnnotationVisitor, *args) -> Any: return visitor.visit_attribute_annotation(self, *args) def __repr__(self): - return repr(self.value) + "." + self.attribute + return f"{repr(self.value)}.{self.attribute}" def __hash__(self): return hash(self.attribute) ^ (hash(self.value) + 13) def __eq__(self, other): - if not isinstance(other, AttributeAnnotationNode): - return False - else: - return self.attribute == other.attribute and self.value == other.value + return ( + self.attribute == other.attribute and self.value == other.value + if isinstance(other, AttributeAnnotationNode) + else False + ) @staticmethod def parse(node) -> "AttributeAnnotationNode": @@ -220,9 +221,11 @@ def __hash__(self): return hash(self.value) def __eq__(self, other): - if not isinstance(other, IndexAnnotationNode): - return False - return self.value == other.value + return ( + self.value == other.value + if isinstance(other, IndexAnnotationNode) + else False + ) @staticmethod def parse(node) -> "IndexAnnotationNode": @@ -271,9 +274,11 @@ def __hash__(self): return hash(self.value) def __eq__(self, other): - if not isinstance(other, NameConstantAnnotationNode): - return False - return self.value == other.value + return ( + self.value == other.value + if isinstance(other, NameConstantAnnotationNode) + else False + ) @staticmethod def parse(node) -> "NameConstantAnnotationNode": @@ -345,11 +350,11 @@ def parse_type_annotation_node(node) -> Optional[TypeAnnotationNode]: Processes the node containing the type annotation and return the object corresponding to the node type. """ try: - if isinstance(node, str): - r = parse_type_comment(node) - else: - r = _parse_recursive(node) - return r + return ( + parse_type_comment(node) + if isinstance(node, str) + else _parse_recursive(node) + ) except Exception as e: pass return None diff --git a/src/graph_generator/typeparsing/rewriterules/removestandalones.py b/src/graph_generator/typeparsing/rewriterules/removestandalones.py index d9a66a6..4a61fc4 100644 --- a/src/graph_generator/typeparsing/rewriterules/removestandalones.py +++ b/src/graph_generator/typeparsing/rewriterules/removestandalones.py @@ -18,11 +18,7 @@ class RemoveStandAlones(RewriteRule): ANY_NODE = parse_type_annotation_node("typing.Any") def matches(self, node: TypeAnnotationNode, parent: Optional[TypeAnnotationNode]) -> bool: - if ( - not node == self.UNION_NODE - and not node == self.OPTIONAL_NODE - and not node == self.GENERIC_NODE - ): + if node not in [self.UNION_NODE, self.OPTIONAL_NODE, self.GENERIC_NODE]: return False return not isinstance(parent, SubscriptAnnotationNode) diff --git a/src/graph_generator/typeparsing/rewriterulevisitor.py b/src/graph_generator/typeparsing/rewriterulevisitor.py index c914734..5ceda3e 100644 --- a/src/graph_generator/typeparsing/rewriterulevisitor.py +++ b/src/graph_generator/typeparsing/rewriterulevisitor.py @@ -22,10 +22,14 @@ def __init__(self, rules: List[RewriteRule]): def __apply_on_match( self, original_node: TypeAnnotationNode, parent: TypeAnnotationNode ) -> TypeAnnotationNode: - for rule in self.__rules: - if rule.matches(original_node, parent): - return rule.apply(original_node) - return original_node + return next( + ( + rule.apply(original_node) + for rule in self.__rules + if rule.matches(original_node, parent) + ), + original_node, + ) def visit_subscript_annotation( self, node: SubscriptAnnotationNode, parent: TypeAnnotationNode