炸 service 层

This commit is contained in:
DrSmoothl
2026-03-14 00:13:35 +08:00
parent 898fab6de9
commit 43c5b34623
13 changed files with 1408 additions and 1736 deletions

View File

@@ -4,7 +4,7 @@
提供聊天信息查询和管理的核心功能。
"""
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
from enum import Enum
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
@@ -23,50 +23,58 @@ class ChatManager:
"""聊天管理器 - 负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
def _validate_platform(platform: Optional[str] | SpecialTypes) -> None:
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
@staticmethod
def _match_platform(chat_stream: BotChatSession, platform: Optional[str] | SpecialTypes) -> bool:
return platform == SpecialTypes.ALL_PLATFORMS or chat_stream.platform == platform
@staticmethod
def _get_streams(
platform: Optional[str] | SpecialTypes = "qq", is_group_session: Optional[bool] = None
) -> List[BotChatSession]:
ChatManager._validate_platform(platform)
try:
for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的聊天流")
streams = [
stream
for stream in _chat_manager.sessions.values()
if ChatManager._match_platform(stream, platform)
and (is_group_session is None or stream.is_group_session == is_group_session)
]
return streams
except Exception as e:
logger.error(f"[ChatService] 获取聊天流失败: {e}")
return []
@staticmethod
def _find_stream(
predicate: Callable[[BotChatSession], bool],
platform: Optional[str] | SpecialTypes = "qq",
) -> Optional[BotChatSession]:
for stream in ChatManager._get_streams(platform=platform):
if predicate(stream):
return stream
return None
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
streams = ChatManager._get_streams(platform=platform)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的聊天流")
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
logger.error(f"[ChatService] 获取群聊流失败: {e}")
streams = ChatManager._get_streams(platform=platform, is_group_session=True)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的群聊流")
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in _chat_manager.sessions.items():
if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
logger.error(f"[ChatService] 获取私聊流失败: {e}")
streams = ChatManager._get_streams(platform=platform, is_group_session=False)
logger.debug(f"[ChatService] 获取到 {len(streams)}{platform} 平台的私聊流")
return streams
@staticmethod
@@ -75,19 +83,17 @@ class ChatManager:
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(group_id, str):
raise TypeError("group_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
ChatManager._validate_platform(platform)
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
stream.is_group_session
and str(stream.group_id) == str(group_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
return stream
stream = ChatManager._find_stream(
lambda item: item.is_group_session and str(item.group_id) == str(group_id),
platform=platform,
)
if stream is not None:
logger.debug(f"[ChatService] 找到群ID {group_id} 的聊天流")
return stream
logger.warning(f"[ChatService] 未找到群ID {group_id} 的聊天流")
except Exception as e:
logger.error(f"[ChatService] 查找群聊流失败: {e}")
@@ -99,19 +105,17 @@ class ChatManager:
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
if not isinstance(user_id, str):
raise TypeError("user_id 必须是字符串类型")
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
ChatManager._validate_platform(platform)
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in _chat_manager.sessions.items():
if (
not stream.is_group_session
and str(stream.user_id) == str(user_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
return stream
stream = ChatManager._find_stream(
lambda item: (not item.is_group_session) and str(item.user_id) == str(user_id),
platform=platform,
)
if stream is not None:
logger.debug(f"[ChatService] 找到用户ID {user_id} 的私聊流")
return stream
logger.warning(f"[ChatService] 未找到用户ID {user_id} 的私聊流")
except Exception as e:
logger.error(f"[ChatService] 查找私聊流失败: {e}")

View File

@@ -1,13 +1,16 @@
"""数据库服务模块
提供数据库操作相关的核心功能。
"""
"""数据库服务模块"""
import json
import time
import traceback
from datetime import datetime
from typing import Any, Optional
from sqlalchemy import delete, func, select
from sqlmodel import SQLModel
from src.common.database.database import get_db_session
from src.common.database.database_model import ActionRecord
from src.common.logger import get_logger
logger = get_logger("database_service")
@@ -25,73 +28,57 @@ def _to_dict(record: Any) -> dict[str, Any]:
return {}
async def db_query(
model_class,
data: Optional[dict[str, Any]] = None,
query_type: str = "get",
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[list[str]] = None,
single_result: bool = False,
def _get_model_field(model_class: type[SQLModel], field_name: str):
field = getattr(model_class, field_name, None)
if field is None:
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
return field
def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> list[Any]:
if not filters:
return []
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None):
if not order_by:
return statement
order_fields = [order_by] if isinstance(order_by, str) else order_by
clauses = []
for item in order_fields:
descending = item.startswith("-")
field_name = item[1:] if descending else item
field = _get_model_field(model_class, field_name)
clauses.append(field.desc() if descending else field.asc())
return statement.order_by(*clauses)
async def db_save(
model_class: type[SQLModel],
data: dict[str, Any],
key_field: Optional[str] = None,
key_value: Optional[Any] = None,
):
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
with get_db_session() as session:
record = None
if key_field and key_value is not None:
key_column = _get_model_field(model_class, key_field)
record = session.exec(select(model_class).where(key_column == key_value)).first()
if query_type == "get":
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if order_by:
query = query.order_by(*order_by)
if limit:
query = query.limit(limit)
results = list(query.dicts())
if single_result:
return results[0] if results else None
return results
if record is None:
record = model_class(**data)
else:
for field_name, value in data.items():
_get_model_field(model_class, field_name)
setattr(record, field_name, value)
if query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
record = model_class.create(**data)
session.add(record)
session.flush()
session.refresh(record)
return _to_dict(record)
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
return query.model_class.update(**data).where(*query.stmt._where_criteria).execute()
if query_type == "delete":
return model_class.delete().where(*query.stmt._where_criteria).execute()
return query.count()
except Exception as e:
logger.error(f"[DatabaseService] 数据库操作出错: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
return None
async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None):
try:
if key_field and key_value is not None:
record = model_class.get_or_none(getattr(model_class, key_field) == key_value)
if record is not None:
for field, value in data.items():
setattr(record, field, value)
record.save()
return _to_dict(record)
new_record = model_class.create(**data)
return _to_dict(new_record)
except Exception as e:
logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
traceback.print_exc()
@@ -99,68 +86,108 @@ async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] =
async def db_get(
model_class,
model_class: type[SQLModel],
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
order_by: Optional[str | list[str]] = None,
single_result: bool = False,
):
try:
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if order_by:
query = query.order_by(order_by)
if limit:
query = query.limit(limit)
results = list(query.dicts())
if single_result:
return results[0] if results else None
return results
with get_db_session(auto_commit=False) as session:
statement = select(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
statement = _apply_order_by(statement, model_class, order_by)
if limit:
statement = statement.limit(limit)
results = session.exec(statement).all()
data = [_to_dict(item) for item in results]
if single_result:
return data[0] if data else None
return data
except Exception as e:
logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if single_result else []
async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session() as session:
statement = select(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
records = session.exec(statement).all()
for record in records:
for field_name, value in data.items():
_get_model_field(model_class, field_name)
setattr(record, field_name, value)
session.add(record)
return len(records)
except Exception as e:
logger.error(f"[DatabaseService] 更新数据库记录出错: {e}")
traceback.print_exc()
return 0
async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session() as session:
statement = delete(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
result = session.exec(statement)
return result.rowcount or 0
except Exception as e:
logger.error(f"[DatabaseService] 删除数据库记录出错: {e}")
traceback.print_exc()
return 0
async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session(auto_commit=False) as session:
statement = select(func.count()).select_from(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
result = session.exec(statement).one()
return int(result or 0)
except Exception as e:
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")
traceback.print_exc()
return 0
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
builtin_prompt: Optional[str] = None,
display_prompt: str = "",
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
action_reasoning: str = "",
):
try:
from src.common.database.database_model import ActionRecords
if chat_stream is None:
raise ValueError("store_action_info 需要 chat_stream")
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"timestamp": datetime.now(),
"session_id": chat_stream.session_id,
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_reasoning": action_reasoning,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
"action_builtin_prompt": builtin_prompt,
"action_display_prompt": display_prompt,
}
if chat_stream:
record_data.update(
{
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
}
)
else:
record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""})
saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
ActionRecord, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")

View File

@@ -1,21 +0,0 @@
"""频率控制服务模块
提供聊天频率控制的核心功能。
"""
from src.chat.heart_flow.frequency_control import frequency_control_manager
from src.config.config import global_config
def get_current_talk_value(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(
chat_id
).get_talk_frequency_adjust() * global_config.chat.get_talk_value(chat_id)
def set_talk_frequency_adjust(chat_id: str, talk_frequency_adjust: float) -> None:
frequency_control_manager.get_or_create_frequency_control(chat_id).set_talk_frequency_adjust(talk_frequency_adjust)
def get_talk_frequency_adjust(chat_id: str) -> float:
return frequency_control_manager.get_or_create_frequency_control(chat_id).get_talk_frequency_adjust()

View File

@@ -16,6 +16,38 @@ from src.llm_models.utils_model import LLMRequest
logger = get_logger("llm_service")
async def _generate_response(
model_config: TaskConfig,
request_type: str,
prompt: Optional[str] = None,
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
tool_options: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[str, str, str, List[ToolCall] | None]:
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
if message_factory is not None:
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
message_factory=message_factory,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return response, reasoning_content, model_name, tool_call
if prompt is None:
raise ValueError("prompt 与 message_factory 不能同时为空")
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return response, reasoning_content, model_name, tool_call
def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用的模型配置
@@ -61,11 +93,12 @@ async def generate_with_model(
"""
try:
logger.debug(f"[LLMService] 完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(
prompt, temperature=temperature, max_tokens=max_tokens
response, reasoning_content, model_name, _ = await _generate_response(
model_config=model_config,
request_type=request_type,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
)
return True, response, reasoning_content, model_name
@@ -101,10 +134,13 @@ async def generate_with_model_with_tools(
logger.info(f"使用模型{model_name_list}生成内容")
logger.debug(f"完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt, tools=tool_options, temperature=temperature, max_tokens=max_tokens
response, reasoning_content, model_name, tool_call = await _generate_response(
model_config=model_config,
request_type=request_type,
prompt=prompt,
tool_options=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
return True, response, reasoning_content, model_name, tool_call
@@ -139,11 +175,11 @@ async def generate_with_model_with_tools_by_message_factory(
model_name_list = model_config.model_list
logger.info(f"使用模型 {model_name_list} 生成内容")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_with_message_async(
response, reasoning_content, model_name, tool_call = await _generate_response(
model_config=model_config,
request_type=request_type,
message_factory=message_factory,
tools=tool_options,
tool_options=tool_options,
temperature=temperature,
max_tokens=max_tokens,
)

View File

@@ -1,65 +0,0 @@
"""个人信息服务模块
提供个人信息查询的核心功能。
"""
from typing import Any
from src.common.logger import get_logger
from src.person_info.person_info import Person
logger = get_logger("person_service")
def get_person_id(platform: str, user_id: int | str) -> str:
"""根据平台和用户ID获取person_id
Args:
platform: 平台名称,如 "qq", "telegram"
user_id: 用户ID
Returns:
str: 唯一的person_idMD5哈希值
"""
try:
return Person(platform=platform, user_id=str(user_id)).person_id
except Exception as e:
logger.error(f"[PersonService] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
return ""
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
"""根据person_id和字段名获取某个值
Args:
person_id: 用户的唯一标识ID
field_name: 要获取的字段名,如 "nickname", "impression"
default: 当字段不存在或获取失败时返回的默认值
Returns:
Any: 字段值或默认值
"""
try:
person = Person(person_id=person_id)
value = getattr(person, field_name)
return value if value is not None else default
except Exception as e:
logger.error(f"[PersonService] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
return default
def get_person_id_by_name(person_name: str) -> str:
"""根据用户名获取person_id
Args:
person_name: 用户名
Returns:
str: person_id如果未找到返回空字符串
"""
try:
person = Person(person_name=person_name)
return person.person_id
except Exception as e:
logger.error(f"[PersonService] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
return ""