Merge branch 'dev' of github.com:MaiM-with-u/MaiBot into dev

This commit is contained in:
UnCLAS-Prommer
2025-09-09 22:36:09 +08:00
66 changed files with 1085 additions and 1038 deletions

1
bot.py
View File

@@ -65,6 +65,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
from src.plugin_system.core.events_manager import events_manager from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.base.component_types import EventType from src.plugin_system.base.component_types import EventType
# 触发 ON_STOP 事件 # 触发 ON_STOP 事件
await events_manager.handle_mai_events(event_type=EventType.ON_STOP) await events_manager.handle_mai_events(event_type=EventType.ON_STOP)

View File

@@ -5,12 +5,11 @@ from typing import Dict, List
# Add project root to Python path # Add project root to Python path
from src.common.database.database_model import Expression, ChatStreams from src.common.database.database_model import Expression, ChatStreams
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
def get_chat_name(chat_id: str) -> str: def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly""" """Get chat name from chat_id by querying ChatStreams table directly"""
try: try:
@@ -35,72 +34,61 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of last active time in days""" """Calculate distribution of last active time in days"""
now = time.time() now = time.time()
distribution = { distribution = {
'0-1天': 0, "0-1天": 0,
'1-3天': 0, "1-3天": 0,
'3-7天': 0, "3-7天": 0,
'7-14天': 0, "7-14天": 0,
'14-30天': 0, "14-30天": 0,
'30-60天': 0, "30-60天": 0,
'60-90天': 0, "60-90天": 0,
'90+天': 0 "90+天": 0,
} }
for expr in expressions: for expr in expressions:
diff_days = (now - expr.last_active_time) / (24*3600) diff_days = (now - expr.last_active_time) / (24 * 3600)
if diff_days < 1: if diff_days < 1:
distribution['0-1天'] += 1 distribution["0-1天"] += 1
elif diff_days < 3: elif diff_days < 3:
distribution['1-3天'] += 1 distribution["1-3天"] += 1
elif diff_days < 7: elif diff_days < 7:
distribution['3-7天'] += 1 distribution["3-7天"] += 1
elif diff_days < 14: elif diff_days < 14:
distribution['7-14天'] += 1 distribution["7-14天"] += 1
elif diff_days < 30: elif diff_days < 30:
distribution['14-30天'] += 1 distribution["14-30天"] += 1
elif diff_days < 60: elif diff_days < 60:
distribution['30-60天'] += 1 distribution["30-60天"] += 1
elif diff_days < 90: elif diff_days < 90:
distribution['60-90天'] += 1 distribution["60-90天"] += 1
else: else:
distribution['90+天'] += 1 distribution["90+天"] += 1
return distribution return distribution
def calculate_count_distribution(expressions) -> Dict[str, int]: def calculate_count_distribution(expressions) -> Dict[str, int]:
"""Calculate distribution of count values""" """Calculate distribution of count values"""
distribution = { distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
'0-1': 0,
'1-2': 0,
'2-3': 0,
'3-4': 0,
'4-5': 0,
'5-10': 0,
'10+': 0
}
for expr in expressions: for expr in expressions:
cnt = expr.count cnt = expr.count
if cnt < 1: if cnt < 1:
distribution['0-1'] += 1 distribution["0-1"] += 1
elif cnt < 2: elif cnt < 2:
distribution['1-2'] += 1 distribution["1-2"] += 1
elif cnt < 3: elif cnt < 3:
distribution['2-3'] += 1 distribution["2-3"] += 1
elif cnt < 4: elif cnt < 4:
distribution['3-4'] += 1 distribution["3-4"] += 1
elif cnt < 5: elif cnt < 5:
distribution['4-5'] += 1 distribution["4-5"] += 1
elif cnt < 10: elif cnt < 10:
distribution['5-10'] += 1 distribution["5-10"] += 1
else: else:
distribution['10+'] += 1 distribution["10+"] += 1
return distribution return distribution
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]: def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
"""Get top N most used expressions for a specific chat_id""" """Get top N most used expressions for a specific chat_id"""
return (Expression.select() return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
.where(Expression.chat_id == chat_id)
.order_by(Expression.count.desc())
.limit(top_n))
def show_overall_statistics(expressions, total: int) -> None: def show_overall_statistics(expressions, total: int) -> None:
@@ -113,11 +101,11 @@ def show_overall_statistics(expressions, total: int) -> None:
print("\n上次激活时间分布:") print("\n上次激活时间分布:")
for period, count in time_dist.items(): for period, count in time_dist.items():
print(f"{period}: {count} ({count/total*100:.2f}%)") print(f"{period}: {count} ({count / total * 100:.2f}%)")
print("\ncount分布:") print("\ncount分布:")
for range_, count in count_dist.items(): for range_, count in count_dist.items():
print(f"{range_}: {count} ({count/total*100:.2f}%)") print(f"{range_}: {count} ({count / total * 100:.2f}%)")
def show_chat_statistics(chat_id: str, chat_name: str) -> None: def show_chat_statistics(chat_id: str, chat_name: str) -> None:
@@ -137,14 +125,14 @@ def show_chat_statistics(chat_id: str, chat_name: str) -> None:
print("\n上次激活时间分布:") print("\n上次激活时间分布:")
for period, count in time_dist.items(): for period, count in time_dist.items():
if count > 0: if count > 0:
print(f"{period}: {count} ({count/chat_total*100:.2f}%)") print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")
# Count distribution for this chat # Count distribution for this chat
count_dist = calculate_count_distribution(chat_exprs) count_dist = calculate_count_distribution(chat_exprs)
print("\ncount分布:") print("\ncount分布:")
for range_, count in count_dist.items(): for range_, count in count_dist.items():
if count > 0: if count > 0:
print(f"{range_}: {count} ({count/chat_total*100:.2f}%)") print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")
# Top expressions # Top expressions
print("\nTop 10使用最多的表达式:") print("\nTop 10使用最多的表达式:")
@@ -172,9 +160,9 @@ def interactive_menu() -> None:
chat_info.sort(key=lambda x: x[1]) # Sort by chat name chat_info.sort(key=lambda x: x[1]) # Sort by chat name
while True: while True:
print("\n" + "="*50) print("\n" + "=" * 50)
print("表达式统计分析") print("表达式统计分析")
print("="*50) print("=" * 50)
print("0. 显示总体统计") print("0. 显示总体统计")
for i, (chat_id, chat_name) in enumerate(chat_info, 1): for i, (chat_id, chat_name) in enumerate(chat_info, 1):
@@ -185,7 +173,7 @@ def interactive_menu() -> None:
choice = input("\n请选择要查看的统计 (输入序号): ").strip() choice = input("\n请选择要查看的统计 (输入序号): ").strip()
if choice.lower() == 'q': if choice.lower() == "q":
print("再见!") print("再见!")
break break

View File

@@ -23,6 +23,7 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入") logger = get_logger("OpenIE导入")
def ensure_openie_dir(): def ensure_openie_dir():
"""确保OpenIE数据目录存在""" """确保OpenIE数据目录存在"""
if not os.path.exists(OPENIE_DIR): if not os.path.exists(OPENIE_DIR):

View File

@@ -12,6 +12,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from rich.progress import Progress # 替换为 rich 进度条 from rich.progress import Progress # 替换为 rich 进度条
from src.common.logger import get_logger from src.common.logger import get_logger
# from src.chat.knowledge.lpmmconfig import global_config # from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.open_ie import OpenIE from src.chat.knowledge.open_ie import OpenIE
@@ -36,6 +37,7 @@ TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
def ensure_dirs(): def ensure_dirs():
"""确保临时目录和输出目录存在""" """确保临时目录和输出目录存在"""
if not os.path.exists(TEMP_DIR): if not os.path.exists(TEMP_DIR):
@@ -48,6 +50,7 @@ def ensure_dirs():
os.makedirs(RAW_DATA_PATH) os.makedirs(RAW_DATA_PATH)
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}") logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
# 创建一个线程安全的锁,用于保护文件操作和共享数据 # 创建一个线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock() file_lock = Lock()
open_ie_doc_lock = Lock() open_ie_doc_lock = Lock()
@@ -56,13 +59,11 @@ open_ie_doc_lock = Lock()
shutdown_event = Event() shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest( lpmm_entity_extract_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract, model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(
model_set=model_config.model_task_config.lpmm_rdf_build,
request_type="lpmm.rdf_build"
) )
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
def process_single_text(pg_hash, raw_data): def process_single_text(pg_hash, raw_data):
"""处理单个文本的函数,用于线程池""" """处理单个文本的函数,用于线程池"""
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json" temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"

View File

@@ -3,12 +3,11 @@ import sys
import os import os
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
from datetime import datetime from datetime import datetime
# Add project root to Python path # Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa from src.common.database.database_model import Messages, ChatStreams # noqa
def get_chat_name(chat_id: str) -> str: def get_chat_name(chat_id: str) -> str:
@@ -39,15 +38,15 @@ def format_timestamp(timestamp: float) -> str:
def calculate_interest_value_distribution(messages) -> Dict[str, int]: def calculate_interest_value_distribution(messages) -> Dict[str, int]:
"""Calculate distribution of interest_value""" """Calculate distribution of interest_value"""
distribution = { distribution = {
'0.000-0.010': 0, "0.000-0.010": 0,
'0.010-0.050': 0, "0.010-0.050": 0,
'0.050-0.100': 0, "0.050-0.100": 0,
'0.100-0.500': 0, "0.100-0.500": 0,
'0.500-1.000': 0, "0.500-1.000": 0,
'1.000-2.000': 0, "1.000-2.000": 0,
'2.000-5.000': 0, "2.000-5.000": 0,
'5.000-10.000': 0, "5.000-10.000": 0,
'10.000+': 0 "10.000+": 0,
} }
for msg in messages: for msg in messages:
@@ -56,49 +55,45 @@ def calculate_interest_value_distribution(messages) -> Dict[str, int]:
value = float(msg.interest_value) value = float(msg.interest_value)
if value < 0.010: if value < 0.010:
distribution['0.000-0.010'] += 1 distribution["0.000-0.010"] += 1
elif value < 0.050: elif value < 0.050:
distribution['0.010-0.050'] += 1 distribution["0.010-0.050"] += 1
elif value < 0.100: elif value < 0.100:
distribution['0.050-0.100'] += 1 distribution["0.050-0.100"] += 1
elif value < 0.500: elif value < 0.500:
distribution['0.100-0.500'] += 1 distribution["0.100-0.500"] += 1
elif value < 1.000: elif value < 1.000:
distribution['0.500-1.000'] += 1 distribution["0.500-1.000"] += 1
elif value < 2.000: elif value < 2.000:
distribution['1.000-2.000'] += 1 distribution["1.000-2.000"] += 1
elif value < 5.000: elif value < 5.000:
distribution['2.000-5.000'] += 1 distribution["2.000-5.000"] += 1
elif value < 10.000: elif value < 10.000:
distribution['5.000-10.000'] += 1 distribution["5.000-10.000"] += 1
else: else:
distribution['10.000+'] += 1 distribution["10.000+"] += 1
return distribution return distribution
def get_interest_value_stats(messages) -> Dict[str, float]: def get_interest_value_stats(messages) -> Dict[str, float]:
"""Calculate basic statistics for interest_value""" """Calculate basic statistics for interest_value"""
values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0] values = [
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
]
if not values: if not values:
return { return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}
'count': 0,
'min': 0,
'max': 0,
'avg': 0,
'median': 0
}
values.sort() values.sort()
count = len(values) count = len(values)
return { return {
'count': count, "count": count,
'min': min(values), "min": min(values),
'max': max(values), "max": max(values),
'avg': sum(values) / count, "avg": sum(values) / count,
'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2 "median": values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2,
} }
@@ -109,11 +104,15 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
chat_counts = {} chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct(): for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id chat_id = msg.chat_id
count = Messages.select().where( count = (
(Messages.chat_id == chat_id) & Messages.select()
(Messages.interest_value.is_null(False)) & .where(
(Messages.interest_value != 0.0) (Messages.chat_id == chat_id)
).count() & (Messages.interest_value.is_null(False))
& (Messages.interest_value != 0.0)
)
.count()
)
if count > 0: if count > 0:
chat_counts[chat_id] = count chat_counts[chat_id] = count
@@ -146,13 +145,13 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
now = time.time() now = time.time()
if choice == "1": if choice == "1":
return now - 24*3600, now return now - 24 * 3600, now
elif choice == "2": elif choice == "2":
return now - 3*24*3600, now return now - 3 * 24 * 3600, now
elif choice == "3": elif choice == "3":
return now - 7*24*3600, now return now - 7 * 24 * 3600, now
elif choice == "4": elif choice == "4":
return now - 30*24*3600, now return now - 30 * 24 * 3600, now
elif choice == "5": elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip() start_str = input().strip()
@@ -170,14 +169,13 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
return None, None return None, None
def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: def analyze_interest_values(
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
"""Analyze interest values with optional filters""" """Analyze interest values with optional filters"""
# 构建查询条件 # 构建查询条件
query = Messages.select().where( query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
(Messages.interest_value.is_null(False)) &
(Messages.interest_value != 0.0)
)
if chat_id: if chat_id:
query = query.where(Messages.chat_id == chat_id) query = query.where(Messages.chat_id == chat_id)
@@ -222,7 +220,7 @@ def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[
print(f"中位数: {stats['median']:.3f}") print(f"中位数: {stats['median']:.3f}")
print("\nInterest Value 分布:") print("\nInterest Value 分布:")
total = stats['count'] total = stats["count"]
for range_name, count in distribution.items(): for range_name, count in distribution.items():
if count > 0: if count > 0:
percentage = count / total * 100 percentage = count / total * 100
@@ -233,16 +231,16 @@ def interactive_menu() -> None:
"""Interactive menu for interest value analysis""" """Interactive menu for interest value analysis"""
while True: while True:
print("\n" + "="*50) print("\n" + "=" * 50)
print("Interest Value 分析工具") print("Interest Value 分析工具")
print("="*50) print("=" * 50)
print("1. 分析全部聊天") print("1. 分析全部聊天")
print("2. 选择特定聊天分析") print("2. 选择特定聊天分析")
print("q. 退出") print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip() choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q': if choice.lower() == "q":
print("再见!") print("再见!")
break break

View File

@@ -1191,6 +1191,7 @@ class LogViewer:
# 如果发现了新模块,在主线程中更新模块集合 # 如果发现了新模块,在主线程中更新模块集合
if new_modules: if new_modules:
def update_modules(): def update_modules():
self.modules.update(new_modules) self.modules.update(new_modules)
self.update_module_list() self.update_module_list()
@@ -1424,4 +1425,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -2,6 +2,7 @@ import os
from pathlib import Path from pathlib import Path
import sys # 新增系统模块导入 import sys # 新增系统模块导入
from src.chat.knowledge.utils.hash import get_sha256 from src.chat.knowledge.utils.hash import get_sha256
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -10,6 +11,7 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
def _process_text_file(file_path): def _process_text_file(file_path):
"""处理单个文本文件,返回段落列表""" """处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
@@ -44,6 +46,7 @@ def _process_multi_files() -> list:
all_paragraphs.extend(paragraphs) all_paragraphs.extend(paragraphs)
return all_paragraphs return all_paragraphs
def load_raw_data() -> tuple[list[str], list[str]]: def load_raw_data() -> tuple[list[str], list[str]]:
"""加载原始数据文件 """加载原始数据文件

View File

@@ -4,10 +4,11 @@ import os
import re import re
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
from datetime import datetime from datetime import datetime
# Add project root to Python path # Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams #noqa from src.common.database.database_model import Messages, ChatStreams # noqa
def contains_emoji_or_image_tags(text: str) -> bool: def contains_emoji_or_image_tags(text: str) -> bool:
@@ -16,8 +17,8 @@ def contains_emoji_or_image_tags(text: str) -> bool:
return False return False
# 检查是否包含 [表情包] 或 [图片] 标记 # 检查是否包含 [表情包] 或 [图片] 标记
emoji_pattern = r'\[表情包[^\]]*\]' emoji_pattern = r"\[表情包[^\]]*\]"
image_pattern = r'\[图片[^\]]*\]' image_pattern = r"\[图片[^\]]*\]"
return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text)) return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text))
@@ -29,7 +30,7 @@ def clean_reply_text(text: str) -> str:
# 匹配 [回复 xxxx...] 格式的内容 # 匹配 [回复 xxxx...] 格式的内容
# 使用非贪婪匹配,匹配到第一个 ] 就停止 # 使用非贪婪匹配,匹配到第一个 ] 就停止
cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text) cleaned_text = re.sub(r"\[回复[^\]]*\]", "", text)
# 去除多余的空白字符 # 去除多余的空白字符
cleaned_text = cleaned_text.strip() cleaned_text = cleaned_text.strip()
@@ -65,20 +66,20 @@ def format_timestamp(timestamp: float) -> str:
def calculate_text_length_distribution(messages) -> Dict[str, int]: def calculate_text_length_distribution(messages) -> Dict[str, int]:
"""Calculate distribution of processed_plain_text length""" """Calculate distribution of processed_plain_text length"""
distribution = { distribution = {
'0': 0, # 空文本 "0": 0, # 空文本
'1-5': 0, # 极短文本 "1-5": 0, # 极短文本
'6-10': 0, # 很短文本 "6-10": 0, # 很短文本
'11-20': 0, # 短文本 "11-20": 0, # 短文本
'21-30': 0, # 较短文本 "21-30": 0, # 较短文本
'31-50': 0, # 中短文本 "31-50": 0, # 中短文本
'51-70': 0, # 中等文本 "51-70": 0, # 中等文本
'71-100': 0, # 较长文本 "71-100": 0, # 较长文本
'101-150': 0, # 长文本 "101-150": 0, # 长文本
'151-200': 0, # 很长文本 "151-200": 0, # 很长文本
'201-300': 0, # 超长文本 "201-300": 0, # 超长文本
'301-500': 0, # 极长文本 "301-500": 0, # 极长文本
'501-1000': 0, # 巨长文本 "501-1000": 0, # 巨长文本
'1000+': 0 # 超巨长文本 "1000+": 0, # 超巨长文本
} }
for msg in messages: for msg in messages:
@@ -94,33 +95,33 @@ def calculate_text_length_distribution(messages) -> Dict[str, int]:
length = len(cleaned_text) length = len(cleaned_text)
if length == 0: if length == 0:
distribution['0'] += 1 distribution["0"] += 1
elif length <= 5: elif length <= 5:
distribution['1-5'] += 1 distribution["1-5"] += 1
elif length <= 10: elif length <= 10:
distribution['6-10'] += 1 distribution["6-10"] += 1
elif length <= 20: elif length <= 20:
distribution['11-20'] += 1 distribution["11-20"] += 1
elif length <= 30: elif length <= 30:
distribution['21-30'] += 1 distribution["21-30"] += 1
elif length <= 50: elif length <= 50:
distribution['31-50'] += 1 distribution["31-50"] += 1
elif length <= 70: elif length <= 70:
distribution['51-70'] += 1 distribution["51-70"] += 1
elif length <= 100: elif length <= 100:
distribution['71-100'] += 1 distribution["71-100"] += 1
elif length <= 150: elif length <= 150:
distribution['101-150'] += 1 distribution["101-150"] += 1
elif length <= 200: elif length <= 200:
distribution['151-200'] += 1 distribution["151-200"] += 1
elif length <= 300: elif length <= 300:
distribution['201-300'] += 1 distribution["201-300"] += 1
elif length <= 500: elif length <= 500:
distribution['301-500'] += 1 distribution["301-500"] += 1
elif length <= 1000: elif length <= 1000:
distribution['501-1000'] += 1 distribution["501-1000"] += 1
else: else:
distribution['1000+'] += 1 distribution["1000+"] += 1
return distribution return distribution
@@ -144,26 +145,26 @@ def get_text_length_stats(messages) -> Dict[str, float]:
if not lengths: if not lengths:
return { return {
'count': 0, "count": 0,
'null_count': null_count, "null_count": null_count,
'excluded_count': excluded_count, "excluded_count": excluded_count,
'min': 0, "min": 0,
'max': 0, "max": 0,
'avg': 0, "avg": 0,
'median': 0 "median": 0,
} }
lengths.sort() lengths.sort()
count = len(lengths) count = len(lengths)
return { return {
'count': count, "count": count,
'null_count': null_count, "null_count": null_count,
'excluded_count': excluded_count, "excluded_count": excluded_count,
'min': min(lengths), "min": min(lengths),
'max': max(lengths), "max": max(lengths),
'avg': sum(lengths) / count, "avg": sum(lengths) / count,
'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2 "median": lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2,
} }
@@ -174,12 +175,16 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
chat_counts = {} chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct(): for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id chat_id = msg.chat_id
count = Messages.select().where( count = (
(Messages.chat_id == chat_id) & Messages.select()
(Messages.is_emoji != 1) & .where(
(Messages.is_picid != 1) & (Messages.chat_id == chat_id)
(Messages.is_command != 1) & (Messages.is_emoji != 1)
).count() & (Messages.is_picid != 1)
& (Messages.is_command != 1)
)
.count()
)
if count > 0: if count > 0:
chat_counts[chat_id] = count chat_counts[chat_id] = count
@@ -212,13 +217,13 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
now = time.time() now = time.time()
if choice == "1": if choice == "1":
return now - 24*3600, now return now - 24 * 3600, now
elif choice == "2": elif choice == "2":
return now - 3*24*3600, now return now - 3 * 24 * 3600, now
elif choice == "3": elif choice == "3":
return now - 7*24*3600, now return now - 7 * 24 * 3600, now
elif choice == "4": elif choice == "4":
return now - 30*24*3600, now return now - 30 * 24 * 3600, now
elif choice == "5": elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip() start_str = input().strip()
@@ -260,15 +265,13 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int,
return message_lengths[:top_n] return message_lengths[:top_n]
def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: def analyze_text_lengths(
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
) -> None:
"""Analyze processed_plain_text lengths with optional filters""" """Analyze processed_plain_text lengths with optional filters"""
# 构建查询条件,排除特殊类型的消息 # 构建查询条件,排除特殊类型的消息
query = Messages.select().where( query = Messages.select().where((Messages.is_emoji != 1) & (Messages.is_picid != 1) & (Messages.is_command != 1))
(Messages.is_emoji != 1) &
(Messages.is_picid != 1) &
(Messages.is_command != 1)
)
if chat_id: if chat_id:
query = query.where(Messages.chat_id == chat_id) query = query.where(Messages.chat_id == chat_id)
@@ -312,14 +315,14 @@ def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[flo
print(f"有文本消息数量: {stats['count']}") print(f"有文本消息数量: {stats['count']}")
print(f"空文本消息数量: {stats['null_count']}") print(f"空文本消息数量: {stats['null_count']}")
print(f"被排除的消息数量: {stats['excluded_count']}") print(f"被排除的消息数量: {stats['excluded_count']}")
if stats['count'] > 0: if stats["count"] > 0:
print(f"最短长度: {stats['min']} 字符") print(f"最短长度: {stats['min']} 字符")
print(f"最长长度: {stats['max']} 字符") print(f"最长长度: {stats['max']} 字符")
print(f"平均长度: {stats['avg']:.2f} 字符") print(f"平均长度: {stats['avg']:.2f} 字符")
print(f"中位数长度: {stats['median']:.2f} 字符") print(f"中位数长度: {stats['median']:.2f} 字符")
print("\n文本长度分布:") print("\n文本长度分布:")
total = stats['count'] total = stats["count"]
if total > 0: if total > 0:
for range_name, count in distribution.items(): for range_name, count in distribution.items():
if count > 0: if count > 0:
@@ -340,16 +343,16 @@ def interactive_menu() -> None:
"""Interactive menu for text length analysis""" """Interactive menu for text length analysis"""
while True: while True:
print("\n" + "="*50) print("\n" + "=" * 50)
print("Processed Plain Text 长度分析工具") print("Processed Plain Text 长度分析工具")
print("="*50) print("=" * 50)
print("1. 分析全部聊天") print("1. 分析全部聊天")
print("2. 选择特定聊天分析") print("2. 选择特定聊天分析")
print("q. 退出") print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip() choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == 'q': if choice.lower() == "q":
print("再见!") print("再见!")
break break

View File

@@ -731,7 +731,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion: if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...") logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.split(',') return emoji_record.emotion.split(",")
except Exception as e: except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}") logger.error(f"从数据库查询表情包情感标签时出错: {e}")

View File

@@ -123,9 +123,7 @@ class ExpressionSelector:
return group_chat_ids return group_chat_ids
return [chat_id] return [chat_id]
def get_random_expressions( def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
self, chat_id: str, total_num: int
) -> List[Dict[str, Any]]:
# sourcery skip: extract-duplicate-method, move-assign # sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选 # 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
@@ -248,7 +246,6 @@ class ExpressionSelector:
# 4. 调用LLM # 4. 调用LLM
try: try:
# start_time = time.time() # start_time = time.time()
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}") # logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
@@ -297,7 +294,6 @@ class ExpressionSelector:
return [], [] return [], []
init_prompt() init_prompt()
try: try:

View File

@@ -119,4 +119,3 @@ def get_global_focus_value() -> Optional[float]:
return get_time_based_focus_value(config_item[1:]) return get_time_based_focus_value(config_item[1:])
return None return None

View File

@@ -124,5 +124,3 @@ def get_global_frequency() -> Optional[float]:
return get_time_based_frequency(config_item[1:]) return get_time_based_frequency(config_item[1:])
return None return None

View File

@@ -261,7 +261,11 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers return loop_info, reply_text, cycle_timers
async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool: async def _observe(
self, interest_value: float = 0.0, recent_messages_list: Optional[List["DatabaseMessages"]] = None
) -> bool:
if recent_messages_list is None:
recent_messages_list = []
reply_text = "" # 初始化reply_text变量避免UnboundLocalError reply_text = "" # 初始化reply_text变量避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率 # 使用sigmoid函数将interest_value转换为概率

View File

@@ -3,8 +3,10 @@ from typing import Any, Optional, Dict
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.heart_flow.heartFC_chat import HeartFChatting from src.chat.heart_flow.heartFC_chat import HeartFChatting
logger = get_logger("heartflow") logger = get_logger("heartflow")
class Heartflow: class Heartflow:
"""主心流协调器,负责初始化并协调聊天""" """主心流协调器,负责初始化并协调聊天"""
@@ -18,7 +20,7 @@ class Heartflow:
if chat := self.heartflow_chat_list.get(chat_id): if chat := self.heartflow_chat_list.get(chat_id):
return chat return chat
else: else:
new_chat = HeartFChatting(chat_id = chat_id) new_chat = HeartFChatting(chat_id=chat_id)
await new_chat.start() await new_chat.start()
self.heartflow_chat_list[chat_id] = new_chat self.heartflow_chat_list[chat_id] = new_chat
return new_chat return new_chat
@@ -27,4 +29,5 @@ class Heartflow:
traceback.print_exc() traceback.print_exc()
return None return None
heartflow = Heartflow() heartflow = Heartflow()

View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
logger = get_logger("chat") logger = get_logger("chat")
async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]: async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
"""计算消息的兴趣度 """计算消息的兴趣度
@@ -35,13 +36,13 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
if message.is_picid or message.is_emoji: if message.is_picid or message.is_emoji:
return 0.0, [] return 0.0, []
is_mentioned,is_at,reply_probability_boost = is_mentioned_bot_in_message(message) is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
interested_rate = 0.0 interested_rate = 0.0
with Timer("记忆激活"): with Timer("记忆激活"):
interested_rate, keywords,keywords_lite = await hippocampus_manager.get_activate_from_text( interested_rate, keywords, keywords_lite = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text, message.processed_plain_text,
max_depth= 4, max_depth=4,
fast_retrieval=global_config.chat.interest_rate_mode == "fast", fast_retrieval=global_config.chat.interest_rate_mode == "fast",
) )
message.key_words = keywords message.key_words = keywords
@@ -79,7 +80,6 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, list[str]]:
# 确保在范围内 # 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3) base_interest = min(max(base_interest, 0.01), 0.3)
message.interest_value = base_interest message.interest_value = base_interest
message.is_mentioned = is_mentioned message.is_mentioned = is_mentioned
message.is_at = is_at message.is_at = is_at
@@ -116,7 +116,6 @@ class HeartFCMessageReceiver:
# 2. 兴趣度计算与更新 # 2. 兴趣度计算与更新
interested_rate, keywords = await _calculate_interest(message) interested_rate, keywords = await _calculate_interest(message)
await self.storage.store_message(message, chat) await self.storage.store_message(message, chat)
heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore heartflow_chat: HeartFChatting = await heartflow.get_or_create_heartflow_chat(chat.stream_id) # type: ignore
@@ -145,18 +144,20 @@ class HeartFCMessageReceiver:
# 如果没有找到图片描述,则移除[picid:xxxx]标记 # 如果没有找到图片描述,则移除[picid:xxxx]标记
processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]") processed_text = processed_text.replace(f"[picid:{picid}]", "[图片:网络不好,图片无法加载]")
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式 # 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references( processed_plain_text = replace_user_references(
processed_text, processed_text,
message.message_info.platform, # type: ignore message.message_info.platform, # type: ignore
replace_bot_name=True replace_bot_name=True,
) )
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[{interested_rate:.2f}]") # type: ignore
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=userinfo.user_nickname) # type: ignore _ = Person.register_person(
platform=message.message_info.platform,
user_id=message.message_info.user_info.user_id,
nickname=userinfo.user_nickname,
) # type: ignore
except Exception as e: except Exception as e:
logger.error(f"消息处理失败: {e}") logger.error(f"消息处理失败: {e}")

View File

@@ -124,6 +124,7 @@ async def send_typing():
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
) )
async def stop_typing(): async def stop_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")

View File

@@ -30,6 +30,7 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None qa_manager = None
inspire_manager = None inspire_manager = None
def lpmm_start_up(): # sourcery skip: extract-duplicate-method def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable: if global_config.lpmm_knowledge.enable:

View File

@@ -32,11 +32,11 @@ install(extra_lines=3)
# 多线程embedding配置常量 # 多线程embedding配置常量
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数 DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小 DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小 MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小 MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数 MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数 MAX_WORKERS = 20 # 最大线程数
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
@@ -93,7 +93,13 @@ class EmbeddingStoreItem:
class EmbeddingStore: class EmbeddingStore:
def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): def __init__(
self,
namespace: str,
dir_path: str,
max_workers: int = DEFAULT_MAX_WORKERS,
chunk_size: int = DEFAULT_CHUNK_SIZE,
):
self.namespace = namespace self.namespace = namespace
self.dir = dir_path self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
@@ -106,9 +112,13 @@ class EmbeddingStore:
# 如果配置值被调整,记录日志 # 如果配置值被调整,记录日志
if self.max_workers != max_workers: if self.max_workers != max_workers:
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})") logger.warning(
f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})"
)
if self.chunk_size != chunk_size: if self.chunk_size != chunk_size:
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})") logger.warning(
f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})"
)
self.store = {} self.store = {}
@@ -147,7 +157,9 @@ class EmbeddingStore:
except Exception: except Exception:
pass pass
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]: def _get_embeddings_batch_threaded(
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量 """使用多线程批量获取嵌入向量
Args: Args:
@@ -165,7 +177,7 @@ class EmbeddingStore:
# 分块 # 分块
chunks = [] chunks = []
for i in range(0, len(strs), chunk_size): for i in range(0, len(strs), chunk_size):
chunk = strs[i:i + chunk_size] chunk = strs[i : i + chunk_size]
chunks.append((i, chunk)) # 保存起始索引以维持顺序 chunks.append((i, chunk)) # 保存起始索引以维持顺序
# 结果存储,使用字典按索引存储以保证顺序 # 结果存储,使用字典按索引存储以保证顺序
@@ -264,7 +276,7 @@ class EmbeddingStore:
embedding_results = self._get_embeddings_batch_threaded( embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS, EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
) )
# 构建测试向量字典 # 构建测试向量字典
@@ -311,7 +323,7 @@ class EmbeddingStore:
embedding_results = self._get_embeddings_batch_threaded( embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS, EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
) )
# 检查一致性 # 检查一致性
@@ -370,8 +382,16 @@ class EmbeddingStore:
if new_strs: if new_strs:
# 使用实例配置的参数,智能调整分块和线程数 # 使用实例配置的参数,智能调整分块和线程数
optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size)) optimal_chunk_size = max(
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1)) MIN_CHUNK_SIZE,
min(
self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
),
)
optimal_max_workers = min(
self.max_workers,
max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
)
logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}") logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
@@ -384,7 +404,7 @@ class EmbeddingStore:
new_strs, new_strs,
chunk_size=optimal_chunk_size, chunk_size=optimal_chunk_size,
max_workers=optimal_max_workers, max_workers=optimal_max_workers,
progress_callback=update_progress progress_callback=update_progress,
) )
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了) # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)

View File

@@ -426,9 +426,7 @@ class KGManager:
# 获取最终结果 # 获取最终结果
# 从搜索结果中提取文段节点的结果 # 从搜索结果中提取文段节点的结果
passage_node_res = [ passage_node_res = [
(node_key, score) (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph")
for node_key, score in ppr_res.items()
if node_key.startswith("paragraph")
] ]
del ppr_res del ppr_res

View File

@@ -1,8 +1,8 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it") raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config from .lpmmconfig import global_config # noqa
from .embedding_store import EmbeddingManager from .embedding_store import EmbeddingManager # noqa
from .llm_client import LLMClient from .llm_client import LLMClient # noqa
from .utils.dyn_topk import dyn_select_top_k from .utils.dyn_topk import dyn_select_top_k # noqa
class MemoryActiveManager: class MemoryActiveManager:

View File

@@ -103,25 +103,23 @@ class ActionModifier:
self.action_manager.remove_action_from_using(action_name) self.action_manager.remove_action_from_using(action_name)
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}") logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
# === 第三阶段:激活类型判定 === # === 第三阶段:激活类型判定 ===
# if chat_content is not None: # if chat_content is not None:
# logger.debug(f"{self.log_prefix}开始激活类型判定阶段") # logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
# 获取当前使用的动作集(经过第一阶段处理) # 获取当前使用的动作集(经过第一阶段处理)
# current_using_actions = self.action_manager.get_using_actions() # current_using_actions = self.action_manager.get_using_actions()
# 获取因激活类型判定而需要移除的动作 # 获取因激活类型判定而需要移除的动作
# removals_s3 = await self._get_deactivated_actions_by_type( # removals_s3 = await self._get_deactivated_actions_by_type(
# current_using_actions, # current_using_actions,
# chat_content, # chat_content,
# ) # )
# 应用第三阶段的移除 # 应用第三阶段的移除
# for action_name, reason in removals_s3: # for action_name, reason in removals_s3:
# self.action_manager.remove_action_from_using(action_name) # self.action_manager.remove_action_from_using(action_name)
# logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}") # logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
# === 统一日志记录 === # === 统一日志记录 ===
all_removals = removals_s1 + removals_s2 all_removals = removals_s1 + removals_s2
@@ -131,9 +129,7 @@ class ActionModifier:
available_actions = list(self.action_manager.get_using_actions().keys()) available_actions = list(self.action_manager.get_using_actions().keys())
available_actions_text = "".join(available_actions) if available_actions else "" available_actions_text = "".join(available_actions) if available_actions else ""
logger.debug( logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}"
)
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
type_mismatched_actions: List[Tuple[str, str]] = [] type_mismatched_actions: List[Tuple[str, str]] = []

View File

@@ -396,7 +396,7 @@ class StatisticOutputTask(AsyncTask):
# 计算标准差 # 计算标准差
if len(time_costs) > 1: if len(time_costs) > 1:
variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs) variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs)
std_time_cost = variance ** 0.5 std_time_cost = variance**0.5
stats[period_key][std_key][item_name] = round(std_time_cost, 3) stats[period_key][std_key][item_name] = round(std_time_cost, 3)
else: else:
stats[period_key][std_key][item_name] = 0.0 stats[period_key][std_key][item_name] = 0.0
@@ -506,8 +506,6 @@ class StatisticOutputTask(AsyncTask):
break break
return stats return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
""" """
收集各时间段的统计数据 收集各时间段的统计数据
@@ -639,7 +637,9 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name] cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name] avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name] std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
output.append(data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)) output.append(
data_fmt.format(name, count, in_tokens, out_tokens, tokens, cost, avg_time_cost, std_time_cost)
)
output.append("") output.append("")
return "\n".join(output) return "\n".join(output)
@@ -728,7 +728,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>" f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
f"</tr>" f"</tr>"
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items()) for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
] if stat_data[REQ_CNT_BY_MODEL] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"] ]
if stat_data[REQ_CNT_BY_MODEL]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
) )
# 按请求类型分类统计 # 按请求类型分类统计
type_rows = "\n".join( type_rows = "\n".join(
@@ -744,7 +746,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>" f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
f"</tr>" f"</tr>"
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items()) for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
] if stat_data[REQ_CNT_BY_TYPE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"] ]
if stat_data[REQ_CNT_BY_TYPE]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
) )
# 按模块分类统计 # 按模块分类统计
module_rows = "\n".join( module_rows = "\n".join(
@@ -760,7 +764,9 @@ class StatisticOutputTask(AsyncTask):
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>" f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
f"</tr>" f"</tr>"
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items()) for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
] if stat_data[REQ_CNT_BY_MODULE] else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"] ]
if stat_data[REQ_CNT_BY_MODULE]
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
) )
# 聊天消息统计 # 聊天消息统计
@@ -768,7 +774,9 @@ class StatisticOutputTask(AsyncTask):
[ [
f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>" f"<tr><td>{self.name_mapping[chat_id][0]}</td><td>{count}</td></tr>"
for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()) for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items())
] if stat_data[MSG_CNT_BY_CHAT] else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"] ]
if stat_data[MSG_CNT_BY_CHAT]
else ["<tr><td colspan='2' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
) )
# 生成HTML # 生成HTML
return f""" return f"""

View File

@@ -51,7 +51,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
is_mentioned = False is_mentioned = False
# 这部分怎么处理啊啊啊啊 # 这部分怎么处理啊啊啊啊
#我觉得可以给消息加一个 reply_probability_boost字段 # 我觉得可以给消息加一个 reply_probability_boost字段
if ( if (
message.message_info.additional_config is not None message.message_info.additional_config is not None
and message.message_info.additional_config.get("is_mentioned") is not None and message.message_info.additional_config.get("is_mentioned") is not None
@@ -826,20 +826,48 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [keywords_str] if keywords_str else [] return [keywords_str] if keywords_str else []
def cut_key_words(concept_name: str) -> list[str]: def cut_key_words(concept_name: str) -> list[str]:
"""对概念名称进行jieba分词并过滤掉关键词列表中的关键词""" """对概念名称进行jieba分词并过滤掉关键词列表中的关键词"""
concept_name_tokens = list(jieba.cut(concept_name)) concept_name_tokens = list(jieba.cut(concept_name))
# 定义常见连词、停用词与标点 # 定义常见连词、停用词与标点
conjunctions = { conjunctions = {"", "", "", "", "以及", "并且", "而且", "", "或者", ""}
"", "", "", "", "以及", "并且", "而且", "", "或者", ""
}
stop_words = { stop_words = {
"", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "", "", "", "", "", "",
"", "", "", "", "", "", "", "而且", "或者", "", "以及" "",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"而且",
"或者",
"",
"以及",
} }
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\") chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
@@ -864,11 +892,16 @@ def cut_key_words(concept_name: str) -> list[str]:
left = merged_tokens[-1] left = merged_tokens[-1]
right = cleaned_tokens[i + 1] right = cleaned_tokens[i + 1]
# 左右都需要是有效词 # 左右都需要是有效词
if left and right \ if (
and left not in conjunctions and right not in conjunctions \ left
and left not in stop_words and right not in stop_words \ and right
and not all(ch in chinese_punctuations for ch in left) \ and left not in conjunctions
and not all(ch in chinese_punctuations for ch in right): and right not in conjunctions
and left not in stop_words
and right not in stop_words
and not all(ch in chinese_punctuations for ch in left)
and not all(ch in chinese_punctuations for ch in right)
):
# 合并为一个新词,并替换掉左侧与跳过右侧 # 合并为一个新词,并替换掉左侧与跳过右侧
combined = f"{left}{tok}{right}" combined = f"{left}{tok}{right}"
merged_tokens[-1] = combined merged_tokens[-1] = combined
@@ -889,7 +922,7 @@ def cut_key_words(concept_name: str) -> list[str]:
if tok in stop_words: if tok in stop_words:
continue continue
# if tok in ban_words: # if tok in ban_words:
# continue # continue
if all(ch in chinese_punctuations for ch in tok): if all(ch in chinese_punctuations for ch in tok):
continue continue
if tok.strip() == "": if tok.strip() == "":

View File

@@ -94,6 +94,7 @@ class ImageManager:
async def get_emoji_tag(self, image_base64: str) -> str: async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
if isinstance(image_base64, str): if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
@@ -120,6 +121,7 @@ class ImageManager:
# 优先使用EmojiManager查询已注册表情包的描述 # 优先使用EmojiManager查询已注册表情包的描述
try: try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash) tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
if tags: if tags:

View File

@@ -205,6 +205,7 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname, "chat_info_user_cardname": self.chat_info.user_info.user_cardname,
} }
@dataclass(init=False) @dataclass(init=False)
class DatabaseActionRecords(BaseDataModel): class DatabaseActionRecords(BaseDataModel):
def __init__( def __init__(

View File

@@ -2,10 +2,12 @@ from dataclasses import dataclass
from typing import Optional, List, TYPE_CHECKING from typing import Optional, List, TYPE_CHECKING
from . import BaseDataModel from . import BaseDataModel
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.message_data_model import ReplySetModel from src.common.data_models.message_data_model import ReplySetModel
from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.payload_content.tool_option import ToolCall
@dataclass @dataclass
class LLMGenerationDataModel(BaseDataModel): class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None content: Optional[str] = None

View File

@@ -271,9 +271,6 @@ class PersonInfo(BaseModel):
attitude_to_me = TextField(null=True) # 对bot的态度 attitude_to_me = TextField(null=True) # 对bot的态度
attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度 attitude_to_me_confidence = FloatField(null=True) # 对bot的态度置信度
class Meta: class Meta:
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "person_info" table_name = "person_info"
@@ -299,6 +296,7 @@ class GroupInfo(BaseModel):
# database = db # 继承自 BaseModel # database = db # 继承自 BaseModel
table_name = "group_info" table_name = "group_info"
class Expression(BaseModel): class Expression(BaseModel):
""" """
用于存储表达风格的模型。 用于存储表达风格的模型。
@@ -315,6 +313,7 @@ class Expression(BaseModel):
class Meta: class Meta:
table_name = "expression" table_name = "expression"
class GraphNodes(BaseModel): class GraphNodes(BaseModel):
""" """
用于存储记忆图节点的模型 用于存储记忆图节点的模型
@@ -504,8 +503,9 @@ def sync_field_constraints():
# 获取当前表结构信息 # 获取当前表结构信息
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]} current_schema = {
for row in cursor.fetchall()} row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
}
# 检查每个模型字段的约束 # 检查每个模型字段的约束
constraints_to_fix = [] constraints_to_fix = []
@@ -513,29 +513,33 @@ def sync_field_constraints():
if field_name not in current_schema: if field_name not in current_schema:
continue # 字段不存在,跳过 continue # 字段不存在,跳过
current_notnull = current_schema[field_name]['notnull'] current_notnull = current_schema[field_name]["notnull"]
model_allows_null = field_obj.null model_allows_null = field_obj.null
# 如果模型允许 null 但数据库字段不允许 null需要修复 # 如果模型允许 null 但数据库字段不允许 null需要修复
if model_allows_null and current_notnull: if model_allows_null and current_notnull:
constraints_to_fix.append({ constraints_to_fix.append(
'field_name': field_name, {
'field_obj': field_obj, "field_name": field_name,
'action': 'allow_null', "field_obj": field_obj,
'current_constraint': 'NOT NULL', "action": "allow_null",
'target_constraint': 'NULL' "current_constraint": "NOT NULL",
}) "target_constraint": "NULL",
}
)
logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL但数据库为NOT NULL") logger.warning(f"字段 '{field_name}' 约束不一致: 模型允许NULL但数据库为NOT NULL")
# 如果模型不允许 null 但数据库字段允许 null也需要修复但要小心 # 如果模型不允许 null 但数据库字段允许 null也需要修复但要小心
elif not model_allows_null and not current_notnull: elif not model_allows_null and not current_notnull:
constraints_to_fix.append({ constraints_to_fix.append(
'field_name': field_name, {
'field_obj': field_obj, "field_name": field_name,
'action': 'disallow_null', "field_obj": field_obj,
'current_constraint': 'NULL', "action": "disallow_null",
'target_constraint': 'NOT NULL' "current_constraint": "NULL",
}) "target_constraint": "NOT NULL",
}
)
logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL但数据库允许NULL") logger.warning(f"字段 '{field_name}' 约束不一致: 模型不允许NULL但数据库允许NULL")
# 修复约束不一致的字段 # 修复约束不一致的字段
@@ -575,7 +579,7 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
# 4. 从备份表恢复数据 # 4. 从备份表恢复数据
# 获取字段列表 # 获取字段列表
fields = list(model._meta.fields.keys()) fields = list(model._meta.fields.keys())
fields_str = ', '.join(fields) fields_str = ", ".join(fields)
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据 # 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值 # 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
@@ -583,8 +587,7 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
# 检查是否有字段需要从 NULL 改为 NOT NULL # 检查是否有字段需要从 NULL 改为 NOT NULL
null_to_notnull_fields = [ null_to_notnull_fields = [
constraint['field_name'] for constraint in constraints_to_fix constraint["field_name"] for constraint in constraints_to_fix if constraint["action"] == "disallow_null"
if constraint['action'] == 'disallow_null'
] ]
if null_to_notnull_fields: if null_to_notnull_fields:
@@ -612,7 +615,7 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
else: else:
select_fields.append(field_name) select_fields.append(field_name)
select_str = ', '.join(select_fields) select_str = ", ".join(select_fields)
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}" insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
db.execute_sql(insert_sql) db.execute_sql(insert_sql)
@@ -633,8 +636,10 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
# 记录修复的约束 # 记录修复的约束
for constraint in constraints_to_fix: for constraint in constraints_to_fix:
logger.info(f"已修复字段 '{constraint['field_name']}': " logger.info(
f"{constraint['current_constraint']} -> {constraint['target_constraint']}") f"已修复字段 '{constraint['field_name']}': "
f"{constraint['current_constraint']} -> {constraint['target_constraint']}"
)
except Exception as e: except Exception as e:
logger.exception(f"修复表 '{table_name}' 约束时出错: {e}") logger.exception(f"修复表 '{table_name}' 约束时出错: {e}")
@@ -681,8 +686,9 @@ def check_field_constraints():
# 获取当前表结构信息 # 获取当前表结构信息
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')") cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
current_schema = {row[1]: {'type': row[2], 'notnull': bool(row[3]), 'default': row[4]} current_schema = {
for row in cursor.fetchall()} row[1]: {"type": row[2], "notnull": bool(row[3]), "default": row[4]} for row in cursor.fetchall()
}
table_inconsistencies = [] table_inconsistencies = []
@@ -691,25 +697,29 @@ def check_field_constraints():
if field_name not in current_schema: if field_name not in current_schema:
continue continue
current_notnull = current_schema[field_name]['notnull'] current_notnull = current_schema[field_name]["notnull"]
model_allows_null = field_obj.null model_allows_null = field_obj.null
if model_allows_null and current_notnull: if model_allows_null and current_notnull:
table_inconsistencies.append({ table_inconsistencies.append(
'field_name': field_name, {
'issue': 'model_allows_null_but_db_not_null', "field_name": field_name,
'model_constraint': 'NULL', "issue": "model_allows_null_but_db_not_null",
'db_constraint': 'NOT NULL', "model_constraint": "NULL",
'recommended_action': 'allow_null' "db_constraint": "NOT NULL",
}) "recommended_action": "allow_null",
}
)
elif not model_allows_null and not current_notnull: elif not model_allows_null and not current_notnull:
table_inconsistencies.append({ table_inconsistencies.append(
'field_name': field_name, {
'issue': 'model_not_null_but_db_allows_null', "field_name": field_name,
'model_constraint': 'NOT NULL', "issue": "model_not_null_but_db_allows_null",
'db_constraint': 'NULL', "model_constraint": "NOT NULL",
'recommended_action': 'disallow_null' "db_constraint": "NULL",
}) "recommended_action": "disallow_null",
}
)
if table_inconsistencies: if table_inconsistencies:
inconsistencies[table_name] = table_inconsistencies inconsistencies[table_name] = table_inconsistencies
@@ -720,10 +730,5 @@ def check_field_constraints():
return inconsistencies return inconsistencies
# 模块加载时调用初始化函数 # 模块加载时调用初始化函数
initialize_database(sync_constraints=True) initialize_database(sync_constraints=True)

View File

@@ -339,24 +339,18 @@ MODULE_COLORS = {
# 67 具体的颜色编号0-255这里是较暗的蓝色 # 67 具体的颜色编号0-255这里是较暗的蓝色
"sender": "\033[38;5;24m", # 67号色较暗的蓝色适合不显眼的日志 "sender": "\033[38;5;24m", # 67号色较暗的蓝色适合不显眼的日志
"send_api": "\033[38;5;24m", # 208号色橙色适合突出显示 "send_api": "\033[38;5;24m", # 208号色橙色适合突出显示
# 生成 # 生成
"replyer": "\033[38;5;208m", # 橙色 "replyer": "\033[38;5;208m", # 橙色
"llm_api": "\033[38;5;208m", # 橙色 "llm_api": "\033[38;5;208m", # 橙色
# 消息处理 # 消息处理
"chat": "\033[38;5;82m", # 亮蓝色 "chat": "\033[38;5;82m", # 亮蓝色
"chat_image": "\033[38;5;68m", # 浅蓝色 "chat_image": "\033[38;5;68m", # 浅蓝色
# emoji
#emoji
"emoji": "\033[38;5;214m", # 橙黄色,偏向橙色 "emoji": "\033[38;5;214m", # 橙黄色,偏向橙色
"emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色 "emoji_api": "\033[38;5;214m", # 橙黄色,偏向橙色
# 核心模块 # 核心模块
"main": "\033[1;97m", # 亮白色+粗体 (主程序) "main": "\033[1;97m", # 亮白色+粗体 (主程序)
"memory": "\033[38;5;34m", # 天蓝色 "memory": "\033[38;5;34m", # 天蓝色
"config": "\033[93m", # 亮黄色 "config": "\033[93m", # 亮黄色
"common": "\033[95m", # 亮紫色 "common": "\033[95m", # 亮紫色
"tools": "\033[96m", # 亮青色 "tools": "\033[96m", # 亮青色
@@ -367,9 +361,6 @@ MODULE_COLORS = {
"llm_models": "\033[36m", # 青色 "llm_models": "\033[36m", # 青色
"remote": "\033[38;5;242m", # 深灰色,更不显眼 "remote": "\033[38;5;242m", # 深灰色,更不显眼
"planner": "\033[36m", "planner": "\033[36m",
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼 "relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
# 聊天相关模块 # 聊天相关模块
"normal_chat": "\033[38;5;81m", # 亮蓝绿色 "normal_chat": "\033[38;5;81m", # 亮蓝绿色
@@ -379,11 +370,9 @@ MODULE_COLORS = {
"background_tasks": "\033[38;5;240m", # 灰色 "background_tasks": "\033[38;5;240m", # 灰色
"chat_message": "\033[38;5;45m", # 青色 "chat_message": "\033[38;5;45m", # 青色
"chat_stream": "\033[38;5;51m", # 亮青色 "chat_stream": "\033[38;5;51m", # 亮青色
"message_storage": "\033[38;5;33m", # 深蓝色 "message_storage": "\033[38;5;33m", # 深蓝色
"expressor": "\033[38;5;166m", # 橙色 "expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块 # 专注聊天模块
"memory_activator": "\033[38;5;117m", # 天蓝色 "memory_activator": "\033[38;5;117m", # 天蓝色
# 插件系统 # 插件系统
"plugins": "\033[31m", # 红色 "plugins": "\033[31m", # 红色
@@ -412,7 +401,6 @@ MODULE_COLORS = {
# 工具和实用模块 # 工具和实用模块
"prompt_build": "\033[38;5;105m", # 紫色 "prompt_build": "\033[38;5;105m", # 紫色
"chat_utils": "\033[38;5;111m", # 蓝色 "chat_utils": "\033[38;5;111m", # 蓝色
"maibot_statistic": "\033[38;5;129m", # 紫色 "maibot_statistic": "\033[38;5;129m", # 紫色
# 特殊功能插件 # 特殊功能插件
"mute_plugin": "\033[38;5;240m", # 灰色 "mute_plugin": "\033[38;5;240m", # 灰色
@@ -447,10 +435,8 @@ MODULE_ALIASES = {
"llm_api": "生成API", "llm_api": "生成API",
"emoji": "表情包", "emoji": "表情包",
"emoji_api": "表情包API", "emoji_api": "表情包API",
"chat": "所见", "chat": "所见",
"chat_image": "识图", "chat_image": "识图",
"action_manager": "动作", "action_manager": "动作",
"memory_activator": "记忆", "memory_activator": "记忆",
"tool_use": "工具", "tool_use": "工具",
@@ -460,7 +446,6 @@ MODULE_ALIASES = {
"memory": "记忆", "memory": "记忆",
"tool_executor": "工具", "tool_executor": "工具",
"hfc": "聊天节奏", "hfc": "聊天节奏",
"plugin_manager": "插件", "plugin_manager": "插件",
"relationship_builder": "关系", "relationship_builder": "关系",
"llm_models": "模型", "llm_models": "模型",

View File

@@ -47,6 +47,7 @@ class PersonalityConfig(ConfigBase):
interest: str = "" interest: str = ""
"""兴趣""" """兴趣"""
@dataclass @dataclass
class RelationshipConfig(ConfigBase): class RelationshipConfig(ConfigBase):
"""关系配置类""" """关系配置类"""
@@ -80,7 +81,6 @@ class ChatConfig(ConfigBase):
# 合并后的时段频率配置 # 合并后的时段频率配置
talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: []) talk_frequency_adjust: list[list[str]] = field(default_factory=lambda: [])
focus_value: float = 0.5 focus_value: float = 0.5
"""麦麦的专注思考能力越低越容易专注消耗token也越多""" """麦麦的专注思考能力越低越容易专注消耗token也越多"""
@@ -112,7 +112,6 @@ class ChatConfig(ConfigBase):
""" """
@dataclass @dataclass
class MessageReceiveConfig(ConfigBase): class MessageReceiveConfig(ConfigBase):
"""消息接收配置类""" """消息接收配置类"""
@@ -123,6 +122,7 @@ class MessageReceiveConfig(ConfigBase):
ban_msgs_regex: set[str] = field(default_factory=lambda: set()) ban_msgs_regex: set[str] = field(default_factory=lambda: set())
"""过滤正则表达式列表""" """过滤正则表达式列表"""
@dataclass @dataclass
class ExpressionConfig(ConfigBase): class ExpressionConfig(ConfigBase):
"""表达配置类""" """表达配置类"""

View File

@@ -531,7 +531,7 @@ class OpenaiClient(BaseClient):
# 添加详细的错误信息以便调试 # 添加详细的错误信息以便调试
logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}") logger.error(f"OpenAI API连接错误嵌入模型: {str(e)}")
logger.error(f"错误类型: {type(e)}") logger.error(f"错误类型: {type(e)}")
if hasattr(e, '__cause__') and e.__cause__: if hasattr(e, "__cause__") and e.__cause__:
logger.error(f"底层错误: {str(e.__cause__)}") logger.error(f"底层错误: {str(e.__cause__)}")
raise NetworkConnectionError() from e raise NetworkConnectionError() from e
except APIStatusError as e: except APIStatusError as e:

View File

@@ -48,8 +48,7 @@ def _json_schema_type_check(instance) -> str | None:
elif not isinstance(instance["name"], str) or instance["name"].strip() == "": elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串" return "schema的'name'字段必须是非空字符串"
if "description" in instance and ( if "description" in instance and (
not isinstance(instance["description"], str) not isinstance(instance["description"], str) or instance["description"].strip() == ""
or instance["description"].strip() == ""
): ):
return "schema的'description'字段只能填入非空字符串" return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance: if "schema" not in instance:
@@ -101,9 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
# 如果当前Schema是列表则遍历每个元素 # 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)): for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict): if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive( sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
f"{path}/{str(i)}", sub_schema[i], defs
)
else: else:
# 否则为字典 # 否则为字典
if "$defs" in sub_schema: if "$defs" in sub_schema:
@@ -125,9 +122,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
for key, value in sub_schema.items(): for key, value in sub_schema.items():
if isinstance(value, (dict, list)): if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用 # 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive( sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
f"{path}/{key}", value, defs
)
return sub_schema return sub_schema
@@ -163,9 +158,7 @@ class RespFormat:
def _generate_schema_from_model(schema): def _generate_schema_from_model(schema):
json_schema = { json_schema = {
"name": schema.__name__, "name": schema.__name__,
"schema": _remove_defs( "schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
_link_definitions(_remove_title(schema.model_json_schema()))
),
"strict": False, "strict": False,
} }
if schema.__doc__: if schema.__doc__:

View File

@@ -155,7 +155,13 @@ class LLMUsageRecorder:
logger.error(f"创建 LLMUsage 表失败: {str(e)}") logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def record_usage_to_database( def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0 self,
model_info: ModelInfo,
model_usage: UsageRecord,
user_id: str,
request_type: str,
endpoint: str,
time_cost: float = 0.0,
): ):
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
@@ -173,7 +179,7 @@ class LLMUsageRecorder:
completion_tokens=model_usage.completion_tokens or 0, completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0, total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0, cost=total_cost or 0.0,
time_cost = round(time_cost or 0.0, 3), time_cost=round(time_cost or 0.0, 3),
status="success", status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField timestamp=datetime.now(), # Peewee 会处理 DateTimeField
) )
@@ -186,4 +192,5 @@ class LLMUsageRecorder:
except Exception as e: except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}") logger.error(f"记录token使用情况失败: {str(e)}")
llm_usage_recorder = LLMUsageRecorder() llm_usage_recorder = LLMUsageRecorder()

View File

@@ -23,17 +23,17 @@ class ContextMessage:
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊" self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
# 识别消息类型 # 识别消息类型
self.is_gift = getattr(message, 'is_gift', False) self.is_gift = getattr(message, "is_gift", False)
self.is_superchat = getattr(message, 'is_superchat', False) self.is_superchat = getattr(message, "is_superchat", False)
# 添加礼物和SC相关信息 # 添加礼物和SC相关信息
if self.is_gift: if self.is_gift:
self.gift_name = getattr(message, 'gift_name', '') self.gift_name = getattr(message, "gift_name", "")
self.gift_count = getattr(message, 'gift_count', '1') self.gift_count = getattr(message, "gift_count", "1")
self.content = f"送出了 {self.gift_name} x{self.gift_count}" self.content = f"送出了 {self.gift_name} x{self.gift_count}"
elif self.is_superchat: elif self.is_superchat:
self.superchat_price = getattr(message, 'superchat_price', '0') self.superchat_price = getattr(message, "superchat_price", "0")
self.superchat_message = getattr(message, 'superchat_message_text', '') self.superchat_message = getattr(message, "superchat_message_text", "")
if self.superchat_message: if self.superchat_message:
self.content = f"{self.superchat_price}] {self.superchat_message}" self.content = f"{self.superchat_price}] {self.superchat_message}"
else: else:
@@ -47,7 +47,7 @@ class ContextMessage:
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"), "timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
"group_name": self.group_name, "group_name": self.group_name,
"is_gift": self.is_gift, "is_gift": self.is_gift,
"is_superchat": self.is_superchat "is_superchat": self.is_superchat,
} }
@@ -83,20 +83,20 @@ class ContextWebManager:
self.app = web.Application() self.app = web.Application()
# 设置CORS # 设置CORS
cors = aiohttp_cors.setup(self.app, defaults={ cors = aiohttp_cors.setup(
"*": aiohttp_cors.ResourceOptions( self.app,
allow_credentials=True, defaults={
expose_headers="*", "*": aiohttp_cors.ResourceOptions(
allow_headers="*", allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
allow_methods="*" )
) },
}) )
# 添加路由 # 添加路由
self.app.router.add_get('/', self.index_handler) self.app.router.add_get("/", self.index_handler)
self.app.router.add_get('/ws', self.websocket_handler) self.app.router.add_get("/ws", self.websocket_handler)
self.app.router.add_get('/api/contexts', self.get_contexts_handler) self.app.router.add_get("/api/contexts", self.get_contexts_handler)
self.app.router.add_get('/debug', self.debug_handler) self.app.router.add_get("/debug", self.debug_handler)
# 为所有路由添加CORS # 为所有路由添加CORS
for route in list(self.app.router.routes()): for route in list(self.app.router.routes()):
@@ -105,7 +105,7 @@ class ContextWebManager:
self.runner = web.AppRunner(self.app) self.runner = web.AppRunner(self.app)
await self.runner.setup() await self.runner.setup()
self.site = web.TCPSite(self.runner, 'localhost', self.port) self.site = web.TCPSite(self.runner, "localhost", self.port)
await self.site.start() await self.site.start()
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}") logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
@@ -135,7 +135,8 @@ class ContextWebManager:
async def index_handler(self, request): async def index_handler(self, request):
"""主页处理器""" """主页处理器"""
html_content = ''' html_content = (
"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@@ -286,7 +287,9 @@ class ContextWebManager:
function connectWebSocket() { function connectWebSocket() {
console.log('正在连接WebSocket...'); console.log('正在连接WebSocket...');
ws = new WebSocket('ws://localhost:''' + str(self.port) + '''/ws'); ws = new WebSocket('ws://localhost:"""
+ str(self.port)
+ """/ws');
ws.onopen = function() { ws.onopen = function() {
console.log('WebSocket连接已建立'); console.log('WebSocket连接已建立');
@@ -470,8 +473,9 @@ class ContextWebManager:
</script> </script>
</body> </body>
</html> </html>
''' """
return web.Response(text=html_content, content_type='text/html') )
return web.Response(text=html_content, content_type="text/html")
async def websocket_handler(self, request): async def websocket_handler(self, request):
"""WebSocket处理器""" """WebSocket处理器"""
@@ -486,7 +490,7 @@ class ContextWebManager:
async for msg in ws: async for msg in ws:
if msg.type == WSMsgType.ERROR: if msg.type == WSMsgType.ERROR:
logger.error(f'WebSocket错误: {ws.exception()}') logger.error(f"WebSocket错误: {ws.exception()}")
break break
# 清理断开的连接 # 清理断开的连接
@@ -506,7 +510,7 @@ class ContextWebManager:
all_context_msgs.sort(key=lambda x: x.timestamp) all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式 # 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息") logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
return web.json_response({"contexts": contexts_data}) return web.json_response({"contexts": contexts_data})
@@ -529,14 +533,14 @@ class ContextWebManager:
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>' messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
chats_html += f''' chats_html += f"""
<div class="chat"> <div class="chat">
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3> <h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
{messages_html} {messages_html}
</div> </div>
''' """
html_content = f''' html_content = f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@@ -578,9 +582,9 @@ class ContextWebManager:
</script> </script>
</body> </body>
</html> </html>
''' """
return web.Response(text=html_content, content_type='text/html') return web.Response(text=html_content, content_type="text/html")
async def add_message(self, chat_id: str, message: MessageRecv): async def add_message(self, chat_id: str, message: MessageRecv):
"""添加新消息到上下文""" """添加新消息到上下文"""
@@ -594,14 +598,18 @@ class ContextWebManager:
# 统计当前总消息数 # 统计当前总消息数
total_messages = sum(len(contexts) for contexts in self.contexts.values()) total_messages = sum(len(contexts) for contexts in self.contexts.values())
logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}") logger.info(
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
)
# 调试:打印当前所有消息 # 调试:打印当前所有消息
logger.info("📝 当前上下文中的所有消息:") logger.info("📝 当前上下文中的所有消息:")
for cid, contexts in self.contexts.items(): for cid, contexts in self.contexts.items():
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息") logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
for i, msg in enumerate(contexts): for i, msg in enumerate(contexts):
logger.info(f" {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...") logger.info(
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
)
# 广播更新给所有WebSocket连接 # 广播更新给所有WebSocket连接
await self.broadcast_contexts() await self.broadcast_contexts()
@@ -616,7 +624,7 @@ class ContextWebManager:
all_context_msgs.sort(key=lambda x: x.timestamp) all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式 # 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data} data = {"contexts": contexts_data}
await ws.send_str(json.dumps(data, ensure_ascii=False)) await ws.send_str(json.dumps(data, ensure_ascii=False))
@@ -635,7 +643,7 @@ class ContextWebManager:
all_context_msgs.sort(key=lambda x: x.timestamp) all_context_msgs.sort(key=lambda x: x.timestamp)
# 转换为字典格式 # 转换为字典格式
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages:]] contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
data = {"contexts": contexts_data} data = {"contexts": contexts_data}
message = json.dumps(data, ensure_ascii=False) message = json.dumps(data, ensure_ascii=False)
@@ -682,4 +690,3 @@ async def init_context_web_manager():
manager = get_context_web_manager() manager = get_context_web_manager()
await manager.start_server() await manager.start_server()
return manager return manager

View File

@@ -11,6 +11,7 @@ logger = get_logger("gift_manager")
@dataclass @dataclass
class PendingGift: class PendingGift:
"""等待中的礼物消息""" """等待中的礼物消息"""
message: MessageRecvS4U message: MessageRecvS4U
total_count: int total_count: int
timer_task: asyncio.Task timer_task: asyncio.Task
@@ -25,7 +26,9 @@ class GiftManager:
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {} self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
self.debounce_timeout = 5.0 # 3秒防抖时间 self.debounce_timeout = 5.0 # 3秒防抖时间
async def handle_gift(self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None) -> bool: async def handle_gift(
self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None
) -> bool:
"""处理礼物消息,返回是否应该立即处理 """处理礼物消息,返回是否应该立即处理
Args: Args:
@@ -73,17 +76,12 @@ class GiftManager:
# 如果无法解析数量,保持原有数量不变 # 如果无法解析数量,保持原有数量不变
# 重新创建定时器 # 重新创建定时器
pending_gift.timer_task = asyncio.create_task( pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key))
self._gift_timeout(gift_key)
)
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}") logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
async def _create_pending_gift( async def _create_pending_gift(
self, self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]]
gift_key: Tuple[str, str],
message: MessageRecvS4U,
callback: Optional[Callable[[MessageRecvS4U], None]]
) -> None: ) -> None:
"""创建新的等待礼物""" """创建新的等待礼物"""
try: try:
@@ -96,12 +94,7 @@ class GiftManager:
timer_task = asyncio.create_task(self._gift_timeout(gift_key)) timer_task = asyncio.create_task(self._gift_timeout(gift_key))
# 创建等待礼物对象 # 创建等待礼物对象
pending_gift = PendingGift( pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
message=message,
total_count=initial_count,
timer_task=timer_task,
callback=callback
)
self.pending_gifts[gift_key] = pending_gift self.pending_gifts[gift_key] = pending_gift
@@ -152,4 +145,3 @@ class GiftManager:
# 创建全局礼物管理器实例 # 创建全局礼物管理器实例
gift_manager = GiftManager() gift_manager = GiftManager()

View File

@@ -2,7 +2,7 @@ class InternalManager:
def __init__(self): def __init__(self):
self.now_internal_state = str() self.now_internal_state = str()
def set_internal_state(self,internal_state:str): def set_internal_state(self, internal_state: str):
self.now_internal_state = internal_state self.now_internal_state = internal_state
def get_internal_state(self): def get_internal_state(self):
@@ -11,4 +11,5 @@ class InternalManager:
def get_internal_state_str(self): def get_internal_state_str(self):
return f"你今天的直播内容是直播QQ水群你正在一边回复弹幕一边在QQ群聊天你在QQ群聊天中产生的想法是{self.now_internal_state}" return f"你今天的直播内容是直播QQ水群你正在一边回复弹幕一边在QQ群聊天你在QQ群聊天中产生的想法是{self.now_internal_state}"
internal_manager = InternalManager() internal_manager = InternalManager()

View File

@@ -16,7 +16,6 @@ import json
from .s4u_mood_manager import mood_manager from .s4u_mood_manager import mood_manager
from src.mais4u.s4u_config import s4u_config from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import get_person_id from src.person_info.person_info import get_person_id
from .super_chat_manager import get_super_chat_manager
from .yes_or_no import yes_or_no_head from .yes_or_no import yes_or_no_head
logger = get_logger("S4U_chat") logger = get_logger("S4U_chat")
@@ -40,9 +39,6 @@ class MessageSenderContainer:
self.voice_done = "" self.voice_done = ""
async def add_message(self, chunk: str): async def add_message(self, chunk: str):
"""向队列中添加一个消息块。""" """向队列中添加一个消息块。"""
await self.queue.put(chunk) await self.queue.put(chunk)
@@ -199,7 +195,7 @@ class S4UChat:
self.gpt.chat_stream = self.chat_stream self.gpt.chat_stream = self.chat_stream
self.interest_dict: Dict[str, float] = {} # 用户兴趣分 self.interest_dict: Dict[str, float] = {} # 用户兴趣分
self.internal_message :List[MessageRecvS4U] = [] self.internal_message: List[MessageRecvS4U] = []
self.msg_id = "" self.msg_id = ""
self.voice_done = "" self.voice_done = ""
@@ -252,14 +248,13 @@ class S4UChat:
else: else:
self.interest_dict[person_id] = 0 self.interest_dict[person_id] = 0
async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None: async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
self.decay_interest_score() self.decay_interest_score()
"""根据VIP状态和中断逻辑将消息放入相应队列。""" """根据VIP状态和中断逻辑将消息放入相应队列。"""
user_id = message.message_info.user_info.user_id user_id = message.message_info.user_info.user_id
platform = message.message_info.platform platform = message.message_info.platform
person_id = get_person_id(platform, user_id) _person_id = get_person_id(platform, user_id)
# try: # try:
# is_gift = message.is_gift # is_gift = message.is_gift
@@ -292,8 +287,11 @@ class S4UChat:
new_priority_score = self._calculate_base_priority_score(message, priority_info) new_priority_score = self._calculate_base_priority_score(message, priority_info)
should_interrupt = False should_interrupt = False
if (s4u_config.enable_message_interruption and if (
self._current_generation_task and not self._current_generation_task.done()): s4u_config.enable_message_interruption
and self._current_generation_task
and not self._current_generation_task.done()
):
if self._current_message_being_replied: if self._current_message_being_replied:
current_queue, current_priority, _, current_msg = self._current_message_being_replied current_queue, current_priority, _, current_msg = self._current_message_being_replied
@@ -359,7 +357,9 @@ class S4UChat:
neg_priority, entry_count, timestamp, message = item neg_priority, entry_count, timestamp, message = item
# 如果消息在最近N条消息范围内保留它 # 如果消息在最近N条消息范围内保留它
logger.info(f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}") logger.info(
f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
)
if entry_count >= cutoff_counter: if entry_count >= cutoff_counter:
temp_messages.append(item) temp_messages.append(item)
@@ -375,8 +375,12 @@ class S4UChat:
self._normal_queue.put_nowait(item) self._normal_queue.put_nowait(item)
if removed_count > 0: if removed_count > 0:
logger.info(f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除") logger.info(
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.") f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}现在counter:{self._entry_counter}被移除"
)
logger.info(
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
)
async def _message_processor(self): async def _message_processor(self):
"""调度器优先处理VIP队列然后处理普通队列。""" """调度器优先处理VIP队列然后处理普通队列。"""
@@ -396,7 +400,6 @@ class S4UChat:
queue_name = "vip" queue_name = "vip"
# 其次处理普通队列 # 其次处理普通队列
elif not self._normal_queue.empty(): elif not self._normal_queue.empty():
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait() neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
priority = -neg_priority priority = -neg_priority
# 检查普通消息是否超时 # 检查普通消息是否超时
@@ -417,7 +420,9 @@ class S4UChat:
entry_count = 0 entry_count = 0
queue_name = "internal" queue_name = "internal"
logger.info(f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}...") logger.info(
f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
)
else: else:
continue # 没有消息了,回去等事件 continue # 没有消息了,回去等事件
@@ -458,12 +463,10 @@ class S4UChat:
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True) logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
await asyncio.sleep(1) await asyncio.sleep(1)
def get_processing_message_id(self): def get_processing_message_id(self):
self.last_msg_id = self.msg_id self.last_msg_id = self.msg_id
self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}" self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"
async def _generate_and_send(self, message: MessageRecv): async def _generate_and_send(self, message: MessageRecv):
"""为单个消息生成文本回复。整个过程可以被中断。""" """为单个消息生成文本回复。整个过程可以被中断。"""
self._is_replying = True self._is_replying = True
@@ -516,7 +519,12 @@ class S4UChat:
total_chars_sent = len("麦麦不知道哦") total_chars_sent = len("麦麦不知道哦")
mood = mood_manager.get_mood_by_chat_id(self.stream_id) mood = mood_manager.get_mood_by_chat_id(self.stream_id)
await yes_or_no_head(text = total_chars_sent,emotion = mood.mood_state,chat_history=message.processed_plain_text,chat_id=self.stream_id) await yes_or_no_head(
text=total_chars_sent,
emotion=mood.mood_state,
chat_history=message.processed_plain_text,
chat_id=self.stream_id,
)
# 等待所有文本消息发送完成 # 等待所有文本消息发送完成
await sender_container.close() await sender_container.close()
@@ -524,8 +532,6 @@ class S4UChat:
await chat_watching.on_thinking_finished() await chat_watching.on_thinking_finished()
start_time = time.time() start_time = time.time()
logged = False logged = False
while not self.go_processing(): while not self.go_processing():
@@ -576,4 +582,3 @@ class S4UChat:
await self._processing_task await self._processing_task
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"处理任务已成功取消: {self.stream_name}") logger.info(f"处理任务已成功取消: {self.stream_name}")

View File

@@ -40,7 +40,7 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
if global_config.memory.enable_memory: if global_config.memory.enable_memory:
with Timer("记忆激活"): with Timer("记忆激活"):
interested_rate,_ ,_= await hippocampus_manager.get_activate_from_text( interested_rate, _, _ = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text, message.processed_plain_text,
fast_retrieval=True, fast_retrieval=True,
) )
@@ -133,20 +133,16 @@ class S4UMessageProcessor:
if await self.handle_screen_message(message): if await self.handle_screen_message(message):
return return
await self.storage.store_message(message, chat) await self.storage.store_message(message, chat)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
await s4u_chat.add_message(message) await s4u_chat.add_message(message)
_interested_rate, _ = await _calculate_interest(message) _interested_rate, _ = await _calculate_interest(message)
await mood_manager.start() await mood_manager.start()
# 一系列llm驱动的前处理 # 一系列llm驱动的前处理
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message)) asyncio.create_task(chat_mood.update_mood_by_message(message))
@@ -167,30 +163,25 @@ class S4UMessageProcessor:
async def handle_internal_message(self, message: MessageRecvS4U): async def handle_internal_message(self, message: MessageRecvS4U):
if message.is_internal: if message.is_internal:
group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心")
group_info = GroupInfo(platform = "amaidesu_default",group_id = 660154,group_name = "内心") chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
chat = await get_chat_manager().get_or_create_stream(
platform = "amaidesu_default",
user_info = message.message_info.user_info,
group_info = group_info
) )
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
message.message_info.group_info = s4u_chat.chat_stream.group_info message.message_info.group_info = s4u_chat.chat_stream.group_info
message.message_info.platform = s4u_chat.chat_stream.platform message.message_info.platform = s4u_chat.chat_stream.platform
s4u_chat.internal_message.append(message) s4u_chat.internal_message.append(message)
s4u_chat._new_message_event.set() s4u_chat._new_message_event.set()
logger.info(
logger.info(f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}") f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
)
return True return True
return False return False
async def handle_screen_message(self, message: MessageRecvS4U): async def handle_screen_message(self, message: MessageRecvS4U):
if message.is_screen: if message.is_screen:
screen_manager.set_screen(message.screen_info) screen_manager.set_screen(message.screen_info)
@@ -209,7 +200,7 @@ class S4UMessageProcessor:
if message.is_gift: if message.is_gift:
return False return False
gift_keywords = ["送出了礼物", "礼物", "送出了","投喂"] gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"]
if any(keyword in message.processed_plain_text for keyword in gift_keywords): if any(keyword in message.processed_plain_text for keyword in gift_keywords):
message.is_fake_gift = True message.is_fake_gift = True
return True return True

View File

@@ -176,7 +176,7 @@ class PromptBuilder:
message_list_before_now = get_raw_msg_before_timestamp_with_chat( message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
timestamp=time.time(), timestamp=time.time(),
# sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if # sourcery skip: lift-duplicated-conditional, merge-duplicate-blocks, remove-redundant-if
limit=300, limit=300,
) )
@@ -228,13 +228,17 @@ class PromptBuilder:
last_speaking_user_id = start_speaking_user_id last_speaking_user_id = start_speaking_user_id
msg_seg_str = "对方的发言:\n" msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n" msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(first_msg.time))}: {first_msg.processed_plain_text}\n"
)
all_msg_seg_list = [] all_msg_seg_list = []
for msg in core_dialogue_list[1:]: for msg in core_dialogue_list[1:]:
speaker = msg.user_info.user_id speaker = msg.user_info.user_id
if speaker == last_speaking_user_id: if speaker == last_speaking_user_id:
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n" msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(msg.time))}: {msg.processed_plain_text}\n"
)
else: else:
msg_seg_str = f"{msg_seg_str}\n" msg_seg_str = f"{msg_seg_str}\n"
all_msg_seg_list.append(msg_seg_str) all_msg_seg_list.append(msg_seg_str)

View File

@@ -14,10 +14,7 @@ logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator: class S4UStreamGenerator:
def __init__(self): def __init__(self):
# 使用LLMRequest替代AsyncOpenAIClient # 使用LLMRequest替代AsyncOpenAIClient
self.llm_request = LLMRequest( self.llm_request = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="s4u_replyer")
model_set=model_config.model_task_config.replyer,
request_type="s4u_replyer"
)
self.current_model_name = "unknown model" self.current_model_name = "unknown model"
self.partial_response = "" self.partial_response = ""

View File

@@ -1,4 +1,3 @@
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
@@ -47,6 +46,7 @@ HEAD_CODE = {
"看向正前方": "(0,0,0)", "看向正前方": "(0,0,0)",
} }
class ChatWatching: class ChatWatching:
def __init__(self, chat_id: str): def __init__(self, chat_id: str):
self.chat_id: str = chat_id self.chat_id: str = chat_id
@@ -75,13 +75,13 @@ class ChatWatching:
message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False
) )
async def on_internal_message_start(self): async def on_internal_message_start(self):
"""收到消息时调用""" """收到消息时调用"""
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False
) )
class WatchingManager: class WatchingManager:
def __init__(self): def __init__(self):
self.watching_list: list[ChatWatching] = [] self.watching_list: list[ChatWatching] = []
@@ -100,6 +100,7 @@ class WatchingManager:
return new_watching return new_watching
# 全局视线管理器实例 # 全局视线管理器实例
watching_manager = WatchingManager() watching_manager = WatchingManager()
"""全局视线管理器""" """全局视线管理器"""

View File

@@ -2,7 +2,7 @@ class ScreenManager:
def __init__(self): def __init__(self):
self.now_screen = str() self.now_screen = str()
def set_screen(self,screen_str:str): def set_screen(self, screen_str: str):
self.now_screen = screen_str self.now_screen = screen_str
def get_screen(self): def get_screen(self):
@@ -11,4 +11,5 @@ class ScreenManager:
def get_screen_str(self): def get_screen_str(self):
return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}" return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"
screen_manager = ScreenManager() screen_manager = ScreenManager()

View File

@@ -4,6 +4,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.message import MessageRecvS4U from src.chat.message_receive.message import MessageRecvS4U
# 全局SuperChat管理器实例 # 全局SuperChat管理器实例
from src.mais4u.s4u_config import s4u_config from src.mais4u.s4u_config import s4u_config
@@ -44,7 +45,7 @@ class SuperChatRecord:
"timestamp": self.timestamp, "timestamp": self.timestamp,
"expire_time": self.expire_time, "expire_time": self.expire_time,
"group_name": self.group_name, "group_name": self.group_name,
"remaining_time": self.remaining_time() "remaining_time": self.remaining_time(),
} }
@@ -82,10 +83,7 @@ class SuperChatManager:
for chat_id in list(self.super_chats.keys()): for chat_id in list(self.super_chats.keys()):
original_count = len(self.super_chats[chat_id]) original_count = len(self.super_chats[chat_id])
# 移除过期的SuperChat # 移除过期的SuperChat
self.super_chats[chat_id] = [ self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
sc for sc in self.super_chats[chat_id]
if not sc.is_expired()
]
removed_count = original_count - len(self.super_chats[chat_id]) removed_count = original_count - len(self.super_chats[chat_id])
total_removed += removed_count total_removed += removed_count
@@ -153,7 +151,7 @@ class SuperChatManager:
user_info = message.message_info.user_info user_info = message.message_info.user_info
group_info = message.message_info.group_info group_info = message.message_info.group_info
chat_id = getattr(message, 'chat_stream', None) chat_id = getattr(message, "chat_stream", None)
if chat_id: if chat_id:
chat_id = chat_id.stream_id chat_id = chat_id.stream_id
else: else:
@@ -173,7 +171,7 @@ class SuperChatManager:
message_text=message.superchat_message_text or "", message_text=message.superchat_message_text or "",
timestamp=message.message_info.time, timestamp=message.message_info.time,
expire_time=expire_time, expire_time=expire_time,
group_name=group_info.group_name if group_info else None group_name=group_info.group_name if group_info else None,
) )
# 添加到对应聊天的SuperChat列表 # 添加到对应聊天的SuperChat列表
@@ -226,7 +224,9 @@ class SuperChatManager:
remaining_minutes = int(sc.remaining_time() / 60) remaining_minutes = int(sc.remaining_time() / 60)
remaining_seconds = int(sc.remaining_time() % 60) remaining_seconds = int(sc.remaining_time() % 60)
time_display = f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}" time_display = (
f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}"
)
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}" line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
if len(line) > 100: # 限制单行长度 if len(line) > 100: # 限制单行长度
@@ -267,13 +267,7 @@ class SuperChatManager:
superchats = self.get_superchats_by_chat(chat_id) superchats = self.get_superchats_by_chat(chat_id)
if not superchats: if not superchats:
return { return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}
"count": 0,
"total_amount": 0,
"average_amount": 0,
"highest_amount": 0,
"lowest_amount": 0
}
amounts = [sc.price for sc in superchats] amounts = [sc.price for sc in superchats]
@@ -282,7 +276,7 @@ class SuperChatManager:
"total_amount": sum(amounts), "total_amount": sum(amounts),
"average_amount": sum(amounts) / len(amounts), "average_amount": sum(amounts) / len(amounts),
"highest_amount": max(amounts), "highest_amount": max(amounts),
"lowest_amount": min(amounts) "lowest_amount": min(amounts),
} }
async def shutdown(self): # sourcery skip: use-contextlib-suppress async def shutdown(self): # sourcery skip: use-contextlib-suppress
@@ -296,14 +290,13 @@ class SuperChatManager:
logger.info("SuperChat管理器已关闭") logger.info("SuperChat管理器已关闭")
# sourcery skip: assign-if-exp # sourcery skip: assign-if-exp
if s4u_config.enable_s4u: if s4u_config.enable_s4u:
super_chat_manager = SuperChatManager() super_chat_manager = SuperChatManager()
else: else:
super_chat_manager = None super_chat_manager = None
def get_super_chat_manager() -> SuperChatManager: def get_super_chat_manager() -> SuperChatManager:
"""获取全局SuperChat管理器实例""" """获取全局SuperChat管理器实例"""

View File

@@ -10,10 +10,12 @@ from src.common.logger import get_logger
logger = get_logger("s4u_config") logger = get_logger("s4u_config")
# 新增兼容dict和tomlkit Table # 新增兼容dict和tomlkit Table
def is_dict_like(obj): def is_dict_like(obj):
return isinstance(obj, (dict, Table)) return isinstance(obj, (dict, Table))
# 新增递归将Table转为dict # 新增递归将Table转为dict
def table_to_dict(obj): def table_to_dict(obj):
if isinstance(obj, Table): if isinstance(obj, Table):
@@ -25,6 +27,7 @@ def table_to_dict(obj):
else: else:
return obj return obj
# 获取mais4u模块目录 # 获取mais4u模块目录
MAIS4U_ROOT = os.path.dirname(__file__) MAIS4U_ROOT = os.path.dirname(__file__)
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config") CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
@@ -243,7 +246,6 @@ class S4UConfig(S4UConfigBase):
# 兼容性字段,保持向后兼容 # 兼容性字段,保持向后兼容
@dataclass @dataclass
class S4UGlobalConfig(S4UConfigBase): class S4UGlobalConfig(S4UConfigBase):
"""S4U总配置类""" """S4U总配置类"""
@@ -354,9 +356,9 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig:
logger.critical("S4U配置文件解析失败") logger.critical("S4U配置文件解析失败")
raise e raise e
# 初始化S4U配置 # 初始化S4U配置
logger.info(f"S4U当前版本: {S4U_VERSION}") logger.info(f"S4U当前版本: {S4U_VERSION}")
update_s4u_config() update_s4u_config()

View File

@@ -21,7 +21,7 @@ async def migrate_memory_items_to_string():
"empty_nodes": 0, "empty_nodes": 0,
"error_nodes": 0, "error_nodes": 0,
"weight_updated_nodes": 0, "weight_updated_nodes": 0,
"truncated_nodes": 0 "truncated_nodes": 0,
} }
try: try:
@@ -35,7 +35,7 @@ async def migrate_memory_items_to_string():
try: try:
concept = node.concept concept = node.concept
memory_items_raw = node.memory_items.strip() if node.memory_items else "" memory_items_raw = node.memory_items.strip() if node.memory_items else ""
original_weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0 original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
# 如果为空,跳过 # 如果为空,跳过
if not memory_items_raw: if not memory_items_raw:
@@ -71,7 +71,9 @@ async def migrate_memory_items_to_string():
migration_stats["weight_updated_nodes"] += 1 migration_stats["weight_updated_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else "" length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.info(f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}") logger.info(
f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
)
else: else:
# 空list设置为空字符串 # 空list设置为空字符串
node.memory_items = "" node.memory_items = ""
@@ -99,7 +101,9 @@ async def migrate_memory_items_to_string():
update_needed = False update_needed = False
if original_weight == 1.0: if original_weight == 1.0:
# 如果weight还是默认值可以根据内容复杂度估算 # 如果weight还是默认值可以根据内容复杂度估算
content_parts = current_content.split(" | ") if " | " in current_content else [current_content] content_parts = (
current_content.split(" | ") if " | " in current_content else [current_content]
)
estimated_weight = max(1.0, float(len(content_parts))) estimated_weight = max(1.0, float(len(content_parts)))
if estimated_weight != original_weight: if estimated_weight != original_weight:
@@ -195,14 +199,18 @@ async def migrate_memory_items_to_string():
logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}") logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}") logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
success_rate = (migration_stats['converted_nodes'] + migration_stats['already_string_nodes']) / migration_stats['total_nodes'] * 100 if migration_stats['total_nodes'] > 0 else 0 success_rate = (
(migration_stats["converted_nodes"] + migration_stats["already_string_nodes"])
/ migration_stats["total_nodes"]
* 100
if migration_stats["total_nodes"] > 0
else 0
)
logger.info(f"迁移成功率: {success_rate:.1f}%") logger.info(f"迁移成功率: {success_rate:.1f}%")
return migration_stats return migration_stats
async def set_all_person_known(): async def set_all_person_known():
""" """
将person_info库中所有记录的is_known字段设置为True 将person_info库中所有记录的is_known字段设置为True
@@ -226,10 +234,10 @@ async def set_all_person_known():
# 删除user_id或platform为空的记录 # 删除user_id或platform为空的记录
deleted_count = 0 deleted_count = 0
invalid_records = PersonInfo.select().where( invalid_records = PersonInfo.select().where(
(PersonInfo.user_id.is_null()) | (PersonInfo.user_id.is_null())
(PersonInfo.user_id == '') | | (PersonInfo.user_id == "")
(PersonInfo.platform.is_null()) | | (PersonInfo.platform.is_null())
(PersonInfo.platform == '') | (PersonInfo.platform == "")
) )
# 记录要删除的记录信息 # 记录要删除的记录信息
@@ -237,15 +245,21 @@ async def set_all_person_known():
user_id_info = f"'{record.user_id}'" if record.user_id else "NULL" user_id_info = f"'{record.user_id}'" if record.user_id else "NULL"
platform_info = f"'{record.platform}'" if record.platform else "NULL" platform_info = f"'{record.platform}'" if record.platform else "NULL"
person_name_info = f"'{record.person_name}'" if record.person_name else "无名称" person_name_info = f"'{record.person_name}'" if record.person_name else "无名称"
logger.debug(f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}") logger.debug(
f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}"
)
# 执行删除操作 # 执行删除操作
deleted_count = PersonInfo.delete().where( deleted_count = (
(PersonInfo.user_id.is_null()) | PersonInfo.delete()
(PersonInfo.user_id == '') | .where(
(PersonInfo.platform.is_null()) | (PersonInfo.user_id.is_null())
(PersonInfo.platform == '') | (PersonInfo.user_id == "")
).execute() | (PersonInfo.platform.is_null())
| (PersonInfo.platform == "")
)
.execute()
)
if deleted_count > 0: if deleted_count > 0:
logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录") logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录")
@@ -268,12 +282,7 @@ async def set_all_person_known():
# 验证更新结果 # 验证更新结果
known_count = PersonInfo.select().where(PersonInfo.is_known).count() known_count = PersonInfo.select().where(PersonInfo.is_known).count()
result = { result = {"total": total_count, "deleted": deleted_count, "updated": updated_count, "known_count": known_count}
"total": total_count,
"deleted": deleted_count,
"updated": updated_count,
"known_count": known_count
}
logger.info("=== person_info更新完成 ===") logger.info("=== person_info更新完成 ===")
logger.info(f"原始记录数: {result['total']}") logger.info(f"原始记录数: {result['total']}")
@@ -288,7 +297,6 @@ async def set_all_person_known():
raise raise
async def check_and_run_migrations(): async def check_and_run_migrations():
# 获取根目录 # 获取根目录
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -309,4 +317,3 @@ async def check_and_run_migrations():
# 创建done.mem文件 # 创建done.mem文件
with open(done_file, "w", encoding="utf-8") as f: with open(done_file, "w", encoding="utf-8") as f:
f.write("done") f.write("done")

View File

@@ -282,7 +282,7 @@ class Person:
memory_category = parts[0].strip() memory_category = parts[0].strip()
memory_text = parts[1].strip() memory_text = parts[1].strip()
memory_weight = parts[2].strip() _memory_weight = parts[2].strip()
# 检查分类是否匹配 # 检查分类是否匹配
if memory_category != category: if memory_category != category:

View File

@@ -1,11 +1,5 @@
import json
from json_repair import repair_json
from datetime import datetime
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest from src.chat.utils.prompt_builder import Prompt
from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .person_info import Person
logger = get_logger("relation") logger = get_logger("relation")
@@ -43,4 +37,3 @@ def init_prompt():
""", """,
"attitude_to_me_prompt", "attitude_to_me_prompt",
) )

View File

@@ -121,5 +121,5 @@ __all__ = [
"DatabaseChatInfo", "DatabaseChatInfo",
"TargetPersonInfo", "TargetPersonInfo",
"ActionPlannerInfo", "ActionPlannerInfo",
"LLMGenerationDataModel" "LLMGenerationDataModel",
] ]

View File

@@ -7,22 +7,24 @@ logger = get_logger("frequency_api")
def get_current_focus_value(chat_id: str) -> float: def get_current_focus_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_focus_value() return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_focus_value()
def get_current_talk_frequency(chat_id: str) -> float: def get_current_talk_frequency(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_talk_frequency() return frequency_control_manager.get_or_create_frequency_control(chat_id).get_final_talk_frequency()
def set_focus_value_adjust(chat_id: str, focus_value_adjust: float) -> None: def set_focus_value_adjust(chat_id: str, focus_value_adjust: float) -> None:
frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust = focus_value_adjust frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust = focus_value_adjust
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None: def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust = talk_frequency_adjust frequency_control_manager.get_or_create_frequency_control(
chat_id
).talk_frequency_external_adjust = talk_frequency_adjust
def get_focus_value_adjust(chat_id: str) -> float: def get_focus_value_adjust(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust return frequency_control_manager.get_or_create_frequency_control(chat_id).focus_value_external_adjust
def get_talk_frequency_adjust(chat_id: str) -> float: def get_talk_frequency_adjust(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust return frequency_control_manager.get_or_create_frequency_control(chat_id).talk_frequency_external_adjust

View File

@@ -72,7 +72,9 @@ async def generate_with_model(
llm_request = LLMRequest(model_set=model_config, request_type=request_type) llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens) response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
prompt, temperature=temperature, max_tokens=max_tokens
)
return True, response, reasoning_content, model_name return True, response, reasoning_content, model_name
except Exception as e: except Exception as e:
@@ -80,6 +82,7 @@ async def generate_with_model(
logger.error(f"[LLMAPI] {error_msg}") logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", "" return False, error_msg, "", ""
async def generate_with_model_with_tools( async def generate_with_model_with_tools(
prompt: str, prompt: str,
model_config: TaskConfig, model_config: TaskConfig,
@@ -109,10 +112,7 @@ async def generate_with_model_with_tools(
llm_request = LLMRequest(model_set=model_config, request_type=request_type) llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async( response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt, prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens
) )
return True, response, reasoning_content, model_name, tool_call return True, response, reasoning_content, model_name, tool_call

View File

@@ -435,9 +435,7 @@ def build_readable_messages_to_str(
Returns: Returns:
格式化后的可读字符串 格式化后的可读字符串
""" """
return build_readable_messages( return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions)
messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions
)
async def build_readable_messages_with_details( async def build_readable_messages_with_details(
@@ -491,8 +489,6 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag
return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)] return [msg for msg in messages if msg.user_info.user_id != str(global_config.bot.qq_account)]
def translate_pid_to_description(pid: str) -> str: def translate_pid_to_description(pid: str) -> str:
image = Images.get_or_none(Images.image_id == pid) image = Images.get_or_none(Images.image_id == pid)
description = "" description = ""

View File

@@ -2,7 +2,7 @@ from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("plugin_manager") # 复用plugin_manager名称 logger = get_logger("plugin_manager") # 复用plugin_manager名称
def register_plugin(cls): def register_plugin(cls):

View File

@@ -401,9 +401,7 @@ class PluginManager:
command_components = [ command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
] ]
tool_components = [ tool_components = [c for c in plugin_info.components if c.component_type == ComponentType.TOOL]
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
event_handler_components = [ event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
] ]

View File

@@ -195,7 +195,9 @@ class ToolExecutor:
return tool_results, used_tools return tool_results, used_tools
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]: async def execute_tool_call(
self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None
) -> Optional[Dict[str, Any]]:
# sourcery skip: use-assigned-variable # sourcery skip: use-assigned-variable
"""执行单个工具调用 """执行单个工具调用

View File

@@ -63,5 +63,4 @@ class CoreActionsPlugin(BasePlugin):
if self.get_config("components.enable_emoji", True): if self.get_config("components.enable_emoji", True):
components.append((EmojiAction.get_action_info(), EmojiAction)) components.append((EmojiAction.get_action_info(), EmojiAction))
return components return components

View File

@@ -74,7 +74,9 @@ class BuildMemoryAction(BaseAction):
# 动作基本信息 # 动作基本信息
action_name = "build_memory" action_name = "build_memory"
action_description = "了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息" action_description = (
"了解对于某个概念或者某件事的记忆,并存储下来,在之后的聊天中,你可以根据这条记忆来获取相关信息"
)
# 动作参数定义 # 动作参数定义
action_parameters = { action_parameters = {
@@ -103,24 +105,28 @@ class BuildMemoryAction(BaseAction):
concept_name = self.action_data.get("concept_name", "") concept_name = self.action_data.get("concept_name", "")
# 2. 获取目标用户信息 # 2. 获取目标用户信息
# 对 concept_name 进行jieba分词 # 对 concept_name 进行jieba分词
concept_name_tokens = cut_key_words(concept_name) concept_name_tokens = cut_key_words(concept_name)
# logger.info(f"{self.log_prefix} 对 concept_name 进行分词结果: {concept_name_tokens}") # logger.info(f"{self.log_prefix} 对 concept_name 进行分词结果: {concept_name_tokens}")
filtered_concept_name_tokens = [ filtered_concept_name_tokens = [
token for token in concept_name_tokens if all(keyword not in token for keyword in global_config.memory.memory_ban_words) token
for token in concept_name_tokens
if all(keyword not in token for keyword in global_config.memory.memory_ban_words)
] ]
if not filtered_concept_name_tokens: if not filtered_concept_name_tokens:
logger.warning(f"{self.log_prefix} 过滤后的概念名称列表为空,跳过添加记忆") logger.warning(f"{self.log_prefix} 过滤后的概念名称列表为空,跳过添加记忆")
return False, "过滤后的概念名称列表为空,跳过添加记忆" return False, "过滤后的概念名称列表为空,跳过添加记忆"
similar_topics_dict = hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(filtered_concept_name_tokens) similar_topics_dict = (
await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(concept_description, similar_topics_dict) hippocampus_manager.get_hippocampus().parahippocampal_gyrus.get_similar_topics_from_keywords(
filtered_concept_name_tokens
)
)
await hippocampus_manager.get_hippocampus().parahippocampal_gyrus.add_memory_with_similar(
concept_description, similar_topics_dict
)
return True, f"成功添加记忆: {concept_name}" return True, f"成功添加记忆: {concept_name}"
@@ -129,6 +135,5 @@ class BuildMemoryAction(BaseAction):
return False, f"构建记忆时出错: {e}" return False, f"构建记忆时出错: {e}"
# 还缺一个关系的太多遗忘和对应的提取 # 还缺一个关系的太多遗忘和对应的提取
init_prompt() init_prompt()

View File

@@ -178,7 +178,9 @@ class BuildRelationAction(BaseAction):
chat_model_config = models.get("utils") chat_model_config = models.get("utils")
success, update_memory, _, _ = await llm_api.generate_with_model( success, update_memory, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore prompt,
model_config=chat_model_config,
request_type="relation.category.update", # type: ignore
) )
update_memory_data = json.loads(repair_json(update_memory)) update_memory_data = json.loads(repair_json(update_memory))
@@ -207,7 +209,9 @@ class BuildRelationAction(BaseAction):
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}") person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
person.sync_to_database() person.sync_to_database()
logger.info(f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}") logger.info(
f"{self.log_prefix} 更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
)
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}" return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
@@ -215,7 +219,6 @@ class BuildRelationAction(BaseAction):
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}") logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
return False, f"删除{person.person_name}的记忆点失败: {memory_content}" return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
return True, "关系动作执行成功" return True, "关系动作执行成功"
except Exception as e: except Exception as e: