🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -43,8 +43,8 @@ class S4UStreamGenerator:
|
||||
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
||||
self.sentence_split_pattern = re.compile(
|
||||
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
|
||||
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))' # 匹配直到句子结束符
|
||||
, re.UNICODE | re.DOTALL
|
||||
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符
|
||||
re.UNICODE | re.DOTALL,
|
||||
)
|
||||
|
||||
async def generate_response(
|
||||
@@ -68,7 +68,7 @@ class S4UStreamGenerator:
|
||||
|
||||
# 构建prompt
|
||||
if previous_reply_context:
|
||||
message_txt = f"""
|
||||
message_txt = f"""
|
||||
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
|
||||
[你已经对上一条消息说的话]: {previous_reply_context}
|
||||
---
|
||||
@@ -78,9 +78,8 @@ class S4UStreamGenerator:
|
||||
else:
|
||||
message_txt = message.processed_plain_text
|
||||
|
||||
|
||||
prompt = await prompt_builder.build_prompt_normal(
|
||||
message = message,
|
||||
message=message,
|
||||
message_txt=message_txt,
|
||||
sender_name=sender_name,
|
||||
chat_stream=message.chat_stream,
|
||||
@@ -109,16 +108,16 @@ class S4UStreamGenerator:
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
print(prompt)
|
||||
|
||||
|
||||
buffer = ""
|
||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||
punctuation_buffer = ""
|
||||
|
||||
|
||||
async for content in client.get_stream_content(
|
||||
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
||||
):
|
||||
buffer += content
|
||||
|
||||
|
||||
# 使用正则表达式匹配句子
|
||||
last_match_end = 0
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
@@ -132,24 +131,23 @@ class S4UStreamGenerator:
|
||||
else:
|
||||
# 发送之前累积的标点和当前句子
|
||||
to_yield = punctuation_buffer + sentence
|
||||
if to_yield.endswith((',', ',')):
|
||||
to_yield = to_yield.rstrip(',,')
|
||||
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
|
||||
yield to_yield
|
||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||
await asyncio.sleep(0) # 允许其他任务运行
|
||||
|
||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||
await asyncio.sleep(0) # 允许其他任务运行
|
||||
|
||||
last_match_end = match.end(0)
|
||||
|
||||
|
||||
# 从缓冲区移除已发送的部分
|
||||
if last_match_end > 0:
|
||||
buffer = buffer[last_match_end:]
|
||||
|
||||
|
||||
# 发送缓冲区中剩余的任何内容
|
||||
to_yield = (punctuation_buffer + buffer).strip()
|
||||
if to_yield:
|
||||
if to_yield.endswith((',', ',')):
|
||||
to_yield = to_yield.rstrip(',,')
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
if to_yield:
|
||||
yield to_yield
|
||||
|
||||
|
||||
Reference in New Issue
Block a user