1- from asyncio import create_task , Future , Task
1+ from asyncio import create_task , Future , Task , wait
22from collections .abc import Awaitable , Callable
33from dataclasses import dataclass
44from typing import Optional , Union
@@ -30,13 +30,12 @@ class ASGIHTTPConnection(HTTPConnection):
3030 This provides the API for sending the response.
3131 """
3232
33- def __init__ (
34- self , send_cb : SendCallable , context : ASGIHTTPRequestContext , task_holder : set
35- ):
33+ def __init__ (self , send_cb : SendCallable , context : ASGIHTTPRequestContext ):
3634 self .send_cb = send_cb
3735 self .context = context
38- self .task_holder = task_holder
36+ self .task_holder : set [ Task ] = set ()
3937 self ._close_callback : Callable [[], None ] | None = None
38+ self ._request_finished : Future [None ] = Future ()
4039
4140 # Various tornado APIs (e.g. RequestHandler.flush()) return a Future which
4241 # application code does not need to await. The operations these represent
@@ -90,10 +89,12 @@ def finish(self) -> None:
9089 self .send_cb (
9190 {
9291 "type" : "http.response.body" ,
92+ "body" : b"" ,
9393 "more_body" : False ,
9494 }
9595 )
9696 )
97+ self ._request_finished .set_result (None )
9798
9899 def set_close_callback (self , callback : Optional [Callable [[], None ]]) -> None :
99100 self ._close_callback = callback
@@ -103,14 +104,19 @@ def _on_connection_close(self) -> None:
103104 callback = self ._close_callback
104105 self ._close_callback = None
105106 callback ()
107+ self ._request_finished .set_result (None )
108+
109+ async def wait_finish (self ) -> None :
110+ """For the ASGI interface: wait for all input & output to finish"""
111+ await self ._request_finished
112+ await wait (self .task_holder )
106113
107114
108115class ASGIAdapter :
109116 """Wrap a tornado application object to use with an ASGI server"""
110117
111118 def __init__ (self , application : Application ):
112119 self .application = application
113- self .task_holder : set [Task ] = set ()
114120
115121 async def __call__ (
116122 self , scope : dict , receive : ReceiveCallable , send : SendCallable
@@ -128,7 +134,7 @@ async def http_scope(
128134 ctx .address = tuple (client_addr )
129135 ctx .remote_ip = client_addr [0 ]
130136
131- conn = ASGIHTTPConnection (send , ctx , self . task_holder )
137+ conn = ASGIHTTPConnection (send , ctx )
132138 req_target = scope ["path" ]
133139 if qs := scope ["query_string" ]:
134140 req_target += "?" + qs .decode ("latin1" )
@@ -156,3 +162,5 @@ async def http_scope(
156162 msg_delegate .on_connection_close ()
157163 conn ._on_connection_close ()
158164 break
165+
166+ await conn .wait_finish ()
0 commit comments