chore: 部分去除print改用logger

Fix/20250309 logger optimize
This commit is contained in:
HYY
2025-03-09 23:55:45 +08:00
committed by GitHub
5 changed files with 184 additions and 150 deletions

28
bot.py
View File

@@ -1,5 +1,7 @@
import os import os
import shutil import shutil
import sys
import nonebot import nonebot
import time import time
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -10,6 +12,7 @@ import platform
# 获取没有加载env时的环境变量 # 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ} env_mask = {key: os.getenv(key) for key in os.environ}
def easter_egg(): def easter_egg():
# 彩蛋 # 彩蛋
from colorama import init, Fore from colorama import init, Fore
@@ -22,11 +25,12 @@ def easter_egg():
rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char rainbow_text += rainbow_colors[i % len(rainbow_colors)] + char
print(rainbow_text) print(rainbow_text)
def init_config(): def init_config():
# 初次启动检测 # 初次启动检测
if not os.path.exists("config/bot_config.toml"): if not os.path.exists("config/bot_config.toml"):
logger.warning("检测到bot_config.toml不存在正在从模板复制") logger.warning("检测到bot_config.toml不存在正在从模板复制")
# 检查config目录是否存在 # 检查config目录是否存在
if not os.path.exists("config"): if not os.path.exists("config"):
os.makedirs("config") os.makedirs("config")
@@ -35,6 +39,7 @@ def init_config():
shutil.copy("template/bot_config_template.toml", "config/bot_config.toml") shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
logger.info("复制完成请修改config/bot_config.toml和.env.prod中的配置后重新启动") logger.info("复制完成请修改config/bot_config.toml和.env.prod中的配置后重新启动")
def init_env(): def init_env():
# 初始化.env 默认ENVIRONMENT=prod # 初始化.env 默认ENVIRONMENT=prod
if not os.path.exists(".env"): if not os.path.exists(".env"):
@@ -46,11 +51,17 @@ def init_env():
logger.error("检测到.env.prod文件不存在") logger.error("检测到.env.prod文件不存在")
shutil.copy("template.env", "./.env.prod") shutil.copy("template.env", "./.env.prod")
# 检测.env.dev文件是否存在不存在的话直接复制生产环境配置
if not os.path.exists(".env.dev"):
logger.error("检测到.env.dev文件不存在")
shutil.copy(".env.prod", "./.env.dev")
# 首先加载基础环境变量.env # 首先加载基础环境变量.env
if os.path.exists(".env"): if os.path.exists(".env"):
load_dotenv(".env") load_dotenv(".env")
logger.success("成功加载基础环境变量配置") logger.success("成功加载基础环境变量配置")
def load_env(): def load_env():
# 使用闭包实现对加载器的横向扩展,避免大量重复判断 # 使用闭包实现对加载器的横向扩展,避免大量重复判断
def prod(): def prod():
@@ -70,7 +81,7 @@ def load_env():
logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}") logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}")
if env in fn_map: if env in fn_map:
fn_map[env]() # 根据映射执行闭包函数 fn_map[env]() # 根据映射执行闭包函数
elif os.path.exists(f".env.{env}"): elif os.path.exists(f".env.{env}"):
logger.success(f"加载{env}环境变量配置") logger.success(f"加载{env}环境变量配置")
@@ -81,6 +92,17 @@ def load_env():
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
def load_logger():
logger.remove() # 移除默认配置
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>",
colorize=True,
level=os.getenv("LOG_LEVEL", "INFO") # 根据环境设置日志级别默认为INFO
)
def scan_provider(env_config: dict): def scan_provider(env_config: dict):
provider = {} provider = {}
@@ -115,6 +137,7 @@ def scan_provider(env_config: dict):
) )
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
if __name__ == "__main__": if __name__ == "__main__":
# 利用 TZ 环境变量设定程序工作的时区 # 利用 TZ 环境变量设定程序工作的时区
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用 # 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
@@ -122,6 +145,7 @@ if __name__ == "__main__":
time.tzset() time.tzset()
easter_egg() easter_egg()
load_logger()
init_config() init_config()
init_env() init_env()
load_env() load_env()

View File

@@ -1,6 +1,4 @@
import asyncio import asyncio
import os
import random
import time import time
from loguru import logger from loguru import logger
@@ -30,16 +28,15 @@ driver = get_driver()
config = driver.config config = driver.config
Database.initialize( Database.initialize(
host= config.MONGODB_HOST, host=config.MONGODB_HOST,
port= int(config.MONGODB_PORT), port=int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME, db_name=config.DATABASE_NAME,
username= config.MONGODB_USERNAME, username=config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD, password=config.MONGODB_PASSWORD,
auth_source= config.MONGODB_AUTH_SOURCE auth_source=config.MONGODB_AUTH_SOURCE
) )
print("\033[1;32m[初始化数据库完成]\033[0m") print("\033[1;32m[初始化数据库完成]\033[0m")
# 导入其他模块 # 导入其他模块
from ..memory_system.memory import hippocampus, memory_graph from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot from .bot import ChatBot
@@ -59,24 +56,24 @@ group_msg = on_message(priority=5)
scheduler = require("nonebot_plugin_apscheduler").scheduler scheduler = require("nonebot_plugin_apscheduler").scheduler
@driver.on_startup @driver.on_startup
async def start_background_tasks(): async def start_background_tasks():
"""启动后台任务""" """启动后台任务"""
# 启动LLM统计 # 启动LLM统计
llm_stats.start() llm_stats.start()
print("\033[1;32m[初始化]\033[0m LLM统计功能已启动") logger.success("[初始化]LLM统计功能已启动")
# 初始化并启动情绪管理器 # 初始化并启动情绪管理器
mood_manager = MoodManager.get_instance() mood_manager = MoodManager.get_instance()
mood_manager.start_mood_update(update_interval=global_config.mood_update_interval) mood_manager.start_mood_update(update_interval=global_config.mood_update_interval)
print("\033[1;32m[初始化]\033[0m 情绪管理器已启动") logger.success("[初始化]情绪管理器已启动")
# 只启动表情包管理任务 # 只启动表情包管理任务
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL)) asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
await bot_schedule.initialize() await bot_schedule.initialize()
bot_schedule.print_schedule() bot_schedule.print_schedule()
@driver.on_startup @driver.on_startup
async def init_relationships(): async def init_relationships():
"""在 NoneBot2 启动时初始化关系管理器""" """在 NoneBot2 启动时初始化关系管理器"""
@@ -84,45 +81,52 @@ async def init_relationships():
await relationship_manager.load_all_relationships() await relationship_manager.load_all_relationships()
asyncio.create_task(relationship_manager._start_relationship_manager()) asyncio.create_task(relationship_manager._start_relationship_manager())
@driver.on_bot_connect @driver.on_bot_connect
async def _(bot: Bot): async def _(bot: Bot):
"""Bot连接成功时的处理""" """Bot连接成功时的处理"""
global _message_manager_started global _message_manager_started
print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m") print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m")
await willing_manager.ensure_started() await willing_manager.ensure_started()
message_sender.set_bot(bot) message_sender.set_bot(bot)
print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m") print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m")
if not _message_manager_started: if not _message_manager_started:
asyncio.create_task(message_manager.start_processor()) asyncio.create_task(message_manager.start_processor())
_message_manager_started = True _message_manager_started = True
print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m") print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m")
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL)) asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m") print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
@group_msg.handle() @group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State): async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
await chat_bot.handle_message(event, bot) await chat_bot.handle_message(event, bot)
# 添加build_memory定时任务 # 添加build_memory定时任务
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory") @scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task(): async def build_memory_task():
"""每build_memory_interval秒执行一次记忆构建""" """每build_memory_interval秒执行一次记忆构建"""
print("\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------") print(
"\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------")
start_time = time.time() start_time = time.time()
await hippocampus.operation_build_memory(chat_size=20) await hippocampus.operation_build_memory(chat_size=20)
end_time = time.time() end_time = time.time()
print(f"\033[1;32m[记忆构建]\033[0m -------------------------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒-------------------------------------------") print(
f"\033[1;32m[记忆构建]\033[0m -------------------------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒-------------------------------------------")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task(): async def forget_memory_task():
"""每30秒执行一次记忆构建""" """每30秒执行一次记忆构建"""
# print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...") # print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
# await hippocampus.operation_forget_topic(percentage=0.1) # await hippocampus.operation_forget_topic(percentage=0.1)
# print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成") # print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory") @scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
async def merge_memory_task(): async def merge_memory_task():
"""每30秒执行一次记忆构建""" """每30秒执行一次记忆构建"""
@@ -130,9 +134,9 @@ async def merge_memory_task():
# await hippocampus.operation_merge_memory(percentage=0.1) # await hippocampus.operation_merge_memory(percentage=0.1)
# print("\033[1;32m[记忆整合]\033[0m 记忆整合完成") # print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
@scheduler.scheduled_job("interval", seconds=30, id="print_mood") @scheduler.scheduled_job("interval", seconds=30, id="print_mood")
async def print_mood_task(): async def print_mood_task():
"""每30秒打印一次情绪状态""" """每30秒打印一次情绪状态"""
mood_manager = MoodManager.get_instance() mood_manager = MoodManager.get_instance()
mood_manager.print_mood_status() mood_manager.print_mood_status()

View File

@@ -6,45 +6,46 @@ import tomli
from loguru import logger from loguru import logger
from packaging import version from packaging import version
from packaging.version import Version, InvalidVersion from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet,InvalidSpecifier from packaging.specifiers import SpecifierSet, InvalidSpecifier
@dataclass @dataclass
class BotConfig: class BotConfig:
"""机器人配置类""" """机器人配置类"""
INNER_VERSION: Version = None INNER_VERSION: Version = None
BOT_QQ: Optional[int] = 1 BOT_QQ: Optional[int] = 1
BOT_NICKNAME: Optional[str] = None BOT_NICKNAME: Optional[str] = None
# 消息处理相关配置 # 消息处理相关配置
MIN_TEXT_LENGTH: int = 2 # 最小处理文本长度 MIN_TEXT_LENGTH: int = 2 # 最小处理文本长度
MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数 MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数
emoji_chance: float = 0.2 # 发送表情包的基础概率 emoji_chance: float = 0.2 # 发送表情包的基础概率
ENABLE_PIC_TRANSLATE: bool = True # 是否启用图片翻译 ENABLE_PIC_TRANSLATE: bool = True # 是否启用图片翻译
talk_allowed_groups = set() talk_allowed_groups = set()
talk_frequency_down_groups = set() talk_frequency_down_groups = set()
thinking_timeout: int = 100 # 思考时间 thinking_timeout: int = 100 # 思考时间
response_willing_amplifier: float = 1.0 # 回复意愿放大系数 response_willing_amplifier: float = 1.0 # 回复意愿放大系数
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数 response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
down_frequency_rate: float = 3.5 # 降低回复频率的群组回复意愿降低系数 down_frequency_rate: float = 3.5 # 降低回复频率的群组回复意愿降低系数
ban_user_id = set() ban_user_id = set()
build_memory_interval: int = 30 # 记忆构建间隔(秒) build_memory_interval: int = 30 # 记忆构建间隔(秒)
forget_memory_interval: int = 300 # 记忆遗忘间隔(秒) forget_memory_interval: int = 300 # 记忆遗忘间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
EMOJI_SAVE: bool = True # 偷表情包 EMOJI_SAVE: bool = True # 偷表情包
EMOJI_CHECK: bool = False #是否开启过滤 EMOJI_CHECK: bool = False # 是否开启过滤
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求 EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
ban_words = set() ban_words = set()
max_response_length: int = 1024 # 最大回复长度 max_response_length: int = 1024 # 最大回复长度
# 模型配置 # 模型配置
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {}) llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {}) llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
@@ -60,28 +61,29 @@ class BotConfig:
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率 MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率 MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率 MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
enable_advance_output: bool = False # 是否启用高级输出 enable_advance_output: bool = False # 是否启用高级输出
enable_kuuki_read: bool = True # 是否启用读空气功能 enable_kuuki_read: bool = True # 是否启用读空气功能
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
keywords_reaction_rules = [] # 关键词回复规则 mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
chinese_typo_enable=True # 是否启用中文错别字生成器 keywords_reaction_rules = [] # 关键词回复规则
chinese_typo_error_rate=0.03 # 单字替换概率
chinese_typo_min_freq=7 # 最小字频阈值 chinese_typo_enable = True # 是否启用中文错别字生成器
chinese_typo_tone_error_rate=0.2 # 声调错误概率 chinese_typo_error_rate = 0.03 # 单字替换概率
chinese_typo_word_replace_rate=0.02 # 整词替换概率 chinese_typo_min_freq = 7 # 最小字频阈值
chinese_typo_tone_error_rate = 0.2 # 声调错误概率
chinese_typo_word_replace_rate = 0.02 # 整词替换概率
# 默认人设 # 默认人设
PROMPT_PERSONALITY=[ PROMPT_PERSONALITY = [
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧", "曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
"是一个女大学生,你有黑色头发,你会刷小红书", "是一个女大学生,你有黑色头发,你会刷小红书",
"是一个女大学生你会刷b站对ACG文化感兴趣" "是一个女大学生你会刷b站对ACG文化感兴趣"
] ]
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书" PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
PERSONALITY_1: float = 0.6 # 第一种人格概率 PERSONALITY_1: float = 0.6 # 第一种人格概率
@@ -99,7 +101,7 @@ class BotConfig:
if not os.path.exists(config_dir): if not os.path.exists(config_dir):
os.makedirs(config_dir) os.makedirs(config_dir)
return config_dir return config_dir
@classmethod @classmethod
def convert_to_specifierset(cls, value: str) -> SpecifierSet: def convert_to_specifierset(cls, value: str) -> SpecifierSet:
"""将 字符串 版本表达式转换成 SpecifierSet """将 字符串 版本表达式转换成 SpecifierSet
@@ -119,7 +121,7 @@ class BotConfig:
exit(1) exit(1)
return converted return converted
@classmethod @classmethod
def get_config_version(cls, toml: dict) -> Version: def get_config_version(cls, toml: dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据 """提取配置文件的 SpecifierSet 版本数据
@@ -131,14 +133,14 @@ class BotConfig:
if 'inner' in toml: if 'inner' in toml:
try: try:
config_version : str = toml["inner"]["version"] config_version: str = toml["inner"]["version"]
except KeyError as e: except KeyError as e:
logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
else: else:
toml["inner"] = { "version": "0.0.0" } toml["inner"] = {"version": "0.0.0"}
config_version = toml["inner"]["version"] config_version = toml["inner"]["version"]
try: try:
ver = version.parse(config_version) ver = version.parse(config_version)
except InvalidVersion as e: except InvalidVersion as e:
@@ -150,38 +152,38 @@ class BotConfig:
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n")
return ver return ver
@classmethod @classmethod
def load_config(cls, config_path: str = None) -> "BotConfig": def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置""" """从TOML配置文件加载配置"""
config = cls() config = cls()
def personality(parent: dict): def personality(parent: dict):
personality_config=parent['personality'] personality_config = parent['personality']
personality=personality_config.get('prompt_personality') personality = personality_config.get('prompt_personality')
if len(personality) >= 2: if len(personality) >= 2:
logger.info(f"载入自定义人格:{personality}") logger.info(f"载入自定义人格:{personality}")
config.PROMPT_PERSONALITY=personality_config.get('prompt_personality',config.PROMPT_PERSONALITY) config.PROMPT_PERSONALITY = personality_config.get('prompt_personality', config.PROMPT_PERSONALITY)
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)}") logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}")
config.PROMPT_SCHEDULE_GEN=personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN) config.PROMPT_SCHEDULE_GEN = personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)
if config.INNER_VERSION in SpecifierSet(">=0.0.2"): if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
config.PERSONALITY_1=personality_config.get('personality_1_probability',config.PERSONALITY_1) config.PERSONALITY_1 = personality_config.get('personality_1_probability', config.PERSONALITY_1)
config.PERSONALITY_2=personality_config.get('personality_2_probability',config.PERSONALITY_2) config.PERSONALITY_2 = personality_config.get('personality_2_probability', config.PERSONALITY_2)
config.PERSONALITY_3=personality_config.get('personality_3_probability',config.PERSONALITY_3) config.PERSONALITY_3 = personality_config.get('personality_3_probability', config.PERSONALITY_3)
def emoji(parent: dict): def emoji(parent: dict):
emoji_config = parent["emoji"] emoji_config = parent["emoji"]
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL) config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL) config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL)
config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt',config.EMOJI_CHECK_PROMPT) config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt', config.EMOJI_CHECK_PROMPT)
config.EMOJI_SAVE = emoji_config.get('auto_save',config.EMOJI_SAVE) config.EMOJI_SAVE = emoji_config.get('auto_save', config.EMOJI_SAVE)
config.EMOJI_CHECK = emoji_config.get('enable_check',config.EMOJI_CHECK) config.EMOJI_CHECK = emoji_config.get('enable_check', config.EMOJI_CHECK)
def cq_code(parent: dict): def cq_code(parent: dict):
cq_code_config = parent["cq_code"] cq_code_config = parent["cq_code"]
config.ENABLE_PIC_TRANSLATE = cq_code_config.get("enable_pic_translate", config.ENABLE_PIC_TRANSLATE) config.ENABLE_PIC_TRANSLATE = cq_code_config.get("enable_pic_translate", config.ENABLE_PIC_TRANSLATE)
def bot(parent: dict): def bot(parent: dict):
# 机器人基础配置 # 机器人基础配置
bot_config = parent["bot"] bot_config = parent["bot"]
@@ -193,12 +195,13 @@ class BotConfig:
response_config = parent["response"] response_config = parent["response"]
config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY) config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY) config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY) config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability",
config.MODEL_R1_DISTILL_PROBABILITY)
config.max_response_length = response_config.get("max_response_length", config.max_response_length) config.max_response_length = response_config.get("max_response_length", config.max_response_length)
def model(parent: dict): def model(parent: dict):
# 加载模型配置 # 加载模型配置
model_config:dict = parent["model"] model_config: dict = parent["model"]
config_list = [ config_list = [
"llm_reasoning", "llm_reasoning",
@@ -215,24 +218,24 @@ class BotConfig:
for item in config_list: for item in config_list:
if item in model_config: if item in model_config:
cfg_item:dict = model_config[item] cfg_item: dict = model_config[item]
# base_url 的例子: SILICONFLOW_BASE_URL # base_url 的例子: SILICONFLOW_BASE_URL
# key 的例子: SILICONFLOW_KEY # key 的例子: SILICONFLOW_KEY
cfg_target = { cfg_target = {
"name" : "", "name": "",
"base_url" : "", "base_url": "",
"key" : "", "key": "",
"pri_in" : 0, "pri_in": 0,
"pri_out" : 0 "pri_out": 0
} }
if config.INNER_VERSION in SpecifierSet("<=0.0.0"): if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
cfg_target = cfg_item cfg_target = cfg_item
elif config.INNER_VERSION in SpecifierSet(">=0.0.1"): elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
stable_item = ["name","pri_in","pri_out"] stable_item = ["name", "pri_in", "pri_out"]
pricing_item = ["pri_in","pri_out"] pricing_item = ["pri_in", "pri_out"]
# 从配置中原始拷贝稳定字段 # 从配置中原始拷贝稳定字段
for i in stable_item: for i in stable_item:
# 如果 字段 属于计费项 且获取不到,那默认值是 0 # 如果 字段 属于计费项 且获取不到,那默认值是 0
@@ -246,18 +249,16 @@ class BotConfig:
logger.error(f"{item} 中的必要字段 {e} 不存在,请检查") logger.error(f"{item} 中的必要字段 {e} 不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
provider = cfg_item.get("provider") provider = cfg_item.get("provider")
if provider == None: if provider == None:
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查") logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查") raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
cfg_target["base_url"] = f"{provider}_BASE_URL" cfg_target["base_url"] = f"{provider}_BASE_URL"
cfg_target["key"] = f"{provider}_KEY" cfg_target["key"] = f"{provider}_KEY"
# 如果 列表中的项目在 model_config 中,利用反射来设置对应项目 # 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
setattr(config,item,cfg_target) setattr(config, item, cfg_target)
else: else:
logger.error(f"模型 {item} 在config中不存在请检查") logger.error(f"模型 {item} 在config中不存在请检查")
raise KeyError(f"模型 {item} 在config中不存在请检查") raise KeyError(f"模型 {item} 在config中不存在请检查")
@@ -267,12 +268,14 @@ class BotConfig:
config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH) config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH)
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE) config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance) config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
config.ban_words=msg_config.get("ban_words",config.ban_words) config.ban_words = msg_config.get("ban_words", config.ban_words)
if config.INNER_VERSION in SpecifierSet(">=0.0.2"): if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
config.thinking_timeout = msg_config.get("thinking_timeout", config.thinking_timeout) config.thinking_timeout = msg_config.get("thinking_timeout", config.thinking_timeout)
config.response_willing_amplifier = msg_config.get("response_willing_amplifier", config.response_willing_amplifier) config.response_willing_amplifier = msg_config.get("response_willing_amplifier",
config.response_interested_rate_amplifier = msg_config.get("response_interested_rate_amplifier", config.response_interested_rate_amplifier) config.response_willing_amplifier)
config.response_interested_rate_amplifier = msg_config.get("response_interested_rate_amplifier",
config.response_interested_rate_amplifier)
config.down_frequency_rate = msg_config.get("down_frequency_rate", config.down_frequency_rate) config.down_frequency_rate = msg_config.get("down_frequency_rate", config.down_frequency_rate)
def memory(parent: dict): def memory(parent: dict):
@@ -300,8 +303,10 @@ class BotConfig:
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable) config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate) config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq) config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
config.chinese_typo_tone_error_rate = chinese_typo_config.get("tone_error_rate", config.chinese_typo_tone_error_rate) config.chinese_typo_tone_error_rate = chinese_typo_config.get("tone_error_rate",
config.chinese_typo_word_replace_rate = chinese_typo_config.get("word_replace_rate", config.chinese_typo_word_replace_rate) config.chinese_typo_tone_error_rate)
config.chinese_typo_word_replace_rate = chinese_typo_config.get("word_replace_rate",
config.chinese_typo_word_replace_rate)
def groups(parent: dict): def groups(parent: dict):
groups_config = parent["groups"] groups_config = parent["groups"]
@@ -389,7 +394,7 @@ class BotConfig:
except(tomli.TOMLDecodeError) as e: except(tomli.TOMLDecodeError) as e:
logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}") logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}")
exit(1) exit(1)
# 获取配置文件版本 # 获取配置文件版本
config.INNER_VERSION = cls.get_config_version(toml_dict) config.INNER_VERSION = cls.get_config_version(toml_dict)
@@ -413,31 +418,32 @@ class BotConfig:
f"当前程序仅支持以下版本范围: {group_specifierset}" f"当前程序仅支持以下版本范围: {group_specifierset}"
) )
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}") raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理 # 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") == False: elif "necessary" in include_configs[key] and include_configs[key].get("necessary") == False:
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理 # 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction": if key == "keywords_reaction":
pass pass
else: else:
# 如果用户根本没有需要的配置项,提示缺少配置 # 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'") logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'") raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
logger.success(f"成功加载配置文件: {config_path}") logger.success(f"成功加载配置文件: {config_path}")
return config return config
# 获取配置文件路径 # 获取配置文件路径
bot_config_floder_path = BotConfig.get_config_dir() bot_config_floder_path = BotConfig.get_config_dir()
print(f"正在品鉴配置文件目录: {bot_config_floder_path}") logger.debug(f"正在品鉴配置文件目录: {bot_config_floder_path}")
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml") bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
if os.path.exists(bot_config_path): if os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件 # 如果开发环境配置文件不存在,则使用默认配置文件
print(f"异常的新鲜,异常的美味: {bot_config_path}") logger.debug(f"异常的新鲜,异常的美味: {bot_config_path}")
logger.info("使用bot配置文件") logger.info("使用bot配置文件")
else: else:
# 配置文件不存在 # 配置文件不存在
@@ -446,8 +452,6 @@ else:
global_config = BotConfig.load_config(config_path=bot_config_path) global_config = BotConfig.load_config(config_path=bot_config_path)
if not global_config.enable_advance_output: if not global_config.enable_advance_output:
logger.remove() logger.remove()
pass pass

View File

View File

@@ -13,21 +13,21 @@ from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
Database.initialize( Database.initialize(
host= config.MONGODB_HOST, host=config.MONGODB_HOST,
port= int(config.MONGODB_PORT), port=int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME, db_name=config.DATABASE_NAME,
username= config.MONGODB_USERNAME, username=config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD, password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE auth_source=config.MONGODB_AUTH_SOURCE
) )
class ScheduleGenerator: class ScheduleGenerator:
def __init__(self): def __init__(self):
#根据global_config.llm_normal这一字典配置指定模型 # 根据global_config.llm_normal这一字典配置指定模型
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9) # self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9)
self.llm_scheduler = LLM_request(model = global_config.llm_normal,temperature=0.9) self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9)
self.db = Database.get_instance() self.db = Database.get_instance()
self.today_schedule_text = "" self.today_schedule_text = ""
self.today_schedule = {} self.today_schedule = {}
@@ -35,39 +35,41 @@ class ScheduleGenerator:
self.tomorrow_schedule = {} self.tomorrow_schedule = {}
self.yesterday_schedule_text = "" self.yesterday_schedule_text = ""
self.yesterday_schedule = {} self.yesterday_schedule = {}
async def initialize(self): async def initialize(self):
today = datetime.datetime.now() today = datetime.datetime.now()
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1) tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
yesterday = datetime.datetime.now() - datetime.timedelta(days=1) yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
self.today_schedule_text, self.today_schedule = await self.generate_daily_schedule(target_date=today) self.today_schedule_text, self.today_schedule = await self.generate_daily_schedule(target_date=today)
self.tomorrow_schedule_text, self.tomorrow_schedule = await self.generate_daily_schedule(target_date=tomorrow,read_only=True) self.tomorrow_schedule_text, self.tomorrow_schedule = await self.generate_daily_schedule(target_date=tomorrow,
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(target_date=yesterday,read_only=True) read_only=True)
self.yesterday_schedule_text, self.yesterday_schedule = await self.generate_daily_schedule(
async def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]: target_date=yesterday, read_only=True)
async def generate_daily_schedule(self, target_date: datetime.datetime = None, read_only: bool = False) -> Dict[
str, str]:
date_str = target_date.strftime("%Y-%m-%d") date_str = target_date.strftime("%Y-%m-%d")
weekday = target_date.strftime("%A") weekday = target_date.strftime("%A")
schedule_text = str schedule_text = str
existing_schedule = self.db.db.schedule.find_one({"date": date_str}) existing_schedule = self.db.db.schedule.find_one({"date": date_str})
if existing_schedule: if existing_schedule:
print(f"{date_str}的日程已存在:") logger.info(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"] schedule_text = existing_schedule["schedule"]
# print(self.schedule_text) # print(self.schedule_text)
elif read_only == False: elif read_only == False:
print(f"{date_str}的日程不存在,准备生成新的日程。") logger.info(f"{date_str}的日程不存在,准备生成新的日程。")
prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:"""+\ prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:""" + \
""" """
1. 早上的学习和工作安排 1. 早上的学习和工作安排
2. 下午的活动和任务 2. 下午的活动和任务
3. 晚上的计划和休息时间 3. 晚上的计划和休息时间
请按照时间顺序列出具体时间点和对应的活动用一个时间点而不是时间段来表示时间用JSON格式返回日程表仅返回内容不要返回注释时间采用24小时制格式为{"时间": "活动","时间": "活动",...}。""" 请按照时间顺序列出具体时间点和对应的活动用一个时间点而不是时间段来表示时间用JSON格式返回日程表仅返回内容不要返回注释时间采用24小时制格式为{"时间": "活动","时间": "活动",...}。"""
try: try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt) schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
@@ -76,36 +78,35 @@ class ScheduleGenerator:
schedule_text = "生成日程时出错了" schedule_text = "生成日程时出错了"
# print(self.schedule_text) # print(self.schedule_text)
else: else:
print(f"{date_str}的日程不存在。") logger.info(f"{date_str}的日程不存在。")
schedule_text = "忘了" schedule_text = "忘了"
return schedule_text,None return schedule_text, None
schedule_form = self._parse_schedule(schedule_text) schedule_form = self._parse_schedule(schedule_text)
return schedule_text,schedule_form return schedule_text, schedule_form
def _parse_schedule(self, schedule_text: str) -> Union[bool, Dict[str, str]]: def _parse_schedule(self, schedule_text: str) -> Union[bool, Dict[str, str]]:
"""解析日程文本,转换为时间和活动的字典""" """解析日程文本,转换为时间和活动的字典"""
try: try:
schedule_dict = json.loads(schedule_text) schedule_dict = json.loads(schedule_text)
return schedule_dict return schedule_dict
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print(schedule_text) logger.exception("解析日程失败: {}".format(schedule_text))
print(f"解析日程失败: {str(e)}")
return False return False
def _parse_time(self, time_str: str) -> str: def _parse_time(self, time_str: str) -> str:
"""解析时间字符串,转换为时间""" """解析时间字符串,转换为时间"""
return datetime.datetime.strptime(time_str, "%H:%M") return datetime.datetime.strptime(time_str, "%H:%M")
def get_current_task(self) -> str: def get_current_task(self) -> str:
"""获取当前时间应该进行的任务""" """获取当前时间应该进行的任务"""
current_time = datetime.datetime.now().strftime("%H:%M") current_time = datetime.datetime.now().strftime("%H:%M")
# 找到最接近当前时间的任务 # 找到最接近当前时间的任务
closest_time = None closest_time = None
min_diff = float('inf') min_diff = float('inf')
# 检查今天的日程 # 检查今天的日程
if not self.today_schedule: if not self.today_schedule:
return "摸鱼" return "摸鱼"
@@ -114,7 +115,7 @@ class ScheduleGenerator:
if closest_time is None or diff < min_diff: if closest_time is None or diff < min_diff:
closest_time = time_str closest_time = time_str
min_diff = diff min_diff = diff
# 检查昨天的日程中的晚间任务 # 检查昨天的日程中的晚间任务
if self.yesterday_schedule: if self.yesterday_schedule:
for time_str in self.yesterday_schedule.keys(): for time_str in self.yesterday_schedule.keys():
@@ -125,17 +126,17 @@ class ScheduleGenerator:
closest_time = time_str closest_time = time_str
min_diff = diff min_diff = diff
return closest_time, self.yesterday_schedule[closest_time] return closest_time, self.yesterday_schedule[closest_time]
if closest_time: if closest_time:
return closest_time, self.today_schedule[closest_time] return closest_time, self.today_schedule[closest_time]
return "摸鱼" return "摸鱼"
def _time_diff(self, time1: str, time2: str) -> int: def _time_diff(self, time1: str, time2: str) -> int:
"""计算两个时间字符串之间的分钟差""" """计算两个时间字符串之间的分钟差"""
if time1=="24:00": if time1 == "24:00":
time1="23:59" time1 = "23:59"
if time2=="24:00": if time2 == "24:00":
time2="23:59" time2 = "23:59"
t1 = datetime.datetime.strptime(time1, "%H:%M") t1 = datetime.datetime.strptime(time1, "%H:%M")
t2 = datetime.datetime.strptime(time2, "%H:%M") t2 = datetime.datetime.strptime(time2, "%H:%M")
diff = int((t2 - t1).total_seconds() / 60) diff = int((t2 - t1).total_seconds() / 60)
@@ -146,17 +147,18 @@ class ScheduleGenerator:
diff -= 1440 # 减一天的分钟 diff -= 1440 # 减一天的分钟
# print(f"时间1[{time1}]: 时间2[{time2}],差值[{diff}]分钟") # print(f"时间1[{time1}]: 时间2[{time2}],差值[{diff}]分钟")
return diff return diff
def print_schedule(self): def print_schedule(self):
"""打印完整的日程安排""" """打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text): if not self._parse_schedule(self.today_schedule_text):
print("今日日程有误,将在下次运行时重新生成") logger.warning("今日日程有误,将在下次运行时重新生成")
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else: else:
print("\n=== 今日日程安排 ===") logger.info("\n=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items(): for time_str, activity in self.today_schedule.items():
print(f"时间[{time_str}]: 活动[{activity}]") logger.info(f"时间[{time_str}]: 活动[{activity}]")
print("==================\n") logger.info("==================\n")
# def main(): # def main():
# # 使用示例 # # 使用示例
@@ -165,7 +167,7 @@ class ScheduleGenerator:
# scheduler.print_schedule() # scheduler.print_schedule()
# print("\n当前任务") # print("\n当前任务")
# print(scheduler.get_current_task()) # print(scheduler.get_current_task())
# print("昨天日程:") # print("昨天日程:")
# print(scheduler.yesterday_schedule) # print(scheduler.yesterday_schedule)
# print("今天日程:") # print("今天日程:")
@@ -175,5 +177,5 @@ class ScheduleGenerator:
# if __name__ == "__main__": # if __name__ == "__main__":
# main() # main()
bot_schedule = ScheduleGenerator() bot_schedule = ScheduleGenerator()