Skip to content

Commit ca1e227

Browse files
committed
fix async queue
Signed-off-by: Richard Chien <[email protected]>
1 parent bf4c832 commit ca1e227

File tree

14 files changed

+110
-25
lines changed

14 files changed

+110
-25
lines changed

src/kimi_cli/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from kimi_cli.soul.agent import Runtime, load_agent
2323
from kimi_cli.soul.context import Context
2424
from kimi_cli.soul.kimisoul import KimiSoul
25+
from kimi_cli.utils.aioqueue import QueueShutDown
2526
from kimi_cli.utils.logging import StreamToLogger, logger
2627
from kimi_cli.utils.path import shorten_home
2728
from kimi_cli.wire import Wire, WireUISide
@@ -201,7 +202,7 @@ async def _ui_loop_fn(wire: Wire) -> None:
201202
while True:
202203
msg = await wire_ui.receive()
203204
yield msg
204-
except asyncio.QueueShutDown:
205+
except QueueShutDown:
205206
pass
206207
finally:
207208
# stop consuming Wire messages

src/kimi_cli/soul/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from kosong.message import ContentPart
1212

13+
from kimi_cli.utils.aioqueue import QueueShutDown
1314
from kimi_cli.utils.logging import logger
1415
from kimi_cli.wire import Wire
1516
from kimi_cli.wire.message import WireMessage
@@ -164,7 +165,7 @@ async def run_soul(
164165
wire.shutdown()
165166
try:
166167
await asyncio.wait_for(ui_task, timeout=0.5)
167-
except asyncio.QueueShutDown:
168+
except QueueShutDown:
168169
logger.debug("UI loop shut down")
169170
pass
170171
except TimeoutError:

src/kimi_cli/soul/approval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Literal
77

88
from kimi_cli.soul.toolset import get_current_tool_call_or_none
9+
from kimi_cli.utils.aioqueue import Queue
910
from kimi_cli.utils.logging import logger
1011
from kimi_cli.wire.display import DisplayBlock
1112

@@ -25,7 +26,7 @@ class Request:
2526

2627
class Approval:
2728
def __init__(self, yolo: bool = False):
28-
self._request_queue = asyncio.Queue[Request]()
29+
self._request_queue = Queue[Request]()
2930
self._requests: dict[str, tuple[Request, asyncio.Future[bool]]] = {}
3031
self._yolo = yolo
3132
self._auto_approve_actions: set[str] = set() # TODO: persist across sessions

src/kimi_cli/ui/acp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from kimi_cli.soul import Soul, run_soul
1818
from kimi_cli.soul.kimisoul import KimiSoul
1919
from kimi_cli.soul.toolset import KimiToolset
20+
from kimi_cli.utils.aioqueue import QueueShutDown
2021
from kimi_cli.utils.logging import logger
2122
from kimi_cli.wire import Wire, WireUISide
2223
from kimi_cli.wire.message import WireMessage
@@ -92,7 +93,7 @@ async def _ui_loop_fn(wire: Wire) -> None:
9293
while True:
9394
msg = await wire_ui.receive()
9495
yield msg
95-
except asyncio.QueueShutDown:
96+
except QueueShutDown:
9697
pass
9798
finally:
9899
# stop consuming Wire messages

src/kimi_cli/ui/print/visualize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
from dataclasses import dataclass
32
from typing import Protocol
43

@@ -8,6 +7,7 @@
87

98
from kimi_cli.cli import OutputFormat
109
from kimi_cli.soul.message import tool_result_to_message
10+
from kimi_cli.utils.aioqueue import QueueShutDown
1111
from kimi_cli.wire import Wire
1212
from kimi_cli.wire.message import StepBegin, StepInterrupted, WireMessage
1313

@@ -162,7 +162,7 @@ async def visualize(output_format: OutputFormat, final_only: bool, wire: Wire) -
162162
while True:
163163
try:
164164
msg = await wire_ui.receive()
165-
except asyncio.QueueShutDown:
165+
except QueueShutDown:
166166
handler.flush()
167167
break
168168

src/kimi_cli/ui/shell/keyboard.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from collections.abc import AsyncGenerator, Callable
88
from enum import Enum, auto
99

10+
from kimi_cli.utils.aioqueue import Queue
11+
1012

1113
class KeyEvent(Enum):
1214
UP = auto()
@@ -20,7 +22,7 @@ class KeyEvent(Enum):
2022

2123
async def listen_for_keyboard() -> AsyncGenerator[KeyEvent]:
2224
loop = asyncio.get_running_loop()
23-
queue = asyncio.Queue[KeyEvent]()
25+
queue = Queue[KeyEvent]()
2426
cancel_event = threading.Event()
2527

2628
def emit(event: KeyEvent) -> None:

src/kimi_cli/ui/shell/replay.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from kimi_cli.ui.shell.console import console
1717
from kimi_cli.ui.shell.prompt import PROMPT_SYMBOL
1818
from kimi_cli.ui.shell.visualize import visualize
19+
from kimi_cli.utils.aioqueue import QueueShutDown
1920
from kimi_cli.utils.logging import logger
2021
from kimi_cli.utils.message import message_stringify
2122
from kimi_cli.wire import Wire
@@ -66,7 +67,7 @@ async def replay_recent_history(
6667
wire.soul_side.send(event)
6768
await asyncio.sleep(0) # yield to UI loop
6869
wire.shutdown()
69-
with contextlib.suppress(asyncio.QueueShutDown):
70+
with contextlib.suppress(QueueShutDown):
7071
await ui_task
7172

7273

src/kimi_cli/ui/shell/visualize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from kimi_cli.tools import extract_key_argument
2020
from kimi_cli.ui.shell.console import console
2121
from kimi_cli.ui.shell.keyboard import KeyEvent, listen_for_keyboard
22+
from kimi_cli.utils.aioqueue import QueueShutDown
2223
from kimi_cli.utils.rich.columns import BulletColumns
2324
from kimi_cli.utils.rich.markdown import Markdown
2425
from kimi_cli.wire import WireUISide
@@ -350,7 +351,7 @@ def keyboard_handler(event: KeyEvent) -> None:
350351
while True:
351352
try:
352353
msg = await wire.receive()
353-
except asyncio.QueueShutDown:
354+
except QueueShutDown:
354355
self.cleanup(is_interrupt=False)
355356
live.update(self.compose())
356357
break

src/kimi_cli/ui/wire/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from kimi_cli.soul import LLMNotSet, LLMNotSupported, MaxStepsReached, RunCancelled, Soul, run_soul
1111
from kimi_cli.soul.kimisoul import KimiSoul
12+
from kimi_cli.utils.aioqueue import Queue, QueueShutDown
1213
from kimi_cli.utils.logging import logger
1314
from kimi_cli.wire import Wire
1415
from kimi_cli.wire.message import ApprovalRequest, Request
@@ -37,7 +38,7 @@ def __init__(self, soul: Soul):
3738

3839
# outward
3940
self._write_task: asyncio.Task[None] | None = None
40-
self._write_queue: asyncio.Queue[JSONRPCOutMessage] = asyncio.Queue()
41+
self._write_queue: Queue[JSONRPCOutMessage] = Queue()
4142

4243
# inward
4344
self._dispatch_tasks: set[asyncio.Task[None]] = set()
@@ -65,7 +66,7 @@ async def _write_loop(self) -> None:
6566
while True:
6667
try:
6768
msg = await self._write_queue.get()
68-
except asyncio.QueueShutDown:
69+
except QueueShutDown:
6970
logger.debug("Send queue shut down, stopping Wire server write loop")
7071
break
7172
self._writer.write(msg.model_dump_json().encode("utf-8") + b"\n")
@@ -141,7 +142,7 @@ async def _dispatch_msg(self, msg: JSONRPCInMessage) -> None:
141142
async def _send_msg(self, msg: JSONRPCOutMessage) -> None:
142143
try:
143144
await self._write_queue.put(msg)
144-
except asyncio.QueueShutDown:
145+
except QueueShutDown:
145146
logger.error("Send queue shut down; dropping message: {msg}", msg=msg)
146147

147148
@property

src/kimi_cli/utils/aioqueue.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import sys
5+
6+
if sys.version_info >= (3, 13):
7+
QueueShutDown = asyncio.QueueShutDown # type: ignore[assignment]
8+
9+
class Queue[T](asyncio.Queue[T]):
10+
"""Asyncio Queue with shutdown support."""
11+
12+
else:
13+
14+
class QueueShutDown(Exception):
15+
"""Raised when operating on a shut down queue."""
16+
17+
class _Shutdown:
18+
"""Sentinel for queue shutdown."""
19+
20+
_SHUTDOWN = _Shutdown()
21+
22+
class Queue[T](asyncio.Queue[T | _Shutdown]):
23+
"""Asyncio Queue with shutdown support for Python < 3.13."""
24+
25+
def __init__(self) -> None:
26+
super().__init__()
27+
self._shutdown = False
28+
29+
def shutdown(self, immediate: bool = False) -> None:
30+
if self._shutdown:
31+
return
32+
self._shutdown = True
33+
if immediate:
34+
self._queue.clear()
35+
36+
getters = list(getattr(self, "_getters", []))
37+
count = max(1, len(getters))
38+
self._enqueue_shutdown(count)
39+
40+
def _enqueue_shutdown(self, count: int) -> None:
41+
for _ in range(count):
42+
try:
43+
super().put_nowait(_SHUTDOWN)
44+
except asyncio.QueueFull:
45+
self._queue.clear()
46+
super().put_nowait(_SHUTDOWN)
47+
48+
async def get(self) -> T:
49+
if self._shutdown and self.empty():
50+
raise QueueShutDown
51+
item = await super().get()
52+
if isinstance(item, _Shutdown):
53+
raise QueueShutDown
54+
return item
55+
56+
def get_nowait(self) -> T:
57+
if self._shutdown and self.empty():
58+
raise QueueShutDown
59+
item = super().get_nowait()
60+
if isinstance(item, _Shutdown):
61+
raise QueueShutDown
62+
return item
63+
64+
async def put(self, item: T) -> None:
65+
if self._shutdown:
66+
raise QueueShutDown
67+
await super().put(item)
68+
69+
def put_nowait(self, item: T) -> None:
70+
if self._shutdown:
71+
raise QueueShutDown
72+
super().put_nowait(item)

0 commit comments

Comments
 (0)