@@ -56,6 +56,9 @@ class LoopState:
5656 nodes_needing_checkpoints : list [cst .Return | cst .Yield ] = field (
5757 default_factory = list
5858 )
59+ possibly_redundant_lowlevel_checkpoints : list [cst .BaseExpression ] = field (
60+ default_factory = list
61+ )
5962
6063 def copy (self ):
6164 return LoopState (
@@ -66,6 +69,7 @@ def copy(self):
6669 uncheckpointed_before_break = self .uncheckpointed_before_break .copy (),
6770 artificial_errors = self .artificial_errors .copy (),
6871 nodes_needing_checkpoints = self .nodes_needing_checkpoints .copy (),
72+ possibly_redundant_lowlevel_checkpoints = self .possibly_redundant_lowlevel_checkpoints .copy (),
6973 )
7074
7175
@@ -214,6 +218,22 @@ def leave_Yield(
214218 leave_Return = leave_Yield # type: ignore
215219
216220
221+ # class RemoveLowlevelCheckpoints(cst.CSTTransformer):
222+ # def __init__(self, stmts_to_remove: set[cst.Await]):
223+ # self.stmts_to_remove = stmts_to_remove
224+ #
225+ # def leave_Await(self, original_node: cst.Await, updated_node: cst.Await) -> cst.Await:
226+ # # return original node to preserve identity
227+ # return original_node
228+ #
229+ # # for some reason you can't just return RemovalSentinel from Await, so we have to
230+ # # visit the possible wrappers and modify their bodies instead
231+ #
232+ # def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
233+ # new_body = [stmt for stmt in updated_node.body.body if stmt not in self.stmts_to_remove]
234+ # return updated_node.with_changes(body=updated_node.body.with_changes(body=new_body))
235+
236+
217237@error_class_cst
218238@disabled_by_default
219239class Visitor91X (Flake8TrioVisitor_cst , CommonVisitors ):
@@ -226,16 +246,27 @@ class Visitor91X(Flake8TrioVisitor_cst, CommonVisitors):
226246 "{0} from async iterable with no guaranteed checkpoint since {1.name} "
227247 "on line {1.lineno}."
228248 ),
249+ "TRIO912" : "Redundant checkpoint with no effect on program execution." ,
229250 }
230251
231252 def __init__ (self , * args : Any , ** kwargs : Any ):
232253 super ().__init__ (* args , ** kwargs )
233254 self .has_yield = False
234255 self .safe_decorator = False
235256 self .async_function = False
236- self .uncheckpointed_statements : set [Statement ] = set ()
237257 self .comp_unknown = False
238258
259+ self .uncheckpointed_statements : set [Statement ] = set ()
260+ self .checkpointed_by_lowlevel = False
261+
262+ # value == False, not redundant (or not determined to be redundant yet)
263+ # value == True, there were no uncheckpointed statements when we encountered it
264+ # value = expr/stmt, made redundant by the given expr/stmt
265+ self .lowlevel_checkpoints : dict [
266+ cst .Await , cst .BaseStatement | cst .BaseExpression | bool
267+ ] = {}
268+ self .lowlevel_checkpoint_updated_nodes : dict [cst .Await , cst .Await ] = {}
269+
239270 self .loop_state = LoopState ()
240271 self .try_state = TryState ()
241272
@@ -258,6 +289,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
258289 "safe_decorator" ,
259290 "async_function" ,
260291 "uncheckpointed_statements" ,
292+ "lowlevel_checkpoints" ,
261293 "loop_state" ,
262294 "try_state" ,
263295 copy = True ,
@@ -299,8 +331,31 @@ def leave_FunctionDef(
299331 indentedblock = updated_node .body .with_changes (body = new_body )
300332 updated_node = updated_node .with_changes (body = indentedblock )
301333
334+ res : cst .FunctionDef = updated_node
335+ to_remove : set [cst .Await ] = set ()
336+ for expr , value in self .lowlevel_checkpoints .items ():
337+ if value != False :
338+ self .error (expr , error_code = "TRIO912" )
339+ if self .should_autofix ():
340+ to_remove .add (self .lowlevel_checkpoint_updated_nodes .pop (expr ))
341+
342+ if to_remove :
343+ new_body = []
344+ for stmt in updated_node .body .body :
345+ if not m .matches (
346+ stmt ,
347+ m .SimpleStatementLine (
348+ [m .Expr (m .MatchIfTrue (lambda x : x in to_remove ))]
349+ ),
350+ ):
351+ new_body .append (stmt ) # type: ignore
352+ assert new_body != updated_node .body .body
353+ res = updated_node .with_changes (
354+ body = updated_node .body .with_changes (body = new_body )
355+ )
356+
302357 self .restore_state (original_node )
303- return updated_node # noqa: R504
358+ return res
304359
305360 # error if function exit/return/yields with uncheckpointed statements
306361 # returns a bool indicating if any real (i.e. not artificial) errors were raised
@@ -372,12 +427,48 @@ def error_91x(
372427 error_code = "TRIO911" if self .has_yield else "TRIO910" ,
373428 )
374429
430+ def is_lowlevel_checkpoint (self , node : cst .BaseExpression ) -> bool :
431+ # TODO: match against both libraries if both are imported
432+ return m .matches (
433+ node ,
434+ m .Call (
435+ m .Attribute (
436+ m .Attribute (m .Name (self .library [0 ]), m .Name ("lowlevel" )),
437+ m .Name ("checkpoint" ),
438+ )
439+ ),
440+ )
441+
442+ def visit_Await (self , node : cst .Await ) -> None :
443+ # do a match against the awaited expr
444+ # if that is trio.lowlevel.checkpoint, and uncheckpointed statements
445+ # are empty, raise TRIO912.
446+ if self .is_lowlevel_checkpoint (node .expression ):
447+ if not self .uncheckpointed_statements :
448+ self .lowlevel_checkpoints [node ] = True
449+ elif self .uncheckpointed_statements == {ARTIFICIAL_STATEMENT }:
450+ self .loop_state .possibly_redundant_lowlevel_checkpoints .append (node )
451+ else :
452+ self .lowlevel_checkpoints [node ] = False
453+ # if trio.lowlevel.checkpoint and *not* empty, take note of it in a special list.
454+ elif not self .uncheckpointed_statements :
455+ for expr , value in self .lowlevel_checkpoints .items ():
456+ if value == False :
457+ self .lowlevel_checkpoints [expr ] = node
458+
459+ # if this is not a trio.lowlevel.checkpoint, and there are no uncheckpointed statements, check if there is a lowlevel checkpoint in the special list. If so, raise a TRIO912 for it and remove it.
460+
375461 def leave_Await (
376462 self , original_node : cst .Await , updated_node : cst .Await
377463 ) -> cst .Await :
378464 # the expression being awaited is not checkpointed
379465 # so only set checkpoint after the await node
380466
467+ # TODO: dirty hack to get identity right, the logic in visit should maybe be
468+ # moved/split into the leave
469+ if original_node in self .lowlevel_checkpoints :
470+ self .lowlevel_checkpoint_updated_nodes [original_node ] = updated_node
471+
381472 # all nodes are now checkpointed
382473 self .uncheckpointed_statements = set ()
383474 return updated_node
@@ -494,6 +585,10 @@ def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.Try:
494585 self .restore_state (original_node )
495586 return updated_node
496587
588+ # if a previous lowlevel checkpoint is marked as redundant after all bodies, then
589+ # it's redundant.
590+ # If any body marks it as necessary, then it's necessary.
591+ # Otherwise, it keeps it's state from before.
497592 def leave_If_test (self , node : cst .If | cst .IfExp ) -> None :
498593 if not self .async_function :
499594 return
@@ -604,6 +699,11 @@ def leave_While_body(self, node: cst.For | cst.While):
604699 if not any_error :
605700 self .loop_state .nodes_needing_checkpoints = []
606701
702+ # but lowlevel checkpoints are redundant
703+ for expr in self .loop_state .possibly_redundant_lowlevel_checkpoints :
704+ self .error (expr , error_code = "TRIO912" )
705+ # self.possibly_redundant_lowlevel_checkpoints.clear()
706+
607707 # replace artificial statements in else with prebody uncheckpointed statements
608708 # non-artificial stmts before continue/break/at body end will already be in them
609709 for stmts in (
@@ -654,6 +754,12 @@ def leave_While_orelse(self, node: cst.For | cst.While):
654754 # reset break & continue in case of nested loops
655755 self .outer [node ]["uncheckpointed_statements" ] = self .uncheckpointed_statements
656756
757+ # TODO: if this loop always checkpoints
758+ # e.g. from being an async for, or being guaranteed to run once, or other stuff.
759+ # then we can warn about redundant checkpoints before the loop.
760+ # ... except if the reason we always checkpoint is due to redundant checkpoints
761+ # we're about to remove.... :thinking:
762+
657763 leave_For_orelse = leave_While_orelse
658764
659765 def leave_While (
0 commit comments