移除残留的KnowU系统,修复gemini请求的思考签名问题
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,7 +6,6 @@ data/
|
|||||||
!pytests/A_memorix_test/data/real_dialogues/private_alice_weekend.json
|
!pytests/A_memorix_test/data/real_dialogues/private_alice_weekend.json
|
||||||
pytests/A_memorix_test/data/benchmarks/results/
|
pytests/A_memorix_test/data/benchmarks/results/
|
||||||
data1/
|
data1/
|
||||||
mai_knowledge/knowledge.json
|
|
||||||
mongodb/
|
mongodb/
|
||||||
NapCat.Framework.Windows.Once/
|
NapCat.Framework.Windows.Once/
|
||||||
NapCat.Framework.Windows.OneKey/
|
NapCat.Framework.Windows.OneKey/
|
||||||
|
|||||||
72
pytests/test_gemini_thought_signatures.py
Normal file
72
pytests/test_gemini_thought_signatures.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import base64
|
||||||
|
import sys
|
||||||
|
from types import ModuleType, SimpleNamespace
|
||||||
|
|
||||||
|
|
||||||
|
# Stand-in for ``src.config.config``; presumably registered so the imports
# below can load without the real configuration module — TODO confirm.
config_module = ModuleType("src.config.config")


class _ConfigManagerStub:
    """Config-manager stub exposing the two methods defined for this test module."""

    def get_model_config(self) -> SimpleNamespace:
        """Return a model config object whose ``api_providers`` list is empty."""
        return SimpleNamespace(api_providers=[])

    def register_reload_callback(self, _: object) -> None:
        """Accept a reload callback and ignore it."""
        return None


config_module.config_manager = _ConfigManagerStub()
# ``setdefault`` keeps a module that is already registered under this name;
# the stub is only installed when nothing is there yet.
sys.modules.setdefault("src.config.config", config_module)
|
||||||
|
|
||||||
|
from src.llm_models.model_client import gemini_client
|
||||||
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||||
|
from src.llm_models.payload_content.tool_option import ToolCall
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_signature(value: bytes) -> str:
|
||||||
|
return base64.b64encode(value).decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_messages_preserves_gemini_function_call_signature_and_tool_result_id() -> None:
    """Check that tool-call id, name and thought signature survive conversion.

    The assistant part must carry the decoded signature bytes, and the tool
    part must carry the same call id, the tool name and the parsed JSON body.
    """
    thought_signature = b"gemini-signature"
    tool_call = ToolCall(
        call_id="call-1",
        func_name="reply",
        args={"msg_id": "42"},
        extra_content={"google": {"thought_signature": _encode_signature(thought_signature)}},
    )
    assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build()
    tool_message = (
        MessageBuilder()
        .set_role(RoleType.Tool)
        .set_tool_call_id("call-1")
        .set_tool_name("reply")
        .add_text_content('{"ok": true}')
        .build()
    )

    contents, _ = gemini_client._convert_messages([assistant_message, tool_message])

    assistant_part = contents[0].parts[0]
    assert assistant_part.function_call is not None
    assert assistant_part.function_call.id == "call-1"
    assert assistant_part.function_call.name == "reply"
    # The signature was supplied base64-encoded; the part is expected to hold raw bytes.
    assert assistant_part.thought_signature == thought_signature

    tool_part = contents[1].parts[0]
    assert tool_part.function_response is not None
    assert tool_part.function_response.id == "call-1"
    assert tool_part.function_response.name == "reply"
    assert tool_part.function_response.response == {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_messages_injects_dummy_signature_for_first_historical_tool_call() -> None:
    """Check the fallback signature is set on the first unsigned tool call only."""
    tool_calls = [
        ToolCall(call_id="call-1", func_name="reply", args={"msg_id": "1"}),
        ToolCall(call_id="call-2", func_name="reply", args={"msg_id": "2"}),
    ]
    assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls(tool_calls).build()

    contents, _ = gemini_client._convert_messages([assistant_message])

    # Neither call carries a signature: the first part gets the fallback, the second stays unset.
    assert contents[0].parts[0].thought_signature == gemini_client.GEMINI_FALLBACK_THOUGHT_SIGNATURE
    assert contents[0].parts[1].thought_signature is None
|
||||||
@@ -135,6 +135,7 @@ class LLMServiceResult(BaseDataModel):
|
|||||||
"name": tool_call.func_name,
|
"name": tool_call.func_name,
|
||||||
"arguments": tool_call.args or {},
|
"arguments": tool_call.args or {},
|
||||||
},
|
},
|
||||||
|
**({"extra_content": tool_call.extra_content} if tool_call.extra_content else {}),
|
||||||
}
|
}
|
||||||
for tool_call in self.completion.tool_calls
|
for tool_call in self.completion.tool_calls
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -221,22 +221,6 @@ class Jargon(SQLModel, table=True):
|
|||||||
inference_with_content_only: Optional[str] = Field(
|
inference_with_content_only: Optional[str] = Field(
|
||||||
default=None, sa_column=Column(Text, nullable=True)
|
default=None, sa_column=Column(Text, nullable=True)
|
||||||
) # 只基于词条的推断结果,JSON格式
|
) # 只基于词条的推断结果,JSON格式
|
||||||
|
|
||||||
|
|
||||||
class MaiKnowledge(SQLModel, table=True):
    """Stores Maisaka's user-profile knowledge."""

    __tablename__ = "mai_knowledge"  # type: ignore

    # Auto-assigned primary key.
    id: Optional[int] = Field(default=None, primary_key=True)
    # Indexed identifier of the knowledge item (at most 255 characters).
    knowledge_id: str = Field(index=True, max_length=255)
    # Indexed category identifier (at most 32 characters).
    category_id: str = Field(index=True, max_length=32)
    # Knowledge text.
    content: str
    # Indexed normalized form of the content.
    normalized_content: str = Field(index=True)
    # Optional metadata stored as text; presumably JSON given the name — TODO confirm.
    metadata_json: Optional[str] = Field(default=None, nullable=True)
    # Creation time, defaulting to the local (naive) current time; indexed.
    created_at: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True))
|
|
||||||
|
|
||||||
|
|
||||||
class ChatHistory(SQLModel, table=True):
|
class ChatHistory(SQLModel, table=True):
|
||||||
"""存储聊天历史记录的模型"""
|
"""存储聊天历史记录的模型"""
|
||||||
|
|
||||||
|
|||||||
@@ -138,18 +138,6 @@ _V2_TABLE_STATEMENTS = (
|
|||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS mai_knowledge (
|
|
||||||
id INTEGER NOT NULL,
|
|
||||||
knowledge_id VARCHAR(255) NOT NULL,
|
|
||||||
category_id VARCHAR(32) NOT NULL,
|
|
||||||
content VARCHAR NOT NULL,
|
|
||||||
normalized_content VARCHAR NOT NULL,
|
|
||||||
metadata_json VARCHAR,
|
|
||||||
created_at DATETIME,
|
|
||||||
PRIMARY KEY (id)
|
|
||||||
)
|
|
||||||
""",
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS mai_messages (
|
CREATE TABLE IF NOT EXISTS mai_messages (
|
||||||
id INTEGER NOT NULL,
|
id INTEGER NOT NULL,
|
||||||
message_id VARCHAR(255) NOT NULL,
|
message_id VARCHAR(255) NOT NULL,
|
||||||
@@ -260,10 +248,6 @@ _V2_INDEX_STATEMENTS = (
|
|||||||
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_assign_name ON llm_usage (model_assign_name)",
|
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_assign_name ON llm_usage (model_assign_name)",
|
||||||
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_name ON llm_usage (model_name)",
|
"CREATE INDEX IF NOT EXISTS ix_llm_usage_model_name ON llm_usage (model_name)",
|
||||||
"CREATE INDEX IF NOT EXISTS ix_llm_usage_timestamp ON llm_usage (timestamp)",
|
"CREATE INDEX IF NOT EXISTS ix_llm_usage_timestamp ON llm_usage (timestamp)",
|
||||||
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_category_id ON mai_knowledge (category_id)",
|
|
||||||
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_created_at ON mai_knowledge (created_at)",
|
|
||||||
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_knowledge_id ON mai_knowledge (knowledge_id)",
|
|
||||||
"CREATE INDEX IF NOT EXISTS ix_mai_knowledge_normalized_content ON mai_knowledge (normalized_content)",
|
|
||||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_group_id ON mai_messages (group_id)",
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_group_id ON mai_messages (group_id)",
|
||||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_message_id ON mai_messages (message_id)",
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_message_id ON mai_messages (message_id)",
|
||||||
"CREATE INDEX IF NOT EXISTS ix_mai_messages_platform ON mai_messages (platform)",
|
"CREATE INDEX IF NOT EXISTS ix_mai_messages_platform ON mai_messages (platform)",
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
|||||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||||
MMC_VERSION: str = "1.0.0"
|
MMC_VERSION: str = "1.0.0"
|
||||||
CONFIG_VERSION: str = "8.4.1"
|
CONFIG_VERSION: str = "8.5.0"
|
||||||
MODEL_CONFIG_VERSION: str = "1.13.1"
|
MODEL_CONFIG_VERSION: str = "1.13.1"
|
||||||
|
|
||||||
logger = get_logger("config")
|
logger = get_logger("config")
|
||||||
|
|||||||
@@ -1456,15 +1456,6 @@ class MaiSakaConfig(ConfigBase):
|
|||||||
|
|
||||||
__ui_label__ = "MaiSaka"
|
__ui_label__ = "MaiSaka"
|
||||||
__ui_icon__ = "message-circle"
|
__ui_icon__ = "message-circle"
|
||||||
|
|
||||||
enable_knowledge_module: bool = Field(
|
|
||||||
default=True,
|
|
||||||
json_schema_extra={
|
|
||||||
"x-widget": "switch",
|
|
||||||
"x-icon": "book",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
"""启用知识库模块"""
|
|
||||||
cli_user_name: str = Field(
|
cli_user_name: str = Field(
|
||||||
default="用户",
|
default="用户",
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -315,6 +315,9 @@ class MaiMessages:
|
|||||||
call_id=str(tool_call.get("call_id", "")),
|
call_id=str(tool_call.get("call_id", "")),
|
||||||
func_name=str(tool_call.get("func_name", "")),
|
func_name=str(tool_call.get("func_name", "")),
|
||||||
args=tool_call.get("args"),
|
args=tool_call.get("args"),
|
||||||
|
extra_content=tool_call.get("extra_content")
|
||||||
|
if isinstance(tool_call.get("extra_content"), dict)
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return deserialized_tool_calls
|
return deserialized_tool_calls
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
"""
|
|
||||||
Knowledge utilities package for Maisaka.
|
|
||||||
"""
|
|
||||||
@@ -1,363 +0,0 @@
|
|||||||
"""
|
|
||||||
Maisaka knowledge retrieval and learning helpers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
|
|
||||||
from src.chat.message_receive.message import SessionMessage
|
|
||||||
from src.chat.utils.utils import is_bot_self
|
|
||||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.know_u.knowledge_store import KNOWLEDGE_CATEGORIES, get_knowledge_store
|
|
||||||
from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, SessionBackedMessage, ToolResultMessage
|
|
||||||
from src.maisaka.message_adapter import parse_speaker_content
|
|
||||||
from src.person_info.person_info import Person
|
|
||||||
from src.services.llm_service import LLMServiceClient
|
|
||||||
|
|
||||||
logger = get_logger("maisaka_knowledge")
|
|
||||||
|
|
||||||
# Chinese phrases that mark an LLM answer as "nothing applicable".
NO_RESULT_KEYWORDS = [
    "无",
    "没有",
    "不适用",
    "无需",
    "无相关",
]


def extract_category_ids_from_result(result: str) -> List[str]:
    """Extract valid category ids from an LLM result string.

    Args:
        result: Raw LLM answer, expected to list category ids separated by
            whitespace or commas.

    Returns:
        The ids that are keys of ``KNOWLEDGE_CATEGORIES``, in first-seen
        order without duplicates. An empty list is returned for empty input
        or when the answer contains one of the "no result" markers.
    """
    normalized = result.strip() if result else ""
    if not normalized:
        return []

    lowered = normalized.lower()
    if any(keyword in lowered for keyword in ("none", "no relevant", "no_need", "no need")):
        return []
    if any(keyword in normalized for keyword in NO_RESULT_KEYWORDS):
        return []

    # Turn ASCII/full-width commas and the enumeration comma into spaces so
    # that a single whitespace split (which also covers newlines) is enough.
    separators = str.maketrans(dict.fromkeys(",,、", " "))
    category_ids: List[str] = []
    for candidate in normalized.translate(separators).split():
        if candidate in KNOWLEDGE_CATEGORIES and candidate not in category_ids:
            category_ids.append(candidate)
    return category_ids
|
|
||||||
|
|
||||||
|
|
||||||
async def retrieve_relevant_knowledge(
    knowledge_analyzer: Any,
    chat_history: List[LLMContextMessage],
) -> str:
    """Retrieve formatted knowledge snippets relevant to the current chat history.

    Args:
        knowledge_analyzer: Object providing an awaitable
            ``analyze_knowledge_need(chat_history, categories_summary)``.
        chat_history: Context messages passed through to the analyzer.

    Returns:
        The store's formatted knowledge for the categories the analyzer
        selected, or an empty string when nothing is selected or any step
        fails (the failure is logged).
    """
    try:
        # Store access is inside the ``try`` so that a failing store is
        # reported the same way as a failing analyzer.
        store = get_knowledge_store()
        categories_summary = store.get_categories_summary()
        category_ids = await knowledge_analyzer.analyze_knowledge_need(chat_history, categories_summary)
        if not category_ids:
            return ""
        return store.get_formatted_knowledge(category_ids)
    except Exception:
        logger.exception("检索相关知识失败")
        return ""
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeLearner:
    """Extract user-profile knowledge from recent conversation and store it."""

    def __init__(self, session_id: str) -> None:
        """Bind the learner to *session_id* and create its store, LLM client and lock."""
        self._session_id = session_id
        self._store = get_knowledge_store()
        self._llm = LLMServiceClient(task_name="utils", request_type="maisaka.knowledge.learn")
        self._learning_lock = asyncio.Lock()
        # Number of leading cache entries already handled by ``learn``.
        self._last_processed_index = 0
        # ``learn`` does nothing until at least this many messages are pending.
        self.min_messages_for_extraction = 10

    def get_pending_count(self, message_cache: List[SessionMessage]) -> int:
        """Return how many entries of *message_cache* have not been processed yet."""
        return max(0, len(message_cache) - self._last_processed_index)

    async def learn(self, message_cache: List[SessionMessage]) -> int:
        """Extract knowledge from the unprocessed tail of *message_cache*.

        Returns:
            The number of knowledge items newly added to the store. ``0`` is
            returned when too few messages are pending, when no excerpt can
            be built, when the LLM call fails, or when nothing was extracted.
        """
        # The pending slice is computed under the lock so that overlapping
        # calls cannot both pick up the same unprocessed messages.
        async with self._learning_lock:
            pending_messages = message_cache[self._last_processed_index :]
            if not pending_messages:
                return 0
            if len(pending_messages) < self.min_messages_for_extraction:
                return 0

            chat_excerpt = self._build_chat_excerpt(pending_messages)
            if not chat_excerpt:
                return 0

            prompt = self._build_learning_prompt(chat_excerpt)
            try:
                result = await self._llm.generate_response(
                    prompt=prompt,
                    options=LLMGenerationOptions(
                        temperature=0.1,
                        max_tokens=512,
                    ),
                )
            except Exception:
                # Best effort: the index is not advanced, so these messages are retried later.
                logger.exception("知识学习请求失败")
                return 0

            knowledge_items = self._parse_learning_result(result.response or "")
            if not knowledge_items:
                logger.debug("本轮对话未提取到可用的知识条目")
                self._last_processed_index = len(message_cache)
                return 0

            added_count = 0
            for item in knowledge_items:
                category_id = str(item.get("category_id", "")).strip()
                content = str(item.get("content", "")).strip()
                if not category_id or not content:
                    continue

                metadata = {
                    "session_id": self._session_id,
                    "source": "maisaka_learning",
                }
                for field_name in ("platform", "user_id", "user_nickname", "person_name"):
                    field_value = str(item.get(field_name, "")).strip()
                    if field_value:
                        metadata[field_name] = field_value

                if self._store.add_knowledge(
                    category_id=category_id,
                    content=content,
                    metadata=metadata,
                ):
                    added_count += 1

            self._last_processed_index = len(message_cache)

            if added_count > 0:
                logger.info(
                    f"Maisaka 知识学习完成: 会话={self._session_id} 新增条数={added_count}"
                )
            else:
                logger.debug(
                    f"Maisaka 知识学习未新增条目: 会话={self._session_id}"
                )

            return added_count

    def _build_chat_excerpt(self, messages: List[SessionMessage]) -> str:
        """Build a chat excerpt for profile extraction from user-visible text only.

        Only the last 30 messages are considered. Assistant messages, tool
        results, messages sent by the bot itself and empty messages are
        skipped. Each kept message becomes a user-info line followed by a
        speech line.
        """
        lines: List[str] = []
        for message in messages[-30:]:
            if isinstance(message, (AssistantMessage, ToolResultMessage)):
                continue

            if isinstance(message, SessionBackedMessage):
                if message.original_message and is_bot_self(
                    message.original_message.platform,
                    message.original_message.message_info.user_info.user_id,
                ):
                    continue
                raw_text = message.processed_plain_text.strip()
                fallback_speaker = (
                    message.original_message.message_info.user_info.user_nickname
                    if message.original_message is not None
                    else "用户"
                )
            else:
                if is_bot_self(message.platform, message.message_info.user_info.user_id):
                    continue
                raw_text = message.processed_plain_text.strip()
                fallback_speaker = message.message_info.user_info.user_nickname or "用户"

            if not raw_text:
                continue

            speaker_name, body = parse_speaker_content(raw_text)
            visible_text = (body or raw_text).strip()
            if not visible_text:
                continue

            speaker = speaker_name or fallback_speaker
            user_metadata = self._extract_message_user_metadata(message)
            metadata_parts = [
                f"platform={user_metadata['platform'] or 'unknown'}",
                f"user_id={user_metadata['user_id'] or 'unknown'}",
                f"user_nickname={user_metadata['user_nickname'] or speaker}",
                f"person_name={user_metadata['person_name'] or ''}",
            ]
            lines.append(
                f"[用户信息] {'; '.join(metadata_parts)}\n"
                f"[发言] {speaker}: {visible_text}"
            )

        return "\n".join(lines)

    @staticmethod
    def _extract_message_user_metadata(message: SessionMessage) -> Dict[str, str]:
        """Extract platform, user id, nickname and known person name for *message*.

        Missing attributes yield empty strings. The person name is filled in
        only when both platform and user id are present and the ``Person``
        lookup reports a known person; a failing lookup leaves it empty.
        """
        source_message = message.original_message if isinstance(message, SessionBackedMessage) else message
        platform = str(getattr(source_message, "platform", "") or "").strip()
        user_info = getattr(getattr(source_message, "message_info", None), "user_info", None)
        user_id = str(getattr(user_info, "user_id", "") or "").strip()
        user_nickname = str(getattr(user_info, "user_nickname", "") or "").strip()

        person_name = ""
        if platform and user_id:
            try:
                person = Person(platform=platform, user_id=user_id)
                if person.is_known and person.person_name:
                    person_name = str(person.person_name).strip()
            except Exception:
                # Deliberate best effort: the person name is optional metadata.
                person_name = ""

        return {
            "platform": platform,
            "user_id": user_id,
            "user_nickname": user_nickname,
            "person_name": person_name,
        }

    def _build_learning_prompt(self, chat_excerpt: str) -> str:
        """Build the knowledge-extraction prompt around *chat_excerpt*.

        The prompt lists every entry of ``KNOWLEDGE_CATEGORIES`` and asks for
        a JSON array whose items carry the category, the statement and the
        user the statement belongs to.
        """
        categories_text = "\n".join(
            f"{category_id}. {category_name}" for category_id, category_name in KNOWLEDGE_CATEGORIES.items()
        )
        return (
            "你是一个用户画像知识提取器,需要从聊天记录里提取稳定、可复用的用户事实。\n"
            "聊天记录每条发言前都带有用户元信息,你必须明确判断这些特征属于哪个用户。\n"
            "只提取用户明确表达或高置信度可归纳的信息,不要猜测,不要提取一次性情绪,不要重复表达。\n"
            "如果没有可提取内容,返回空数组[]。\n"
            "输出必须是 JSON 数组,每项格式为 "
            '{"category_id":"分类编号","content":"简洁中文陈述","platform":"平台","user_id":"用户ID","user_nickname":"用户昵称","person_name":"人物名或空字符串"}。\n'
            "其中 platform 和 user_id 必填;user_nickname 尽量填写;person_name 仅在用户信息中明确给出时填写,否则填空字符串。\n"
            "同一条知识只能归属到一个用户,不要混合不同人的信息。\n"
            "分类如下:\n"
            f"{categories_text}\n\n"
            "聊天记录:\n"
            f"{chat_excerpt}"
        )

    def _parse_learning_result(self, result: str) -> List[Dict[str, str]]:
        """Parse the knowledge items returned by the model.

        Markdown code fences are removed before JSON decoding. Items are
        kept only when they are objects with a known category id and
        non-empty content, platform and user id; duplicates of the same
        (category, content, platform, user id) are dropped.

        Returns:
            The cleaned items, or an empty list when the text is empty, is
            not valid JSON (a warning is logged) or is not a JSON array.
        """
        normalized = result.strip()
        if not normalized:
            return []

        if "```" in normalized:
            normalized = normalized.replace("```json", "").replace("```JSON", "").replace("```", "").strip()

        try:
            parsed = json.loads(normalized)
        except json.JSONDecodeError:
            logger.warning("知识学习结果不是有效的 JSON")
            return []

        if not isinstance(parsed, list):
            return []

        normalized_items: List[Dict[str, str]] = []
        seen_pairs: set[tuple[str, str, str, str]] = set()
        for item in parsed:
            if not isinstance(item, dict):
                continue

            category_id = str(item.get("category_id", "")).strip()
            content = " ".join(str(item.get("content", "")).strip().split())
            platform = str(item.get("platform", "")).strip()
            user_id = str(item.get("user_id", "")).strip()
            user_nickname = str(item.get("user_nickname", "")).strip()
            person_name = str(item.get("person_name", "")).strip()
            if category_id not in KNOWLEDGE_CATEGORIES:
                continue
            if not content or not platform or not user_id:
                continue

            pair = (category_id, content, platform, user_id)
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
            normalized_items.append(
                {
                    "category_id": category_id,
                    "content": content,
                    "platform": platform,
                    "user_id": user_id,
                    "user_nickname": user_nickname,
                    "person_name": person_name,
                }
            )

        return normalized_items
|
|
||||||
@@ -1,370 +0,0 @@
|
|||||||
"""
|
|
||||||
MaiSaka knowledge store.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from sqlmodel import col, select
|
|
||||||
|
|
||||||
from src.common.database.database import DATABASE_URL, get_db_session
|
|
||||||
from src.common.database.database_model import MaiKnowledge
|
|
||||||
|
|
||||||
# Two directory levels above the directory containing this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Location of the legacy JSON knowledge file.
KNOWLEDGE_DATA_DIR = PROJECT_ROOT / "mai_knowledge"
KNOWLEDGE_FILE = KNOWLEDGE_DATA_DIR / "knowledge.json"


# Category id -> display name of each user-profile knowledge category.
KNOWLEDGE_CATEGORIES = {
    "1": "性别",
    "2": "性格",
    "3": "饮食口味",
    "4": "交友偏好",
    "5": "情绪/理性倾向",
    "6": "兴趣爱好",
    "7": "职业/专业",
    "8": "生活习惯",
    "9": "价值观",
    "10": "沟通风格",
    "11": "学习方式",
    "12": "压力应对方式",
}
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeStore:
|
|
||||||
"""存储 Maisaka 的用户画像知识。"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
    """Initialise the store and migrate legacy JSON data when required."""
    self._ensure_legacy_data_dir()
    self._migrate_legacy_file_if_needed()
|
|
||||||
|
|
||||||
def _ensure_legacy_data_dir(self) -> None:
    """Create the legacy knowledge directory (and parents) if it does not exist."""
    KNOWLEDGE_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _normalize_content(content: str) -> str:
|
|
||||||
"""标准化知识内容,便于去重。"""
|
|
||||||
return " ".join(str(content).strip().split())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _serialize_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[str]:
|
|
||||||
"""将元数据序列化为 JSON 文本。"""
|
|
||||||
if not metadata:
|
|
||||||
return None
|
|
||||||
return json.dumps(metadata, ensure_ascii=False, sort_keys=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _deserialize_metadata(raw_text: Optional[str]) -> Dict[str, Any]:
|
|
||||||
"""将 JSON 文本反序列化为元数据字典。"""
|
|
||||||
if not raw_text:
|
|
||||||
return {}
|
|
||||||
try:
|
|
||||||
parsed = json.loads(raw_text)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {}
|
|
||||||
return parsed if isinstance(parsed, dict) else {}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_created_at(raw_value: Any) -> datetime:
|
|
||||||
"""解析旧版数据中的创建时间。"""
|
|
||||||
if isinstance(raw_value, datetime):
|
|
||||||
return raw_value
|
|
||||||
if isinstance(raw_value, str):
|
|
||||||
raw_text = raw_value.strip()
|
|
||||||
if raw_text:
|
|
||||||
try:
|
|
||||||
return datetime.fromisoformat(raw_text)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
return datetime.now()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _build_item_dict(cls, record: MaiKnowledge) -> Dict[str, Any]:
|
|
||||||
"""将数据库记录转换为兼容旧接口的字典。"""
|
|
||||||
return {
|
|
||||||
"id": record.knowledge_id,
|
|
||||||
"content": record.content,
|
|
||||||
"metadata": cls._deserialize_metadata(record.metadata_json),
|
|
||||||
"created_at": record.created_at.isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def _load_legacy_knowledge_file(self) -> Dict[str, List[Dict[str, Any]]]:
    """Read the legacy JSON knowledge file.

    Returns:
        A mapping from each known category id to the list of dict items
        stored for it. An empty dict is returned when the file is missing,
        unreadable, not valid JSON, or not a JSON object.
    """
    if not KNOWLEDGE_FILE.exists():
        return {}

    try:
        with open(KNOWLEDGE_FILE, "r", encoding="utf-8") as file:
            loaded = json.load(file)
    except (OSError, ValueError):
        # ValueError covers both invalid JSON and undecodable bytes.
        return {}

    if not isinstance(loaded, dict):
        return {}

    normalized_knowledge: Dict[str, List[Dict[str, Any]]] = {}
    for category_id in KNOWLEDGE_CATEGORIES:
        category_items = loaded.get(category_id, [])
        # Categories whose value is not a list are left out entirely.
        if isinstance(category_items, list):
            normalized_knowledge[category_id] = [
                item for item in category_items if isinstance(item, dict)
            ]
    return normalized_knowledge
|
|
||||||
|
|
||||||
def _migrate_legacy_file_if_needed(self) -> None:
    """Import knowledge from the legacy JSON file when the table is empty.

    Nothing happens when the legacy file yields no data or when the table
    already holds at least one record. Items without content are skipped.
    """
    legacy_knowledge = self._load_legacy_knowledge_file()
    if not legacy_knowledge:
        return

    with get_db_session(auto_commit=False) as session:
        existing_record = session.exec(select(MaiKnowledge.id).limit(1)).first()
        if existing_record is not None:
            return

        for category_id, items in legacy_knowledge.items():
            if category_id not in KNOWLEDGE_CATEGORIES:
                continue

            for item in items:
                content = self._normalize_content(str(item.get("content", "")))
                if not content:
                    continue

                metadata = item.get("metadata")
                session.add(
                    MaiKnowledge(
                        # Keep the legacy id when present; otherwise derive one from the current time.
                        knowledge_id=str(item.get("id") or f"know_{category_id}_{datetime.now().timestamp()}"),
                        category_id=category_id,
                        content=content,
                        normalized_content=content,
                        metadata_json=self._serialize_metadata(metadata if isinstance(metadata, dict) else None),
                        created_at=self._parse_created_at(item.get("created_at")),
                    )
                )

        session.commit()
|
|
||||||
|
|
||||||
def add_knowledge(
    self,
    category_id: str,
    content: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Add one knowledge item.

    Args:
        category_id: Must be a key of ``KNOWLEDGE_CATEGORIES``.
        content: Knowledge text; it is normalized before storage.
        metadata: Optional metadata; its ``platform`` and ``user_id``
            entries take part in duplicate detection.

    Returns:
        ``True`` when a record was inserted. ``False`` when the category is
        unknown, the content is empty after normalization, or an equivalent
        record already exists.
    """
    if category_id not in KNOWLEDGE_CATEGORIES:
        return False

    normalized_content = self._normalize_content(content)
    if not normalized_content:
        return False

    user_platform = str((metadata or {}).get("platform", "")).strip()
    user_id = str((metadata or {}).get("user_id", "")).strip()
    with get_db_session(auto_commit=False) as session:
        existing_records = session.exec(
            select(MaiKnowledge).where(
                MaiKnowledge.category_id == category_id,
                MaiKnowledge.normalized_content == normalized_content,
            )
        ).all()
        for existing_record in existing_records:
            existing_metadata = self._deserialize_metadata(existing_record.metadata_json)
            existing_platform = str(existing_metadata.get("platform", "")).strip()
            existing_user_id = str(existing_metadata.get("user_id", "")).strip()
            if user_platform and user_id:
                # Attributed item: a duplicate only if stored for the same user.
                if existing_platform == user_platform and existing_user_id == user_id:
                    return False
                continue
            # Unattributed item: a duplicate only of another unattributed record.
            if not existing_platform and not existing_user_id:
                return False

        session.add(
            MaiKnowledge(
                knowledge_id=f"know_{category_id}_{datetime.now().timestamp()}",
                category_id=category_id,
                content=normalized_content,
                normalized_content=normalized_content,
                metadata_json=self._serialize_metadata(metadata),
                created_at=datetime.now(),
            )
        )
        session.commit()
        return True
|
|
||||||
|
|
||||||
def search_knowledge(
    self,
    keyword: str,
    limit: int = 10,
) -> List[Dict[str, Any]]:
    """Search knowledge content by keyword.

    Args:
        keyword: Text to look for; it is normalized first and matched
            literally as a substring.
        limit: Maximum number of results; values below 1 are raised to 1.

    Returns:
        Matching items, newest first, each extended with ``category_id``
        and ``category_name``. Empty when the keyword normalizes to nothing.
    """
    normalized_keyword = self._normalize_content(keyword)
    if not normalized_keyword:
        return []

    limit_value = max(1, int(limit))
    with get_db_session() as session:
        records = session.exec(
            select(MaiKnowledge)
            .where(
                # autoescape: "%" and "_" in the keyword are matched literally
                # instead of acting as LIKE wildcards.
                col(MaiKnowledge.content).contains(normalized_keyword, autoescape=True)
                | col(MaiKnowledge.normalized_content).contains(normalized_keyword, autoescape=True)
            )
            .order_by(MaiKnowledge.created_at.desc(), MaiKnowledge.id.desc())
            .limit(limit_value)
        ).all()

        results: List[Dict[str, Any]] = []
        for record in records:
            item = self._build_item_dict(record)
            item["category_id"] = record.category_id
            item["category_name"] = self.get_category_name(record.category_id)
            results.append(item)
        return results
|
|
||||||
|
|
||||||
def get_knowledge_by_user(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
platform: str = "",
|
|
||||||
user_id: str = "",
|
|
||||||
user_nickname: str = "",
|
|
||||||
person_name: str = "",
|
|
||||||
limit: int = 10,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""按用户元信息筛选知识条目。"""
|
|
||||||
platform = str(platform).strip()
|
|
||||||
user_id = str(user_id).strip()
|
|
||||||
user_nickname = str(user_nickname).strip()
|
|
||||||
person_name = str(person_name).strip()
|
|
||||||
if not any((platform, user_id, user_nickname, person_name)):
|
|
||||||
return []
|
|
||||||
|
|
||||||
limit_value = max(1, int(limit))
|
|
||||||
with get_db_session() as session:
|
|
||||||
records = session.exec(
|
|
||||||
select(MaiKnowledge).order_by(MaiKnowledge.created_at.desc(), MaiKnowledge.id.desc())
|
|
||||||
).all()
|
|
||||||
|
|
||||||
results: List[Dict[str, Any]] = []
|
|
||||||
for record in records:
|
|
||||||
metadata = self._deserialize_metadata(record.metadata_json)
|
|
||||||
if user_id and str(metadata.get("user_id", "")).strip() != user_id:
|
|
||||||
continue
|
|
||||||
if platform and str(metadata.get("platform", "")).strip() != platform:
|
|
||||||
continue
|
|
||||||
if user_nickname and str(metadata.get("user_nickname", "")).strip() != user_nickname:
|
|
||||||
continue
|
|
||||||
if person_name and str(metadata.get("person_name", "")).strip() != person_name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
item = self._build_item_dict(record)
|
|
||||||
item["category_id"] = record.category_id
|
|
||||||
item["category_name"] = self.get_category_name(record.category_id)
|
|
||||||
results.append(item)
|
|
||||||
if len(results) >= limit_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def get_category_knowledge(self, category_id: str) -> List[Dict[str, Any]]:
|
|
||||||
"""获取某个分类下的所有知识。"""
|
|
||||||
if category_id not in KNOWLEDGE_CATEGORIES:
|
|
||||||
return []
|
|
||||||
|
|
||||||
with get_db_session() as session:
|
|
||||||
records = session.exec(
|
|
||||||
select(MaiKnowledge)
|
|
||||||
.where(MaiKnowledge.category_id == category_id)
|
|
||||||
.order_by(MaiKnowledge.created_at.asc(), MaiKnowledge.id.asc())
|
|
||||||
).all()
|
|
||||||
return [self._build_item_dict(record) for record in records]
|
|
||||||
|
|
||||||
def get_all_knowledge(self) -> Dict[str, List[Dict[str, Any]]]:
|
|
||||||
"""获取全部知识。"""
|
|
||||||
all_knowledge: Dict[str, List[Dict[str, Any]]] = {
|
|
||||||
category_id: [] for category_id in KNOWLEDGE_CATEGORIES
|
|
||||||
}
|
|
||||||
with get_db_session() as session:
|
|
||||||
records = session.exec(
|
|
||||||
select(MaiKnowledge).order_by(
|
|
||||||
MaiKnowledge.category_id.asc(),
|
|
||||||
MaiKnowledge.created_at.asc(),
|
|
||||||
MaiKnowledge.id.asc(),
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
for record in records:
|
|
||||||
all_knowledge.setdefault(record.category_id, []).append(self._build_item_dict(record))
|
|
||||||
return all_knowledge
|
|
||||||
|
|
||||||
def get_category_name(self, category_id: str) -> str:
|
|
||||||
"""获取分类名称。"""
|
|
||||||
return KNOWLEDGE_CATEGORIES.get(category_id, "未知分类")
|
|
||||||
|
|
||||||
def get_categories_summary(self) -> str:
|
|
||||||
"""获取分类摘要,供模型判断是否需要检索。"""
|
|
||||||
counts: Dict[str, int] = {category_id: 0 for category_id in KNOWLEDGE_CATEGORIES}
|
|
||||||
with get_db_session() as session:
|
|
||||||
records = session.exec(select(MaiKnowledge.category_id)).all()
|
|
||||||
|
|
||||||
for category_id in records:
|
|
||||||
if category_id in counts:
|
|
||||||
counts[category_id] += 1
|
|
||||||
|
|
||||||
lines: List[str] = []
|
|
||||||
for category_id, category_name in KNOWLEDGE_CATEGORIES.items():
|
|
||||||
count = counts.get(category_id, 0)
|
|
||||||
count_text = f"{count}条" if count > 0 else "无数据"
|
|
||||||
lines.append(f"{category_id}. {category_name} ({count_text})")
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
def get_formatted_knowledge(self, category_ids: List[str], limit_per_category: int = 5) -> str:
|
|
||||||
"""获取指定分类的格式化知识内容。"""
|
|
||||||
parts: List[str] = []
|
|
||||||
for category_id in category_ids:
|
|
||||||
items = self.get_category_knowledge(category_id)
|
|
||||||
if not items:
|
|
||||||
continue
|
|
||||||
|
|
||||||
category_name = self.get_category_name(category_id)
|
|
||||||
parts.append(f"【{category_name}】")
|
|
||||||
|
|
||||||
recent_items = items[-limit_per_category:]
|
|
||||||
for item in recent_items:
|
|
||||||
content = str(item.get("content", "")).strip()
|
|
||||||
if content:
|
|
||||||
parts.append(f"- {content}")
|
|
||||||
|
|
||||||
return "\n".join(parts)
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
|
||||||
"""获取知识数据统计。"""
|
|
||||||
with get_db_session() as session:
|
|
||||||
total_items = len(session.exec(select(MaiKnowledge.id)).all())
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_categories": len(KNOWLEDGE_CATEGORIES),
|
|
||||||
"total_items": total_items,
|
|
||||||
"data_file": DATABASE_URL,
|
|
||||||
"data_exists": True,
|
|
||||||
"data_size_kb": 0,
|
|
||||||
"legacy_data_file": str(KNOWLEDGE_FILE),
|
|
||||||
"legacy_data_exists": KNOWLEDGE_FILE.exists(),
|
|
||||||
"storage_type": "database",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
_knowledge_store_instance: Optional[KnowledgeStore] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_knowledge_store() -> KnowledgeStore:
|
|
||||||
"""获取知识存储单例。"""
|
|
||||||
global _knowledge_store_instance
|
|
||||||
if _knowledge_store_instance is None:
|
|
||||||
_knowledge_store_instance = KnowledgeStore()
|
|
||||||
return _knowledge_store_instance
|
|
||||||
@@ -524,6 +524,7 @@ class ToolCall:
|
|||||||
call_id: str
|
call_id: str
|
||||||
func_name: str
|
func_name: str
|
||||||
args: Dict[str, Any] | None = None
|
args: Dict[str, Any] | None = None
|
||||||
|
extra_content: Dict[str, Any] | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""执行工具调用的基础校验。
|
"""执行工具调用的基础校验。
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from sqlmodel import col, select
|
|||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import PersonInfo
|
from src.common.database.database_model import PersonInfo
|
||||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||||
from src.know_u.knowledge_store import get_knowledge_store
|
|
||||||
|
|
||||||
from .context import BuiltinToolRuntimeContext
|
from .context import BuiltinToolRuntimeContext
|
||||||
|
|
||||||
@@ -79,7 +78,6 @@ async def handle_tool(
|
|||||||
result: Dict[str, Any] = {
|
result: Dict[str, Any] = {
|
||||||
"query": person_name,
|
"query": person_name,
|
||||||
"persons": persons,
|
"persons": persons,
|
||||||
"related_knowledge": _query_related_knowledge(person_name, persons, limit),
|
|
||||||
}
|
}
|
||||||
return tool_ctx.build_success_result(
|
return tool_ctx.build_success_result(
|
||||||
invocation.tool_name,
|
invocation.tool_name,
|
||||||
@@ -129,55 +127,3 @@ def _query_person_records(person_name: str, limit: int) -> List[Dict[str, Any]]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return persons
|
return persons
|
||||||
|
|
||||||
|
|
||||||
def _query_related_knowledge(
|
|
||||||
person_name: str,
|
|
||||||
persons: List[Dict[str, Any]],
|
|
||||||
limit: int,
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""从 Maisaka knowledge 中补充检索与该人物相关的条目。"""
|
|
||||||
|
|
||||||
store = get_knowledge_store()
|
|
||||||
knowledge_items: List[Dict[str, Any]] = []
|
|
||||||
seen_ids: set[str] = set()
|
|
||||||
|
|
||||||
for person in persons:
|
|
||||||
matched_items = store.get_knowledge_by_user(
|
|
||||||
platform=str(person.get("platform", "")).strip(),
|
|
||||||
user_id=str(person.get("user_id", "")).strip(),
|
|
||||||
user_nickname=str(person.get("user_nickname", "")).strip(),
|
|
||||||
person_name=str(person.get("person_name", "")).strip(),
|
|
||||||
limit=max(limit, 5),
|
|
||||||
)
|
|
||||||
for item in matched_items:
|
|
||||||
item_id = str(item.get("id", "")).strip()
|
|
||||||
if item_id and item_id in seen_ids:
|
|
||||||
continue
|
|
||||||
if item_id:
|
|
||||||
seen_ids.add(item_id)
|
|
||||||
knowledge_items.append(item)
|
|
||||||
|
|
||||||
if not knowledge_items:
|
|
||||||
fallback_items = store.search_knowledge(person_name, limit=max(limit, 5))
|
|
||||||
for item in fallback_items:
|
|
||||||
item_id = str(item.get("id", "")).strip()
|
|
||||||
if item_id and item_id in seen_ids:
|
|
||||||
continue
|
|
||||||
if item_id:
|
|
||||||
seen_ids.add(item_id)
|
|
||||||
knowledge_items.append(item)
|
|
||||||
|
|
||||||
results: List[Dict[str, Any]] = []
|
|
||||||
for item in knowledge_items:
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"id": str(item.get("id", "")).strip(),
|
|
||||||
"category_id": str(item.get("category_id", "")).strip(),
|
|
||||||
"category_name": str(item.get("category_name", "")).strip(),
|
|
||||||
"content": str(item.get("content", "")).strip(),
|
|
||||||
"metadata": item.get("metadata", {}),
|
|
||||||
"created_at": item.get("created_at"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from src.common.prompt_i18n import load_prompt
|
|||||||
from src.common.utils.utils_session import SessionUtils
|
from src.common.utils.utils_session import SessionUtils
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.core.tooling import ToolRegistry, ToolSpec
|
from src.core.tooling import ToolRegistry, ToolSpec
|
||||||
from src.know_u.knowledge import extract_category_ids_from_result
|
|
||||||
from src.llm_models.model_client.base_client import BaseClient
|
from src.llm_models.model_client.base_client import BaseClient
|
||||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||||
@@ -665,41 +664,6 @@ class MaisakaChatLoopService:
|
|||||||
)
|
)
|
||||||
return filtered_tool_specs
|
return filtered_tool_specs
|
||||||
|
|
||||||
async def analyze_knowledge_need(
|
|
||||||
self,
|
|
||||||
chat_history: List[LLMContextMessage],
|
|
||||||
categories_summary: str,
|
|
||||||
) -> List[str]:
|
|
||||||
"""分析当前对话是否需要检索知识库分类。"""
|
|
||||||
visible_history: List[str] = []
|
|
||||||
for message in chat_history[-8:]:
|
|
||||||
if not message.processed_plain_text:
|
|
||||||
continue
|
|
||||||
visible_history.append(f"{message.role}: {message.processed_plain_text}")
|
|
||||||
|
|
||||||
if not visible_history or not categories_summary.strip():
|
|
||||||
return []
|
|
||||||
|
|
||||||
prompt = (
|
|
||||||
"你需要判断当前对话是否需要查询知识库。\n"
|
|
||||||
"请只返回最相关的分类编号,多个编号用空格分隔;如果完全不需要,返回 none。\n\n"
|
|
||||||
f"【可用分类】\n{categories_summary}\n\n"
|
|
||||||
f"【最近对话】\n{chr(10).join(visible_history)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
generation_result = await self._llm_chat.generate_response(
|
|
||||||
prompt=prompt,
|
|
||||||
options=LLMGenerationOptions(
|
|
||||||
temperature=0.1,
|
|
||||||
max_tokens=64,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return extract_category_ids_from_result(generation_result.response or "")
|
|
||||||
|
|
||||||
async def chat_loop_step(
|
async def chat_loop_step(
|
||||||
self,
|
self,
|
||||||
chat_history: List[LLMContextMessage],
|
chat_history: List[LLMContextMessage],
|
||||||
|
|||||||
@@ -263,7 +263,6 @@ class ReferenceMessageType(str, Enum):
|
|||||||
|
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
JARGON = "jargon"
|
JARGON = "jargon"
|
||||||
KNOWLEDGE = "knowledge"
|
|
||||||
MEMORY = "memory"
|
MEMORY = "memory"
|
||||||
TOOL_HINT = "tool_hint"
|
TOOL_HINT = "tool_hint"
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from src.common.logger import get_logger
|
|||||||
from src.common.utils.utils_config import ExpressionConfigUtils
|
from src.common.utils.utils_config import ExpressionConfigUtils
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.core.tooling import ToolRegistry
|
from src.core.tooling import ToolRegistry
|
||||||
from src.know_u.knowledge import KnowledgeLearner
|
|
||||||
from src.learners.expression_learner import ExpressionLearner
|
from src.learners.expression_learner import ExpressionLearner
|
||||||
from src.learners.jargon_miner import JargonMiner
|
from src.learners.jargon_miner import JargonMiner
|
||||||
from src.llm_models.payload_content.resp_format import RespFormat
|
from src.llm_models.payload_content.resp_format import RespFormat
|
||||||
@@ -102,10 +101,8 @@ class MaisakaHeartFlowChatting:
|
|||||||
self._enable_jargon_learning = jargon_learn
|
self._enable_jargon_learning = jargon_learn
|
||||||
self._min_extraction_interval = 30
|
self._min_extraction_interval = 30
|
||||||
self._last_expression_extraction_time = 0.0
|
self._last_expression_extraction_time = 0.0
|
||||||
self._last_knowledge_extraction_time = 0.0
|
|
||||||
self._expression_learner = ExpressionLearner(session_id)
|
self._expression_learner = ExpressionLearner(session_id)
|
||||||
self._jargon_miner = JargonMiner(session_id, session_name=session_name)
|
self._jargon_miner = JargonMiner(session_id, session_name=session_name)
|
||||||
self._knowledge_learner = KnowledgeLearner(session_id)
|
|
||||||
|
|
||||||
self._reasoning_engine = MaisakaReasoningEngine(self)
|
self._reasoning_engine = MaisakaReasoningEngine(self)
|
||||||
self._tool_registry = ToolRegistry()
|
self._tool_registry = ToolRegistry()
|
||||||
@@ -449,16 +446,11 @@ class MaisakaHeartFlowChatting:
|
|||||||
self._wait_timeout_task = None
|
self._wait_timeout_task = None
|
||||||
|
|
||||||
async def _trigger_batch_learning(self, messages: list[SessionMessage]) -> None:
|
async def _trigger_batch_learning(self, messages: list[SessionMessage]) -> None:
|
||||||
"""按同一批消息触发表达方式、黑话和 knowledge 学习。"""
|
"""按同一批消息触发表达方式和黑话学习。"""
|
||||||
expression_result, knowledge_result = await asyncio.gather(
|
try:
|
||||||
self._trigger_expression_learning(messages),
|
await self._trigger_expression_learning(messages)
|
||||||
self._trigger_knowledge_learning(messages),
|
except Exception as exc:
|
||||||
return_exceptions=True,
|
logger.error(f"{self.log_prefix} 表达学习任务异常退出: {exc}")
|
||||||
)
|
|
||||||
if isinstance(expression_result, Exception):
|
|
||||||
logger.error(f"{self.log_prefix} 表达学习任务异常退出: {expression_result}")
|
|
||||||
if isinstance(knowledge_result, Exception):
|
|
||||||
logger.error(f"{self.log_prefix} 知识学习任务异常退出: {knowledge_result}")
|
|
||||||
|
|
||||||
def _should_trigger_learning(
|
def _should_trigger_learning(
|
||||||
self,
|
self,
|
||||||
@@ -523,34 +515,6 @@ class MaisakaHeartFlowChatting:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"{self.log_prefix} ??????")
|
logger.exception(f"{self.log_prefix} ??????")
|
||||||
|
|
||||||
async def _trigger_knowledge_learning(self, messages: list[SessionMessage]) -> None:
|
|
||||||
"""?????????????????"""
|
|
||||||
pending_count = self._knowledge_learner.get_pending_count(self.message_cache)
|
|
||||||
if not self._should_trigger_learning(
|
|
||||||
enabled=global_config.maisaka.enable_knowledge_module,
|
|
||||||
feature_name="知识学习",
|
|
||||||
last_extraction_time=self._last_knowledge_extraction_time,
|
|
||||||
pending_count=pending_count,
|
|
||||||
min_messages_for_extraction=self._knowledge_learner.min_messages_for_extraction,
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
self._last_knowledge_extraction_time = time.time()
|
|
||||||
logger.info(
|
|
||||||
f"{self.log_prefix} ??????: "
|
|
||||||
f"??????={len(messages)} ??????={pending_count} "
|
|
||||||
f"?????={len(self.message_cache)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
added_count = await self._knowledge_learner.learn(self.message_cache)
|
|
||||||
if added_count > 0:
|
|
||||||
logger.info(f"{self.log_prefix} ???????: ?????={added_count}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"{self.log_prefix} ???????????????")
|
|
||||||
except Exception:
|
|
||||||
logger.exception(f"{self.log_prefix} ??????")
|
|
||||||
|
|
||||||
async def _init_mcp(self) -> None:
|
async def _init_mcp(self) -> None:
|
||||||
"""初始化 MCP 工具并注册到统一工具层。"""
|
"""初始化 MCP 工具并注册到统一工具层。"""
|
||||||
self._mcp_host_bridge = MCPHostLLMBridge(
|
self._mcp_host_bridge = MCPHostLLMBridge(
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -62,6 +62,7 @@ def serialize_tool_calls(tool_calls: Sequence[ToolCall] | None) -> List[Dict[str
|
|||||||
"name": tool_call.func_name,
|
"name": tool_call.func_name,
|
||||||
"arguments": dict(tool_call.args or {}),
|
"arguments": dict(tool_call.args or {}),
|
||||||
},
|
},
|
||||||
|
**({"extra_content": tool_call.extra_content} if tool_call.extra_content else {}),
|
||||||
}
|
}
|
||||||
for tool_call in tool_calls
|
for tool_call in tool_calls
|
||||||
]
|
]
|
||||||
@@ -102,11 +103,13 @@ def deserialize_tool_calls(raw_tool_calls: Any) -> List[ToolCall]:
|
|||||||
if not isinstance(call_id, str) or not isinstance(function_name, str):
|
if not isinstance(call_id, str) or not isinstance(function_name, str):
|
||||||
raise ValueError("Hook 返回的工具调用缺少 `id` 或函数名称")
|
raise ValueError("Hook 返回的工具调用缺少 `id` 或函数名称")
|
||||||
|
|
||||||
|
extra_content = raw_tool_call.get("extra_content")
|
||||||
normalized_tool_calls.append(
|
normalized_tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
call_id=call_id,
|
call_id=call_id,
|
||||||
func_name=function_name,
|
func_name=function_name,
|
||||||
args=function_arguments if isinstance(function_arguments, dict) else {},
|
args=function_arguments if isinstance(function_arguments, dict) else {},
|
||||||
|
extra_content=extra_content if isinstance(extra_content, dict) else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return normalized_tool_calls
|
return normalized_tool_calls
|
||||||
|
|||||||
@@ -384,11 +384,13 @@ def _build_tool_calls(raw_tool_calls: Any) -> List[ToolCall] | None:
|
|||||||
if not isinstance(call_id, str) or not isinstance(func_name, str):
|
if not isinstance(call_id, str) or not isinstance(func_name, str):
|
||||||
raise ValueError("工具调用缺少 `id` 或函数名称")
|
raise ValueError("工具调用缺少 `id` 或函数名称")
|
||||||
|
|
||||||
|
extra_content = raw_tool_call.get("extra_content")
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
call_id=call_id,
|
call_id=call_id,
|
||||||
func_name=func_name,
|
func_name=func_name,
|
||||||
args=_normalize_tool_arguments(arguments),
|
args=_normalize_tool_arguments(arguments),
|
||||||
|
extra_content=extra_content if isinstance(extra_content, dict) else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user