remove:移除tool_use模型,修复Jargon提取问题,修改统计为tool统计
This commit is contained in:
@@ -348,17 +348,11 @@ class MessageSequence:
|
||||
if isinstance(item, TextComponent):
|
||||
return {"type": "text", "data": item.text}
|
||||
elif isinstance(item, ImageComponent):
|
||||
if not item.content:
|
||||
raise RuntimeError("ImageComponent content 未初始化")
|
||||
return {"type": "image", "data": item.content, "hash": item.binary_hash}
|
||||
return {"type": "image", "data": self._ensure_binary_component_content(item, "[图片]"), "hash": item.binary_hash}
|
||||
elif isinstance(item, EmojiComponent):
|
||||
if not item.content:
|
||||
raise RuntimeError("EmojiComponent content 未初始化")
|
||||
return {"type": "emoji", "data": item.content, "hash": item.binary_hash}
|
||||
return {"type": "emoji", "data": self._ensure_binary_component_content(item, "[表情包]"), "hash": item.binary_hash}
|
||||
elif isinstance(item, VoiceComponent):
|
||||
if not item.content:
|
||||
raise RuntimeError("VoiceComponent content 未初始化")
|
||||
return {"type": "voice", "data": item.content, "hash": item.binary_hash}
|
||||
return {"type": "voice", "data": self._ensure_binary_component_content(item, "[语音消息]"), "hash": item.binary_hash}
|
||||
elif isinstance(item, AtComponent):
|
||||
return {
|
||||
"type": "at",
|
||||
@@ -388,6 +382,14 @@ class MessageSequence:
|
||||
logger.warning(f"Unofficial component type: {type(item)}, defaulting to DictComponent")
|
||||
return {"type": "dict", "data": item.data}
|
||||
|
||||
@staticmethod
def _ensure_binary_component_content(item: ByteComponent, fallback_text: str) -> str:
    """Return the component's content, substituting a placeholder when it is empty.

    Args:
        item: Binary component whose ``content`` attribute is read and, if
            falsy, overwritten.
        fallback_text: Placeholder stored into ``item.content`` when the
            component has no content.

    Returns:
        ``item.content`` after the optional substitution.
    """
    # The placeholder is written back onto the component itself, so later
    # reads of ``item.content`` see the same text.
    if not item.content:
        item.content = fallback_text
    return item.content
|
||||
|
||||
@classmethod
|
||||
def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents:
|
||||
"""内部方法:将单个消息组件的字典格式转换回组件对象"""
|
||||
|
||||
59
src/common/data_models/tool_record_data_model.py
Normal file
59
src/common/data_models/tool_record_data_model.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
import json
|
||||
|
||||
from src.common.database.database_model import ToolRecord
|
||||
|
||||
from . import BaseDatabaseDataModel
|
||||
|
||||
|
||||
class MaiToolRecord(BaseDatabaseDataModel[ToolRecord]):
    """Data model for one tool-call record.

    ``tool_data`` is kept as a dict on this object and converted to/from
    JSON text when moving to/from a ``ToolRecord`` database instance.
    """

    def __init__(
        self,
        tool_id: str,
        timestamp: datetime,
        session_id: str,
        tool_name: str,
        tool_reasoning: Optional[str] = None,
        tool_data: Optional[Dict] = None,
        tool_builtin_prompt: Optional[str] = None,
        tool_display_prompt: Optional[str] = None,
    ):
        """Store the given fields on the instance.

        A falsy ``tool_data`` (``None`` or an empty dict) is replaced by a
        fresh empty dict; every other argument is stored unchanged.
        """
        self.tool_id = tool_id
        self.timestamp = timestamp
        self.session_id = session_id
        self.tool_name = tool_name
        self.tool_reasoning = tool_reasoning
        self.tool_data = tool_data if tool_data else {}
        self.tool_builtin_prompt = tool_builtin_prompt
        self.tool_display_prompt = tool_display_prompt

    @classmethod
    def from_db_instance(cls, db_record: ToolRecord):
        """Build a data-model object from a ``ToolRecord`` database instance.

        The stored ``tool_data`` text is decoded with ``json.loads``; a falsy
        value (``None`` or an empty string) is passed on as ``None``.
        """
        raw_tool_data = db_record.tool_data
        decoded_tool_data = json.loads(raw_tool_data) if raw_tool_data else None
        return cls(
            tool_id=db_record.tool_id,
            timestamp=db_record.timestamp,
            session_id=db_record.session_id,
            tool_name=db_record.tool_name,
            tool_reasoning=db_record.tool_reasoning,
            tool_data=decoded_tool_data,
            tool_builtin_prompt=db_record.tool_builtin_prompt,
            tool_display_prompt=db_record.tool_display_prompt,
        )

    def to_db_instance(self):
        """Convert this data-model object into a ``ToolRecord`` database instance.

        ``tool_data`` is encoded with ``json.dumps``; an empty dict is stored
        as ``None`` rather than as ``"{}"``.
        """
        encoded_tool_data = json.dumps(self.tool_data) if self.tool_data else None
        return ToolRecord(
            tool_id=self.tool_id,
            timestamp=self.timestamp,
            session_id=self.session_id,
            tool_name=self.tool_name,
            tool_reasoning=self.tool_reasoning,
            tool_data=encoded_tool_data,
            tool_builtin_prompt=self.tool_builtin_prompt,
            tool_display_prompt=self.tool_display_prompt,
        )
|
||||
@@ -3,7 +3,7 @@ from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import event, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
@@ -57,6 +57,41 @@ SessionLocal = sessionmaker(
|
||||
_db_initialized = False
|
||||
|
||||
|
||||
def _migrate_action_records_to_tool_records() -> None:
    """Copy legacy ``action_records`` rows into ``tool_records``.

    Each ``action_*`` column is inserted into the matching ``tool_*`` column.
    Rows whose ``action_id`` already appears as a ``tool_id`` in
    ``tool_records`` are skipped, so re-running the statement does not insert
    the same legacy row twice. The statement is executed inside an
    ``engine.begin()`` block.
    """
    statement = text("""
        INSERT INTO tool_records (
            tool_id,
            timestamp,
            session_id,
            tool_name,
            tool_reasoning,
            tool_data,
            tool_builtin_prompt,
            tool_display_prompt
        )
        SELECT
            action_id,
            timestamp,
            session_id,
            action_name,
            action_reasoning,
            action_data,
            action_builtin_prompt,
            action_display_prompt
        FROM action_records
        WHERE NOT EXISTS (
            SELECT 1
            FROM tool_records
            WHERE tool_records.tool_id = action_records.action_id
        )
        """)
    with engine.begin() as conn:
        conn.execute(statement)
|
||||
|
||||
|
||||
def initialize_database() -> None:
|
||||
global _db_initialized
|
||||
if _db_initialized:
|
||||
@@ -65,6 +100,7 @@ def initialize_database() -> None:
|
||||
import src.common.database.database_model # noqa: F401
|
||||
|
||||
SQLModel.metadata.create_all(engine)
|
||||
_migrate_action_records_to_tool_records()
|
||||
_db_initialized = True
|
||||
|
||||
|
||||
|
||||
@@ -134,6 +134,27 @@ class ActionRecord(SQLModel, table=True):
|
||||
action_display_prompt: Optional[str] = Field(default=None) # 最终输入到Prompt的内容
|
||||
|
||||
|
||||
class ToolRecord(SQLModel, table=True):
    """Stores tool-call records (table ``tool_records``)."""

    __tablename__ = "tool_records"  # type: ignore

    id: Optional[int] = Field(default=None, primary_key=True)  # auto-increment primary key

    # Metadata
    tool_id: str = Field(index=True, max_length=255)  # tool-call ID
    timestamp: datetime = Field(default_factory=datetime.now, sa_column=Column(DateTime, index=True))  # record timestamp; defaults to datetime.now() (naive local time)
    session_id: str = Field(index=True, max_length=255)  # session_id of the corresponding ChatSession

    # Call information
    tool_name: str = Field(index=True, max_length=255)  # tool name
    tool_reasoning: Optional[str] = Field(default=None)  # reasoning behind the tool call
    tool_data: Optional[str] = Field(default=None)  # tool data, stored as JSON text

    tool_builtin_prompt: Optional[str] = Field(default=None)  # built-in tool prompt
    tool_display_prompt: Optional[str] = Field(default=None)  # content that is finally fed into the prompt
|
||||
|
||||
|
||||
class CommandRecord(SQLModel, table=True):
|
||||
"""记录命令执行情况"""
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import TYPE_CHECKING, List
|
||||
from src.common.utils.math_utils import translate_timestamp_to_human_readable, TimestampMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.action_record_data_model import MaiActionRecord
|
||||
from src.common.data_models.tool_record_data_model import MaiToolRecord
|
||||
|
||||
|
||||
class ActionUtils:
|
||||
@staticmethod
|
||||
def build_readable_action_records(action_records: List["MaiActionRecord"], timestamp_mode: str | TimestampMode):
|
||||
def build_readable_action_records(action_records: List["MaiToolRecord"], timestamp_mode: str | TimestampMode):
|
||||
"""
|
||||
将动作列表转换为可读的文本格式。
|
||||
|
||||
@@ -27,6 +27,6 @@ class ActionUtils:
|
||||
output_lines = []
|
||||
for record in action_records:
|
||||
timestamp_str = translate_timestamp_to_human_readable(record.timestamp.timestamp(), mode=timestamp_mode)
|
||||
line = f"在{timestamp_str},你使用了{record.action_name},具体内容是:{record.action_display_prompt}"
|
||||
line = f"在{timestamp_str},你使用了{record.tool_name},具体内容是:{record.tool_display_prompt}"
|
||||
output_lines.append(line)
|
||||
return "\n".join(output_lines)
|
||||
|
||||
@@ -579,26 +579,26 @@ class MessageUtils:
|
||||
List[Tuple[float, str]]: 按时间排序的动作文本列表,每个元素为 (timestamp, action_text)
|
||||
"""
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ActionRecord
|
||||
from src.common.database.database_model import ToolRecord
|
||||
|
||||
# 获取这个时间范围内的动作记录,并匹配session_id
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
actions_in_range = session.exec(
|
||||
select(ActionRecord)
|
||||
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time))
|
||||
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time))
|
||||
.where(col(ActionRecord.session_id) == session_id)
|
||||
.order_by(col(ActionRecord.timestamp))
|
||||
select(ToolRecord)
|
||||
.where(col(ToolRecord.timestamp) >= datetime.fromtimestamp(min_time))
|
||||
.where(col(ToolRecord.timestamp) <= datetime.fromtimestamp(max_time))
|
||||
.where(col(ToolRecord.session_id) == session_id)
|
||||
.order_by(col(ToolRecord.timestamp))
|
||||
).all()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
with get_db_session() as session:
|
||||
action_after_latest = session.exec(
|
||||
select(ActionRecord)
|
||||
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time))
|
||||
.where(col(ActionRecord.session_id) == session_id)
|
||||
.order_by(col(ActionRecord.timestamp))
|
||||
select(ToolRecord)
|
||||
.where(col(ToolRecord.timestamp) > datetime.fromtimestamp(max_time))
|
||||
.where(col(ToolRecord.session_id) == session_id)
|
||||
.order_by(col(ToolRecord.timestamp))
|
||||
.limit(1)
|
||||
).all()
|
||||
except Exception as e:
|
||||
@@ -611,7 +611,7 @@ class MessageUtils:
|
||||
# 构建动作文本列表
|
||||
action_messages: List[Tuple[float, str]] = []
|
||||
for action in actions:
|
||||
if action_display_prompt := action.action_display_prompt or "":
|
||||
if action_display_prompt := action.tool_display_prompt or "":
|
||||
action_time = action.timestamp.timestamp()
|
||||
action_messages.append((action_time, action_display_prompt))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user