@@ -1467,7 +1467,7 @@ def with_ctx(
14671467 self ,
14681468 / ,
14691469 * ,
1470- disposables : Disposables | Collection [Disposable ],
1470+ disposables : Collection [Disposable ],
14711471 ) -> Self : ...
14721472
14731473 @overload
@@ -1478,12 +1478,12 @@ def with_ctx(
14781478 * state : State ,
14791479 ) -> Self : ...
14801480
1481- def with_ctx ( # noqa: C901
1481+ def with_ctx (
14821482 self ,
14831483 state : State | None = None ,
14841484 / ,
14851485 * states : State ,
1486- disposables : Disposables | Collection [Disposable ] | None = None ,
1486+ disposables : Collection [Disposable ] = () ,
14871487 ) -> Self :
14881488 """
14891489 Apply the specified state context to this stage.
@@ -1502,7 +1502,7 @@ def with_ctx( # noqa: C901
15021502 Optional additional `State` objects to include in the state context.
15031503 Only applicable when `state_context` is a `State`.
15041504
1505- disposables: Disposables | Collection[Disposable] | None
1505+ disposables: Collection[Disposable]
15061506 Optional Disposables which will be used for execution of this stage.
15071507 State produced by disposables will be used within the context state.
15081508
@@ -1513,47 +1513,17 @@ def with_ctx( # noqa: C901
15131513 """
15141514 execution : StageExecution = self ._execution
15151515
1516- resolved_disposables : Disposables | None
1517- if disposables is None :
1518- resolved_disposables = None
1519-
1520- elif isinstance (disposables , Disposables ):
1521- resolved_disposables = disposables
1522-
1523- else :
1524- resolved_disposables = Disposables (disposables )
1525-
1526- match (state , resolved_disposables ):
1527- case (None , None ):
1528- assert not states # nosec: B101
1529- return self # nothing to change...
1530-
1531- case (None , ctx_disposables ):
1532- assert not states # nosec: B101
1516+ if ctx_state := state :
1517+ if disposables :
15331518
15341519 async def stage (
15351520 * ,
15361521 state : StageState ,
15371522 ) -> StageState :
1538- result_state : StageState
1539- try :
1540- with ctx .updated (* await ctx_disposables .prepare ()):
1541- result_state = await execution (state = state )
1542-
1543- except BaseException as exc :
1544- await ctx_disposables .dispose (
1545- exc_type = type (exc ),
1546- exc_val = exc ,
1547- exc_tb = exc .__traceback__ ,
1548- )
1549- raise exc
1550-
1551- else :
1552- await ctx_disposables .dispose ()
1553-
1554- return result_state
1555-
1556- case (ctx_state , None ):
1523+ async with Disposables (disposables ) as disposable_state :
1524+ with ctx .updated (* disposable_state , ctx_state , * states ):
1525+ return await execution (state = state )
1526+ else :
15571527
15581528 async def stage (
15591529 * ,
@@ -1562,30 +1532,18 @@ async def stage(
15621532 with ctx .updated (ctx_state , * states ):
15631533 return await execution (state = state )
15641534
1565- case (ctx_state , ctx_disposables ):
1566-
1567- async def stage (
1568- * ,
1569- state : StageState ,
1570- ) -> StageState :
1571- disposables_state : Iterable [State ] = await ctx_disposables .prepare ()
1572- result_state : StageState
1573- try :
1574- with ctx .updated (ctx_state , * disposables_state , * states ):
1575- result_state = await execution (state = state )
1576-
1577- except BaseException as exc :
1578- await ctx_disposables .dispose (
1579- exc_type = type (exc ),
1580- exc_val = exc ,
1581- exc_tb = exc .__traceback__ ,
1582- )
1583- raise exc
1584-
1585- else :
1586- await ctx_disposables .dispose ()
1535+ elif disposables :
1536+ assert not states # nosec: B101
15871537
1588- return result_state
1538+ async def stage (
1539+ * ,
1540+ state : StageState ,
1541+ ) -> StageState :
1542+ async with ctx .disposables (* disposables ):
1543+ return await execution (state = state )
1544+ else :
1545+ assert not states # nosec: B101
1546+ return self # nothing to change...
15891547
15901548 return self .__class__ (
15911549 stage ,
0 commit comments