77 contextmanager ,
88)
99from dataclasses import dataclass , field
10- from datetime import timedelta
10+ from datetime import datetime , timedelta
11+ from typing import Any
1112
12- from anyio import fail_after , sleep
13+ from anyio import create_task_group , fail_after , sleep
1314from anyio .from_thread import BlockingPortal
1415from grpc .aio import Channel
1516from jumpstarter_protocol import jumpstarter_pb2 , jumpstarter_pb2_grpc
@@ -39,6 +40,7 @@ class Lease(AbstractContextManager, AbstractAsyncContextManager):
3940 release : bool = True # release on contexts exit
4041 controller : jumpstarter_pb2_grpc .ControllerServiceStub = field (init = False )
4142 tls_config : TLSConfigV1Alpha1 = field (default_factory = TLSConfigV1Alpha1 )
43+ grpc_options : dict [str , Any ] = field (default_factory = dict )
4244
4345 def __post_init__ (self ):
4446 if hasattr (super (), "__post_init__" ):
@@ -62,6 +64,11 @@ async def _create(self):
6264 ).name
6365 logger .info ("Created lease request for selector %s for duration %s" , selector , duration )
6466
67+ async def get (self ):
68+ with translate_grpc_exceptions ():
69+ svc = ClientService (channel = self .channel , namespace = self .namespace )
70+ return await svc .GetLease (name = self .name )
71+
6572 def request (self ):
6673 """Request a lease, or verifies a lease which was already created.
6774
@@ -96,11 +103,7 @@ async def _acquire(self):
96103 with fail_after (300 ): # TODO: configurable timeout
97104 while True :
98105 logger .debug ("Polling Lease %s" , self .name )
99- with translate_grpc_exceptions ():
100- result = await self .svc .GetLease (
101- name = self .name ,
102- )
103-
106+ result = await self .get ()
104107 # lease ready
105108 if condition_true (result .conditions , "Ready" ):
106109 logger .debug ("Lease %s acquired" , self .name )
@@ -148,14 +151,39 @@ def __exit__(self, exc_type, exc_value, traceback):
148151 async def handle_async (self , stream ):
149152 logger .debug ("Connecting to Lease with name %s" , self .name )
150153 response = await self .controller .Dial (jumpstarter_pb2 .DialRequest (lease_name = self .name ))
151- async with connect_router_stream (response .router_endpoint , response .router_token , stream , self .tls_config ):
154+ async with connect_router_stream (
155+ response .router_endpoint , response .router_token , stream , self .tls_config , self .grpc_options
156+ ):
152157 pass
153158
154159 @asynccontextmanager
155160 async def serve_unix_async (self ):
156161 async with TemporaryUnixListener (self .handle_async ) as path :
157162 yield path
158163
164+ @asynccontextmanager
165+ async def monitor_async (self , threshold : timedelta = timedelta (minutes = 5 )):
166+ async def _monitor ():
167+ while True :
168+ lease = await self .get ()
169+ if lease .effective_begin_time :
170+ end_time = lease .effective_begin_time + lease .duration
171+ remain = end_time - datetime .now (tz = datetime .now ().astimezone ().tzinfo )
172+ if remain < threshold :
173+ logger .info ("Lease {} ending soon in {} at {}" .format (self .name , remain , end_time ))
174+ await sleep (threshold .total_seconds ())
175+ else :
176+ await sleep (5 )
177+ else :
178+ await sleep (1 )
179+
180+ async with create_task_group () as tg :
181+ tg .start_soon (_monitor )
182+ try :
183+ yield
184+ finally :
185+ tg .cancel_scope .cancel ()
186+
159187 @asynccontextmanager
160188 async def connect_async (self , stack ):
161189 async with self .serve_unix_async () as path :
@@ -172,3 +200,8 @@ def connect(self):
172200 def serve_unix (self ):
173201 with self .portal .wrap_async_context_manager (self .serve_unix_async ()) as path :
174202 yield path
203+
204+ @contextmanager
205+ def monitor (self , threshold : timedelta = timedelta (minutes = 5 )):
206+ with self .portal .wrap_async_context_manager (self .monitor_async (threshold )):
207+ yield
0 commit comments