Ruff Format
This commit is contained in:
@@ -4,7 +4,6 @@ MaiSaka LLM 服务 - 使用主项目 LLM 系统
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Literal
|
||||
|
||||
@@ -34,6 +33,7 @@ MSG_TYPE_FIELD = "_type"
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""工具调用信息"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict
|
||||
@@ -42,6 +42,7 @@ class ToolCall:
|
||||
@dataclass
|
||||
class ChatResponse:
|
||||
"""LLM 对话循环单步响应"""
|
||||
|
||||
content: Optional[str]
|
||||
tool_calls: List[ToolCall]
|
||||
raw_message: dict # 可直接追加到对话历史的消息字典
|
||||
@@ -49,6 +50,7 @@ class ChatResponse:
|
||||
|
||||
# ──────────────────── 工具函数 ────────────────────
|
||||
|
||||
|
||||
def build_message(role: str, content: str, msg_type: MessageType = "user", **kwargs) -> dict:
|
||||
"""构建消息字典,包含消息类型标记。"""
|
||||
msg = {"role": role, "content": content, MSG_TYPE_FIELD: msg_type, **kwargs}
|
||||
@@ -93,23 +95,18 @@ class MaiSakaLLMService:
|
||||
except Exception:
|
||||
# 如果配置加载失败,使用默认配置
|
||||
from src.config.model_configs import ModelTaskConfig
|
||||
|
||||
self._model_configs = ModelTaskConfig()
|
||||
logger.warning("无法加载主项目模型配置,使用默认配置")
|
||||
|
||||
# 初始化 LLMRequest 实例(只使用 tool_use 和 replyer)
|
||||
self._llm_tool_use = LLMRequest(
|
||||
model_set=self._model_configs.tool_use,
|
||||
request_type="maisaka_tool_use"
|
||||
)
|
||||
self._llm_tool_use = LLMRequest(model_set=self._model_configs.tool_use, request_type="maisaka_tool_use")
|
||||
# 主对话也使用 tool_use 模型(因为需要工具调用支持)
|
||||
self._llm_chat = self._llm_tool_use
|
||||
# 分析模块也使用 tool_use 模型
|
||||
self._llm_utils = self._llm_tool_use
|
||||
# 回复生成使用 replyer 模型
|
||||
self._llm_replyer = LLMRequest(
|
||||
model_set=self._model_configs.replyer,
|
||||
request_type="maisaka_replyer"
|
||||
)
|
||||
self._llm_replyer = LLMRequest(model_set=self._model_configs.replyer, request_type="maisaka_replyer")
|
||||
|
||||
# 尝试修复数据库 schema(忽略错误)
|
||||
self._try_fix_database_schema()
|
||||
@@ -133,6 +130,7 @@ class MaiSakaLLMService:
|
||||
|
||||
chat_prompt.add_context("file_tools_section", tools_section if tools_section else "")
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
@@ -147,7 +145,9 @@ class MaiSakaLLMService:
|
||||
self._chat_system_prompt = chat_system_prompt
|
||||
|
||||
# 获取模型名称用于显示
|
||||
self._model_name = self._model_configs.tool_use.model_list[0] if self._model_configs.tool_use.model_list else "未配置"
|
||||
self._model_name = (
|
||||
self._model_configs.tool_use.model_list[0] if self._model_configs.tool_use.model_list else "未配置"
|
||||
)
|
||||
|
||||
# 加载子模块提示词
|
||||
self._emotion_prompt: Optional[str] = None
|
||||
@@ -157,21 +157,22 @@ class MaiSakaLLMService:
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
self._emotion_prompt = loop.run_until_complete(prompt_manager.render_prompt(
|
||||
prompt_manager.get_prompt("maidairy_emotion")
|
||||
))
|
||||
self._cognition_prompt = loop.run_until_complete(prompt_manager.render_prompt(
|
||||
prompt_manager.get_prompt("maidairy_cognition")
|
||||
))
|
||||
self._timing_prompt = loop.run_until_complete(prompt_manager.render_prompt(
|
||||
prompt_manager.get_prompt("maidairy_timing")
|
||||
))
|
||||
self._context_summarize_prompt = loop.run_until_complete(prompt_manager.render_prompt(
|
||||
prompt_manager.get_prompt("maidairy_context_summarize")
|
||||
))
|
||||
self._emotion_prompt = loop.run_until_complete(
|
||||
prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_emotion"))
|
||||
)
|
||||
self._cognition_prompt = loop.run_until_complete(
|
||||
prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_cognition"))
|
||||
)
|
||||
self._timing_prompt = loop.run_until_complete(
|
||||
prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_timing"))
|
||||
)
|
||||
self._context_summarize_prompt = loop.run_until_complete(
|
||||
prompt_manager.render_prompt(prompt_manager.get_prompt("maidairy_context_summarize"))
|
||||
)
|
||||
logger.info("成功加载 MaiSaka 子模块提示词")
|
||||
finally:
|
||||
loop.close()
|
||||
@@ -191,9 +192,7 @@ class MaiSakaLLMService:
|
||||
|
||||
if "model_api_provider_name" not in columns:
|
||||
# 添加缺失的列
|
||||
session.execute(text(
|
||||
"ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)"
|
||||
))
|
||||
session.execute(text("ALTER TABLE llm_usage ADD COLUMN model_api_provider_name VARCHAR(255)"))
|
||||
session.commit()
|
||||
logger.info("数据库 schema 已修复:添加 model_api_provider_name 列")
|
||||
except Exception:
|
||||
@@ -205,7 +204,7 @@ class MaiSakaLLMService:
|
||||
self._extra_tools = list(tools)
|
||||
|
||||
@staticmethod
|
||||
def _tool_option_to_dict(tool: 'ToolOption') -> dict:
|
||||
def _tool_option_to_dict(tool: "ToolOption") -> dict:
|
||||
"""将 ToolOption 对象转换为主项目期望的 dict 格式
|
||||
|
||||
主项目的 _build_tool_options() 期望的格式:
|
||||
@@ -218,18 +217,8 @@ class MaiSakaLLMService:
|
||||
params = []
|
||||
if tool.params:
|
||||
for param in tool.params:
|
||||
params.append((
|
||||
param.name,
|
||||
param.param_type,
|
||||
param.description,
|
||||
param.required,
|
||||
param.enum_values
|
||||
))
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": params
|
||||
}
|
||||
params.append((param.name, param.param_type, param.description, param.required, param.enum_values))
|
||||
return {"name": tool.name, "description": tool.description, "parameters": params}
|
||||
|
||||
async def chat_loop_step(self, chat_history: List[dict]) -> ChatResponse:
|
||||
"""执行对话循环的一步 - 使用 tool_use 模型"""
|
||||
@@ -271,11 +260,13 @@ class MaiSakaLLMService:
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_func = tc.get("function", {})
|
||||
# 主项目的 ToolCall: call_id, func_name, args
|
||||
tool_calls_list.append(ToolCallOption(
|
||||
call_id=tc.get("id", ""),
|
||||
func_name=tc_func.get("name", ""),
|
||||
args=json.loads(tc_func.get("arguments", "{}")) if tc_func.get("arguments") else {}
|
||||
))
|
||||
tool_calls_list.append(
|
||||
ToolCallOption(
|
||||
call_id=tc.get("id", ""),
|
||||
func_name=tc_func.get("name", ""),
|
||||
args=json.loads(tc_func.get("arguments", "{}")) if tc_func.get("arguments") else {},
|
||||
)
|
||||
)
|
||||
builder.set_tool_calls(tool_calls_list)
|
||||
elif role == "tool" and "tool_call_id" in msg:
|
||||
builder.add_tool_call(msg["tool_call_id"])
|
||||
@@ -290,15 +281,17 @@ class MaiSakaLLMService:
|
||||
|
||||
# 调用 LLM(使用带消息的接口)
|
||||
# 合并内置工具和额外工具(将 ToolOption 对象转换为 dict)
|
||||
all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (self._extra_tools if self._extra_tools else [])
|
||||
all_tools = [self._tool_option_to_dict(t) for t in get_builtin_tools()] + (
|
||||
self._extra_tools if self._extra_tools else []
|
||||
)
|
||||
|
||||
# 打印消息列表
|
||||
built_messages = message_factory(None)
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("MaiSaka LLM Request - chat_loop_step:")
|
||||
for msg in built_messages:
|
||||
print(f" {msg}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
response, (reasoning, model, tool_calls) = await self._llm_chat.generate_response_with_message_async(
|
||||
message_factory=message_factory,
|
||||
@@ -312,15 +305,17 @@ class MaiSakaLLMService:
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
# 主项目的 ToolCall 有 call_id, func_name, args
|
||||
call_id = tc.call_id if hasattr(tc, 'call_id') else ""
|
||||
func_name = tc.func_name if hasattr(tc, 'func_name') else ""
|
||||
args = tc.args if hasattr(tc, 'args') else {}
|
||||
call_id = tc.call_id if hasattr(tc, "call_id") else ""
|
||||
func_name = tc.func_name if hasattr(tc, "func_name") else ""
|
||||
args = tc.args if hasattr(tc, "args") else {}
|
||||
|
||||
converted_tool_calls.append(ToolCall(
|
||||
id=call_id,
|
||||
name=func_name,
|
||||
arguments=args,
|
||||
))
|
||||
converted_tool_calls.append(
|
||||
ToolCall(
|
||||
id=call_id,
|
||||
name=func_name,
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
# 构建原始消息格式(MaiSaka 风格)
|
||||
raw_message = {"role": "assistant", "content": response}
|
||||
@@ -394,10 +389,10 @@ class MaiSakaLLMService:
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("MaiSaka LLM Request - analyze_emotion:")
|
||||
print(f" {prompt}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_utils.generate_response_async(
|
||||
@@ -428,10 +423,10 @@ class MaiSakaLLMService:
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("MaiSaka LLM Request - analyze_cognition:")
|
||||
print(f" {prompt}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_utils.generate_response_async(
|
||||
@@ -463,10 +458,10 @@ class MaiSakaLLMService:
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("MaiSaka LLM Request - analyze_timing:")
|
||||
print(f" {prompt}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_utils.generate_response_async(
|
||||
@@ -498,10 +493,10 @@ class MaiSakaLLMService:
|
||||
|
||||
prompt = "\n".join(prompt_parts)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("MaiSaka LLM Request - summarize_context:")
|
||||
print(f" {prompt}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_utils.generate_response_async(
|
||||
@@ -529,8 +524,7 @@ class MaiSakaLLMService:
|
||||
|
||||
# 格式化对话历史
|
||||
filtered_history = [
|
||||
msg for msg in chat_history
|
||||
if msg.get("role") != "system" and msg.get("_type") != "perception"
|
||||
msg for msg in chat_history if msg.get("role") != "system" and msg.get("_type") != "perception"
|
||||
]
|
||||
formatted_history = format_chat_history(filtered_history)
|
||||
|
||||
@@ -542,18 +536,15 @@ class MaiSakaLLMService:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据用户的想法生成自然的回复。"
|
||||
|
||||
user_prompt = (
|
||||
f"当前时间:{current_time}\n\n"
|
||||
f"【聊天记录】\n{formatted_history}\n\n"
|
||||
f"【你的想法】\n{reason}\n\n"
|
||||
f"现在,你说:"
|
||||
f"当前时间:{current_time}\n\n【聊天记录】\n{formatted_history}\n\n【你的想法】\n{reason}\n\n现在,你说:"
|
||||
)
|
||||
|
||||
messages = f"System: {system_prompt}\n\nUser: {user_prompt}"
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("\n" + "=" * 60)
|
||||
print("MaiSaka LLM Request - generate_reply:")
|
||||
print(f" {messages}")
|
||||
print("="*60 + "\n")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
try:
|
||||
response, _ = await self._llm_replyer.generate_response_async(
|
||||
|
||||
Reference in New Issue
Block a user