feat:修改maisaka
This commit is contained in:
@@ -1,62 +1,56 @@
|
||||
"""
|
||||
MaiSaka - 异步输入读取器
|
||||
基于后台线程的异步标准输入读取,通过 asyncio.Queue 传递给异步代码。
|
||||
将阻塞的标准输入读取放到后台线程中,供 asyncio 循环安全消费。
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InputReader:
|
||||
"""
|
||||
基于后台线程的异步标准输入读取器。
|
||||
"""后台读取标准输入,并通过 asyncio.Queue 向主循环投递结果。"""
|
||||
|
||||
使用单一守护线程持续读取 stdin,通过 asyncio.Queue 传递给异步代码。
|
||||
保证整个应用只有一个线程读 stdin,避免多线程竞争。
|
||||
支持带超时的读取,用于 LLM wait 工具。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
def __init__(self) -> None:
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._queue: asyncio.Queue[Optional[str]] = asyncio.Queue()
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
def start(self, loop: asyncio.AbstractEventLoop):
|
||||
"""启动后台读取线程(仅首次调用生效)"""
|
||||
if self._thread is not None:
|
||||
def start(self, loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""启动后台输入线程。重复调用时忽略。"""
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._loop = loop
|
||||
self._thread = threading.Thread(target=self._read_loop, daemon=True)
|
||||
self._stop_event.clear()
|
||||
self._thread = threading.Thread(target=self._read_loop, name="maisaka-input-reader", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def _read_loop(self):
|
||||
"""后台线程:持续从 stdin 读取行"""
|
||||
try:
|
||||
while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line: # EOF
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||
break
|
||||
stripped = line.rstrip("\n").rstrip("\r")
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, stripped)
|
||||
except Exception:
|
||||
pass
|
||||
def _read_loop(self) -> None:
|
||||
"""在后台线程中阻塞读取 stdin。"""
|
||||
while not self._stop_event.is_set():
|
||||
line = sys.stdin.readline()
|
||||
if self._loop is None:
|
||||
return
|
||||
|
||||
async def get_line(self, timeout: Optional[float] = None) -> Optional[str]:
|
||||
"""
|
||||
异步获取下一行输入。
|
||||
if line == "":
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||
return
|
||||
|
||||
Args:
|
||||
timeout: 超时秒数,None 表示无限等待
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, line.rstrip("\r\n"))
|
||||
|
||||
Returns:
|
||||
输入的字符串,超时或 EOF 返回 None
|
||||
"""
|
||||
try:
|
||||
if timeout is not None:
|
||||
return await asyncio.wait_for(self._queue.get(), timeout=timeout)
|
||||
async def get_line(self, timeout: Optional[int] = None) -> Optional[str]:
|
||||
"""异步获取一行输入;设置 timeout 时支持超时返回。"""
|
||||
if timeout is None:
|
||||
return await self._queue.get()
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(self._queue.get(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
"""请求后台线程停止。"""
|
||||
self._stop_event.set()
|
||||
|
||||
Reference in New Issue
Block a user