feat: 添加 NapCat 适配器的入站消息编解码功能,增强插件配置更新逻辑和数据库交互测试

This commit is contained in:
DrSmoothl
2026-03-22 00:43:34 +08:00
parent 56a6d2fd8c
commit 89df7ccf6b
10 changed files with 511 additions and 35 deletions

View File

@@ -0,0 +1,81 @@
"""测试表达方式学习器的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.bw_learner.expression_learner import ExpressionLearner
from src.common.database.database_model import Expression
@pytest.fixture(name="expression_learner_engine")
def expression_learner_engine_fixture() -> Generator:
    """Build an in-memory SQLite engine for expression-learner tests.

    Yields:
        Generator: The SQLite in-memory engine used by the tests.
    """
    test_engine = create_engine(
        "sqlite://",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    SQLModel.metadata.create_all(test_engine)
    yield test_engine
def test_find_similar_expression_uses_read_only_session_and_history_content(
    monkeypatch: pytest.MonkeyPatch,
    expression_learner_engine,
) -> None:
    """Similar-expression lookup must survive session close and compare history content."""
    import src.bw_learner.expression_learner as expression_learner_module

    seeded_expression = Expression(
        situation="发送汗滴表情",
        style="发送💦表情符号",
        content_list='["表达情绪高涨或生理反应"]',
        count=1,
        session_id="session-a",
        checked=False,
        rejected=False,
    )
    with Session(expression_learner_engine) as seed_session:
        seed_session.add(seeded_expression)
        seed_session.commit()

    @contextmanager
    def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
        """Session factory with opt-in auto-commit semantics for this test.

        Args:
            auto_commit: Whether to commit when the context exits cleanly.

        Yields:
            Generator[Session, None, None]: A SQLModel session object.
        """
        db_session = Session(expression_learner_engine)
        try:
            yield db_session
            if auto_commit:
                db_session.commit()
        except Exception:
            db_session.rollback()
            raise
        finally:
            db_session.close()

    monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
    learner = ExpressionLearner(session_id="session-a")

    result = learner._find_similar_expression("表达情绪高涨或生理反应")

    assert result is not None
    expression, similarity = result
    assert expression.item_id is not None
    assert expression.style == "发送💦表情符号"
    assert similarity == pytest.approx(1.0)

View File

@@ -0,0 +1,90 @@
"""测试黑话学习器的数据库读取行为。"""
from contextlib import contextmanager
from typing import Generator
import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine, select
from src.bw_learner.jargon_miner import JargonMiner
from src.common.database.database_model import Jargon
@pytest.fixture(name="jargon_miner_engine")
def jargon_miner_engine_fixture() -> Generator:
    """Build an in-memory SQLite engine for jargon-miner tests.

    Yields:
        Generator: The SQLite in-memory engine used by the tests.
    """
    test_engine = create_engine(
        "sqlite://",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    SQLModel.metadata.create_all(test_engine)
    yield test_engine
@pytest.mark.asyncio
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
    monkeypatch: pytest.MonkeyPatch,
    jargon_miner_engine,
) -> None:
    """Updating an existing jargon row must not break because the ORM instance got detached."""
    # Explicit import instead of the inline `__import__("json")` hack in the
    # final assertion; also lets us assert raw_content is present before parsing.
    import json

    import src.bw_learner.jargon_miner as jargon_miner_module

    with Session(jargon_miner_engine) as session:
        session.add(
            Jargon(
                content="VF8V4L",
                raw_content='["[1] first"]',
                meaning="",
                session_id_dict='{"session-a": 1}',
                count=0,
                is_jargon=True,
                is_complete=False,
                is_global=False,
                last_inference_count=0,
            )
        )
        session.commit()

    @contextmanager
    def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
        """Session factory with opt-in auto-commit semantics for this test.

        Args:
            auto_commit: Whether to commit when the context exits cleanly.

        Yields:
            Generator[Session, None, None]: A SQLModel session object.
        """
        session = Session(jargon_miner_engine)
        try:
            yield session
            if auto_commit:
                session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
    jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")

    await jargon_miner.process_extracted_entries(
        [{"content": "VF8V4L", "raw_content": {"[2] second"}}],
    )

    with Session(jargon_miner_engine) as session:
        db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
        assert db_jargon.count == 1
        assert db_jargon.session_id_dict == '{"session-a": 2}'
        # Fail loudly if the update dropped raw_content instead of merging it.
        assert db_jargon.raw_content
        assert sorted(json.loads(db_jargon.raw_content)) == [
            "[1] first",
            "[2] second",
        ]

View File

@@ -3,12 +3,16 @@ from typing import Any, Dict
import importlib import importlib
import sys import sys
from types import SimpleNamespace
import pytest
BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
NapCatInboundCodec = importlib.import_module("napcat_adapter.codec_inbound").NapCatInboundCodec
NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec
@@ -68,3 +72,80 @@ def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> No
assert action_name == "send_private_msg" assert action_name == "send_private_msg"
assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"} assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"}
class DummyQueryService:
"""用于测试的轻量查询服务。"""
async def download_binary(self, url: str) -> bytes:
"""返回固定图片二进制。
Args:
url: 图片地址。
Returns:
bytes: 固定测试图片二进制。
"""
if url:
return b"image-bytes"
return b""
async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None:
"""返回空消息详情。
Args:
message_id: 目标消息 ID。
Returns:
Dict[str, Any] | None: 固定空结果。
"""
del message_id
return None
async def get_record_detail(self, file_name: str, file_id: str | None = None) -> Dict[str, Any] | None:
"""返回空语音详情。
Args:
file_name: 语音文件名。
file_id: 可选文件 ID。
Returns:
Dict[str, Any] | None: 固定空结果。
"""
del file_name
del file_id
return None
async def get_forward_message(self, message_id: str) -> Dict[str, Any] | None:
"""返回空转发详情。
Args:
message_id: 转发消息 ID。
Returns:
Dict[str, Any] | None: 固定空结果。
"""
del message_id
return None
@pytest.mark.asyncio
async def test_napcat_inbound_codec_parses_cq_string_image_segments() -> None:
    """CQ-string payloads must decode into structured image/at/text segments."""
    silent_logger = SimpleNamespace(debug=lambda message: None)
    codec = NapCatInboundCodec(silent_logger, DummyQueryService())
    payload = {
        "message": "[CQ:image,file=test.png,sub_type=0,url=https://example.com/test.png][CQ:at,qq=10001] 看到是国人直接给你封了",
    }

    raw_message, is_at = await codec.convert_segments(payload, "10001")

    assert is_at is True
    assert raw_message[0]["type"] == "image"
    expected_at_segment = {
        "type": "at",
        "data": {
            "target_user_id": "10001",
            "target_user_nickname": None,
            "target_user_cardname": None,
        },
    }
    assert raw_message[1] == expected_at_segment
    assert raw_message[2] == {"type": "text", "data": " 看到是国人直接给你封了"}

View File

@@ -0,0 +1,60 @@
"""NapCat 插件入口行为测试。"""
from pathlib import Path
from typing import List
from types import SimpleNamespace
import importlib
import sys
import pytest
BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in"
if str(BUILT_IN_PLUGIN_ROOT) not in sys.path:
sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT))
NapCatAdapterPlugin = importlib.import_module("napcat_adapter.plugin").NapCatAdapterPlugin
class DummyLogger:
    """Minimal logger double that records every debug call."""

    def __init__(self) -> None:
        """Start with an empty recorded-message list."""
        self.debug_messages: List[str] = list()

    def debug(self, message: str) -> None:
        """Record one debug message.

        Args:
            message: Log content to remember for later assertions.
        """
        self.debug_messages.append(message)
@pytest.mark.asyncio
async def test_on_config_update_refreshes_settings_and_restarts(monkeypatch: pytest.MonkeyPatch) -> None:
    """A config update must refresh plugin config, drop cached settings and restart the connection."""
    plugin = NapCatAdapterPlugin()
    plugin._ctx = SimpleNamespace(logger=DummyLogger())
    plugin._settings = object()

    restart_calls: List[dict] = []

    async def fake_restart() -> None:
        """Capture the plugin config visible at restart time."""
        restart_calls.append(dict(plugin._plugin_config))

    monkeypatch.setattr(plugin, "_restart_connection_if_needed", fake_restart)

    new_config = {
        "plugin": {"enabled": True, "config_version": "0.1.0"},
        "napcat_server": {"host": "127.0.0.1", "port": 3001},
    }
    await plugin.on_config_update(new_config, "v2")

    assert plugin._plugin_config == new_config
    assert plugin._settings is None
    assert restart_calls == [new_config]
    assert plugin.ctx.logger.debug_messages == ["NapCat 适配器收到配置更新通知: v2"]

View File

@@ -2238,6 +2238,7 @@ class TestIntegration:
async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path): async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path):
from src.config.file_watcher import FileChange from src.config.file_watcher import FileChange
from src.plugin_runtime import integration as integration_module from src.plugin_runtime import integration as integration_module
import json
builtin_root = tmp_path / "src" / "plugins" / "built_in" builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins" thirdparty_root = tmp_path / "plugins"
@@ -2247,6 +2248,10 @@ class TestIntegration:
beta_dir.mkdir(parents=True) beta_dir.mkdir(parents=True)
(alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
(beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
(beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
monkeypatch.chdir(tmp_path) monkeypatch.chdir(tmp_path)
@@ -2257,8 +2262,8 @@ class TestIntegration:
self.reload_reasons = [] self.reload_reasons = []
self.config_updates = [] self.config_updates = []
async def reload_plugins(self, reason="manual"): async def reload_plugins(self, plugin_ids=None, reason="manual"):
self.reload_reasons.append(reason) self.reload_reasons.append((plugin_ids, reason))
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
self.config_updates.append((plugin_id, config_data, config_version)) self.config_updates.append((plugin_id, config_data, config_version))
@@ -2283,13 +2288,13 @@ class TestIntegration:
await manager._handle_plugin_source_changes(changes) await manager._handle_plugin_source_changes(changes)
assert manager._builtin_supervisor.reload_reasons == [] assert manager._builtin_supervisor.reload_reasons == []
assert manager._third_party_supervisor.reload_reasons == ["file_watcher"] assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")]
assert manager._builtin_supervisor.config_updates == [] assert manager._builtin_supervisor.config_updates == []
assert manager._third_party_supervisor.config_updates == [] assert manager._third_party_supervisor.config_updates == []
assert refresh_calls == [True] assert refresh_calls == [True]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path):
from src.plugin_runtime import integration as integration_module from src.plugin_runtime import integration as integration_module
from src.config.file_watcher import FileChange from src.config.file_watcher import FileChange
@@ -2308,27 +2313,35 @@ class TestIntegration:
def __init__(self, plugin_dirs, plugins): def __init__(self, plugin_dirs, plugins):
self._plugin_dirs = plugin_dirs self._plugin_dirs = plugin_dirs
self._registered_plugins = {plugin_id: object() for plugin_id in plugins} self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
self.config_updates = [] self.reload_calls = []
async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): async def reload_plugin(self, plugin_id, reason="manual"):
self.config_updates.append((plugin_id, config_data, config_version)) self.reload_calls.append((plugin_id, reason))
return True return True
manager = integration_module.PluginRuntimeManager() manager = integration_module.PluginRuntimeManager()
manager._started = True manager._started = True
manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"]) manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"]) manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])
refresh_calls = []
def fake_refresh() -> None:
refresh_calls.append(True)
manager._refresh_plugin_config_watch_subscriptions = fake_refresh
await manager._handle_plugin_config_changes( await manager._handle_plugin_config_changes(
"alpha", "alpha",
[FileChange(change_type=1, path=alpha_dir / "config.toml")], [FileChange(change_type=1, path=alpha_dir / "config.toml")],
) )
assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "")] assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")]
assert manager._third_party_supervisor.config_updates == [] assert manager._third_party_supervisor.reload_calls == []
assert refresh_calls == [True]
def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path): def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
from src.plugin_runtime import integration as integration_module from src.plugin_runtime import integration as integration_module
import json
builtin_root = tmp_path / "src" / "plugins" / "built_in" builtin_root = tmp_path / "src" / "plugins" / "built_in"
thirdparty_root = tmp_path / "plugins" thirdparty_root = tmp_path / "plugins"
@@ -2336,6 +2349,10 @@ class TestIntegration:
beta_dir = thirdparty_root / "beta" beta_dir = thirdparty_root / "beta"
alpha_dir.mkdir(parents=True) alpha_dir.mkdir(parents=True)
beta_dir.mkdir(parents=True) beta_dir.mkdir(parents=True)
(alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
(alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
(beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")
class FakeWatcher: class FakeWatcher:
def __init__(self): def __init__(self):

View File

@@ -461,25 +461,43 @@ class ExpressionLearner:
def _find_similar_expression( def _find_similar_expression(
self, situation: str, similarity_threshold: float = 0.75 self, situation: str, similarity_threshold: float = 0.75
) -> Optional[Tuple[MaiExpression, float]]: ) -> Optional[Tuple[MaiExpression, float]]:
"""在数据库中查找相似的表达方式""" """在数据库中查找相似的表达方式
Args:
situation: 当前待匹配的情景描述。
similarity_threshold: 认定为相似表达方式的最低相似度阈值。
Returns:
Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回
``(表达方式对象, 相似度)``;否则返回 ``None``。
"""
try: try:
with get_db_session() as session: with get_db_session(auto_commit=False) as session:
statement = select(Expression).filter_by(session_id=self.session_id) statement = select(Expression).filter_by(session_id=self.session_id)
expressions = session.exec(statement).all() expressions = session.exec(statement).all()
best_match: Optional[Expression] = None best_match: Optional[MaiExpression] = None
best_similarity = 0.0 best_similarity = 0.0
for db_expression in expressions:
expression = MaiExpression.from_db_instance(db_expression)
candidate_situations = [expression.situation, *expression.content]
for candidate_situation in candidate_situations:
normalized_candidate_situation = candidate_situation.strip()
if not normalized_candidate_situation:
continue
similarity = difflib.SequenceMatcher(
None,
situation,
normalized_candidate_situation,
).ratio()
if similarity > similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expression
for expr in expressions:
content_list = json.loads(expr.content_list)
for situation in content_list:
similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio()
if similarity > similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match: if best_match:
logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}") logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
return MaiExpression.from_db_instance(best_match), best_similarity return best_match, best_similarity
except Exception as e: except Exception as e:
logger.error(f"查找相似表达方式失败: {e}") logger.error(f"查找相似表达方式失败: {e}")

View File

@@ -199,7 +199,7 @@ class JargonMiner:
async def process_extracted_entries( async def process_extracted_entries(
self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
): ) -> None:
""" """
处理已提取的黑话条目(从 expression_learner 路由过来的) 处理已提取的黑话条目(从 expression_learner 路由过来的)
@@ -230,7 +230,7 @@ class JargonMiner:
content = entry["content"] content = entry["content"]
raw_content_set = entry["raw_content"] raw_content_set = entry["raw_content"]
try: try:
with get_db_session() as session: with get_db_session(auto_commit=False) as session:
jargon_items = session.exec(select(Jargon).filter_by(content=content)).all() jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
except Exception as e: except Exception as e:
logger.error(f"查询黑话 '{content}' 失败: {e}") logger.error(f"查询黑话 '{content}' 失败: {e}")
@@ -306,7 +306,13 @@ class JargonMiner:
removed_content, _ = self.cache.popitem(last=False) removed_content, _ = self.cache.popitem(last=False)
logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}") logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")
def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]): def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None:
"""更新已有黑话记录并写回数据库。
Args:
db_jargon: 已命中的黑话 ORM 对象。
raw_content_set: 本次新增的原始上下文集合。
"""
db_jargon.count += 1 db_jargon.count += 1
existing_raw_content: List[str] = [] existing_raw_content: List[str] = []
if db_jargon.raw_content: if db_jargon.raw_content:
@@ -328,7 +334,17 @@ class JargonMiner:
try: try:
with get_db_session() as session: with get_db_session() as session:
session.add(db_jargon) if db_jargon.id is None:
raise ValueError("黑话记录缺少 id无法更新数据库")
statement = select(Jargon).filter_by(id=db_jargon.id).limit(1)
if persisted_jargon := session.exec(statement).first():
persisted_jargon.count = db_jargon.count
persisted_jargon.raw_content = db_jargon.raw_content
persisted_jargon.session_id_dict = db_jargon.session_id_dict
persisted_jargon.is_global = db_jargon.is_global
session.add(persisted_jargon)
else:
logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新")
except Exception as e: except Exception as e:
logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}") logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")

View File

@@ -612,7 +612,17 @@ class PluginRuntimeManager(
return None if plugin_path is None else plugin_path / "config.toml" return None if plugin_path is None else plugin_path / "config.toml"
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None: async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
"""处理单个插件配置文件变化,并仅向目标插件推送配置更新。""" """处理单个插件配置文件变化,并精确重载目标插件。
Args:
plugin_id: 发生配置变更的插件 ID。
changes: 当前批次收集到的配置文件变更列表。
Notes:
这里选择“精确重载该插件”,而不是仅推送软性的配置更新通知。
这样可以保证没有实现 ``on_config_update()`` 的插件也能重新执行
``on_load()``,让磁盘上的 ``config.toml`` 修改对插件运行态真正生效。
"""
if not self._started or not changes: if not self._started or not changes:
return return
@@ -626,18 +636,24 @@ class PluginRuntimeManager(
return return
try: try:
await supervisor.notify_plugin_config_updated( self._load_plugin_config_for_supervisor(supervisor, plugin_id)
reload_success = await supervisor.reload_plugin(
plugin_id=plugin_id, plugin_id=plugin_id,
config_data=self._load_plugin_config_for_supervisor(supervisor, plugin_id), reason="config_file_changed",
) )
if reload_success:
self._refresh_plugin_config_watch_subscriptions()
else:
logger.warning(f"插件 {plugin_id} 配置文件变更后重载失败")
except Exception as exc: except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}") logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None: async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None:
"""处理插件源码相关变化。 """处理插件源码相关变化。
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由 这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。 单独的 per-plugin watcher 处理,并精确重载对应插件,避免放大成
不必要的跨插件 reload。
""" """
if not self._started or not changes: if not self._started or not changes:
return return

View File

@@ -5,11 +5,15 @@ from uuid import uuid4
import hashlib import hashlib
import json import json
import re
import time import time
from napcat_adapter.qq_queries import NapCatQueryService from napcat_adapter.qq_queries import NapCatQueryService
_CQ_SEGMENT_PATTERN = re.compile(r"\[CQ:(?P<type>[a-zA-Z0-9_]+)(?P<params>(?:,[^\]]*)?)\]")
class NapCatInboundCodec: class NapCatInboundCodec:
"""NapCat 入站消息编码器。""" """NapCat 入站消息编码器。"""
@@ -104,8 +108,12 @@ class NapCatInboundCodec:
""" """
message_payload = payload.get("message") message_payload = payload.get("message")
if isinstance(message_payload, str): if isinstance(message_payload, str):
normalized_text = message_payload.strip() parsed_message_payload = self._parse_cq_message_text(message_payload)
return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False if parsed_message_payload:
message_payload = parsed_message_payload
else:
normalized_text = self._decode_cq_entities(message_payload).strip()
return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
if not isinstance(message_payload, list): if not isinstance(message_payload, list):
return [], False return [], False
@@ -223,8 +231,8 @@ class NapCatInboundCodec:
Returns: Returns:
Dict[str, Any]: 转换后的图片或表情消息段。 Dict[str, Any]: 转换后的图片或表情消息段。
""" """
subtype = segment_data.get("sub_type") subtype = self._normalize_numeric_segment_value(segment_data.get("sub_type"))
actual_is_emoji = is_emoji or (isinstance(subtype, int) and subtype not in {0, 4, 9}) actual_is_emoji = is_emoji or (subtype is not None and subtype not in {0, 4, 9})
image_url = str(segment_data.get("url") or "").strip() image_url = str(segment_data.get("url") or "").strip()
binary_data = await self._query_service.download_binary(image_url) binary_data = await self._query_service.download_binary(image_url)
@@ -412,3 +420,91 @@ class NapCatInboundCodec:
plain_text = "".join(part for part in plain_text_parts if part).strip() plain_text = "".join(part for part in plain_text_parts if part).strip()
return plain_text or fallback_text or "[unsupported]" return plain_text or fallback_text or "[unsupported]"
def _parse_cq_message_text(self, message_text: str) -> List[Dict[str, Any]]:
    """Parse a CQ-code string into OneBot-style message segments.

    Args:
        message_text: Message content NapCat returned in string mode.

    Returns:
        List[Dict[str, Any]]: The parsed OneBot-style segment list.
    """
    segments: List[Dict[str, Any]] = []
    cursor = 0

    def emit_text(raw_chunk: str) -> None:
        """Append a text segment for any non-empty decoded chunk."""
        decoded_chunk = self._decode_cq_entities(raw_chunk)
        if decoded_chunk:
            segments.append({"type": "text", "data": {"text": decoded_chunk}})

    for match in _CQ_SEGMENT_PATTERN.finditer(message_text):
        # Plain text sitting between CQ codes becomes its own text segment.
        emit_text(message_text[cursor : match.start()])
        segment_type = str(match.group("type") or "").strip()
        if segment_type:
            segment_data = self._parse_cq_segment_data(match.group("params") or "")
            segments.append({"type": segment_type, "data": segment_data})
        cursor = match.end()
    # Trailing text after the final CQ code.
    emit_text(message_text[cursor:])
    return segments
def _parse_cq_segment_data(self, raw_params: str) -> Dict[str, Any]:
    """Parse the parameter string of a single CQ segment.

    Args:
        raw_params: Raw parameter string shaped like ``,key=value,key2=value2``.

    Returns:
        Dict[str, Any]: The parsed parameter mapping.
    """
    parsed: Dict[str, Any] = {}
    if not raw_params:
        return parsed
    for pair in raw_params.lstrip(",").split(","):
        key, separator, value = pair.partition("=")
        if not separator:
            # Empty item or no "=" at all — nothing to record.
            continue
        field_name = key.strip()
        if not field_name:
            continue
        decoded_value = self._decode_cq_entities(value)
        parsed[field_name] = self._normalize_numeric_segment_value(decoded_value)
    return parsed
@staticmethod
def _decode_cq_entities(text: str) -> str:
"""解码 CQ 码中的 HTML 风格转义实体。
Args:
text: 待解码的 CQ 文本。
Returns:
str: 解码后的普通文本。
"""
return (
text.replace("&amp;", "&")
.replace("&#91;", "[")
.replace("&#93;", "]")
.replace("&#44;", ",")
)
@staticmethod
def _normalize_numeric_segment_value(value: Any) -> Any:
"""将可安全识别的数字字符串转为整数。
Args:
value: 原始字段值。
Returns:
Any: 规范化后的字段值。
"""
if isinstance(value, str):
stripped_value = value.strip()
if stripped_value.isdigit():
return int(stripped_value)
return stripped_value
return value

View File

@@ -71,6 +71,7 @@ class NapCatAdapterPlugin(MaiBotPlugin):
version: 配置版本号。 version: 配置版本号。
""" """
self.set_plugin_config(new_config) self.set_plugin_config(new_config)
self._settings = None
if version: if version:
self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}") self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}")
await self._restart_connection_if_needed() await self._restart_connection_if_needed()