超级Ruff
This commit is contained in:
3
bot.py
3
bot.py
@@ -62,9 +62,10 @@ def easter_egg():
|
||||
async def graceful_shutdown(): # sourcery skip: use-named-expression
|
||||
try:
|
||||
logger.info("正在优雅关闭麦麦...")
|
||||
|
||||
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
|
||||
# 触发 ON_STOP 事件
|
||||
await events_manager.handle_mai_events(event_type=EventType.ON_STOP)
|
||||
|
||||
|
||||
@@ -5,12 +5,11 @@ from typing import Dict, List
|
||||
|
||||
# Add project root to Python path
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
||||
try:
|
||||
@@ -18,7 +17,7 @@ def get_chat_name(chat_id: str) -> str:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
|
||||
# 如果有群组信息,显示群组名称
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
@@ -35,117 +34,106 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
|
||||
"""Calculate distribution of last active time in days"""
|
||||
now = time.time()
|
||||
distribution = {
|
||||
'0-1天': 0,
|
||||
'1-3天': 0,
|
||||
'3-7天': 0,
|
||||
'7-14天': 0,
|
||||
'14-30天': 0,
|
||||
'30-60天': 0,
|
||||
'60-90天': 0,
|
||||
'90+天': 0
|
||||
"0-1天": 0,
|
||||
"1-3天": 0,
|
||||
"3-7天": 0,
|
||||
"7-14天": 0,
|
||||
"14-30天": 0,
|
||||
"30-60天": 0,
|
||||
"60-90天": 0,
|
||||
"90+天": 0,
|
||||
}
|
||||
for expr in expressions:
|
||||
diff_days = (now - expr.last_active_time) / (24*3600)
|
||||
diff_days = (now - expr.last_active_time) / (24 * 3600)
|
||||
if diff_days < 1:
|
||||
distribution['0-1天'] += 1
|
||||
distribution["0-1天"] += 1
|
||||
elif diff_days < 3:
|
||||
distribution['1-3天'] += 1
|
||||
distribution["1-3天"] += 1
|
||||
elif diff_days < 7:
|
||||
distribution['3-7天'] += 1
|
||||
distribution["3-7天"] += 1
|
||||
elif diff_days < 14:
|
||||
distribution['7-14天'] += 1
|
||||
distribution["7-14天"] += 1
|
||||
elif diff_days < 30:
|
||||
distribution['14-30天'] += 1
|
||||
distribution["14-30天"] += 1
|
||||
elif diff_days < 60:
|
||||
distribution['30-60天'] += 1
|
||||
distribution["30-60天"] += 1
|
||||
elif diff_days < 90:
|
||||
distribution['60-90天'] += 1
|
||||
distribution["60-90天"] += 1
|
||||
else:
|
||||
distribution['90+天'] += 1
|
||||
distribution["90+天"] += 1
|
||||
return distribution
|
||||
|
||||
|
||||
def calculate_count_distribution(expressions) -> Dict[str, int]:
|
||||
"""Calculate distribution of count values"""
|
||||
distribution = {
|
||||
'0-1': 0,
|
||||
'1-2': 0,
|
||||
'2-3': 0,
|
||||
'3-4': 0,
|
||||
'4-5': 0,
|
||||
'5-10': 0,
|
||||
'10+': 0
|
||||
}
|
||||
distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
|
||||
for expr in expressions:
|
||||
cnt = expr.count
|
||||
if cnt < 1:
|
||||
distribution['0-1'] += 1
|
||||
distribution["0-1"] += 1
|
||||
elif cnt < 2:
|
||||
distribution['1-2'] += 1
|
||||
distribution["1-2"] += 1
|
||||
elif cnt < 3:
|
||||
distribution['2-3'] += 1
|
||||
distribution["2-3"] += 1
|
||||
elif cnt < 4:
|
||||
distribution['3-4'] += 1
|
||||
distribution["3-4"] += 1
|
||||
elif cnt < 5:
|
||||
distribution['4-5'] += 1
|
||||
distribution["4-5"] += 1
|
||||
elif cnt < 10:
|
||||
distribution['5-10'] += 1
|
||||
distribution["5-10"] += 1
|
||||
else:
|
||||
distribution['10+'] += 1
|
||||
distribution["10+"] += 1
|
||||
return distribution
|
||||
|
||||
|
||||
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
|
||||
"""Get top N most used expressions for a specific chat_id"""
|
||||
return (Expression.select()
|
||||
.where(Expression.chat_id == chat_id)
|
||||
.order_by(Expression.count.desc())
|
||||
.limit(top_n))
|
||||
return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
|
||||
|
||||
|
||||
def show_overall_statistics(expressions, total: int) -> None:
|
||||
"""Show overall statistics"""
|
||||
time_dist = calculate_time_distribution(expressions)
|
||||
count_dist = calculate_count_distribution(expressions)
|
||||
|
||||
|
||||
print("\n=== 总体统计 ===")
|
||||
print(f"总表达式数量: {total}")
|
||||
|
||||
|
||||
print("\n上次激活时间分布:")
|
||||
for period, count in time_dist.items():
|
||||
print(f"{period}: {count} ({count/total*100:.2f}%)")
|
||||
|
||||
print(f"{period}: {count} ({count / total * 100:.2f}%)")
|
||||
|
||||
print("\ncount分布:")
|
||||
for range_, count in count_dist.items():
|
||||
print(f"{range_}: {count} ({count/total*100:.2f}%)")
|
||||
print(f"{range_}: {count} ({count / total * 100:.2f}%)")
|
||||
|
||||
|
||||
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
|
||||
"""Show statistics for a specific chat"""
|
||||
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
|
||||
chat_total = len(chat_exprs)
|
||||
|
||||
|
||||
print(f"\n=== {chat_name} ===")
|
||||
print(f"表达式数量: {chat_total}")
|
||||
|
||||
|
||||
if chat_total == 0:
|
||||
print("该聊天没有表达式数据")
|
||||
return
|
||||
|
||||
|
||||
# Time distribution for this chat
|
||||
time_dist = calculate_time_distribution(chat_exprs)
|
||||
print("\n上次激活时间分布:")
|
||||
for period, count in time_dist.items():
|
||||
if count > 0:
|
||||
print(f"{period}: {count} ({count/chat_total*100:.2f}%)")
|
||||
|
||||
print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")
|
||||
|
||||
# Count distribution for this chat
|
||||
count_dist = calculate_count_distribution(chat_exprs)
|
||||
print("\ncount分布:")
|
||||
for range_, count in count_dist.items():
|
||||
if count > 0:
|
||||
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)")
|
||||
|
||||
print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")
|
||||
|
||||
# Top expressions
|
||||
print("\nTop 10使用最多的表达式:")
|
||||
top_exprs = get_top_expressions_by_chat(chat_id, 10)
|
||||
@@ -163,32 +151,32 @@ def interactive_menu() -> None:
|
||||
if not expressions:
|
||||
print("数据库中没有找到表达式")
|
||||
return
|
||||
|
||||
|
||||
total = len(expressions)
|
||||
|
||||
|
||||
# Get unique chat_ids and their names
|
||||
chat_ids = list(set(expr.chat_id for expr in expressions))
|
||||
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
|
||||
chat_info.sort(key=lambda x: x[1]) # Sort by chat name
|
||||
|
||||
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("\n" + "=" * 50)
|
||||
print("表达式统计分析")
|
||||
print("="*50)
|
||||
print("=" * 50)
|
||||
print("0. 显示总体统计")
|
||||
|
||||
|
||||
for i, (chat_id, chat_name) in enumerate(chat_info, 1):
|
||||
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
|
||||
print(f"{i}. {chat_name} ({chat_count}个表达式)")
|
||||
|
||||
|
||||
print("q. 退出")
|
||||
|
||||
|
||||
choice = input("\n请选择要查看的统计 (输入序号): ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
|
||||
if choice.lower() == "q":
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
choice_num = int(choice)
|
||||
if choice_num == 0:
|
||||
@@ -200,9 +188,9 @@ def interactive_menu() -> None:
|
||||
print("无效的选择,请重新输入")
|
||||
except ValueError:
|
||||
print("请输入有效的数字")
|
||||
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
|
||||
interactive_menu()
|
||||
|
||||
@@ -23,6 +23,7 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
logger = get_logger("OpenIE导入")
|
||||
|
||||
|
||||
def ensure_openie_dir():
|
||||
"""确保OpenIE数据目录存在"""
|
||||
if not os.path.exists(OPENIE_DIR):
|
||||
@@ -253,7 +254,7 @@ def main():
|
||||
# 没有运行的事件循环,创建新的
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
||||
try:
|
||||
# 在新的事件循环中运行异步主函数
|
||||
loop.run_until_complete(main_async())
|
||||
|
||||
@@ -12,6 +12,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from rich.progress import Progress # 替换为 rich 进度条
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.ie_process import info_extract_from_str
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
@@ -36,6 +37,7 @@ TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
"""确保临时目录和输出目录存在"""
|
||||
if not os.path.exists(TEMP_DIR):
|
||||
@@ -48,6 +50,7 @@ def ensure_dirs():
|
||||
os.makedirs(RAW_DATA_PATH)
|
||||
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
|
||||
|
||||
|
||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||
file_lock = Lock()
|
||||
open_ie_doc_lock = Lock()
|
||||
@@ -56,13 +59,11 @@ open_ie_doc_lock = Lock()
|
||||
shutdown_event = Event()
|
||||
|
||||
lpmm_entity_extract_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract,
|
||||
request_type="lpmm.entity_extract"
|
||||
)
|
||||
lpmm_rdf_build_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_rdf_build,
|
||||
request_type="lpmm.rdf_build"
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
)
|
||||
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
|
||||
|
||||
def process_single_text(pg_hash, raw_data):
|
||||
"""处理单个文本的函数,用于线程池"""
|
||||
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
||||
|
||||
@@ -3,12 +3,11 @@ import sys
|
||||
import os
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Messages, ChatStreams #noqa
|
||||
|
||||
|
||||
from src.common.database.database_model import Messages, ChatStreams # noqa
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
@@ -17,7 +16,7 @@ def get_chat_name(chat_id: str) -> str:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
@@ -39,66 +38,62 @@ def format_timestamp(timestamp: float) -> str:
|
||||
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
|
||||
"""Calculate distribution of interest_value"""
|
||||
distribution = {
|
||||
'0.000-0.010': 0,
|
||||
'0.010-0.050': 0,
|
||||
'0.050-0.100': 0,
|
||||
'0.100-0.500': 0,
|
||||
'0.500-1.000': 0,
|
||||
'1.000-2.000': 0,
|
||||
'2.000-5.000': 0,
|
||||
'5.000-10.000': 0,
|
||||
'10.000+': 0
|
||||
"0.000-0.010": 0,
|
||||
"0.010-0.050": 0,
|
||||
"0.050-0.100": 0,
|
||||
"0.100-0.500": 0,
|
||||
"0.500-1.000": 0,
|
||||
"1.000-2.000": 0,
|
||||
"2.000-5.000": 0,
|
||||
"5.000-10.000": 0,
|
||||
"10.000+": 0,
|
||||
}
|
||||
|
||||
|
||||
for msg in messages:
|
||||
if msg.interest_value is None or msg.interest_value == 0.0:
|
||||
continue
|
||||
|
||||
|
||||
value = float(msg.interest_value)
|
||||
if value < 0.010:
|
||||
distribution['0.000-0.010'] += 1
|
||||
distribution["0.000-0.010"] += 1
|
||||
elif value < 0.050:
|
||||
distribution['0.010-0.050'] += 1
|
||||
distribution["0.010-0.050"] += 1
|
||||
elif value < 0.100:
|
||||
distribution['0.050-0.100'] += 1
|
||||
distribution["0.050-0.100"] += 1
|
||||
elif value < 0.500:
|
||||
distribution['0.100-0.500'] += 1
|
||||
distribution["0.100-0.500"] += 1
|
||||
elif value < 1.000:
|
||||
distribution['0.500-1.000'] += 1
|
||||
distribution["0.500-1.000"] += 1
|
||||
elif value < 2.000:
|
||||
distribution['1.000-2.000'] += 1
|
||||
distribution["1.000-2.000"] += 1
|
||||
elif value < 5.000:
|
||||
distribution['2.000-5.000'] += 1
|
||||
distribution["2.000-5.000"] += 1
|
||||
elif value < 10.000:
|
||||
distribution['5.000-10.000'] += 1
|
||||
distribution["5.000-10.000"] += 1
|
||||
else:
|
||||
distribution['10.000+'] += 1
|
||||
|
||||
distribution["10.000+"] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
def get_interest_value_stats(messages) -> Dict[str, float]:
|
||||
"""Calculate basic statistics for interest_value"""
|
||||
values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0]
|
||||
|
||||
values = [
|
||||
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
|
||||
]
|
||||
|
||||
if not values:
|
||||
return {
|
||||
'count': 0,
|
||||
'min': 0,
|
||||
'max': 0,
|
||||
'avg': 0,
|
||||
'median': 0
|
||||
}
|
||||
|
||||
return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}
|
||||
|
||||
values.sort()
|
||||
count = len(values)
|
||||
|
||||
|
||||
return {
|
||||
'count': count,
|
||||
'min': min(values),
|
||||
'max': max(values),
|
||||
'avg': sum(values) / count,
|
||||
'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2
|
||||
"count": count,
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"avg": sum(values) / count,
|
||||
"median": values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2,
|
||||
}
|
||||
|
||||
|
||||
@@ -109,20 +104,24 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
chat_counts = {}
|
||||
for msg in Messages.select(Messages.chat_id).distinct():
|
||||
chat_id = msg.chat_id
|
||||
count = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.interest_value.is_null(False)) &
|
||||
(Messages.interest_value != 0.0)
|
||||
).count()
|
||||
count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.chat_id == chat_id)
|
||||
& (Messages.interest_value.is_null(False))
|
||||
& (Messages.interest_value != 0.0)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
if count > 0:
|
||||
chat_counts[chat_id] = count
|
||||
|
||||
|
||||
# 获取聊天名称
|
||||
result = []
|
||||
for chat_id, count in chat_counts.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
result.append((chat_id, chat_name, count))
|
||||
|
||||
|
||||
# 按消息数量排序
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
return result
|
||||
@@ -135,30 +134,30 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
print("2. 最近3天")
|
||||
print("2. 最近3天")
|
||||
print("3. 最近7天")
|
||||
print("4. 最近30天")
|
||||
print("5. 自定义时间范围")
|
||||
print("6. 不限制时间")
|
||||
|
||||
|
||||
choice = input("请选择时间范围 (1-6): ").strip()
|
||||
|
||||
|
||||
now = time.time()
|
||||
|
||||
|
||||
if choice == "1":
|
||||
return now - 24*3600, now
|
||||
return now - 24 * 3600, now
|
||||
elif choice == "2":
|
||||
return now - 3*24*3600, now
|
||||
return now - 3 * 24 * 3600, now
|
||||
elif choice == "3":
|
||||
return now - 7*24*3600, now
|
||||
return now - 7 * 24 * 3600, now
|
||||
elif choice == "4":
|
||||
return now - 30*24*3600, now
|
||||
return now - 30 * 24 * 3600, now
|
||||
elif choice == "5":
|
||||
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
start_str = input().strip()
|
||||
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
end_str = input().strip()
|
||||
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
@@ -170,41 +169,40 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
return None, None
|
||||
|
||||
|
||||
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
|
||||
def analyze_interest_values(
|
||||
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
|
||||
) -> None:
|
||||
"""Analyze interest values with optional filters"""
|
||||
|
||||
|
||||
# 构建查询条件
|
||||
query = Messages.select().where(
|
||||
(Messages.interest_value.is_null(False)) &
|
||||
(Messages.interest_value != 0.0)
|
||||
)
|
||||
|
||||
query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
|
||||
|
||||
if chat_id:
|
||||
query = query.where(Messages.chat_id == chat_id)
|
||||
|
||||
|
||||
if start_time:
|
||||
query = query.where(Messages.time >= start_time)
|
||||
|
||||
|
||||
if end_time:
|
||||
query = query.where(Messages.time <= end_time)
|
||||
|
||||
|
||||
messages = list(query)
|
||||
|
||||
|
||||
if not messages:
|
||||
print("没有找到符合条件的消息")
|
||||
return
|
||||
|
||||
|
||||
# 计算统计信息
|
||||
distribution = calculate_interest_value_distribution(messages)
|
||||
stats = get_interest_value_stats(messages)
|
||||
|
||||
|
||||
# 显示结果
|
||||
print("\n=== Interest Value 分析结果 ===")
|
||||
if chat_id:
|
||||
print(f"聊天: {get_chat_name(chat_id)}")
|
||||
else:
|
||||
print("聊天: 全部聊天")
|
||||
|
||||
|
||||
if start_time and end_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
||||
elif start_time:
|
||||
@@ -213,16 +211,16 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[
|
||||
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
||||
else:
|
||||
print("时间范围: 不限制")
|
||||
|
||||
|
||||
print("\n基本统计:")
|
||||
print(f"有效消息数量: {stats['count']} (排除null和0值)")
|
||||
print(f"最小值: {stats['min']:.3f}")
|
||||
print(f"最大值: {stats['max']:.3f}")
|
||||
print(f"平均值: {stats['avg']:.3f}")
|
||||
print(f"中位数: {stats['median']:.3f}")
|
||||
|
||||
|
||||
print("\nInterest Value 分布:")
|
||||
total = stats['count']
|
||||
total = stats["count"]
|
||||
for range_name, count in distribution.items():
|
||||
if count > 0:
|
||||
percentage = count / total * 100
|
||||
@@ -231,34 +229,34 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[
|
||||
|
||||
def interactive_menu() -> None:
|
||||
"""Interactive menu for interest value analysis"""
|
||||
|
||||
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("\n" + "=" * 50)
|
||||
print("Interest Value 分析工具")
|
||||
print("="*50)
|
||||
print("=" * 50)
|
||||
print("1. 分析全部聊天")
|
||||
print("2. 选择特定聊天分析")
|
||||
print("q. 退出")
|
||||
|
||||
|
||||
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
|
||||
if choice.lower() == "q":
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
|
||||
chat_id = None
|
||||
|
||||
|
||||
if choice == "2":
|
||||
# 显示可用的聊天列表
|
||||
chats = get_available_chats()
|
||||
if not chats:
|
||||
print("没有找到有interest_value数据的聊天")
|
||||
continue
|
||||
|
||||
|
||||
print(f"\n可用的聊天 (共{len(chats)}个):")
|
||||
for i, (_cid, name, count) in enumerate(chats, 1):
|
||||
print(f"{i}. {name} ({count}条有效消息)")
|
||||
|
||||
|
||||
try:
|
||||
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
||||
if 1 <= chat_choice <= len(chats):
|
||||
@@ -269,19 +267,19 @@ def interactive_menu() -> None:
|
||||
except ValueError:
|
||||
print("请输入有效数字")
|
||||
continue
|
||||
|
||||
|
||||
elif choice != "1":
|
||||
print("无效选择")
|
||||
continue
|
||||
|
||||
|
||||
# 获取时间范围
|
||||
start_time, end_time = get_time_range_input()
|
||||
|
||||
|
||||
# 执行分析
|
||||
analyze_interest_values(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
|
||||
interactive_menu()
|
||||
|
||||
@@ -828,7 +828,7 @@ class LogViewer:
|
||||
parts, tags = self.formatter.format_log_entry(log_entry)
|
||||
line_text = " ".join(parts)
|
||||
log_lines.append(line_text)
|
||||
|
||||
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_lines))
|
||||
messagebox.showinfo("导出成功", f"日志已导出到: {filename}")
|
||||
@@ -1188,15 +1188,16 @@ class LogViewer:
|
||||
line_count += 1
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
|
||||
# 如果发现了新模块,在主线程中更新模块集合
|
||||
if new_modules:
|
||||
|
||||
def update_modules():
|
||||
self.modules.update(new_modules)
|
||||
self.update_module_list()
|
||||
|
||||
|
||||
self.root.after(0, update_modules)
|
||||
|
||||
|
||||
return new_entries
|
||||
|
||||
def append_new_logs(self, new_entries):
|
||||
@@ -1424,4 +1425,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from pathlib import Path
|
||||
import sys # 新增系统模块导入
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -10,6 +11,7 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||
|
||||
|
||||
def _process_text_file(file_path):
|
||||
"""处理单个文本文件,返回段落列表"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
@@ -44,6 +46,7 @@ def _process_multi_files() -> list:
|
||||
all_paragraphs.extend(paragraphs)
|
||||
return all_paragraphs
|
||||
|
||||
|
||||
def load_raw_data() -> tuple[list[str], list[str]]:
|
||||
"""加载原始数据文件
|
||||
|
||||
@@ -72,4 +75,4 @@ def load_raw_data() -> tuple[list[str], list[str]]:
|
||||
raw_data.append(item)
|
||||
logger.info(f"共读取到{len(raw_data)}条数据")
|
||||
|
||||
return sha256_list, raw_data
|
||||
return sha256_list, raw_data
|
||||
|
||||
@@ -4,21 +4,22 @@ import os
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
from src.common.database.database_model import Messages, ChatStreams #noqa
|
||||
from src.common.database.database_model import Messages, ChatStreams # noqa
|
||||
|
||||
|
||||
def contains_emoji_or_image_tags(text: str) -> bool:
|
||||
"""Check if text contains [表情包xxxxx] or [图片xxxxx] tags"""
|
||||
if not text:
|
||||
return False
|
||||
|
||||
|
||||
# 检查是否包含 [表情包] 或 [图片] 标记
|
||||
emoji_pattern = r'\[表情包[^\]]*\]'
|
||||
image_pattern = r'\[图片[^\]]*\]'
|
||||
|
||||
emoji_pattern = r"\[表情包[^\]]*\]"
|
||||
image_pattern = r"\[图片[^\]]*\]"
|
||||
|
||||
return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
|
||||
|
||||
|
||||
@@ -26,14 +27,14 @@ def clean_reply_text(text: str) -> str:
|
||||
"""Remove reply references like [回复 xxxx...] from text"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
|
||||
# 匹配 [回复 xxxx...] 格式的内容
|
||||
# 使用非贪婪匹配,匹配到第一个 ] 就停止
|
||||
cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text)
|
||||
|
||||
cleaned_text = re.sub(r"\[回复[^\]]*\]", "", text)
|
||||
|
||||
# 去除多余的空白字符
|
||||
cleaned_text = cleaned_text.strip()
|
||||
|
||||
|
||||
return cleaned_text
|
||||
|
||||
|
||||
@@ -43,7 +44,7 @@ def get_chat_name(chat_id: str) -> str:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream is None:
|
||||
return f"未知聊天 ({chat_id})"
|
||||
|
||||
|
||||
if chat_stream.group_name:
|
||||
return f"{chat_stream.group_name} ({chat_id})"
|
||||
elif chat_stream.user_nickname:
|
||||
@@ -65,63 +66,63 @@ def format_timestamp(timestamp: float) -> str:
|
||||
def calculate_text_length_distribution(messages) -> Dict[str, int]:
|
||||
"""Calculate distribution of processed_plain_text length"""
|
||||
distribution = {
|
||||
'0': 0, # 空文本
|
||||
'1-5': 0, # 极短文本
|
||||
'6-10': 0, # 很短文本
|
||||
'11-20': 0, # 短文本
|
||||
'21-30': 0, # 较短文本
|
||||
'31-50': 0, # 中短文本
|
||||
'51-70': 0, # 中等文本
|
||||
'71-100': 0, # 较长文本
|
||||
'101-150': 0, # 长文本
|
||||
'151-200': 0, # 很长文本
|
||||
'201-300': 0, # 超长文本
|
||||
'301-500': 0, # 极长文本
|
||||
'501-1000': 0, # 巨长文本
|
||||
'1000+': 0 # 超巨长文本
|
||||
"0": 0, # 空文本
|
||||
"1-5": 0, # 极短文本
|
||||
"6-10": 0, # 很短文本
|
||||
"11-20": 0, # 短文本
|
||||
"21-30": 0, # 较短文本
|
||||
"31-50": 0, # 中短文本
|
||||
"51-70": 0, # 中等文本
|
||||
"71-100": 0, # 较长文本
|
||||
"101-150": 0, # 长文本
|
||||
"151-200": 0, # 很长文本
|
||||
"201-300": 0, # 超长文本
|
||||
"301-500": 0, # 极长文本
|
||||
"501-1000": 0, # 巨长文本
|
||||
"1000+": 0, # 超巨长文本
|
||||
}
|
||||
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is None:
|
||||
continue
|
||||
|
||||
|
||||
# 排除包含表情包或图片标记的消息
|
||||
if contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
continue
|
||||
|
||||
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
length = len(cleaned_text)
|
||||
|
||||
|
||||
if length == 0:
|
||||
distribution['0'] += 1
|
||||
distribution["0"] += 1
|
||||
elif length <= 5:
|
||||
distribution['1-5'] += 1
|
||||
distribution["1-5"] += 1
|
||||
elif length <= 10:
|
||||
distribution['6-10'] += 1
|
||||
distribution["6-10"] += 1
|
||||
elif length <= 20:
|
||||
distribution['11-20'] += 1
|
||||
distribution["11-20"] += 1
|
||||
elif length <= 30:
|
||||
distribution['21-30'] += 1
|
||||
distribution["21-30"] += 1
|
||||
elif length <= 50:
|
||||
distribution['31-50'] += 1
|
||||
distribution["31-50"] += 1
|
||||
elif length <= 70:
|
||||
distribution['51-70'] += 1
|
||||
distribution["51-70"] += 1
|
||||
elif length <= 100:
|
||||
distribution['71-100'] += 1
|
||||
distribution["71-100"] += 1
|
||||
elif length <= 150:
|
||||
distribution['101-150'] += 1
|
||||
distribution["101-150"] += 1
|
||||
elif length <= 200:
|
||||
distribution['151-200'] += 1
|
||||
distribution["151-200"] += 1
|
||||
elif length <= 300:
|
||||
distribution['201-300'] += 1
|
||||
distribution["201-300"] += 1
|
||||
elif length <= 500:
|
||||
distribution['301-500'] += 1
|
||||
distribution["301-500"] += 1
|
||||
elif length <= 1000:
|
||||
distribution['501-1000'] += 1
|
||||
distribution["501-1000"] += 1
|
||||
else:
|
||||
distribution['1000+'] += 1
|
||||
|
||||
distribution["1000+"] += 1
|
||||
|
||||
return distribution
|
||||
|
||||
|
||||
@@ -130,7 +131,7 @@ def get_text_length_stats(messages) -> Dict[str, float]:
|
||||
lengths = []
|
||||
null_count = 0
|
||||
excluded_count = 0 # 被排除的消息数量
|
||||
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is None:
|
||||
null_count += 1
|
||||
@@ -141,29 +142,29 @@ def get_text_length_stats(messages) -> Dict[str, float]:
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
lengths.append(len(cleaned_text))
|
||||
|
||||
|
||||
if not lengths:
|
||||
return {
|
||||
'count': 0,
|
||||
'null_count': null_count,
|
||||
'excluded_count': excluded_count,
|
||||
'min': 0,
|
||||
'max': 0,
|
||||
'avg': 0,
|
||||
'median': 0
|
||||
"count": 0,
|
||||
"null_count": null_count,
|
||||
"excluded_count": excluded_count,
|
||||
"min": 0,
|
||||
"max": 0,
|
||||
"avg": 0,
|
||||
"median": 0,
|
||||
}
|
||||
|
||||
|
||||
lengths.sort()
|
||||
count = len(lengths)
|
||||
|
||||
|
||||
return {
|
||||
'count': count,
|
||||
'null_count': null_count,
|
||||
'excluded_count': excluded_count,
|
||||
'min': min(lengths),
|
||||
'max': max(lengths),
|
||||
'avg': sum(lengths) / count,
|
||||
'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2
|
||||
"count": count,
|
||||
"null_count": null_count,
|
||||
"excluded_count": excluded_count,
|
||||
"min": min(lengths),
|
||||
"max": max(lengths),
|
||||
"avg": sum(lengths) / count,
|
||||
"median": lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2,
|
||||
}
|
||||
|
||||
|
||||
@@ -174,21 +175,25 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
chat_counts = {}
|
||||
for msg in Messages.select(Messages.chat_id).distinct():
|
||||
chat_id = msg.chat_id
|
||||
count = Messages.select().where(
|
||||
(Messages.chat_id == chat_id) &
|
||||
(Messages.is_emoji != 1) &
|
||||
(Messages.is_picid != 1) &
|
||||
(Messages.is_command != 1)
|
||||
).count()
|
||||
count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.chat_id == chat_id)
|
||||
& (Messages.is_emoji != 1)
|
||||
& (Messages.is_picid != 1)
|
||||
& (Messages.is_command != 1)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
if count > 0:
|
||||
chat_counts[chat_id] = count
|
||||
|
||||
|
||||
# 获取聊天名称
|
||||
result = []
|
||||
for chat_id, count in chat_counts.items():
|
||||
chat_name = get_chat_name(chat_id)
|
||||
result.append((chat_id, chat_name, count))
|
||||
|
||||
|
||||
# 按消息数量排序
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
return result
|
||||
@@ -201,30 +206,30 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
print("2. 最近3天")
|
||||
print("2. 最近3天")
|
||||
print("3. 最近7天")
|
||||
print("4. 最近30天")
|
||||
print("5. 自定义时间范围")
|
||||
print("6. 不限制时间")
|
||||
|
||||
|
||||
choice = input("请选择时间范围 (1-6): ").strip()
|
||||
|
||||
|
||||
now = time.time()
|
||||
|
||||
|
||||
if choice == "1":
|
||||
return now - 24*3600, now
|
||||
return now - 24 * 3600, now
|
||||
elif choice == "2":
|
||||
return now - 3*24*3600, now
|
||||
return now - 3 * 24 * 3600, now
|
||||
elif choice == "3":
|
||||
return now - 7*24*3600, now
|
||||
return now - 7 * 24 * 3600, now
|
||||
elif choice == "4":
|
||||
return now - 30*24*3600, now
|
||||
return now - 30 * 24 * 3600, now
|
||||
elif choice == "5":
|
||||
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
start_str = input().strip()
|
||||
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
||||
end_str = input().strip()
|
||||
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
||||
@@ -239,13 +244,13 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
|
||||
"""Get top N longest messages"""
|
||||
message_lengths = []
|
||||
|
||||
|
||||
for msg in messages:
|
||||
if msg.processed_plain_text is not None:
|
||||
# 排除包含表情包或图片标记的消息
|
||||
if contains_emoji_or_image_tags(msg.processed_plain_text):
|
||||
continue
|
||||
|
||||
|
||||
# 清理文本中的回复引用
|
||||
cleaned_text = clean_reply_text(msg.processed_plain_text)
|
||||
length = len(cleaned_text)
|
||||
@@ -254,42 +259,40 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int,
|
||||
# 截取前100个字符作为预览
|
||||
preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text
|
||||
message_lengths.append((chat_name, length, time_str, preview))
|
||||
|
||||
|
||||
# 按长度排序,取前N个
|
||||
message_lengths.sort(key=lambda x: x[1], reverse=True)
|
||||
return message_lengths[:top_n]
|
||||
|
||||
|
||||
def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None:
|
||||
def analyze_text_lengths(
|
||||
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
|
||||
) -> None:
|
||||
"""Analyze processed_plain_text lengths with optional filters"""
|
||||
|
||||
|
||||
# 构建查询条件,排除特殊类型的消息
|
||||
query = Messages.select().where(
|
||||
(Messages.is_emoji != 1) &
|
||||
(Messages.is_picid != 1) &
|
||||
(Messages.is_command != 1)
|
||||
)
|
||||
|
||||
query = Messages.select().where((Messages.is_emoji != 1) & (Messages.is_picid != 1) & (Messages.is_command != 1))
|
||||
|
||||
if chat_id:
|
||||
query = query.where(Messages.chat_id == chat_id)
|
||||
|
||||
|
||||
if start_time:
|
||||
query = query.where(Messages.time >= start_time)
|
||||
|
||||
|
||||
if end_time:
|
||||
query = query.where(Messages.time <= end_time)
|
||||
|
||||
|
||||
messages = list(query)
|
||||
|
||||
|
||||
if not messages:
|
||||
print("没有找到符合条件的消息")
|
||||
return
|
||||
|
||||
|
||||
# 计算统计信息
|
||||
distribution = calculate_text_length_distribution(messages)
|
||||
stats = get_text_length_stats(messages)
|
||||
top_longest = get_top_longest_messages(messages, 10)
|
||||
|
||||
|
||||
# 显示结果
|
||||
print("\n=== Processed Plain Text 长度分析结果 ===")
|
||||
print("(已排除表情、图片ID、命令类型消息,已排除[表情包]和[图片]标记消息,已清理回复引用)")
|
||||
@@ -297,7 +300,7 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
|
||||
print(f"聊天: {get_chat_name(chat_id)}")
|
||||
else:
|
||||
print("聊天: 全部聊天")
|
||||
|
||||
|
||||
if start_time and end_time:
|
||||
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
||||
elif start_time:
|
||||
@@ -306,26 +309,26 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
|
||||
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
||||
else:
|
||||
print("时间范围: 不限制")
|
||||
|
||||
|
||||
print("\n基本统计:")
|
||||
print(f"总消息数量: {len(messages)}")
|
||||
print(f"有文本消息数量: {stats['count']}")
|
||||
print(f"空文本消息数量: {stats['null_count']}")
|
||||
print(f"被排除的消息数量: {stats['excluded_count']}")
|
||||
if stats['count'] > 0:
|
||||
if stats["count"] > 0:
|
||||
print(f"最短长度: {stats['min']} 字符")
|
||||
print(f"最长长度: {stats['max']} 字符")
|
||||
print(f"平均长度: {stats['avg']:.2f} 字符")
|
||||
print(f"中位数长度: {stats['median']:.2f} 字符")
|
||||
|
||||
|
||||
print("\n文本长度分布:")
|
||||
total = stats['count']
|
||||
total = stats["count"]
|
||||
if total > 0:
|
||||
for range_name, count in distribution.items():
|
||||
if count > 0:
|
||||
percentage = count / total * 100
|
||||
print(f"{range_name} 字符: {count} ({percentage:.2f}%)")
|
||||
|
||||
|
||||
# 显示最长的消息
|
||||
if top_longest:
|
||||
print(f"\n最长的 {len(top_longest)} 条消息:")
|
||||
@@ -338,34 +341,34 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
|
||||
|
||||
def interactive_menu() -> None:
|
||||
"""Interactive menu for text length analysis"""
|
||||
|
||||
|
||||
while True:
|
||||
print("\n" + "="*50)
|
||||
print("\n" + "=" * 50)
|
||||
print("Processed Plain Text 长度分析工具")
|
||||
print("="*50)
|
||||
print("=" * 50)
|
||||
print("1. 分析全部聊天")
|
||||
print("2. 选择特定聊天分析")
|
||||
print("q. 退出")
|
||||
|
||||
|
||||
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
|
||||
if choice.lower() == "q":
|
||||
print("再见!")
|
||||
break
|
||||
|
||||
|
||||
chat_id = None
|
||||
|
||||
|
||||
if choice == "2":
|
||||
# 显示可用的聊天列表
|
||||
chats = get_available_chats()
|
||||
if not chats:
|
||||
print("没有找到聊天数据")
|
||||
continue
|
||||
|
||||
|
||||
print(f"\n可用的聊天 (共{len(chats)}个):")
|
||||
for i, (_cid, name, count) in enumerate(chats, 1):
|
||||
print(f"{i}. {name} ({count}条消息)")
|
||||
|
||||
|
||||
try:
|
||||
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
||||
if 1 <= chat_choice <= len(chats):
|
||||
@@ -376,19 +379,19 @@ def interactive_menu() -> None:
|
||||
except ValueError:
|
||||
print("请输入有效数字")
|
||||
continue
|
||||
|
||||
|
||||
elif choice != "1":
|
||||
print("无效选择")
|
||||
continue
|
||||
|
||||
|
||||
# 获取时间范围
|
||||
start_time, end_time = get_time_range_input()
|
||||
|
||||
|
||||
# 执行分析
|
||||
analyze_text_lengths(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
input("\n按回车键继续...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interactive_menu()
|
||||
interactive_menu()
|
||||
|
||||
@@ -708,7 +708,7 @@ class EmojiManager:
|
||||
if not emoji.is_deleted and emoji.hash == emoji_hash:
|
||||
return emoji
|
||||
return None # 如果循环结束还没找到,则返回 None
|
||||
|
||||
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
|
||||
"""根据哈希值获取已注册表情包的情感标签列表
|
||||
|
||||
@@ -731,7 +731,7 @@ class EmojiManager:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if emoji_record and emoji_record.emotion:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
|
||||
return emoji_record.emotion.split(',')
|
||||
return emoji_record.emotion.split(",")
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
|
||||
|
||||
|
||||
@@ -77,10 +77,10 @@ class ExpressionSelector:
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
@@ -123,9 +123,7 @@ class ExpressionSelector:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, total_num: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
@@ -200,7 +198,7 @@ class ExpressionSelector:
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
@@ -208,7 +206,7 @@ class ExpressionSelector:
|
||||
|
||||
# 1. 获取20个随机表达方式(现在按权重抽取)
|
||||
style_exprs = self.get_random_expressions(chat_id, 10)
|
||||
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
@@ -248,7 +246,6 @@ class ExpressionSelector:
|
||||
|
||||
# 4. 调用LLM
|
||||
try:
|
||||
|
||||
# start_time = time.time()
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
||||
@@ -295,7 +292,6 @@ class ExpressionSelector:
|
||||
except Exception as e:
|
||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
@@ -119,4 +119,3 @@ def get_global_focus_value() -> Optional[float]:
|
||||
return get_time_based_focus_value(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -124,5 +124,3 @@ def get_global_frequency() -> Optional[float]:
|
||||
return get_time_based_frequency(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -34,4 +34,4 @@ def parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -155,21 +155,21 @@ class HeartFChatting:
|
||||
timer_strings.append(f"{name}: {formatted_time}")
|
||||
|
||||
# 获取动作类型,兼容新旧格式
|
||||
action_type = "未知动作"
|
||||
_action_type = "未知动作"
|
||||
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
|
||||
loop_plan_info = self._current_cycle_detail.loop_plan_info
|
||||
if isinstance(loop_plan_info, dict):
|
||||
action_result = loop_plan_info.get("action_result", {})
|
||||
if isinstance(action_result, dict):
|
||||
# 旧格式:action_result是字典
|
||||
action_type = action_result.get("action_type", "未知动作")
|
||||
_action_type = action_result.get("action_type", "未知动作")
|
||||
elif isinstance(action_result, list) and action_result:
|
||||
# 新格式:action_result是actions列表
|
||||
# TODO: 把这里写明白
|
||||
action_type = action_result[0].action_type or "未知动作"
|
||||
_action_type = action_result[0].action_type or "未知动作"
|
||||
elif isinstance(loop_plan_info, list) and loop_plan_info:
|
||||
# 直接是actions列表的情况
|
||||
action_type = loop_plan_info[0].get("action_type", "未知动作")
|
||||
_action_type = loop_plan_info[0].get("action_type", "未知动作")
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考,"
|
||||
@@ -258,7 +258,11 @@ class HeartFChatting:
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool:
|
||||
async def _observe(
|
||||
self, interest_value: float = 0.0, recent_messages_list: Optional[List["DatabaseMessages"]] = None
|
||||
) -> bool:
|
||||
if recent_messages_list is None:
|
||||
recent_messages_list = []
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
|
||||
# 使用sigmoid函数将interest_value转换为概率
|
||||
|
||||
@@ -3,14 +3,16 @@ from typing import Any, Optional, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
|
||||
class Heartflow:
|
||||
"""主心流协调器,负责初始化并协调聊天"""
|
||||
|
||||
def __init__(self):
|
||||
self.heartflow_chat_list: Dict[Any, HeartFChatting] = {}
|
||||
|
||||
|
||||
async def get_or_create_heartflow_chat(self, chat_id: Any) -> Optional[HeartFChatting]:
|
||||
"""获取或创建一个新的HeartFChatting实例"""
|
||||
try:
|
||||
@@ -18,7 +20,7 @@ class Heartflow:
|
||||
if chat := self.heartflow_chat_list.get(chat_id):
|
||||
return chat
|
||||
else:
|
||||
new_chat = HeartFChatting(chat_id = chat_id)
|
||||
new_chat = HeartFChatting(chat_id=chat_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[chat_id] = new_chat
|
||||
return new_chat
|
||||
@@ -27,4 +29,5 @@ class Heartflow:
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
heartflow = Heartflow()
|
||||
|
||||
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
@@ -34,14 +35,14 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
"""
|
||||
if message.is_picid or message.is_emoji:
|
||||
return 0.0, []
|
||||
|
||||
is_mentioned,is_at,reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
|
||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
with Timer("记忆激活"):
|
||||
interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text(
|
||||
interested_rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
max_depth= 4,
|
||||
max_depth=4,
|
||||
fast_retrieval=global_config.chat.interest_rate_mode == "fast",
|
||||
)
|
||||
message.key_words = keywords
|
||||
@@ -51,7 +52,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
@@ -75,16 +76,15 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
|
||||
message.interest_value = base_interest
|
||||
message.is_mentioned = is_mentioned
|
||||
message.is_at = is_at
|
||||
message.reply_probability_boost = reply_probability_boost
|
||||
|
||||
|
||||
return base_interest, keywords
|
||||
|
||||
|
||||
@@ -115,14 +115,13 @@ class HeartFCMessageReceiver:
|
||||
|
||||
# 2. 兴趣度计算与更新
|
||||
interested_rate, keywords = await _calculate_interest(message)
|
||||
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
|
||||
|
||||
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||
if global_config.mood.enable_mood:
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(heartflow_chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||
|
||||
@@ -132,7 +131,7 @@ class HeartFCMessageReceiver:
|
||||
# 用这个pattern截取出id部分,picid是一个list,并替换成对应的图片描述
|
||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||
picid_list = re.findall(picid_pattern, message.processed_plain_text)
|
||||
|
||||
|
||||
# 创建替换后的文本
|
||||
processed_text = message.processed_plain_text
|
||||
if picid_list:
|
||||
@@ -145,18 +144,20 @@ class HeartFCMessageReceiver:
|
||||
# 如果没有找到图片描述,则移除[picid:xxxx]标记
|
||||
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
|
||||
|
||||
|
||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||
processed_plain_text = replace_user_references(
|
||||
processed_text,
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore
|
||||
|
||||
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore
|
||||
_ = Person.register_person(
|
||||
platform=message.message_info.platform,
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
nickname=userinfo.user_nickname,
|
||||
) # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"消息处理失败: {e}")
|
||||
|
||||
@@ -124,6 +124,7 @@ async def send_typing():
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
async def stop_typing():
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
@@ -135,4 +136,4 @@ async def stop_typing():
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
|
||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
# 检查LPMM知识库是否启用
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
|
||||
@@ -32,11 +32,11 @@ install(extra_lines=3)
|
||||
|
||||
# 多线程embedding配置常量
|
||||
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||
MIN_WORKERS = 1 # 最小线程数
|
||||
MAX_WORKERS = 20 # 最大线程数
|
||||
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
|
||||
MIN_CHUNK_SIZE = 1 # 最小分块大小
|
||||
MAX_CHUNK_SIZE = 50 # 最大分块大小
|
||||
MIN_WORKERS = 1 # 最小线程数
|
||||
MAX_WORKERS = 20 # 最大线程数
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
|
||||
@@ -93,7 +93,13 @@ class EmbeddingStoreItem:
|
||||
|
||||
|
||||
class EmbeddingStore:
|
||||
def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
def __init__(
|
||||
self,
|
||||
namespace: str,
|
||||
dir_path: str,
|
||||
max_workers: int = DEFAULT_MAX_WORKERS,
|
||||
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||
):
|
||||
self.namespace = namespace
|
||||
self.dir = dir_path
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
@@ -103,12 +109,16 @@ class EmbeddingStore:
|
||||
# 多线程配置参数验证和设置
|
||||
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
|
||||
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
|
||||
|
||||
|
||||
# 如果配置值被调整,记录日志
|
||||
if self.max_workers != max_workers:
|
||||
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
|
||||
logger.warning(
|
||||
f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})"
|
||||
)
|
||||
if self.chunk_size != chunk_size:
|
||||
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
|
||||
logger.warning(
|
||||
f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})"
|
||||
)
|
||||
|
||||
self.store = {}
|
||||
|
||||
@@ -120,23 +130,23 @@ class EmbeddingStore:
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
|
||||
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
return embedding
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
return []
|
||||
@@ -147,43 +157,45 @@ class EmbeddingStore:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
|
||||
def _get_embeddings_batch_threaded(
|
||||
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
|
||||
Args:
|
||||
strs: 要获取嵌入的字符串列表
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
max_workers: 最大线程数
|
||||
progress_callback: 进度回调函数,接收一个参数表示完成的数量
|
||||
|
||||
|
||||
Returns:
|
||||
包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
|
||||
"""
|
||||
if not strs:
|
||||
return []
|
||||
|
||||
|
||||
# 分块
|
||||
chunks = []
|
||||
for i in range(0, len(strs), chunk_size):
|
||||
chunk = strs[i:i + chunk_size]
|
||||
chunk = strs[i : i + chunk_size]
|
||||
chunks.append((i, chunk)) # 保存起始索引以维持顺序
|
||||
|
||||
|
||||
# 结果存储,使用字典按索引存储以保证顺序
|
||||
results = {}
|
||||
|
||||
|
||||
def process_chunk(chunk_data):
|
||||
"""处理单个数据块的函数"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
# 在线程中创建独立的事件循环
|
||||
@@ -193,25 +205,25 @@ class EmbeddingStore:
|
||||
embedding = loop.run_until_complete(llm.get_embedding(s))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
|
||||
# 每完成一个嵌入立即更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
# 如果创建LLM实例失败,返回空结果
|
||||
@@ -220,14 +232,14 @@ class EmbeddingStore:
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
return chunk_results
|
||||
|
||||
|
||||
# 使用线程池处理
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# 提交所有任务
|
||||
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
|
||||
|
||||
|
||||
# 收集结果(进度已在process_chunk中实时更新)
|
||||
for future in as_completed(future_to_chunk):
|
||||
try:
|
||||
@@ -241,7 +253,7 @@ class EmbeddingStore:
|
||||
start_idx, chunk_strs = chunk
|
||||
for i, s in enumerate(chunk_strs):
|
||||
results[start_idx + i] = (s, [])
|
||||
|
||||
|
||||
# 按原始顺序返回结果
|
||||
ordered_results = []
|
||||
for i in range(len(strs)):
|
||||
@@ -250,7 +262,7 @@ class EmbeddingStore:
|
||||
else:
|
||||
# 防止遗漏
|
||||
ordered_results.append((strs[i], []))
|
||||
|
||||
|
||||
return ordered_results
|
||||
|
||||
def get_test_file_path(self):
|
||||
@@ -259,14 +271,14 @@ class EmbeddingStore:
|
||||
def save_embedding_test_vectors(self):
|
||||
"""保存测试字符串的嵌入到本地(使用多线程优化)"""
|
||||
logger.info("开始保存测试字符串的嵌入向量...")
|
||||
|
||||
|
||||
# 使用多线程批量获取测试字符串的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
|
||||
|
||||
# 构建测试向量字典
|
||||
test_vectors = {}
|
||||
for idx, (s, embedding) in enumerate(embedding_results):
|
||||
@@ -276,10 +288,10 @@ class EmbeddingStore:
|
||||
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||
# 使用原始单线程方法作为后备
|
||||
test_vectors[str(idx)] = self._get_embedding(s)
|
||||
|
||||
|
||||
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info("测试字符串嵌入向量保存完成")
|
||||
|
||||
def load_embedding_test_vectors(self):
|
||||
@@ -297,35 +309,35 @@ class EmbeddingStore:
|
||||
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
||||
self.save_embedding_test_vectors()
|
||||
return True
|
||||
|
||||
|
||||
# 检查本地向量完整性
|
||||
for idx in range(len(EMBEDDING_TEST_STRINGS)):
|
||||
if local_vectors.get(str(idx)) is None:
|
||||
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
|
||||
self.save_embedding_test_vectors()
|
||||
return True
|
||||
|
||||
|
||||
logger.info("开始检验嵌入模型一致性...")
|
||||
|
||||
|
||||
# 使用多线程批量获取当前模型的嵌入
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
EMBEDDING_TEST_STRINGS,
|
||||
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
|
||||
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
|
||||
)
|
||||
|
||||
|
||||
# 检查一致性
|
||||
for idx, (s, new_emb) in enumerate(embedding_results):
|
||||
local_emb = local_vectors.get(str(idx))
|
||||
if not new_emb:
|
||||
logger.error(f"获取测试字符串嵌入失败: {s}")
|
||||
return False
|
||||
|
||||
|
||||
sim = cosine_similarity(local_emb, new_emb)
|
||||
if sim < EMBEDDING_SIM_THRESHOLD:
|
||||
logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
|
||||
return False
|
||||
|
||||
|
||||
logger.info("嵌入模型一致性校验通过。")
|
||||
return True
|
||||
|
||||
@@ -333,22 +345,22 @@ class EmbeddingStore:
|
||||
"""向库中存入字符串(使用多线程优化)"""
|
||||
if not strs:
|
||||
return
|
||||
|
||||
|
||||
total = len(strs)
|
||||
|
||||
|
||||
# 过滤已存在的字符串
|
||||
new_strs = []
|
||||
for s in strs:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
if item_hash not in self.store:
|
||||
new_strs.append(s)
|
||||
|
||||
|
||||
if not new_strs:
|
||||
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
|
||||
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
@@ -362,31 +374,39 @@ class EmbeddingStore:
|
||||
transient=False,
|
||||
) as progress:
|
||||
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
|
||||
|
||||
|
||||
# 首先更新已存在项的进度
|
||||
already_processed = total - len(new_strs)
|
||||
if already_processed > 0:
|
||||
progress.update(task, advance=already_processed)
|
||||
|
||||
|
||||
if new_strs:
|
||||
# 使用实例配置的参数,智能调整分块和线程数
|
||||
optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
|
||||
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
|
||||
|
||||
optimal_chunk_size = max(
|
||||
MIN_CHUNK_SIZE,
|
||||
min(
|
||||
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
|
||||
),
|
||||
)
|
||||
optimal_max_workers = min(
|
||||
self.max_workers,
|
||||
max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
|
||||
)
|
||||
|
||||
logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
|
||||
|
||||
|
||||
# 定义进度更新回调函数
|
||||
def update_progress(count):
|
||||
progress.update(task, advance=count)
|
||||
|
||||
|
||||
# 批量获取嵌入,并实时更新进度
|
||||
embedding_results = self._get_embeddings_batch_threaded(
|
||||
new_strs,
|
||||
chunk_size=optimal_chunk_size,
|
||||
new_strs,
|
||||
chunk_size=optimal_chunk_size,
|
||||
max_workers=optimal_max_workers,
|
||||
progress_callback=update_progress
|
||||
progress_callback=update_progress,
|
||||
)
|
||||
|
||||
|
||||
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
|
||||
for s, embedding in embedding_results:
|
||||
item_hash = self.namespace + "-" + get_sha256(s)
|
||||
@@ -519,7 +539,7 @@ class EmbeddingManager:
|
||||
def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
|
||||
"""
|
||||
初始化EmbeddingManager
|
||||
|
||||
|
||||
Args:
|
||||
max_workers: 最大线程数
|
||||
chunk_size: 每个线程处理的数据块大小
|
||||
|
||||
@@ -426,9 +426,7 @@ class KGManager:
|
||||
# 获取最终结果
|
||||
# 从搜索结果中提取文段节点的结果
|
||||
passage_node_res = [
|
||||
(node_key, score)
|
||||
for node_key, score in ppr_res.items()
|
||||
if node_key.startswith("paragraph")
|
||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph")
|
||||
]
|
||||
del ppr_res
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
|
||||
from .lpmmconfig import global_config
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
from .lpmmconfig import global_config # noqa
|
||||
from .embedding_store import EmbeddingManager # noqa
|
||||
from .llm_client import LLMClient # noqa
|
||||
from .utils.dyn_topk import dyn_select_top_k # noqa
|
||||
|
||||
|
||||
class MemoryActiveManager:
|
||||
|
||||
@@ -8,7 +8,7 @@ def dyn_select_top_k(
|
||||
# 检查输入列表是否为空
|
||||
if not score:
|
||||
return []
|
||||
|
||||
|
||||
# 按照分数排序(降序)
|
||||
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class MessageStorage:
|
||||
if isinstance(keywords, list):
|
||||
return json.dumps(keywords, ensure_ascii=False)
|
||||
return "[]"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_keywords(keywords_str: str) -> list:
|
||||
"""将JSON字符串反序列化为关键词列表"""
|
||||
@@ -85,7 +85,7 @@ class MessageStorage:
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
selected_expressions = ""
|
||||
|
||||
|
||||
chat_info_dict = chat_stream.to_dict()
|
||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||
|
||||
|
||||
@@ -124,4 +124,4 @@ class ActionManager:
|
||||
"""恢复到默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
@@ -103,25 +103,23 @@ class ActionModifier:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
|
||||
|
||||
|
||||
# === 第三阶段:激活类型判定 ===
|
||||
# if chat_content is not None:
|
||||
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
# current_using_actions = self.action_manager.get_using_actions()
|
||||
# 获取当前使用的动作集(经过第一阶段处理)
|
||||
# current_using_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
# removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
# current_using_actions,
|
||||
# chat_content,
|
||||
# )
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
# removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
# current_using_actions,
|
||||
# chat_content,
|
||||
# )
|
||||
|
||||
# 应用第三阶段的移除
|
||||
# for action_name, reason in removals_s3:
|
||||
# self.action_manager.remove_action_from_using(action_name)
|
||||
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
# 应用第三阶段的移除
|
||||
# for action_name, reason in removals_s3:
|
||||
# self.action_manager.remove_action_from_using(action_name)
|
||||
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2
|
||||
@@ -131,9 +129,7 @@ class ActionModifier:
|
||||
|
||||
available_actions = list(self.action_manager.get_using_actions().keys())
|
||||
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
|
||||
@@ -385,18 +385,18 @@ class StatisticOutputTask(AsyncTask):
|
||||
time_cost_key = f"time_costs_by_{category.split('_')[-1]}"
|
||||
avg_key = f"avg_time_costs_by_{category.split('_')[-1]}"
|
||||
std_key = f"std_time_costs_by_{category.split('_')[-1]}"
|
||||
|
||||
|
||||
for item_name in stats[period_key][category]:
|
||||
time_costs = stats[period_key][time_cost_key].get(item_name, [])
|
||||
if time_costs:
|
||||
# 计算平均耗时
|
||||
avg_time_cost = sum(time_costs) / len(time_costs)
|
||||
stats[period_key][avg_key][item_name] = round(avg_time_cost, 3)
|
||||
|
||||
|
||||
# 计算标准差
|
||||
if len(time_costs) > 1:
|
||||
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
|
||||
std_time_cost = variance ** 0.5
|
||||
std_time_cost = variance**0.5
|
||||
stats[period_key][std_key][item_name] = round(std_time_cost, 3)
|
||||
else:
|
||||
stats[period_key][std_key][item_name] = 0.0
|
||||
@@ -506,8 +506,6 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
@@ -639,7 +637,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost))
|
||||
output.append(
|
||||
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
|
||||
)
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
@@ -728,7 +728,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
|
||||
] if stat_data[REQ_CNT_BY_MODEL] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODEL]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按请求类型分类统计
|
||||
type_rows = "\n".join(
|
||||
@@ -744,7 +746,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
|
||||
] if stat_data[REQ_CNT_BY_TYPE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_TYPE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按模块分类统计
|
||||
module_rows = "\n".join(
|
||||
@@ -760,7 +764,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
f"</tr>"
|
||||
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
|
||||
] if stat_data[REQ_CNT_BY_MODULE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODULE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
|
||||
# 聊天消息统计
|
||||
@@ -768,7 +774,9 @@ class StatisticOutputTask(AsyncTask):
|
||||
[
|
||||
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
|
||||
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
|
||||
] if stat_data[MSG_CNT_BY_CHAT] else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
]
|
||||
if stat_data[MSG_CNT_BY_CHAT]
|
||||
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 生成HTML
|
||||
return f"""
|
||||
|
||||
@@ -49,9 +49,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
|
||||
reply_probability = 0.0
|
||||
is_at = False
|
||||
is_mentioned = False
|
||||
|
||||
|
||||
# 这部分怎么处理啊啊啊啊
|
||||
#我觉得可以给消息加一个 reply_probability_boost字段
|
||||
# 我觉得可以给消息加一个 reply_probability_boost字段
|
||||
if (
|
||||
message.message_info.additional_config is not None
|
||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
||||
@@ -826,20 +826,48 @@ def parse_keywords_string(keywords_input) -> list[str]:
|
||||
return [keywords_str] if keywords_str else []
|
||||
|
||||
|
||||
|
||||
|
||||
def cut_key_words(concept_name: str) -> list[str]:
|
||||
"""对概念名称进行jieba分词,并过滤掉关键词列表中的关键词"""
|
||||
concept_name_tokens = list(jieba.cut(concept_name))
|
||||
|
||||
# 定义常见连词、停用词与标点
|
||||
conjunctions = {
|
||||
"和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并"
|
||||
}
|
||||
conjunctions = {"和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并"}
|
||||
stop_words = {
|
||||
"的", "了", "呢", "吗", "吧", "啊", "哦", "恩", "嗯", "呀", "嘛", "哇",
|
||||
"在", "是", "很", "也", "又", "就", "都", "还", "更", "最", "被", "把",
|
||||
"给", "对", "和", "与", "及", "跟", "并", "而且", "或者", "或", "以及"
|
||||
"的",
|
||||
"了",
|
||||
"呢",
|
||||
"吗",
|
||||
"吧",
|
||||
"啊",
|
||||
"哦",
|
||||
"恩",
|
||||
"嗯",
|
||||
"呀",
|
||||
"嘛",
|
||||
"哇",
|
||||
"在",
|
||||
"是",
|
||||
"很",
|
||||
"也",
|
||||
"又",
|
||||
"就",
|
||||
"都",
|
||||
"还",
|
||||
"更",
|
||||
"最",
|
||||
"被",
|
||||
"把",
|
||||
"给",
|
||||
"对",
|
||||
"和",
|
||||
"与",
|
||||
"及",
|
||||
"跟",
|
||||
"并",
|
||||
"而且",
|
||||
"或者",
|
||||
"或",
|
||||
"以及",
|
||||
}
|
||||
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
|
||||
|
||||
@@ -864,11 +892,16 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
left = merged_tokens[-1]
|
||||
right = cleaned_tokens[i + 1]
|
||||
# 左右都需要是有效词
|
||||
if left and right \
|
||||
and left not in conjunctions and right not in conjunctions \
|
||||
and left not in stop_words and right not in stop_words \
|
||||
and not all(ch in chinese_punctuations for ch in left) \
|
||||
and not all(ch in chinese_punctuations for ch in right):
|
||||
if (
|
||||
left
|
||||
and right
|
||||
and left not in conjunctions
|
||||
and right not in conjunctions
|
||||
and left not in stop_words
|
||||
and right not in stop_words
|
||||
and not all(ch in chinese_punctuations for ch in left)
|
||||
and not all(ch in chinese_punctuations for ch in right)
|
||||
):
|
||||
# 合并为一个新词,并替换掉左侧与跳过右侧
|
||||
combined = f"{left}{tok}{right}"
|
||||
merged_tokens[-1] = combined
|
||||
@@ -889,7 +922,7 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
if tok in stop_words:
|
||||
continue
|
||||
# if tok in ban_words:
|
||||
# continue
|
||||
# continue
|
||||
if all(ch in chinese_punctuations for ch in tok):
|
||||
continue
|
||||
if tok.strip() == "":
|
||||
@@ -899,4 +932,4 @@ def cut_key_words(concept_name: str) -> list[str]:
|
||||
result_tokens.append(tok)
|
||||
|
||||
filtered_concept_name_tokens = result_tokens
|
||||
return filtered_concept_name_tokens
|
||||
return filtered_concept_name_tokens
|
||||
|
||||
@@ -91,9 +91,10 @@ class ImageManager:
|
||||
desc_obj.save()
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
@@ -120,6 +121,7 @@ class ImageManager:
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
|
||||
if tags:
|
||||
|
||||
@@ -6,6 +6,7 @@ class BaseDataModel:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else
|
||||
"""
|
||||
|
||||
@@ -205,6 +205,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseActionRecords(BaseDataModel):
|
||||
def __init__(
|
||||
@@ -232,4 +233,4 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
self.action_prompt_display = action_prompt_display
|
||||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.chat_info_platform = chat_info_platform
|
||||
|
||||
@@ -2,9 +2,11 @@ from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
@@ -13,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
|
||||
@@ -135,7 +135,7 @@ class Messages(BaseModel):
|
||||
interest_value = DoubleField(null=True)
|
||||
key_words = TextField(null=True)
|
||||
key_words_lite = TextField(null=True)
|
||||
|
||||
|
||||
is_mentioned = BooleanField(null=True)
|
||||
is_at = BooleanField(null=True)
|
||||
reply_probability_boost = DoubleField(null=True)
|
||||
@@ -169,7 +169,7 @@ class Messages(BaseModel):
|
||||
is_picid = BooleanField(default=False)
|
||||
is_command = BooleanField(default=False)
|
||||
is_notify = BooleanField(default=False)
|
||||
|
||||
|
||||
selected_expressions = TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
@@ -267,13 +267,10 @@ class PersonInfo(BaseModel):
|
||||
know_times = FloatField(null=True) # 认识时间 (时间戳)
|
||||
know_since = FloatField(null=True) # 首次印象总结时间
|
||||
last_know = FloatField(null=True) # 最后一次印象总结时间
|
||||
|
||||
|
||||
attitude_to_me = TextField(null=True) # 对bot的态度
|
||||
attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度
|
||||
|
||||
|
||||
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "person_info"
|
||||
@@ -299,6 +296,7 @@ class GroupInfo(BaseModel):
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "group_info"
|
||||
|
||||
|
||||
class Expression(BaseModel):
|
||||
"""
|
||||
用于存储表达风格的模型。
|
||||
@@ -315,6 +313,7 @@ class Expression(BaseModel):
|
||||
class Meta:
|
||||
table_name = "expression"
|
||||
|
||||
|
||||
class GraphNodes(BaseModel):
|
||||
"""
|
||||
用于存储记忆图节点的模型
|
||||
@@ -374,7 +373,7 @@ def initialize_database(sync_constraints=False):
|
||||
"""
|
||||
检查所有定义的表是否存在,如果不存在则创建它们。
|
||||
检查所有表的所有字段是否存在,如果缺失则自动添加。
|
||||
|
||||
|
||||
Args:
|
||||
sync_constraints (bool): 是否同步字段约束。默认为 False。
|
||||
如果为 True,会检查并修复字段的 NULL 约束不一致问题。
|
||||
@@ -456,13 +455,13 @@ def initialize_database(sync_constraints=False):
|
||||
logger.info(f"字段 '{field_name}' 删除成功")
|
||||
except Exception as e:
|
||||
logger.error(f"删除字段 '{field_name}' 失败: {e}")
|
||||
|
||||
|
||||
# 如果启用了约束同步,执行约束检查和修复
|
||||
if sync_constraints:
|
||||
logger.debug("开始同步数据库字段约束...")
|
||||
sync_field_constraints()
|
||||
logger.debug("数据库字段约束同步完成")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"检查表或字段是否存在时出错: {e}")
|
||||
# 如果检查失败(例如数据库不可用),则退出
|
||||
@@ -476,7 +475,7 @@ def sync_field_constraints():
|
||||
同步数据库字段约束,确保现有数据库字段的 NULL 约束与模型定义一致。
|
||||
如果发现不一致,会自动修复字段约束。
|
||||
"""
|
||||
|
||||
|
||||
models = [
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
@@ -501,50 +500,55 @@ def sync_field_constraints():
|
||||
continue
|
||||
|
||||
logger.debug(f"检查表 '{table_name}' 的字段约束...")
|
||||
|
||||
|
||||
# 获取当前表结构信息
|
||||
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
|
||||
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]}
|
||||
for row in cursor.fetchall()}
|
||||
|
||||
current_schema = {
|
||||
row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
# 检查每个模型字段的约束
|
||||
constraints_to_fix = []
|
||||
for field_name, field_obj in model._meta.fields.items():
|
||||
if field_name not in current_schema:
|
||||
continue # 字段不存在,跳过
|
||||
|
||||
current_notnull = current_schema[field_name]['notnull']
|
||||
|
||||
current_notnull = current_schema[field_name]["notnull"]
|
||||
model_allows_null = field_obj.null
|
||||
|
||||
|
||||
# 如果模型允许 null 但数据库字段不允许 null,需要修复
|
||||
if model_allows_null and current_notnull:
|
||||
constraints_to_fix.append({
|
||||
'field_name': field_name,
|
||||
'field_obj': field_obj,
|
||||
'action': 'allow_null',
|
||||
'current_constraint': 'NOT NULL',
|
||||
'target_constraint': 'NULL'
|
||||
})
|
||||
constraints_to_fix.append(
|
||||
{
|
||||
"field_name": field_name,
|
||||
"field_obj": field_obj,
|
||||
"action": "allow_null",
|
||||
"current_constraint": "NOT NULL",
|
||||
"target_constraint": "NULL",
|
||||
}
|
||||
)
|
||||
logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL,但数据库为NOT NULL")
|
||||
|
||||
|
||||
# 如果模型不允许 null 但数据库字段允许 null,也需要修复(但要小心)
|
||||
elif not model_allows_null and not current_notnull:
|
||||
constraints_to_fix.append({
|
||||
'field_name': field_name,
|
||||
'field_obj': field_obj,
|
||||
'action': 'disallow_null',
|
||||
'current_constraint': 'NULL',
|
||||
'target_constraint': 'NOT NULL'
|
||||
})
|
||||
constraints_to_fix.append(
|
||||
{
|
||||
"field_name": field_name,
|
||||
"field_obj": field_obj,
|
||||
"action": "disallow_null",
|
||||
"current_constraint": "NULL",
|
||||
"target_constraint": "NOT NULL",
|
||||
}
|
||||
)
|
||||
logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL,但数据库允许NULL")
|
||||
|
||||
|
||||
# 修复约束不一致的字段
|
||||
if constraints_to_fix:
|
||||
logger.info(f"表 '{table_name}' 需要修复 {len(constraints_to_fix)} 个字段约束")
|
||||
_fix_table_constraints(table_name, model, constraints_to_fix)
|
||||
else:
|
||||
logger.debug(f"表 '{table_name}' 的字段约束已同步")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"同步字段约束时出错: {e}")
|
||||
|
||||
@@ -557,40 +561,39 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
try:
|
||||
# 备份表名
|
||||
backup_table = f"{table_name}_backup_{int(datetime.datetime.now().timestamp())}"
|
||||
|
||||
|
||||
logger.info(f"开始修复表 '{table_name}' 的字段约束...")
|
||||
|
||||
|
||||
# 1. 创建备份表
|
||||
db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
|
||||
logger.info(f"已创建备份表 '{backup_table}'")
|
||||
|
||||
|
||||
# 2. 删除原表
|
||||
db.execute_sql(f"DROP TABLE {table_name}")
|
||||
logger.info(f"已删除原表 '{table_name}'")
|
||||
|
||||
|
||||
# 3. 重新创建表(使用当前模型定义)
|
||||
db.create_tables([model])
|
||||
logger.info(f"已重新创建表 '{table_name}' 使用新的约束")
|
||||
|
||||
|
||||
# 4. 从备份表恢复数据
|
||||
# 获取字段列表
|
||||
fields = list(model._meta.fields.keys())
|
||||
fields_str = ', '.join(fields)
|
||||
|
||||
fields_str = ", ".join(fields)
|
||||
|
||||
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
|
||||
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
|
||||
|
||||
|
||||
# 检查是否有字段需要从 NULL 改为 NOT NULL
|
||||
null_to_notnull_fields = [
|
||||
constraint['field_name'] for constraint in constraints_to_fix
|
||||
if constraint['action'] == 'disallow_null'
|
||||
constraint["field_name"] for constraint in constraints_to_fix if constraint["action"] == "disallow_null"
|
||||
]
|
||||
|
||||
|
||||
if null_to_notnull_fields:
|
||||
# 需要处理 NULL 值,为这些字段设置默认值
|
||||
logger.warning(f"字段 {null_to_notnull_fields} 将从允许NULL改为不允许NULL,需要处理现有的NULL值")
|
||||
|
||||
|
||||
# 构建更复杂的 SELECT 语句来处理 NULL 值
|
||||
select_fields = []
|
||||
for field_name in fields:
|
||||
@@ -607,21 +610,21 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
default_value = f"'{datetime.datetime.now()}'"
|
||||
else:
|
||||
default_value = "''"
|
||||
|
||||
|
||||
select_fields.append(f"COALESCE({field_name}, {default_value}) as {field_name}")
|
||||
else:
|
||||
select_fields.append(field_name)
|
||||
|
||||
select_str = ', '.join(select_fields)
|
||||
|
||||
select_str = ", ".join(select_fields)
|
||||
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
|
||||
|
||||
|
||||
db.execute_sql(insert_sql)
|
||||
logger.info(f"已从备份表恢复数据到 '{table_name}'")
|
||||
|
||||
|
||||
# 5. 验证数据完整性
|
||||
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
|
||||
new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
|
||||
|
||||
|
||||
if original_count == new_count:
|
||||
logger.info(f"数据完整性验证通过: {original_count} 行数据")
|
||||
# 删除备份表
|
||||
@@ -630,12 +633,14 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
|
||||
else:
|
||||
logger.error(f"数据完整性验证失败: 原始 {original_count} 行,新表 {new_count} 行")
|
||||
logger.error(f"备份表 '{backup_table}' 已保留,请手动检查")
|
||||
|
||||
|
||||
# 记录修复的约束
|
||||
for constraint in constraints_to_fix:
|
||||
logger.info(f"已修复字段 '{constraint['field_name']}': "
|
||||
f"{constraint['current_constraint']} -> {constraint['target_constraint']}")
|
||||
|
||||
logger.info(
|
||||
f"已修复字段 '{constraint['field_name']}': "
|
||||
f"{constraint['current_constraint']} -> {constraint['target_constraint']}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"修复表 '{table_name}' 约束时出错: {e}")
|
||||
# 尝试恢复
|
||||
@@ -654,7 +659,7 @@ def check_field_constraints():
|
||||
检查但不修复字段约束,返回不一致的字段信息。
|
||||
用于在修复前预览需要修复的内容。
|
||||
"""
|
||||
|
||||
|
||||
models = [
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
@@ -669,9 +674,9 @@ def check_field_constraints():
|
||||
GraphEdges,
|
||||
ActionRecords,
|
||||
]
|
||||
|
||||
|
||||
inconsistencies = {}
|
||||
|
||||
|
||||
try:
|
||||
with db:
|
||||
for model in models:
|
||||
@@ -681,49 +686,49 @@ def check_field_constraints():
|
||||
|
||||
# 获取当前表结构信息
|
||||
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
|
||||
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]}
|
||||
for row in cursor.fetchall()}
|
||||
|
||||
current_schema = {
|
||||
row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
table_inconsistencies = []
|
||||
|
||||
|
||||
# 检查每个模型字段的约束
|
||||
for field_name, field_obj in model._meta.fields.items():
|
||||
if field_name not in current_schema:
|
||||
continue
|
||||
|
||||
current_notnull = current_schema[field_name]['notnull']
|
||||
|
||||
current_notnull = current_schema[field_name]["notnull"]
|
||||
model_allows_null = field_obj.null
|
||||
|
||||
|
||||
if model_allows_null and current_notnull:
|
||||
table_inconsistencies.append({
|
||||
'field_name': field_name,
|
||||
'issue': 'model_allows_null_but_db_not_null',
|
||||
'model_constraint': 'NULL',
|
||||
'db_constraint': 'NOT NULL',
|
||||
'recommended_action': 'allow_null'
|
||||
})
|
||||
table_inconsistencies.append(
|
||||
{
|
||||
"field_name": field_name,
|
||||
"issue": "model_allows_null_but_db_not_null",
|
||||
"model_constraint": "NULL",
|
||||
"db_constraint": "NOT NULL",
|
||||
"recommended_action": "allow_null",
|
||||
}
|
||||
)
|
||||
elif not model_allows_null and not current_notnull:
|
||||
table_inconsistencies.append({
|
||||
'field_name': field_name,
|
||||
'issue': 'model_not_null_but_db_allows_null',
|
||||
'model_constraint': 'NOT NULL',
|
||||
'db_constraint': 'NULL',
|
||||
'recommended_action': 'disallow_null'
|
||||
})
|
||||
|
||||
table_inconsistencies.append(
|
||||
{
|
||||
"field_name": field_name,
|
||||
"issue": "model_not_null_but_db_allows_null",
|
||||
"model_constraint": "NOT NULL",
|
||||
"db_constraint": "NULL",
|
||||
"recommended_action": "disallow_null",
|
||||
}
|
||||
)
|
||||
|
||||
if table_inconsistencies:
|
||||
inconsistencies[table_name] = table_inconsistencies
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"检查字段约束时出错: {e}")
|
||||
|
||||
return inconsistencies
|
||||
|
||||
return inconsistencies
|
||||
|
||||
|
||||
# 模块加载时调用初始化函数
|
||||
initialize_database(sync_constraints=True)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -339,24 +339,18 @@ MODULE_COLORS = {
|
||||
# 67 :具体的颜色编号(0-255),这里是较暗的蓝色
|
||||
"sender": "\033[38;5;24m", # 67号色,较暗的蓝色,适合不显眼的日志
|
||||
"send_api": "\033[38;5;24m", # 208号色,橙色,适合突出显示
|
||||
|
||||
# 生成
|
||||
"replyer": "\033[38;5;208m", # 橙色
|
||||
"llm_api": "\033[38;5;208m", # 橙色
|
||||
|
||||
# 消息处理
|
||||
"chat": "\033[38;5;82m", # 亮蓝色
|
||||
"chat_image": "\033[38;5;68m", # 浅蓝色
|
||||
|
||||
#emoji
|
||||
# emoji
|
||||
"emoji": "\033[38;5;214m", # 橙黄色,偏向橙色
|
||||
"emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色
|
||||
|
||||
# 核心模块
|
||||
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
|
||||
|
||||
"memory": "\033[38;5;34m", # 天蓝色
|
||||
|
||||
"config": "\033[93m", # 亮黄色
|
||||
"common": "\033[95m", # 亮紫色
|
||||
"tools": "\033[96m", # 亮青色
|
||||
@@ -367,9 +361,6 @@ MODULE_COLORS = {
|
||||
"llm_models": "\033[36m", # 青色
|
||||
"remote": "\033[38;5;242m", # 深灰色,更不显眼
|
||||
"planner": "\033[36m",
|
||||
|
||||
|
||||
|
||||
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
||||
# 聊天相关模块
|
||||
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
|
||||
@@ -379,11 +370,9 @@ MODULE_COLORS = {
|
||||
"background_tasks": "\033[38;5;240m", # 灰色
|
||||
"chat_message": "\033[38;5;45m", # 青色
|
||||
"chat_stream": "\033[38;5;51m", # 亮青色
|
||||
|
||||
"message_storage": "\033[38;5;33m", # 深蓝色
|
||||
"expressor": "\033[38;5;166m", # 橙色
|
||||
# 专注聊天模块
|
||||
|
||||
"memory_activator": "\033[38;5;117m", # 天蓝色
|
||||
# 插件系统
|
||||
"plugins": "\033[31m", # 红色
|
||||
@@ -412,7 +401,6 @@ MODULE_COLORS = {
|
||||
# 工具和实用模块
|
||||
"prompt_build": "\033[38;5;105m", # 紫色
|
||||
"chat_utils": "\033[38;5;111m", # 蓝色
|
||||
|
||||
"maibot_statistic": "\033[38;5;129m", # 紫色
|
||||
# 特殊功能插件
|
||||
"mute_plugin": "\033[38;5;240m", # 灰色
|
||||
@@ -447,10 +435,8 @@ MODULE_ALIASES = {
|
||||
"llm_api": "生成API",
|
||||
"emoji": "表情包",
|
||||
"emoji_api": "表情包API",
|
||||
|
||||
"chat": "所见",
|
||||
"chat_image": "识图",
|
||||
|
||||
"action_manager": "动作",
|
||||
"memory_activator": "记忆",
|
||||
"tool_use": "工具",
|
||||
@@ -460,7 +446,6 @@ MODULE_ALIASES = {
|
||||
"memory": "记忆",
|
||||
"tool_executor": "工具",
|
||||
"hfc": "聊天节奏",
|
||||
|
||||
"plugin_manager": "插件",
|
||||
"relationship_builder": "关系",
|
||||
"llm_models": "模型",
|
||||
|
||||
@@ -114,7 +114,7 @@ def set_value_by_path(d, path, value):
|
||||
if k not in d or not isinstance(d[k], dict):
|
||||
d[k] = {}
|
||||
d = d[k]
|
||||
|
||||
|
||||
# 使用 tomlkit.item 来保持 TOML 格式
|
||||
try:
|
||||
d[path[-1]] = tomlkit.item(value)
|
||||
@@ -253,7 +253,7 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
|
||||
)
|
||||
config_updated = True
|
||||
|
||||
|
||||
# 如果配置有更新,立即保存到文件
|
||||
if config_updated:
|
||||
with open(old_config_path, "w", encoding="utf-8") as f:
|
||||
|
||||
@@ -43,10 +43,11 @@ class PersonalityConfig(ConfigBase):
|
||||
|
||||
reply_style: str = ""
|
||||
"""表达风格"""
|
||||
|
||||
|
||||
interest: str = ""
|
||||
"""兴趣"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelationshipConfig(ConfigBase):
|
||||
"""关系配置类"""
|
||||
@@ -61,31 +62,30 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
|
||||
interest_rate_mode: Literal["fast", "accurate"] = "fast"
|
||||
"""兴趣值计算模式,fast为快速计算,accurate为精确计算"""
|
||||
|
||||
mentioned_bot_reply: float = 1
|
||||
"""提及 bot 必然回复,1为100%回复,0为不额外增幅"""
|
||||
|
||||
|
||||
planner_size: float = 1.5
|
||||
"""副规划器大小,越小,麦麦的动作执行能力越精细,但是消耗更多token,调大可以缓解429类错误"""
|
||||
|
||||
at_bot_inevitable_reply: float = 1
|
||||
"""@bot 必然回复,1为100%回复,0为不额外增幅"""
|
||||
|
||||
|
||||
talk_frequency: float = 0.5
|
||||
"""回复频率阈值"""
|
||||
|
||||
# 合并后的时段频率配置
|
||||
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
focus_value: float = 0.5
|
||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||
|
||||
|
||||
focus_value_adjust: list[list[str]] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
"""
|
||||
统一的活跃度和专注度配置
|
||||
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
|
||||
@@ -110,7 +110,6 @@ class ChatConfig(ConfigBase):
|
||||
- talk_frequency_adjust 控制回复频率,数值越高回复越频繁
|
||||
- focus_value_adjust 控制专注思考能力,数值越低越容易专注,消耗token也越多
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -123,6 +122,7 @@ class MessageReceiveConfig(ConfigBase):
|
||||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||
"""过滤正则表达式列表"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
@@ -174,7 +174,7 @@ class ClientRegistry:
|
||||
return client_class(api_provider)
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
|
||||
|
||||
# 正常的缓存逻辑
|
||||
if api_provider.name not in self.client_instance_cache:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
|
||||
@@ -531,7 +531,7 @@ class OpenaiClient(BaseClient):
|
||||
# 添加详细的错误信息以便调试
|
||||
logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}")
|
||||
logger.error(f"错误类型: {type(e)}")
|
||||
if hasattr(e, '__cause__') and e.__cause__:
|
||||
if hasattr(e, "__cause__") and e.__cause__:
|
||||
logger.error(f"底层错误: {str(e.__cause__)}")
|
||||
raise NetworkConnectionError() from e
|
||||
except APIStatusError as e:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .tool_option import ToolCall
|
||||
|
||||
__all__ = ["ToolCall"]
|
||||
__all__ = ["ToolCall"]
|
||||
|
||||
@@ -48,8 +48,7 @@ def _json_schema_type_check(instance) -> str | None:
|
||||
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
|
||||
return "schema的'name'字段必须是非空字符串"
|
||||
if "description" in instance and (
|
||||
not isinstance(instance["description"], str)
|
||||
or instance["description"].strip() == ""
|
||||
not isinstance(instance["description"], str) or instance["description"].strip() == ""
|
||||
):
|
||||
return "schema的'description'字段只能填入非空字符串"
|
||||
if "schema" not in instance:
|
||||
@@ -101,9 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
# 如果当前Schema是列表,则遍历每个元素
|
||||
for i in range(len(sub_schema)):
|
||||
if isinstance(sub_schema[i], dict):
|
||||
sub_schema[i] = link_definitions_recursive(
|
||||
f"{path}/{str(i)}", sub_schema[i], defs
|
||||
)
|
||||
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
|
||||
else:
|
||||
# 否则为字典
|
||||
if "$defs" in sub_schema:
|
||||
@@ -125,9 +122,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
for key, value in sub_schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
# 如果当前值是字典或列表,则递归调用
|
||||
sub_schema[key] = link_definitions_recursive(
|
||||
f"{path}/{key}", value, defs
|
||||
)
|
||||
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
|
||||
|
||||
return sub_schema
|
||||
|
||||
@@ -163,9 +158,7 @@ class RespFormat:
|
||||
def _generate_schema_from_model(schema):
|
||||
json_schema = {
|
||||
"name": schema.__name__,
|
||||
"schema": _remove_defs(
|
||||
_link_definitions(_remove_title(schema.model_json_schema()))
|
||||
),
|
||||
"schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
|
||||
"strict": False,
|
||||
}
|
||||
if schema.__doc__:
|
||||
|
||||
@@ -155,7 +155,13 @@ class LLMUsageRecorder:
|
||||
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
||||
|
||||
def record_usage_to_database(
|
||||
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
model_usage: UsageRecord,
|
||||
user_id: str,
|
||||
request_type: str,
|
||||
endpoint: str,
|
||||
time_cost: float = 0.0,
|
||||
):
|
||||
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
|
||||
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
|
||||
@@ -173,7 +179,7 @@ class LLMUsageRecorder:
|
||||
completion_tokens=model_usage.completion_tokens or 0,
|
||||
total_tokens=model_usage.total_tokens or 0,
|
||||
cost=total_cost or 0.0,
|
||||
time_cost = round(time_cost or 0.0, 3),
|
||||
time_cost=round(time_cost or 0.0, 3),
|
||||
status="success",
|
||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||
)
|
||||
@@ -186,4 +192,5 @@ class LLMUsageRecorder:
|
||||
except Exception as e:
|
||||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||||
|
||||
llm_usage_recorder = LLMUsageRecorder()
|
||||
|
||||
llm_usage_recorder = LLMUsageRecorder()
|
||||
|
||||
@@ -14,31 +14,31 @@ logger = get_logger("context_web")
|
||||
|
||||
class ContextMessage:
|
||||
"""上下文消息类"""
|
||||
|
||||
|
||||
def __init__(self, message: MessageRecv):
|
||||
self.user_name = message.message_info.user_info.user_nickname
|
||||
self.user_id = message.message_info.user_info.user_id
|
||||
self.content = message.processed_plain_text
|
||||
self.timestamp = datetime.now()
|
||||
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
|
||||
|
||||
|
||||
# 识别消息类型
|
||||
self.is_gift = getattr(message, 'is_gift', False)
|
||||
self.is_superchat = getattr(message, 'is_superchat', False)
|
||||
|
||||
self.is_gift = getattr(message, "is_gift", False)
|
||||
self.is_superchat = getattr(message, "is_superchat", False)
|
||||
|
||||
# 添加礼物和SC相关信息
|
||||
if self.is_gift:
|
||||
self.gift_name = getattr(message, 'gift_name', '')
|
||||
self.gift_count = getattr(message, 'gift_count', '1')
|
||||
self.gift_name = getattr(message, "gift_name", "")
|
||||
self.gift_count = getattr(message, "gift_count", "1")
|
||||
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
|
||||
elif self.is_superchat:
|
||||
self.superchat_price = getattr(message, 'superchat_price', '0')
|
||||
self.superchat_message = getattr(message, 'superchat_message_text', '')
|
||||
self.superchat_price = getattr(message, "superchat_price", "0")
|
||||
self.superchat_message = getattr(message, "superchat_message_text", "")
|
||||
if self.superchat_message:
|
||||
self.content = f"[¥{self.superchat_price}] {self.superchat_message}"
|
||||
else:
|
||||
self.content = f"[¥{self.superchat_price}] {self.content}"
|
||||
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"user_name": self.user_name,
|
||||
@@ -47,13 +47,13 @@ class ContextMessage:
|
||||
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
|
||||
"group_name": self.group_name,
|
||||
"is_gift": self.is_gift,
|
||||
"is_superchat": self.is_superchat
|
||||
"is_superchat": self.is_superchat,
|
||||
}
|
||||
|
||||
|
||||
class ContextWebManager:
|
||||
"""上下文网页管理器"""
|
||||
|
||||
|
||||
def __init__(self, max_messages: int = 10, port: int = 8765):
|
||||
self.max_messages = max_messages
|
||||
self.port = port
|
||||
@@ -63,53 +63,53 @@ class ContextWebManager:
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False # 添加启动标志防止并发
|
||||
|
||||
|
||||
async def start_server(self):
|
||||
"""启动web服务器"""
|
||||
if self.site is not None:
|
||||
logger.debug("Web服务器已经启动,跳过重复启动")
|
||||
return
|
||||
|
||||
|
||||
if self._server_starting:
|
||||
logger.debug("Web服务器正在启动中,等待启动完成...")
|
||||
# 等待启动完成
|
||||
while self._server_starting and self.site is None:
|
||||
await asyncio.sleep(0.1)
|
||||
return
|
||||
|
||||
|
||||
self._server_starting = True
|
||||
|
||||
|
||||
try:
|
||||
self.app = web.Application()
|
||||
|
||||
|
||||
# 设置CORS
|
||||
cors = aiohttp_cors.setup(self.app, defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True,
|
||||
expose_headers="*",
|
||||
allow_headers="*",
|
||||
allow_methods="*"
|
||||
)
|
||||
})
|
||||
|
||||
cors = aiohttp_cors.setup(
|
||||
self.app,
|
||||
defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# 添加路由
|
||||
self.app.router.add_get('/', self.index_handler)
|
||||
self.app.router.add_get('/ws', self.websocket_handler)
|
||||
self.app.router.add_get('/api/contexts', self.get_contexts_handler)
|
||||
self.app.router.add_get('/debug', self.debug_handler)
|
||||
|
||||
self.app.router.add_get("/", self.index_handler)
|
||||
self.app.router.add_get("/ws", self.websocket_handler)
|
||||
self.app.router.add_get("/api/contexts", self.get_contexts_handler)
|
||||
self.app.router.add_get("/debug", self.debug_handler)
|
||||
|
||||
# 为所有路由添加CORS
|
||||
for route in list(self.app.router.routes()):
|
||||
cors.add(route)
|
||||
|
||||
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
|
||||
self.site = web.TCPSite(self.runner, 'localhost', self.port)
|
||||
|
||||
self.site = web.TCPSite(self.runner, "localhost", self.port)
|
||||
await self.site.start()
|
||||
|
||||
|
||||
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动Web服务器失败: {e}")
|
||||
# 清理部分启动的资源
|
||||
@@ -121,7 +121,7 @@ class ContextWebManager:
|
||||
raise
|
||||
finally:
|
||||
self._server_starting = False
|
||||
|
||||
|
||||
async def stop_server(self):
|
||||
"""停止web服务器"""
|
||||
if self.site:
|
||||
@@ -132,10 +132,11 @@ class ContextWebManager:
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False
|
||||
|
||||
|
||||
async def index_handler(self, request):
|
||||
"""主页处理器"""
|
||||
html_content = '''
|
||||
html_content = (
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
@@ -286,7 +287,9 @@ class ContextWebManager:
|
||||
|
||||
function connectWebSocket() {
|
||||
console.log('正在连接WebSocket...');
|
||||
ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws');
|
||||
ws = new WebSocket('ws://localhost:"""
|
||||
+ str(self.port)
|
||||
+ """/ws');
|
||||
|
||||
ws.onopen = function() {
|
||||
console.log('WebSocket连接已建立');
|
||||
@@ -470,47 +473,48 @@ class ContextWebManager:
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
return web.Response(text=html_content, content_type='text/html')
|
||||
|
||||
"""
|
||||
)
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def websocket_handler(self, request):
|
||||
"""WebSocket处理器"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
|
||||
self.websockets.append(ws)
|
||||
logger.debug(f"WebSocket连接建立,当前连接数: {len(self.websockets)}")
|
||||
|
||||
|
||||
# 发送初始数据
|
||||
await self.send_contexts_to_websocket(ws)
|
||||
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
logger.error(f'WebSocket错误: {ws.exception()}')
|
||||
logger.error(f"WebSocket错误: {ws.exception()}")
|
||||
break
|
||||
|
||||
|
||||
# 清理断开的连接
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
logger.debug(f"WebSocket连接断开,当前连接数: {len(self.websockets)}")
|
||||
|
||||
|
||||
return ws
|
||||
|
||||
|
||||
async def get_contexts_handler(self, request):
|
||||
"""获取上下文API"""
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
|
||||
return web.json_response({"contexts": contexts_data})
|
||||
|
||||
|
||||
async def debug_handler(self, request):
|
||||
"""调试信息处理器"""
|
||||
debug_info = {
|
||||
@@ -519,7 +523,7 @@ class ContextWebManager:
|
||||
"total_chats": len(self.contexts),
|
||||
"total_messages": sum(len(contexts) for contexts in self.contexts.values()),
|
||||
}
|
||||
|
||||
|
||||
# 构建聊天详情HTML
|
||||
chats_html = ""
|
||||
for chat_id, contexts in self.contexts.items():
|
||||
@@ -528,15 +532,15 @@ class ContextWebManager:
|
||||
timestamp = msg.timestamp.strftime("%H:%M:%S")
|
||||
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
|
||||
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
|
||||
|
||||
chats_html += f'''
|
||||
|
||||
chats_html += f"""
|
||||
<div class="chat">
|
||||
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
|
||||
{messages_html}
|
||||
</div>
|
||||
'''
|
||||
|
||||
html_content = f'''
|
||||
"""
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
@@ -578,74 +582,78 @@ class ContextWebManager:
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
|
||||
return web.Response(text=html_content, content_type='text/html')
|
||||
|
||||
"""
|
||||
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def add_message(self, chat_id: str, message: MessageRecv):
|
||||
"""添加新消息到上下文"""
|
||||
if chat_id not in self.contexts:
|
||||
self.contexts[chat_id] = deque(maxlen=self.max_messages)
|
||||
logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")
|
||||
|
||||
|
||||
context_msg = ContextMessage(message)
|
||||
self.contexts[chat_id].append(context_msg)
|
||||
|
||||
|
||||
# 统计当前总消息数
|
||||
total_messages = sum(len(contexts) for contexts in self.contexts.values())
|
||||
|
||||
logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
|
||||
)
|
||||
|
||||
# 调试:打印当前所有消息
|
||||
logger.info("📝 当前上下文中的所有消息:")
|
||||
for cid, contexts in self.contexts.items():
|
||||
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
|
||||
for i, msg in enumerate(contexts):
|
||||
logger.info(f" {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...")
|
||||
|
||||
logger.info(
|
||||
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
|
||||
)
|
||||
|
||||
# 广播更新给所有WebSocket连接
|
||||
await self.broadcast_contexts()
|
||||
|
||||
|
||||
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
|
||||
"""向单个WebSocket发送上下文数据"""
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
await ws.send_str(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
|
||||
async def broadcast_contexts(self):
|
||||
"""向所有WebSocket连接广播上下文更新"""
|
||||
if not self.websockets:
|
||||
logger.debug("没有WebSocket连接,跳过广播")
|
||||
return
|
||||
|
||||
|
||||
all_context_msgs = []
|
||||
for _chat_id, contexts in self.contexts.items():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]]
|
||||
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
message = json.dumps(data, ensure_ascii=False)
|
||||
|
||||
|
||||
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
|
||||
|
||||
|
||||
# 创建WebSocket列表的副本,避免在遍历时修改
|
||||
websockets_copy = self.websockets.copy()
|
||||
removed_count = 0
|
||||
|
||||
|
||||
for ws in websockets_copy:
|
||||
if ws.closed:
|
||||
if ws in self.websockets:
|
||||
@@ -660,7 +668,7 @@ class ContextWebManager:
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
|
||||
|
||||
if removed_count > 0:
|
||||
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
|
||||
|
||||
@@ -681,5 +689,4 @@ async def init_context_web_manager():
|
||||
"""初始化上下文网页管理器"""
|
||||
manager = get_context_web_manager()
|
||||
await manager.start_server()
|
||||
return manager
|
||||
|
||||
return manager
|
||||
|
||||
@@ -11,6 +11,7 @@ logger = get_logger("gift_manager")
|
||||
@dataclass
|
||||
class PendingGift:
|
||||
"""等待中的礼物消息"""
|
||||
|
||||
message: MessageRecvS4U
|
||||
total_count: int
|
||||
timer_task: asyncio.Task
|
||||
@@ -19,71 +20,68 @@ class PendingGift:
|
||||
|
||||
class GiftManager:
|
||||
"""礼物管理器,提供防抖功能"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""初始化礼物管理器"""
|
||||
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
|
||||
self.debounce_timeout = 5.0 # 3秒防抖时间
|
||||
|
||||
async def handle_gift(self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None) -> bool:
|
||||
|
||||
async def handle_gift(
|
||||
self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None
|
||||
) -> bool:
|
||||
"""处理礼物消息,返回是否应该立即处理
|
||||
|
||||
|
||||
Args:
|
||||
message: 礼物消息
|
||||
callback: 防抖完成后的回调函数
|
||||
|
||||
|
||||
Returns:
|
||||
bool: False表示消息被暂存等待防抖,True表示应该立即处理
|
||||
"""
|
||||
if not message.is_gift:
|
||||
return True
|
||||
|
||||
|
||||
# 构建礼物的唯一键:(发送人ID, 礼物名称)
|
||||
gift_key = (message.message_info.user_info.user_id, message.gift_name)
|
||||
|
||||
|
||||
# 如果已经有相同的礼物在等待中,则合并
|
||||
if gift_key in self.pending_gifts:
|
||||
await self._merge_gift(gift_key, message)
|
||||
return False
|
||||
|
||||
|
||||
# 创建新的等待礼物
|
||||
await self._create_pending_gift(gift_key, message, callback)
|
||||
return False
|
||||
|
||||
|
||||
async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
|
||||
"""合并礼物消息"""
|
||||
pending_gift = self.pending_gifts[gift_key]
|
||||
|
||||
|
||||
# 取消之前的定时器
|
||||
if not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
|
||||
|
||||
# 累加礼物数量
|
||||
try:
|
||||
new_count = int(new_message.gift_count)
|
||||
pending_gift.total_count += new_count
|
||||
|
||||
|
||||
# 更新消息为最新的(保留最新的消息,但累加数量)
|
||||
pending_gift.message = new_message
|
||||
pending_gift.message.gift_count = str(pending_gift.total_count)
|
||||
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
|
||||
|
||||
|
||||
except ValueError:
|
||||
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
|
||||
# 如果无法解析数量,保持原有数量不变
|
||||
|
||||
|
||||
# 重新创建定时器
|
||||
pending_gift.timer_task = asyncio.create_task(
|
||||
self._gift_timeout(gift_key)
|
||||
)
|
||||
|
||||
pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
|
||||
|
||||
|
||||
async def _create_pending_gift(
|
||||
self,
|
||||
gift_key: Tuple[str, str],
|
||||
message: MessageRecvS4U,
|
||||
callback: Optional[Callable[[MessageRecvS4U], None]]
|
||||
self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]]
|
||||
) -> None:
|
||||
"""创建新的等待礼物"""
|
||||
try:
|
||||
@@ -91,56 +89,51 @@ class GiftManager:
|
||||
except ValueError:
|
||||
initial_count = 1
|
||||
logger.warning(f"无法解析礼物数量: {message.gift_count},默认设为1")
|
||||
|
||||
|
||||
# 创建定时器任务
|
||||
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
|
||||
# 创建等待礼物对象
|
||||
pending_gift = PendingGift(
|
||||
message=message,
|
||||
total_count=initial_count,
|
||||
timer_task=timer_task,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
|
||||
|
||||
self.pending_gifts[gift_key] = pending_gift
|
||||
|
||||
|
||||
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
|
||||
|
||||
|
||||
async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
|
||||
"""礼物防抖超时处理"""
|
||||
try:
|
||||
# 等待防抖时间
|
||||
await asyncio.sleep(self.debounce_timeout)
|
||||
|
||||
|
||||
# 获取等待中的礼物
|
||||
if gift_key not in self.pending_gifts:
|
||||
return
|
||||
|
||||
|
||||
pending_gift = self.pending_gifts.pop(gift_key)
|
||||
|
||||
|
||||
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
|
||||
|
||||
|
||||
message = pending_gift.message
|
||||
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
|
||||
|
||||
|
||||
# 执行回调
|
||||
if pending_gift.callback:
|
||||
try:
|
||||
pending_gift.callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 定时器被取消,不需要处理
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
def get_pending_count(self) -> int:
|
||||
"""获取当前等待中的礼物数量"""
|
||||
return len(self.pending_gifts)
|
||||
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
"""立即处理所有等待中的礼物"""
|
||||
for gift_key in list(self.pending_gifts.keys()):
|
||||
@@ -152,4 +145,3 @@ class GiftManager:
|
||||
|
||||
# 创建全局礼物管理器实例
|
||||
gift_manager = GiftManager()
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
class InternalManager:
|
||||
def __init__(self):
|
||||
self.now_internal_state = str()
|
||||
|
||||
def set_internal_state(self,internal_state:str):
|
||||
|
||||
def set_internal_state(self, internal_state: str):
|
||||
self.now_internal_state = internal_state
|
||||
|
||||
|
||||
def get_internal_state(self):
|
||||
return self.now_internal_state
|
||||
|
||||
|
||||
def get_internal_state_str(self):
|
||||
return f"你今天的直播内容是直播QQ水群,你正在一边回复弹幕,一边在QQ群聊天,你在QQ群聊天中产生的想法是:{self.now_internal_state}"
|
||||
|
||||
internal_manager = InternalManager()
|
||||
|
||||
internal_manager = InternalManager()
|
||||
|
||||
@@ -16,7 +16,6 @@ import json
|
||||
from .s4u_mood_manager import mood_manager
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import get_person_id
|
||||
from .super_chat_manager import get_super_chat_manager
|
||||
from .yes_or_no import yes_or_no_head
|
||||
|
||||
logger = get_logger("S4U_chat")
|
||||
@@ -33,15 +32,12 @@ class MessageSenderContainer:
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._paused_event = asyncio.Event()
|
||||
self._paused_event.set() # 默认设置为非暂停状态
|
||||
|
||||
self.msg_id = ""
|
||||
|
||||
self.last_msg_id = ""
|
||||
|
||||
self.voice_done = ""
|
||||
|
||||
|
||||
|
||||
self.msg_id = ""
|
||||
|
||||
self.last_msg_id = ""
|
||||
|
||||
self.voice_done = ""
|
||||
|
||||
async def add_message(self, chunk: str):
|
||||
"""向队列中添加一个消息块。"""
|
||||
@@ -131,7 +127,7 @@ class MessageSenderContainer:
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
await bot_message.process()
|
||||
|
||||
|
||||
await self.storage.store_message(bot_message, self.chat_stream)
|
||||
|
||||
except Exception as e:
|
||||
@@ -198,12 +194,12 @@ class S4UChat:
|
||||
self.gpt = S4UStreamGenerator()
|
||||
self.gpt.chat_stream = self.chat_stream
|
||||
self.interest_dict: Dict[str, float] = {} # 用户兴趣分
|
||||
|
||||
self.internal_message :List[MessageRecvS4U] = []
|
||||
|
||||
|
||||
self.internal_message: List[MessageRecvS4U] = []
|
||||
|
||||
self.msg_id = ""
|
||||
self.voice_done = ""
|
||||
|
||||
|
||||
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
|
||||
|
||||
def _get_priority_info(self, message: MessageRecv) -> dict:
|
||||
@@ -226,7 +222,7 @@ class S4UChat:
|
||||
def _get_interest_score(self, user_id: str) -> float:
|
||||
"""获取用户的兴趣分,默认为1.0"""
|
||||
return self.interest_dict.get(user_id, 1.0)
|
||||
|
||||
|
||||
def go_processing(self):
|
||||
if self.voice_done == self.last_msg_id:
|
||||
return True
|
||||
@@ -237,14 +233,14 @@ class S4UChat:
|
||||
为消息计算基础优先级分数。分数越高,优先级越高。
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
|
||||
# 加上消息自带的优先级
|
||||
score += priority_info.get("message_priority", 0.0)
|
||||
|
||||
# 加上用户的固有兴趣分
|
||||
score += self._get_interest_score(message.message_info.user_info.user_id)
|
||||
return score
|
||||
|
||||
|
||||
def decay_interest_score(self):
|
||||
for person_id, score in self.interest_dict.items():
|
||||
if score > 0:
|
||||
@@ -252,15 +248,14 @@ class S4UChat:
|
||||
else:
|
||||
self.interest_dict[person_id] = 0
|
||||
|
||||
async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None:
|
||||
|
||||
async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
|
||||
self.decay_interest_score()
|
||||
|
||||
|
||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
person_id = get_person_id(platform, user_id)
|
||||
|
||||
_person_id = get_person_id(platform, user_id)
|
||||
|
||||
# try:
|
||||
# is_gift = message.is_gift
|
||||
# is_superchat = message.is_superchat
|
||||
@@ -276,7 +271,7 @@ class S4UChat:
|
||||
# # 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
# current_score = self.interest_dict.get(person_id, 1.0)
|
||||
# self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
|
||||
|
||||
|
||||
# # 添加SuperChat到管理器
|
||||
# super_chat_manager = get_super_chat_manager()
|
||||
# await super_chat_manager.add_superchat(message)
|
||||
@@ -284,16 +279,19 @@ class S4UChat:
|
||||
# await self.relationship_builder.build_relation(20)
|
||||
# except Exception:
|
||||
# traceback.print_exc()
|
||||
|
||||
|
||||
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
|
||||
|
||||
|
||||
priority_info = self._get_priority_info(message)
|
||||
is_vip = self._is_vip(priority_info)
|
||||
new_priority_score = self._calculate_base_priority_score(message, priority_info)
|
||||
|
||||
should_interrupt = False
|
||||
if (s4u_config.enable_message_interruption and
|
||||
self._current_generation_task and not self._current_generation_task.done()):
|
||||
if (
|
||||
s4u_config.enable_message_interruption
|
||||
and self._current_generation_task
|
||||
and not self._current_generation_task.done()
|
||||
):
|
||||
if self._current_message_being_replied:
|
||||
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||
|
||||
@@ -344,39 +342,45 @@ class S4UChat:
|
||||
"""清理普通队列中不在最近N条消息范围内的消息"""
|
||||
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
|
||||
return
|
||||
|
||||
|
||||
# 计算阈值:保留最近 recent_message_keep_count 条消息
|
||||
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
|
||||
|
||||
|
||||
# 临时存储需要保留的消息
|
||||
temp_messages = []
|
||||
removed_count = 0
|
||||
|
||||
|
||||
# 取出所有普通队列中的消息
|
||||
while not self._normal_queue.empty():
|
||||
try:
|
||||
item = self._normal_queue.get_nowait()
|
||||
neg_priority, entry_count, timestamp, message = item
|
||||
|
||||
|
||||
# 如果消息在最近N条消息范围内,保留它
|
||||
logger.info(f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}")
|
||||
|
||||
logger.info(
|
||||
f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
|
||||
)
|
||||
|
||||
if entry_count >= cutoff_counter:
|
||||
temp_messages.append(item)
|
||||
else:
|
||||
removed_count += 1
|
||||
self._normal_queue.task_done() # 标记被移除的任务为完成
|
||||
|
||||
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
|
||||
# 将保留的消息重新放入队列
|
||||
for item in temp_messages:
|
||||
self._normal_queue.put_nowait(item)
|
||||
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除")
|
||||
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.")
|
||||
logger.info(
|
||||
f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除"
|
||||
)
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
|
||||
)
|
||||
|
||||
async def _message_processor(self):
|
||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||
@@ -385,7 +389,7 @@ class S4UChat:
|
||||
# 等待有新消息的信号,避免空转
|
||||
await self._new_message_event.wait()
|
||||
self._new_message_event.clear()
|
||||
|
||||
|
||||
# 清理普通队列中的过旧消息
|
||||
self._cleanup_old_normal_messages()
|
||||
|
||||
@@ -396,7 +400,6 @@ class S4UChat:
|
||||
queue_name = "vip"
|
||||
# 其次处理普通队列
|
||||
elif not self._normal_queue.empty():
|
||||
|
||||
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
# 检查普通消息是否超时
|
||||
@@ -411,13 +414,15 @@ class S4UChat:
|
||||
if self.internal_message:
|
||||
message = self.internal_message[-1]
|
||||
self.internal_message = []
|
||||
|
||||
|
||||
priority = 0
|
||||
neg_priority = 0
|
||||
entry_count = 0
|
||||
queue_name = "internal"
|
||||
|
||||
logger.info(f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}...")
|
||||
logger.info(
|
||||
f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
|
||||
)
|
||||
else:
|
||||
continue # 没有消息了,回去等事件
|
||||
|
||||
@@ -457,23 +462,21 @@ class S4UChat:
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
|
||||
def get_processing_message_id(self):
|
||||
self.last_msg_id = self.msg_id
|
||||
self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"
|
||||
|
||||
|
||||
async def _generate_and_send(self, message: MessageRecv):
|
||||
"""为单个消息生成文本回复。整个过程可以被中断。"""
|
||||
self._is_replying = True
|
||||
total_chars_sent = 0 # 跟踪发送的总字符数
|
||||
|
||||
|
||||
self.get_processing_message_id()
|
||||
|
||||
|
||||
# 视线管理:开始生成回复时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
|
||||
|
||||
if message.is_internal:
|
||||
await chat_watching.on_internal_message_start()
|
||||
else:
|
||||
@@ -516,16 +519,19 @@ class S4UChat:
|
||||
total_chars_sent = len("麦麦不知道哦")
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(self.stream_id)
|
||||
await yes_or_no_head(text = total_chars_sent,emotion = mood.mood_state,chat_history=message.processed_plain_text,chat_id=self.stream_id)
|
||||
await yes_or_no_head(
|
||||
text=total_chars_sent,
|
||||
emotion=mood.mood_state,
|
||||
chat_history=message.processed_plain_text,
|
||||
chat_id=self.stream_id,
|
||||
)
|
||||
|
||||
# 等待所有文本消息发送完成
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
|
||||
|
||||
await chat_watching.on_thinking_finished()
|
||||
|
||||
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
logged = False
|
||||
while not self.go_processing():
|
||||
@@ -536,7 +542,7 @@ class S4UChat:
|
||||
logger.info(f"[{self.stream_name}] 等待消息发送完成...")
|
||||
logged = True
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
|
||||
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
@@ -548,11 +554,11 @@ class S4UChat:
|
||||
# 回复生成实时展示:清空内容(出错时)
|
||||
finally:
|
||||
self._is_replying = False
|
||||
|
||||
|
||||
# 视线管理:回复结束时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
await chat_watching.on_reply_finished()
|
||||
|
||||
|
||||
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
|
||||
sender_container.resume()
|
||||
if not sender_container._task.done():
|
||||
@@ -576,4 +582,3 @@ class S4UChat:
|
||||
await self._processing_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
|
||||
if global_config.memory.enable_memory:
|
||||
with Timer("记忆激活"):
|
||||
interested_rate,_ ,_= await hippocampus_manager.get_activate_from_text(
|
||||
interested_rate, _, _ = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=True,
|
||||
)
|
||||
@@ -49,7 +49,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
@@ -73,7 +73,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
@@ -117,36 +117,32 @@ class S4UMessageProcessor:
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
|
||||
if await self.handle_internal_message(message):
|
||||
return
|
||||
|
||||
|
||||
if await self.hadle_if_voice_done(message):
|
||||
return
|
||||
|
||||
|
||||
# 处理礼物消息,如果消息被暂存则停止当前处理流程
|
||||
if not skip_gift_debounce and not await self.handle_if_gift(message):
|
||||
return
|
||||
await self.check_if_fake_gift(message)
|
||||
|
||||
|
||||
# 处理屏幕消息
|
||||
if await self.handle_screen_message(message):
|
||||
return
|
||||
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
|
||||
|
||||
await s4u_chat.add_message(message)
|
||||
|
||||
_interested_rate, _ = await _calculate_interest(message)
|
||||
|
||||
|
||||
await mood_manager.start()
|
||||
|
||||
|
||||
|
||||
# 一系列llm驱动的前处理
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message))
|
||||
@@ -164,61 +160,56 @@ class S4UMessageProcessor:
|
||||
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
|
||||
else:
|
||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
|
||||
async def handle_internal_message(self, message: MessageRecvS4U):
|
||||
if message.is_internal:
|
||||
|
||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 660154,group_name = "内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform = "amaidesu_default",
|
||||
user_info = message.message_info.user_info,
|
||||
group_info = group_info
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
|
||||
)
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
message.message_info.group_info = s4u_chat.chat_stream.group_info
|
||||
message.message_info.platform = s4u_chat.chat_stream.platform
|
||||
|
||||
|
||||
|
||||
s4u_chat.internal_message.append(message)
|
||||
s4u_chat._new_message_event.set()
|
||||
|
||||
|
||||
logger.info(f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}")
|
||||
|
||||
|
||||
|
||||
logger.info(
|
||||
f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
|
||||
)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
async def handle_screen_message(self, message: MessageRecvS4U):
|
||||
if message.is_screen:
|
||||
screen_manager.set_screen(message.screen_info)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def hadle_if_voice_done(self, message: MessageRecvS4U):
|
||||
if message.voice_done:
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
|
||||
s4u_chat.voice_done = message.voice_done
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""检查消息是否为假礼物"""
|
||||
if message.is_gift:
|
||||
return False
|
||||
|
||||
gift_keywords = ["送出了礼物", "礼物", "送出了","投喂"]
|
||||
|
||||
gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"]
|
||||
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
|
||||
message.is_fake_gift = True
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""处理礼物消息
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示应该继续处理消息,False表示消息已被暂存不需要继续处理
|
||||
"""
|
||||
@@ -228,37 +219,37 @@ class S4UMessageProcessor:
|
||||
"""礼物防抖完成后的回调"""
|
||||
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
|
||||
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
|
||||
|
||||
|
||||
# 交给礼物管理器处理,并传入回调函数
|
||||
# 对于礼物消息,handle_gift 总是返回 False(消息被暂存)
|
||||
await gift_manager.handle_gift(message, gift_callback)
|
||||
return False # 消息被暂存,不继续处理
|
||||
|
||||
|
||||
return True # 非礼物消息,继续正常处理
|
||||
|
||||
async def _handle_context_web_update(self, chat_id: str, message: MessageRecv):
|
||||
"""处理上下文网页更新的独立task
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
|
||||
context_manager = get_context_web_manager()
|
||||
|
||||
|
||||
# 只在服务器未启动时启动(避免重复启动)
|
||||
if context_manager.site is None:
|
||||
logger.info("🚀 首次启动上下文网页服务器...")
|
||||
await context_manager.start_server()
|
||||
|
||||
|
||||
# 添加消息到上下文并更新网页
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
|
||||
await context_manager.add_message(chat_id, message)
|
||||
|
||||
|
||||
logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)
|
||||
|
||||
@@ -176,7 +176,7 @@ class PromptBuilder:
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
|
||||
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
|
||||
limit=300,
|
||||
)
|
||||
|
||||
@@ -228,13 +228,17 @@ class PromptBuilder:
|
||||
last_speaking_user_id = start_speaking_user_id
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
|
||||
msg_seg_str += (
|
||||
f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
|
||||
)
|
||||
|
||||
all_msg_seg_list = []
|
||||
for msg in core_dialogue_list[1:]:
|
||||
speaker = msg.user_info.user_id
|
||||
if speaker == last_speaking_user_id:
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
|
||||
msg_seg_str += (
|
||||
f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
|
||||
)
|
||||
else:
|
||||
msg_seg_str = f"{msg_seg_str}\n"
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
|
||||
@@ -14,11 +14,8 @@ logger = get_logger("s4u_stream_generator")
|
||||
class S4UStreamGenerator:
|
||||
def __init__(self):
|
||||
# 使用LLMRequest替代AsyncOpenAIClient
|
||||
self.llm_request = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer,
|
||||
request_type="s4u_replyer"
|
||||
)
|
||||
|
||||
self.llm_request = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="s4u_replyer")
|
||||
|
||||
self.current_model_name = "unknown model"
|
||||
self.partial_response = ""
|
||||
|
||||
@@ -89,16 +86,16 @@ class S4UStreamGenerator:
|
||||
|
||||
async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]:
|
||||
"""使用LLMRequest进行流式响应生成"""
|
||||
|
||||
|
||||
# 构建消息
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
messages = [message_builder.build()]
|
||||
|
||||
|
||||
# 选择模型
|
||||
model_info, api_provider, client = self.llm_request._select_model()
|
||||
self.current_model_name = model_info.name
|
||||
|
||||
|
||||
# 如果模型支持强制流式模式,使用真正的流式处理
|
||||
if model_info.force_stream_mode:
|
||||
# 简化流式处理:直接使用LLMRequest的流式功能
|
||||
@@ -111,14 +108,14 @@ class S4UStreamGenerator:
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
)
|
||||
|
||||
|
||||
# 处理响应内容
|
||||
content = response.content or ""
|
||||
if content:
|
||||
# 将内容按句子分割并输出
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式请求执行失败: {e}")
|
||||
# 如果流式请求失败,回退到普通模式
|
||||
@@ -132,7 +129,7 @@ class S4UStreamGenerator:
|
||||
content = response.content or ""
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
|
||||
|
||||
else:
|
||||
# 如果不支持流式,使用普通方式然后模拟流式输出
|
||||
response = await self.llm_request._execute_request(
|
||||
@@ -142,7 +139,7 @@ class S4UStreamGenerator:
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
)
|
||||
|
||||
|
||||
content = response.content or ""
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
@@ -163,7 +160,7 @@ class S4UStreamGenerator:
|
||||
"""处理内容进行流式输出(用于非流式模型的模拟流式输出)"""
|
||||
buffer = content
|
||||
punctuation_buffer = ""
|
||||
|
||||
|
||||
# 使用正则表达式匹配句子
|
||||
last_match_end = 0
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
@@ -47,6 +46,7 @@ HEAD_CODE = {
|
||||
"看向正前方": "(0,0,0)",
|
||||
}
|
||||
|
||||
|
||||
class ChatWatching:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
@@ -56,13 +56,13 @@ class ChatWatching:
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
async def on_reply_finished(self):
|
||||
"""生成回复完毕时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="finish_reply", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
async def on_thinking_finished(self):
|
||||
"""思考完毕时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
@@ -74,14 +74,14 @@ class ChatWatching:
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def on_internal_message_start(self):
|
||||
"""收到消息时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
class WatchingManager:
|
||||
def __init__(self):
|
||||
self.watching_list: list[ChatWatching] = []
|
||||
@@ -100,6 +100,7 @@ class WatchingManager:
|
||||
|
||||
return new_watching
|
||||
|
||||
|
||||
# 全局视线管理器实例
|
||||
watching_manager = WatchingManager()
|
||||
"""全局视线管理器"""
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
class ScreenManager:
|
||||
def __init__(self):
|
||||
self.now_screen = str()
|
||||
|
||||
def set_screen(self,screen_str:str):
|
||||
|
||||
def set_screen(self, screen_str: str):
|
||||
self.now_screen = screen_str
|
||||
|
||||
|
||||
def get_screen(self):
|
||||
return self.now_screen
|
||||
|
||||
|
||||
def get_screen_str(self):
|
||||
return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"
|
||||
|
||||
screen_manager = ScreenManager()
|
||||
|
||||
screen_manager = ScreenManager()
|
||||
|
||||
@@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
|
||||
# 全局SuperChat管理器实例
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
|
||||
@@ -13,7 +14,7 @@ logger = get_logger("super_chat_manager")
|
||||
@dataclass
|
||||
class SuperChatRecord:
|
||||
"""SuperChat记录数据类"""
|
||||
|
||||
|
||||
user_id: str
|
||||
user_nickname: str
|
||||
platform: str
|
||||
@@ -23,15 +24,15 @@ class SuperChatRecord:
|
||||
timestamp: float
|
||||
expire_time: float
|
||||
group_name: Optional[str] = None
|
||||
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""检查SuperChat是否已过期"""
|
||||
return time.time() > self.expire_time
|
||||
|
||||
|
||||
def remaining_time(self) -> float:
|
||||
"""获取剩余时间(秒)"""
|
||||
return max(0, self.expire_time - time.time())
|
||||
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
@@ -44,19 +45,19 @@ class SuperChatRecord:
|
||||
"timestamp": self.timestamp,
|
||||
"expire_time": self.expire_time,
|
||||
"group_name": self.group_name,
|
||||
"remaining_time": self.remaining_time()
|
||||
"remaining_time": self.remaining_time(),
|
||||
}
|
||||
|
||||
|
||||
class SuperChatManager:
|
||||
"""SuperChat管理器,负责管理和跟踪SuperChat消息"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._is_initialized = False
|
||||
logger.info("SuperChat管理器已初始化")
|
||||
|
||||
|
||||
def _ensure_cleanup_task_started(self):
|
||||
"""确保清理任务已启动(延迟启动)"""
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
@@ -68,7 +69,7 @@ class SuperChatManager:
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,稍后再启动
|
||||
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
|
||||
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动清理任务(已弃用,保留向后兼容)"""
|
||||
self._ensure_cleanup_task_started()
|
||||
@@ -78,39 +79,36 @@ class SuperChatManager:
|
||||
while True:
|
||||
try:
|
||||
total_removed = 0
|
||||
|
||||
|
||||
for chat_id in list(self.super_chats.keys()):
|
||||
original_count = len(self.super_chats[chat_id])
|
||||
# 移除过期的SuperChat
|
||||
self.super_chats[chat_id] = [
|
||||
sc for sc in self.super_chats[chat_id]
|
||||
if not sc.is_expired()
|
||||
]
|
||||
|
||||
self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
|
||||
removed_count = original_count - len(self.super_chats[chat_id])
|
||||
total_removed += removed_count
|
||||
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
|
||||
|
||||
|
||||
# 如果列表为空,删除该聊天的记录
|
||||
if not self.super_chats[chat_id]:
|
||||
del self.super_chats[chat_id]
|
||||
|
||||
|
||||
if total_removed > 0:
|
||||
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
|
||||
|
||||
|
||||
# 每30秒检查一次
|
||||
await asyncio.sleep(30)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
|
||||
await asyncio.sleep(60) # 出错时等待更长时间
|
||||
|
||||
|
||||
def _calculate_expire_time(self, price: float) -> float:
|
||||
"""根据SuperChat金额计算过期时间"""
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 根据金额阶梯设置不同的存活时间
|
||||
if price >= 500:
|
||||
# 500元以上:保持4小时
|
||||
@@ -133,27 +131,27 @@ class SuperChatManager:
|
||||
else:
|
||||
# 10元以下:保持5分钟
|
||||
duration = 5 * 60
|
||||
|
||||
|
||||
return current_time + duration
|
||||
|
||||
|
||||
async def add_superchat(self, message: MessageRecvS4U) -> None:
|
||||
"""添加新的SuperChat记录"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
|
||||
if not message.is_superchat or not message.superchat_price:
|
||||
logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
price = float(message.superchat_price)
|
||||
except (ValueError, TypeError):
|
||||
logger.error(f"无效的SuperChat价格: {message.superchat_price}")
|
||||
return
|
||||
|
||||
|
||||
user_info = message.message_info.user_info
|
||||
group_info = message.message_info.group_info
|
||||
chat_id = getattr(message, 'chat_stream', None)
|
||||
chat_id = getattr(message, "chat_stream", None)
|
||||
if chat_id:
|
||||
chat_id = chat_id.stream_id
|
||||
else:
|
||||
@@ -161,9 +159,9 @@ class SuperChatManager:
|
||||
chat_id = f"{message.message_info.platform}_{user_info.user_id}"
|
||||
if group_info:
|
||||
chat_id = f"{message.message_info.platform}_{group_info.group_id}"
|
||||
|
||||
|
||||
expire_time = self._calculate_expire_time(price)
|
||||
|
||||
|
||||
record = SuperChatRecord(
|
||||
user_id=user_info.user_id,
|
||||
user_nickname=user_info.user_nickname,
|
||||
@@ -173,44 +171,44 @@ class SuperChatManager:
|
||||
message_text=message.superchat_message_text or "",
|
||||
timestamp=message.message_info.time,
|
||||
expire_time=expire_time,
|
||||
group_name=group_info.group_name if group_info else None
|
||||
group_name=group_info.group_name if group_info else None,
|
||||
)
|
||||
|
||||
|
||||
# 添加到对应聊天的SuperChat列表
|
||||
if chat_id not in self.super_chats:
|
||||
self.super_chats[chat_id] = []
|
||||
|
||||
|
||||
self.super_chats[chat_id].append(record)
|
||||
|
||||
|
||||
# 按价格降序排序(价格高的在前)
|
||||
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
|
||||
|
||||
|
||||
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
|
||||
|
||||
|
||||
def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]:
|
||||
"""获取指定聊天的所有有效SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
|
||||
if chat_id not in self.super_chats:
|
||||
return []
|
||||
|
||||
|
||||
# 过滤掉过期的SuperChat
|
||||
valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
return valid_superchats
|
||||
|
||||
|
||||
def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]:
|
||||
"""获取所有有效的SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
|
||||
result = {}
|
||||
for chat_id, superchats in self.super_chats.items():
|
||||
valid_superchats = [sc for sc in superchats if not sc.is_expired()]
|
||||
if valid_superchats:
|
||||
result[chat_id] = valid_superchats
|
||||
return result
|
||||
|
||||
|
||||
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
|
||||
"""构建SuperChat显示字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
@@ -226,7 +224,9 @@ class SuperChatManager:
|
||||
remaining_minutes = int(sc.remaining_time() / 60)
|
||||
remaining_seconds = int(sc.remaining_time() % 60)
|
||||
|
||||
time_display = f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒"
|
||||
time_display = (
|
||||
f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒"
|
||||
)
|
||||
|
||||
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
|
||||
if len(line) > 100: # 限制单行长度
|
||||
@@ -238,7 +238,7 @@ class SuperChatManager:
|
||||
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def build_superchat_summary_string(self, chat_id: str) -> str:
|
||||
"""构建SuperChat摘要字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
@@ -261,30 +261,24 @@ class SuperChatManager:
|
||||
if lines:
|
||||
final_str += "\n" + "\n".join(lines)
|
||||
return final_str
|
||||
|
||||
|
||||
def get_superchat_statistics(self, chat_id: str) -> dict:
|
||||
"""获取SuperChat统计信息"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
|
||||
if not superchats:
|
||||
return {
|
||||
"count": 0,
|
||||
"total_amount": 0,
|
||||
"average_amount": 0,
|
||||
"highest_amount": 0,
|
||||
"lowest_amount": 0
|
||||
}
|
||||
|
||||
return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}
|
||||
|
||||
amounts = [sc.price for sc in superchats]
|
||||
|
||||
|
||||
return {
|
||||
"count": len(superchats),
|
||||
"total_amount": sum(amounts),
|
||||
"average_amount": sum(amounts) / len(amounts),
|
||||
"highest_amount": max(amounts),
|
||||
"lowest_amount": min(amounts)
|
||||
"lowest_amount": min(amounts),
|
||||
}
|
||||
|
||||
|
||||
async def shutdown(self): # sourcery skip: use-contextlib-suppress
|
||||
"""关闭管理器,清理资源"""
|
||||
if self._cleanup_task and not self._cleanup_task.done():
|
||||
@@ -296,15 +290,14 @@ class SuperChatManager:
|
||||
logger.info("SuperChat管理器已关闭")
|
||||
|
||||
|
||||
|
||||
|
||||
# sourcery skip: assign-if-exp
|
||||
if s4u_config.enable_s4u:
|
||||
super_chat_manager = SuperChatManager()
|
||||
else:
|
||||
super_chat_manager = None
|
||||
|
||||
|
||||
def get_super_chat_manager() -> SuperChatManager:
|
||||
"""获取全局SuperChat管理器实例"""
|
||||
|
||||
return super_chat_manager
|
||||
return super_chat_manager
|
||||
|
||||
@@ -10,10 +10,12 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("s4u_config")
|
||||
|
||||
|
||||
# 新增:兼容dict和tomlkit Table
|
||||
def is_dict_like(obj):
|
||||
return isinstance(obj, (dict, Table))
|
||||
|
||||
|
||||
# 新增:递归将Table转为dict
|
||||
def table_to_dict(obj):
|
||||
if isinstance(obj, Table):
|
||||
@@ -25,6 +27,7 @@ def table_to_dict(obj):
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
# 获取mais4u模块目录
|
||||
MAIS4U_ROOT = os.path.dirname(__file__)
|
||||
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
|
||||
@@ -190,7 +193,7 @@ class S4UModelConfig(S4UConfigBase):
|
||||
@dataclass
|
||||
class S4UConfig(S4UConfigBase):
|
||||
"""S4U聊天系统配置类"""
|
||||
|
||||
|
||||
enable_s4u: bool = False
|
||||
"""是否启用S4U聊天系统"""
|
||||
|
||||
@@ -229,12 +232,12 @@ class S4UConfig(S4UConfigBase):
|
||||
|
||||
enable_streaming_output: bool = True
|
||||
"""是否启用流式输出,false时全部生成后一次性发送"""
|
||||
|
||||
|
||||
max_context_message_length: int = 20
|
||||
"""上下文消息最大长度"""
|
||||
|
||||
|
||||
max_core_message_length: int = 30
|
||||
"""核心消息最大长度"""
|
||||
"""核心消息最大长度"""
|
||||
|
||||
# 模型配置
|
||||
models: S4UModelConfig = field(default_factory=S4UModelConfig)
|
||||
@@ -243,7 +246,6 @@ class S4UConfig(S4UConfigBase):
|
||||
# 兼容性字段,保持向后兼容
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UGlobalConfig(S4UConfigBase):
|
||||
"""S4U总配置类"""
|
||||
@@ -256,7 +258,7 @@ def update_s4u_config():
|
||||
"""更新S4U配置文件"""
|
||||
# 创建配置目录(如果不存在)
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
|
||||
|
||||
# 检查模板文件是否存在
|
||||
if not os.path.exists(TEMPLATE_PATH):
|
||||
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
@@ -354,13 +356,13 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig:
|
||||
logger.critical("S4U配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
# 初始化S4U配置
|
||||
|
||||
|
||||
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
||||
update_s4u_config()
|
||||
|
||||
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
||||
logger.info("S4U配置文件加载完成!")
|
||||
|
||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||
|
||||
@@ -13,7 +13,7 @@ async def migrate_memory_items_to_string():
|
||||
并根据原始list的项目数量设置weight值
|
||||
"""
|
||||
logger.info("开始迁移记忆节点格式...")
|
||||
|
||||
|
||||
migration_stats = {
|
||||
"total_nodes": 0,
|
||||
"converted_nodes": 0,
|
||||
@@ -21,72 +21,74 @@ async def migrate_memory_items_to_string():
|
||||
"empty_nodes": 0,
|
||||
"error_nodes": 0,
|
||||
"weight_updated_nodes": 0,
|
||||
"truncated_nodes": 0
|
||||
"truncated_nodes": 0,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# 获取所有图节点
|
||||
all_nodes = GraphNodes.select()
|
||||
migration_stats["total_nodes"] = all_nodes.count()
|
||||
|
||||
|
||||
logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")
|
||||
|
||||
|
||||
for node in all_nodes:
|
||||
try:
|
||||
concept = node.concept
|
||||
memory_items_raw = node.memory_items.strip() if node.memory_items else ""
|
||||
original_weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0
|
||||
|
||||
original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
|
||||
|
||||
# 如果为空,跳过
|
||||
if not memory_items_raw:
|
||||
migration_stats["empty_nodes"] += 1
|
||||
logger.debug(f"跳过空节点: {concept}")
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# 尝试解析JSON
|
||||
parsed_data = json.loads(memory_items_raw)
|
||||
|
||||
|
||||
if isinstance(parsed_data, list):
|
||||
# 如果是list格式,需要转换
|
||||
if parsed_data:
|
||||
# 转换为字符串格式
|
||||
new_memory_items = " | ".join(str(item) for item in parsed_data)
|
||||
original_length = len(new_memory_items)
|
||||
|
||||
|
||||
# 检查长度并截断
|
||||
if len(new_memory_items) > 100:
|
||||
new_memory_items = new_memory_items[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
|
||||
new_weight = float(len(parsed_data)) # weight = list项目数量
|
||||
|
||||
|
||||
# 更新数据库
|
||||
node.memory_items = new_memory_items
|
||||
node.weight = new_weight
|
||||
node.save()
|
||||
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
migration_stats["weight_updated_nodes"] += 1
|
||||
|
||||
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.info(f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}")
|
||||
logger.info(
|
||||
f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
|
||||
)
|
||||
else:
|
||||
# 空list,设置为空字符串
|
||||
node.memory_items = ""
|
||||
node.weight = 1.0
|
||||
node.save()
|
||||
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
logger.debug(f"转换空list节点: {concept}")
|
||||
|
||||
|
||||
elif isinstance(parsed_data, str):
|
||||
# 已经是字符串格式,检查长度和weight
|
||||
current_content = parsed_data
|
||||
original_length = len(current_content)
|
||||
content_truncated = False
|
||||
|
||||
|
||||
# 检查长度并截断
|
||||
if len(current_content) > 100:
|
||||
current_content = current_content[:100]
|
||||
@@ -94,19 +96,21 @@ async def migrate_memory_items_to_string():
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
node.memory_items = current_content
|
||||
logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
|
||||
# 检查weight是否需要更新
|
||||
update_needed = False
|
||||
if original_weight == 1.0:
|
||||
# 如果weight还是默认值,可以根据内容复杂度估算
|
||||
content_parts = current_content.split(" | ") if " | " in current_content else [current_content]
|
||||
content_parts = (
|
||||
current_content.split(" | ") if " | " in current_content else [current_content]
|
||||
)
|
||||
estimated_weight = max(1.0, float(len(content_parts)))
|
||||
|
||||
|
||||
if estimated_weight != original_weight:
|
||||
node.weight = estimated_weight
|
||||
update_needed = True
|
||||
logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")
|
||||
|
||||
|
||||
# 如果内容被截断或权重需要更新,保存到数据库
|
||||
if content_truncated or update_needed:
|
||||
node.save()
|
||||
@@ -118,26 +122,26 @@ async def migrate_memory_items_to_string():
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
else:
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
|
||||
|
||||
else:
|
||||
# 其他JSON类型,转换为字符串
|
||||
new_memory_items = str(parsed_data) if parsed_data else ""
|
||||
original_length = len(new_memory_items)
|
||||
|
||||
|
||||
# 检查长度并截断
|
||||
if len(new_memory_items) > 100:
|
||||
new_memory_items = new_memory_items[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
|
||||
node.memory_items = new_memory_items
|
||||
node.weight = 1.0
|
||||
node.save()
|
||||
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.debug(f"转换其他类型节点: {concept}{length_info}")
|
||||
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 不是JSON格式,假设已经是纯字符串
|
||||
# 检查是否是带引号的字符串
|
||||
@@ -145,16 +149,16 @@ async def migrate_memory_items_to_string():
|
||||
# 去掉引号
|
||||
clean_content = memory_items_raw[1:-1]
|
||||
original_length = len(clean_content)
|
||||
|
||||
|
||||
# 检查长度并截断
|
||||
if len(clean_content) > 100:
|
||||
clean_content = clean_content[:100]
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
|
||||
|
||||
node.memory_items = clean_content
|
||||
node.save()
|
||||
|
||||
|
||||
migration_stats["converted_nodes"] += 1
|
||||
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
|
||||
logger.debug(f"去除引号节点: {concept}{length_info}")
|
||||
@@ -162,29 +166,29 @@ async def migrate_memory_items_to_string():
|
||||
# 已经是纯字符串格式,检查长度
|
||||
current_content = memory_items_raw
|
||||
original_length = len(current_content)
|
||||
|
||||
|
||||
# 检查长度并截断
|
||||
if len(current_content) > 100:
|
||||
current_content = current_content[:100]
|
||||
node.memory_items = current_content
|
||||
node.save()
|
||||
|
||||
|
||||
migration_stats["converted_nodes"] += 1 # 算作转换节点
|
||||
migration_stats["truncated_nodes"] += 1
|
||||
logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符")
|
||||
else:
|
||||
migration_stats["already_string_nodes"] += 1
|
||||
logger.debug(f"已是字符串格式节点: {concept}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
migration_stats["error_nodes"] += 1
|
||||
logger.error(f"处理节点 {concept} 时发生错误: {e}")
|
||||
continue
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生严重错误: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# 输出迁移统计
|
||||
logger.info("=== 记忆节点迁移完成 ===")
|
||||
logger.info(f"总节点数: {migration_stats['total_nodes']}")
|
||||
@@ -194,101 +198,105 @@ async def migrate_memory_items_to_string():
|
||||
logger.info(f"错误节点: {migration_stats['error_nodes']}")
|
||||
logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
|
||||
logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
|
||||
|
||||
success_rate = (migration_stats['converted_nodes'] + migration_stats['already_string_nodes']) / migration_stats['total_nodes'] * 100 if migration_stats['total_nodes'] > 0 else 0
|
||||
|
||||
success_rate = (
|
||||
(migration_stats["converted_nodes"] + migration_stats["already_string_nodes"])
|
||||
/ migration_stats["total_nodes"]
|
||||
* 100
|
||||
if migration_stats["total_nodes"] > 0
|
||||
else 0
|
||||
)
|
||||
logger.info(f"迁移成功率: {success_rate:.1f}%")
|
||||
|
||||
|
||||
return migration_stats
|
||||
|
||||
|
||||
|
||||
|
||||
async def set_all_person_known():
|
||||
"""
|
||||
将person_info库中所有记录的is_known字段设置为True
|
||||
在设置之前,先清理掉user_id或platform为空的记录
|
||||
"""
|
||||
logger.info("开始设置所有person_info记录为已认识...")
|
||||
|
||||
|
||||
try:
|
||||
from src.common.database.database_model import PersonInfo
|
||||
|
||||
|
||||
# 获取所有PersonInfo记录
|
||||
all_persons = PersonInfo.select()
|
||||
total_count = all_persons.count()
|
||||
|
||||
|
||||
logger.info(f"找到 {total_count} 个人员记录")
|
||||
|
||||
|
||||
if total_count == 0:
|
||||
logger.info("没有找到任何人员记录")
|
||||
return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0}
|
||||
|
||||
|
||||
# 删除user_id或platform为空的记录
|
||||
deleted_count = 0
|
||||
invalid_records = PersonInfo.select().where(
|
||||
(PersonInfo.user_id.is_null()) |
|
||||
(PersonInfo.user_id == '') |
|
||||
(PersonInfo.platform.is_null()) |
|
||||
(PersonInfo.platform == '')
|
||||
(PersonInfo.user_id.is_null())
|
||||
| (PersonInfo.user_id == "")
|
||||
| (PersonInfo.platform.is_null())
|
||||
| (PersonInfo.platform == "")
|
||||
)
|
||||
|
||||
|
||||
# 记录要删除的记录信息
|
||||
for record in invalid_records:
|
||||
user_id_info = f"'{record.user_id}'" if record.user_id else "NULL"
|
||||
platform_info = f"'{record.platform}'" if record.platform else "NULL"
|
||||
person_name_info = f"'{record.person_name}'" if record.person_name else "无名称"
|
||||
logger.debug(f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}")
|
||||
|
||||
logger.debug(
|
||||
f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}"
|
||||
)
|
||||
|
||||
# 执行删除操作
|
||||
deleted_count = PersonInfo.delete().where(
|
||||
(PersonInfo.user_id.is_null()) |
|
||||
(PersonInfo.user_id == '') |
|
||||
(PersonInfo.platform.is_null()) |
|
||||
(PersonInfo.platform == '')
|
||||
).execute()
|
||||
|
||||
deleted_count = (
|
||||
PersonInfo.delete()
|
||||
.where(
|
||||
(PersonInfo.user_id.is_null())
|
||||
| (PersonInfo.user_id == "")
|
||||
| (PersonInfo.platform.is_null())
|
||||
| (PersonInfo.platform == "")
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录")
|
||||
else:
|
||||
logger.info("没有发现user_id或platform为空的记录")
|
||||
|
||||
|
||||
# 重新获取剩余记录数量
|
||||
remaining_count = PersonInfo.select().count()
|
||||
logger.info(f"清理后剩余 {remaining_count} 个有效记录")
|
||||
|
||||
|
||||
if remaining_count == 0:
|
||||
logger.info("清理后没有剩余记录")
|
||||
return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0}
|
||||
|
||||
|
||||
# 批量更新剩余记录的is_known字段为True
|
||||
updated_count = PersonInfo.update(is_known=True).execute()
|
||||
|
||||
|
||||
logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True")
|
||||
|
||||
|
||||
# 验证更新结果
|
||||
known_count = PersonInfo.select().where(PersonInfo.is_known).count()
|
||||
|
||||
result = {
|
||||
"total": total_count,
|
||||
"deleted": deleted_count,
|
||||
"updated": updated_count,
|
||||
"known_count": known_count
|
||||
}
|
||||
|
||||
|
||||
result = {"total": total_count, "deleted": deleted_count, "updated": updated_count, "known_count": known_count}
|
||||
|
||||
logger.info("=== person_info更新完成 ===")
|
||||
logger.info(f"原始记录数: {result['total']}")
|
||||
logger.info(f"删除记录数: {result['deleted']}")
|
||||
logger.info(f"更新记录数: {result['updated']}")
|
||||
logger.info(f"已认识记录数: {result['known_count']}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新person_info过程中发生错误: {e}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
async def check_and_run_migrations():
|
||||
# 获取根目录
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
@@ -309,4 +317,3 @@ async def check_and_run_migrations():
|
||||
# 创建done.mem文件
|
||||
with open(done_file, "w", encoding="utf-8") as f:
|
||||
f.write("done")
|
||||
|
||||
@@ -282,7 +282,7 @@ class Person:
|
||||
|
||||
memory_category = parts[0].strip()
|
||||
memory_text = parts[1].strip()
|
||||
memory_weight = parts[2].strip()
|
||||
_memory_weight = parts[2].strip()
|
||||
|
||||
# 检查分类是否匹配
|
||||
if memory_category != category:
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
import json
|
||||
from json_repair import repair_json
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from .person_info import Person
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
|
||||
logger = get_logger("relation")
|
||||
@@ -43,4 +37,3 @@ def init_prompt():
|
||||
""",
|
||||
"attitude_to_me_prompt",
|
||||
)
|
||||
|
||||
|
||||
@@ -121,5 +121,5 @@ __all__ = [
|
||||
"DatabaseChatInfo",
|
||||
"TargetPersonInfo",
|
||||
"ActionPlannerInfo",
|
||||
"LLMGenerationDataModel"
|
||||
"LLMGenerationDataModel",
|
||||
]
|
||||
|
||||
@@ -7,22 +7,24 @@ logger = get_logger("frequency_api")
|
||||
def get_current_focus_value(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_focus_value()
|
||||
|
||||
|
||||
def get_current_talk_frequency(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_talk_frequency()
|
||||
|
||||
|
||||
def set_focus_value_adjust(chat_id: str, focus_value_adjust: float) -> None:
|
||||
frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust = focus_value_adjust
|
||||
|
||||
|
||||
|
||||
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
|
||||
frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust = talk_frequency_adjust
|
||||
frequency_control_manager.get_or_create_frequency_control(
|
||||
chat_id
|
||||
).talk_frequency_external_adjust = talk_frequency_adjust
|
||||
|
||||
|
||||
def get_focus_value_adjust(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust
|
||||
|
||||
|
||||
|
||||
def get_talk_frequency_adjust(chat_id: str) -> float:
|
||||
return frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -159,6 +159,7 @@ async def generate_reply(
|
||||
logger.error(traceback.format_exc())
|
||||
return False, None
|
||||
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
|
||||
@@ -72,7 +72,9 @@ async def generate_with_model(
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
|
||||
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
|
||||
prompt, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
@@ -80,6 +82,7 @@ async def generate_with_model(
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
|
||||
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
@@ -109,10 +112,7 @@ async def generate_with_model_with_tools(
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
|
||||
prompt,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
|
||||
@@ -435,9 +435,7 @@ def build_readable_messages_to_str(
|
||||
Returns:
|
||||
格式化后的可读字符串
|
||||
"""
|
||||
return build_readable_messages(
|
||||
messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions
|
||||
)
|
||||
return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
|
||||
|
||||
|
||||
async def build_readable_messages_with_details(
|
||||
@@ -491,8 +489,6 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag
|
||||
return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)]
|
||||
|
||||
|
||||
|
||||
|
||||
def translate_pid_to_description(pid: str) -> str:
|
||||
image = Images.get_or_none(Images.image_id == pid)
|
||||
description = ""
|
||||
@@ -500,4 +496,4 @@ def translate_pid_to_description(pid: str) -> str:
|
||||
description = image.description
|
||||
else:
|
||||
description = "[图片]"
|
||||
return description
|
||||
return description
|
||||
|
||||
@@ -34,7 +34,7 @@ def get_plugin_path(plugin_name: str) -> str:
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径。
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 如果插件不存在。
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@ from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
|
||||
|
||||
def register_plugin(cls):
|
||||
|
||||
@@ -88,7 +88,7 @@ class GlobalAnnouncementManager:
|
||||
return False
|
||||
self._user_disabled_tools[chat_id].append(tool_name)
|
||||
return True
|
||||
|
||||
|
||||
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""启用特定聊天的某个工具"""
|
||||
if chat_id in self._user_disabled_tools:
|
||||
@@ -111,7 +111,7 @@ class GlobalAnnouncementManager:
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
@@ -224,7 +224,7 @@ class PluginManager:
|
||||
list: 已注册的插件类名称列表。
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
|
||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
@@ -401,9 +401,7 @@ class PluginManager:
|
||||
command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||
]
|
||||
tool_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
|
||||
]
|
||||
tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
|
||||
event_handler_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||
]
|
||||
|
||||
@@ -149,10 +149,10 @@ class ToolExecutor:
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return [], []
|
||||
|
||||
|
||||
# 提取tool_calls中的函数名称
|
||||
func_names = [call.func_name for call in tool_calls if call.func_name]
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
|
||||
# 执行每个工具调用
|
||||
@@ -195,7 +195,9 @@ class ToolExecutor:
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
async def execute_tool_call(
|
||||
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
|
||||
|
||||
@@ -63,5 +63,4 @@ class CoreActionsPlugin(BasePlugin):
|
||||
if self.get_config("components.enable_emoji", True):
|
||||
components.append((EmojiAction.get_action_info(), EmojiAction))
|
||||
|
||||
|
||||
return components
|
||||
|
||||
@@ -74,7 +74,9 @@ class BuildMemoryAction(BaseAction):
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "build_memory"
|
||||
action_description = "了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息"
|
||||
action_description = (
|
||||
"了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息"
|
||||
)
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
@@ -103,31 +105,34 @@ class BuildMemoryAction(BaseAction):
|
||||
concept_name = self.action_data.get("concept_name", "")
|
||||
# 2. 获取目标用户信息
|
||||
|
||||
|
||||
|
||||
# 对 concept_name 进行jieba分词
|
||||
concept_name_tokens = cut_key_words(concept_name)
|
||||
# logger.info(f"{self.log_prefix} 对 concept_name 进行分词结果: {concept_name_tokens}")
|
||||
|
||||
|
||||
filtered_concept_name_tokens = [
|
||||
token for token in concept_name_tokens if all(keyword not in token for keyword in global_config.memory.memory_ban_words)
|
||||
token
|
||||
for token in concept_name_tokens
|
||||
if all(keyword not in token for keyword in global_config.memory.memory_ban_words)
|
||||
]
|
||||
|
||||
|
||||
if not filtered_concept_name_tokens:
|
||||
logger.warning(f"{self.log_prefix} 过滤后的概念名称列表为空,跳过添加记忆")
|
||||
return False, "过滤后的概念名称列表为空,跳过添加记忆"
|
||||
|
||||
similar_topics_dict = hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(filtered_concept_name_tokens)
|
||||
await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(concept_description, similar_topics_dict)
|
||||
|
||||
|
||||
|
||||
|
||||
similar_topics_dict = (
|
||||
hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(
|
||||
filtered_concept_name_tokens
|
||||
)
|
||||
)
|
||||
await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(
|
||||
concept_description, similar_topics_dict
|
||||
)
|
||||
|
||||
return True, f"成功添加记忆: {concept_name}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 构建记忆时出错: {e}")
|
||||
return False, f"构建记忆时出错: {e}"
|
||||
|
||||
|
||||
|
||||
# 还缺一个关系的太多遗忘和对应的提取
|
||||
|
||||
@@ -425,7 +425,7 @@ class ManagementCommand(BaseCommand):
|
||||
await self._send_message(f"本地禁用组件成功: {component_name}")
|
||||
else:
|
||||
await self._send_message(f"本地禁用组件失败: {component_name}")
|
||||
|
||||
|
||||
async def _send_message(self, message: str):
|
||||
await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ class BuildRelationAction(BaseAction):
|
||||
if not person.is_known:
|
||||
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
|
||||
return False, f"用户 {person_name} 不存在,跳过添加记忆"
|
||||
|
||||
|
||||
person.last_know = time.time()
|
||||
person.know_times += 1
|
||||
person.sync_to_database()
|
||||
@@ -178,7 +178,9 @@ class BuildRelationAction(BaseAction):
|
||||
|
||||
chat_model_config = models.get("utils")
|
||||
success, update_memory, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore
|
||||
prompt,
|
||||
model_config=chat_model_config,
|
||||
request_type="relation.category.update", # type: ignore
|
||||
)
|
||||
|
||||
update_memory_data = json.loads(repair_json(update_memory))
|
||||
@@ -190,7 +192,7 @@ class BuildRelationAction(BaseAction):
|
||||
# 新记忆
|
||||
person.memory_points.append(f"{category}:{new_memory}:1.0")
|
||||
person.sync_to_database()
|
||||
|
||||
|
||||
logger.info(f"{self.log_prefix} 为{person.person_name}新增记忆点: {new_memory}")
|
||||
|
||||
return True, f"为{person.person_name}新增记忆点: {new_memory}"
|
||||
@@ -207,14 +209,15 @@ class BuildRelationAction(BaseAction):
|
||||
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
|
||||
person.sync_to_database()
|
||||
|
||||
logger.info(f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
|
||||
)
|
||||
|
||||
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
|
||||
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
|
||||
return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
|
||||
|
||||
|
||||
return True, "关系动作执行成功"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user