From 89df7ccf6b213916aa787290e8c3f3ac97a4fbb0 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sun, 22 Mar 2026 00:43:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20NapCat=20=E9=80=82?= =?UTF-8?q?=E9=85=8D=E5=99=A8=E7=9A=84=E5=85=A5=E7=AB=99=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E7=BC=96=E8=A7=A3=E7=A0=81=E5=8A=9F=E8=83=BD=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=BC=BA=E6=8F=92=E4=BB=B6=E9=85=8D=E7=BD=AE=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E9=80=BB=E8=BE=91=E5=92=8C=E6=95=B0=E6=8D=AE=E5=BA=93=E4=BA=A4?= =?UTF-8?q?=E4=BA=92=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../common_test/test_expression_learner.py | 81 ++++++++++++++ pytests/common_test/test_jargon_miner.py | 90 +++++++++++++++ pytests/test_napcat_adapter_codec.py | 81 ++++++++++++++ pytests/test_napcat_adapter_plugin.py | 60 ++++++++++ pytests/test_plugin_runtime.py | 35 ++++-- src/learners/expression_learner.py | 44 +++++--- src/learners/jargon_miner.py | 24 +++- src/plugin_runtime/integration.py | 26 ++++- .../built_in/napcat_adapter/codec_inbound.py | 104 +++++++++++++++++- src/plugins/built_in/napcat_adapter/plugin.py | 1 + 10 files changed, 511 insertions(+), 35 deletions(-) create mode 100644 pytests/common_test/test_expression_learner.py create mode 100644 pytests/common_test/test_jargon_miner.py create mode 100644 pytests/test_napcat_adapter_plugin.py diff --git a/pytests/common_test/test_expression_learner.py b/pytests/common_test/test_expression_learner.py new file mode 100644 index 00000000..951aa424 --- /dev/null +++ b/pytests/common_test/test_expression_learner.py @@ -0,0 +1,81 @@ +"""测试表达方式学习器的数据库读取行为。""" + +from contextlib import contextmanager +from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine + +from src.bw_learner.expression_learner import ExpressionLearner +from src.common.database.database_model import 
Expression + + +@pytest.fixture(name="expression_learner_engine") +def expression_learner_engine_fixture() -> Generator: + """创建用于表达方式学习器测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +def test_find_similar_expression_uses_read_only_session_and_history_content( + monkeypatch: pytest.MonkeyPatch, + expression_learner_engine, +) -> None: + """查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。""" + import src.bw_learner.expression_learner as expression_learner_module + + with Session(expression_learner_engine) as session: + session.add( + Expression( + situation="发送汗滴表情", + style="发送💦表情符号", + content_list='["表达情绪高涨或生理反应"]', + count=1, + session_id="session-a", + checked=False, + rejected=False, + ) + ) + session.commit() + + @contextmanager + def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: + """构造带自动提交语义的测试会话工厂。 + + Args: + auto_commit: 退出上下文时是否自动提交。 + + Yields: + Generator[Session, None, None]: SQLModel 会话对象。 + """ + session = Session(expression_learner_engine) + try: + yield session + if auto_commit: + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session) + + learner = ExpressionLearner(session_id="session-a") + result = learner._find_similar_expression("表达情绪高涨或生理反应") + + assert result is not None + expression, similarity = result + assert expression.item_id is not None + assert expression.style == "发送💦表情符号" + assert similarity == pytest.approx(1.0) diff --git a/pytests/common_test/test_jargon_miner.py b/pytests/common_test/test_jargon_miner.py new file mode 100644 index 00000000..bf81e4d2 --- /dev/null +++ b/pytests/common_test/test_jargon_miner.py @@ -0,0 +1,90 @@ +"""测试黑话学习器的数据库读取行为。""" + +from contextlib import contextmanager 
+from typing import Generator + +import pytest +from sqlalchemy.pool import StaticPool +from sqlmodel import Session, SQLModel, create_engine, select + +from src.bw_learner.jargon_miner import JargonMiner +from src.common.database.database_model import Jargon + + +@pytest.fixture(name="jargon_miner_engine") +def jargon_miner_engine_fixture() -> Generator: + """创建用于黑话学习器测试的内存数据库引擎。 + + Yields: + Generator: 供测试使用的 SQLite 内存引擎。 + """ + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + yield engine + + +@pytest.mark.asyncio +async def test_process_extracted_entries_updates_existing_jargon_without_detached_session( + monkeypatch: pytest.MonkeyPatch, + jargon_miner_engine, +) -> None: + """更新已有黑话时,不应因会话关闭导致 ORM 实例失效。""" + import src.bw_learner.jargon_miner as jargon_miner_module + + with Session(jargon_miner_engine) as session: + session.add( + Jargon( + content="VF8V4L", + raw_content='["[1] first"]', + meaning="", + session_id_dict='{"session-a": 1}', + count=0, + is_jargon=True, + is_complete=False, + is_global=False, + last_inference_count=0, + ) + ) + session.commit() + + @contextmanager + def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: + """构造带自动提交语义的测试会话工厂。 + + Args: + auto_commit: 退出上下文时是否自动提交。 + + Yields: + Generator[Session, None, None]: SQLModel 会话对象。 + """ + session = Session(jargon_miner_engine) + try: + yield session + if auto_commit: + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session) + + jargon_miner = JargonMiner(session_id="session-a", session_name="测试群") + await jargon_miner.process_extracted_entries( + [{"content": "VF8V4L", "raw_content": {"[2] second"}}], + ) + + with Session(jargon_miner_engine) as session: + db_jargon = session.exec(select(Jargon).where(Jargon.content == 
"VF8V4L")).one() + + assert db_jargon.count == 1 + assert db_jargon.session_id_dict == '{"session-a": 2}' + assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [ + "[1] first", + "[2] second", + ] diff --git a/pytests/test_napcat_adapter_codec.py b/pytests/test_napcat_adapter_codec.py index 6f557e08..97ed1d9e 100644 --- a/pytests/test_napcat_adapter_codec.py +++ b/pytests/test_napcat_adapter_codec.py @@ -3,12 +3,16 @@ from typing import Any, Dict import importlib import sys +from types import SimpleNamespace + +import pytest BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) +NapCatInboundCodec = importlib.import_module("napcat_adapter.codec_inbound").NapCatInboundCodec NapCatOutboundCodec = importlib.import_module("napcat_adapter.codec_outbound").NapCatOutboundCodec @@ -68,3 +72,80 @@ def test_napcat_outbound_codec_builds_private_action_from_route_metadata() -> No assert action_name == "send_private_msg" assert params == {"message": [{"type": "text", "data": {"text": "hello"}}], "user_id": "30001"} + + +class DummyQueryService: + """用于测试的轻量查询服务。""" + + async def download_binary(self, url: str) -> bytes: + """返回固定图片二进制。 + + Args: + url: 图片地址。 + + Returns: + bytes: 固定测试图片二进制。 + """ + if url: + return b"image-bytes" + return b"" + + async def get_message_detail(self, message_id: str) -> Dict[str, Any] | None: + """返回空消息详情。 + + Args: + message_id: 目标消息 ID。 + + Returns: + Dict[str, Any] | None: 固定空结果。 + """ + del message_id + return None + + async def get_record_detail(self, file_name: str, file_id: str | None = None) -> Dict[str, Any] | None: + """返回空语音详情。 + + Args: + file_name: 语音文件名。 + file_id: 可选文件 ID。 + + Returns: + Dict[str, Any] | None: 固定空结果。 + """ + del file_name + del file_id + return None + + async def get_forward_message(self, message_id: str) -> Dict[str, Any] | None: + """返回空转发详情。 + + 
Args: + message_id: 转发消息 ID。 + + Returns: + Dict[str, Any] | None: 固定空结果。 + """ + del message_id + return None + + +@pytest.mark.asyncio +async def test_napcat_inbound_codec_parses_cq_string_image_segments() -> None: + codec = NapCatInboundCodec(SimpleNamespace(debug=lambda message: None), DummyQueryService()) + payload = { + "message": "[CQ:image,file=test.png,sub_type=0,url=https://example.com/test.png][CQ:at,qq=10001] 看到是国人直接给你封了", + } + + raw_message, is_at = await codec.convert_segments(payload, "10001") + + assert raw_message[0]["type"] == "image" + assert raw_message[1] == { + "type": "at", + "data": { + "target_user_id": "10001", + "target_user_nickname": None, + "target_user_cardname": None, + }, + } + assert raw_message[2] == {"type": "text", "data": " 看到是国人直接给你封了"} + assert is_at is True diff --git a/pytests/test_napcat_adapter_plugin.py b/pytests/test_napcat_adapter_plugin.py new file mode 100644 index 00000000..ca550a39 --- /dev/null +++ b/pytests/test_napcat_adapter_plugin.py @@ -0,0 +1,60 @@ +"""NapCat 插件入口行为测试。""" + +from pathlib import Path +from typing import List +from types import SimpleNamespace + +import importlib +import sys + +import pytest + + +BUILT_IN_PLUGIN_ROOT = Path(__file__).resolve().parents[1] / "src" / "plugins" / "built_in" +if str(BUILT_IN_PLUGIN_ROOT) not in sys.path: + sys.path.insert(0, str(BUILT_IN_PLUGIN_ROOT)) + +NapCatAdapterPlugin = importlib.import_module("napcat_adapter.plugin").NapCatAdapterPlugin + + +class DummyLogger: + """用于测试的轻量日志对象。""" + + def __init__(self) -> None: + """初始化测试日志对象。""" + self.debug_messages: List[str] = [] + + def debug(self, message: str) -> None: + """记录调试日志。 + + Args: + message: 待记录的日志内容。 + """ + self.debug_messages.append(message) + + +@pytest.mark.asyncio +async def test_on_config_update_refreshes_settings_and_restarts(monkeypatch: pytest.MonkeyPatch) -> None: + """配置更新时应刷新插件配置、清空旧 settings,并触发连接重启。""" + plugin = NapCatAdapterPlugin() + plugin._ctx = SimpleNamespace(logger=DummyLogger()) + 
plugin._settings = object() + + restart_calls: List[dict] = [] + + async def fake_restart() -> None: + """记录一次重启调用。""" + restart_calls.append(dict(plugin._plugin_config)) + + monkeypatch.setattr(plugin, "_restart_connection_if_needed", fake_restart) + + new_config = { + "plugin": {"enabled": True, "config_version": "0.1.0"}, + "napcat_server": {"host": "127.0.0.1", "port": 3001}, + } + await plugin.on_config_update(new_config, "v2") + + assert plugin._plugin_config == new_config + assert plugin._settings is None + assert restart_calls == [new_config] + assert plugin.ctx.logger.debug_messages == ["NapCat 适配器收到配置更新通知: v2"] diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 5ab16c85..20cceb82 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -2238,6 +2238,7 @@ class TestIntegration: async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path): from src.config.file_watcher import FileChange from src.plugin_runtime import integration as integration_module + import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" @@ -2247,6 +2248,10 @@ class TestIntegration: beta_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") + (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8") monkeypatch.chdir(tmp_path) @@ -2257,8 +2262,8 @@ class TestIntegration: self.reload_reasons = [] self.config_updates = [] - async def reload_plugins(self, reason="manual"): - 
self.reload_reasons.append(reason) + async def reload_plugins(self, plugin_ids=None, reason="manual"): + self.reload_reasons.append((plugin_ids, reason)) async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): self.config_updates.append((plugin_id, config_data, config_version)) @@ -2283,13 +2288,13 @@ class TestIntegration: await manager._handle_plugin_source_changes(changes) assert manager._builtin_supervisor.reload_reasons == [] - assert manager._third_party_supervisor.reload_reasons == ["file_watcher"] + assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")] assert manager._builtin_supervisor.config_updates == [] assert manager._third_party_supervisor.config_updates == [] assert refresh_calls == [True] @pytest.mark.asyncio - async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): + async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module from src.config.file_watcher import FileChange @@ -2308,27 +2313,35 @@ class TestIntegration: def __init__(self, plugin_dirs, plugins): self._plugin_dirs = plugin_dirs self._registered_plugins = {plugin_id: object() for plugin_id in plugins} - self.config_updates = [] + self.reload_calls = [] - async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): - self.config_updates.append((plugin_id, config_data, config_version)) + async def reload_plugin(self, plugin_id, reason="manual"): + self.reload_calls.append((plugin_id, reason)) return True manager = integration_module.PluginRuntimeManager() manager._started = True manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"]) manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"]) + refresh_calls = [] + + def fake_refresh() -> None: + refresh_calls.append(True) + + 
manager._refresh_plugin_config_watch_subscriptions = fake_refresh await manager._handle_plugin_config_changes( "alpha", [FileChange(change_type=1, path=alpha_dir / "config.toml")], ) - assert manager._builtin_supervisor.config_updates == [("alpha", {"enabled": True}, "")] - assert manager._third_party_supervisor.config_updates == [] + assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")] + assert manager._third_party_supervisor.reload_calls == [] + assert refresh_calls == [True] def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path): from src.plugin_runtime import integration as integration_module + import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" @@ -2336,6 +2349,10 @@ class TestIntegration: beta_dir = thirdparty_root / "beta" alpha_dir.mkdir(parents=True) beta_dir.mkdir(parents=True) + (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") + (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8") + (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8") class FakeWatcher: def __init__(self): diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py index 156fedc5..b82ae1fa 100644 --- a/src/learners/expression_learner.py +++ b/src/learners/expression_learner.py @@ -461,25 +461,43 @@ class ExpressionLearner: def _find_similar_expression( self, situation: str, similarity_threshold: float = 0.75 ) -> Optional[Tuple[MaiExpression, float]]: - """在数据库中查找相似的表达方式""" + """在数据库中查找相似的表达方式。 + + Args: + situation: 当前待匹配的情景描述。 + similarity_threshold: 认定为相似表达方式的最低相似度阈值。 + + Returns: + Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回 + ``(表达方式对象, 相似度)``;否则返回 ``None``。 + """ try: - with get_db_session() as session: + 
with get_db_session(auto_commit=False) as session: statement = select(Expression).filter_by(session_id=self.session_id) expressions = session.exec(statement).all() - best_match: Optional[Expression] = None - best_similarity = 0.0 + best_match: Optional[MaiExpression] = None + best_similarity = 0.0 + + for db_expression in expressions: + expression = MaiExpression.from_db_instance(db_expression) + candidate_situations = [expression.situation, *expression.content] + for candidate_situation in candidate_situations: + normalized_candidate_situation = candidate_situation.strip() + if not normalized_candidate_situation: + continue + similarity = difflib.SequenceMatcher( + None, + situation, + normalized_candidate_situation, + ).ratio() + if similarity > similarity_threshold and similarity > best_similarity: + best_similarity = similarity + best_match = expression - for expr in expressions: - content_list = json.loads(expr.content_list) - for situation in content_list: - similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio() - if similarity > similarity_threshold and similarity > best_similarity: - best_similarity = similarity - best_match = expr if best_match: - logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}") - return MaiExpression.from_db_instance(best_match), best_similarity + logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}") + return best_match, best_similarity except Exception as e: logger.error(f"查找相似表达方式失败: {e}") diff --git a/src/learners/jargon_miner.py b/src/learners/jargon_miner.py index 674e5cc0..32926894 100644 --- a/src/learners/jargon_miner.py +++ b/src/learners/jargon_miner.py @@ -199,7 +199,7 @@ class JargonMiner: async def process_extracted_entries( self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None - ): + ) -> None: """ 处理已提取的黑话条目(从 expression_learner 路由过来的) @@ -230,7 +230,7 @@ class JargonMiner: content = entry["content"] 
raw_content_set = entry["raw_content"] try: - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: jargon_items = session.exec(select(Jargon).filter_by(content=content)).all() except Exception as e: logger.error(f"查询黑话 '{content}' 失败: {e}") @@ -306,7 +306,13 @@ class JargonMiner: removed_content, _ = self.cache.popitem(last=False) logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}") - def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]): + def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None: + """更新已有黑话记录并写回数据库。 + + Args: + db_jargon: 已命中的黑话 ORM 对象。 + raw_content_set: 本次新增的原始上下文集合。 + """ db_jargon.count += 1 existing_raw_content: List[str] = [] if db_jargon.raw_content: @@ -328,7 +334,17 @@ class JargonMiner: try: with get_db_session() as session: - session.add(db_jargon) + if db_jargon.id is None: + raise ValueError("黑话记录缺少 id,无法更新数据库") + statement = select(Jargon).filter_by(id=db_jargon.id).limit(1) + if persisted_jargon := session.exec(statement).first(): + persisted_jargon.count = db_jargon.count + persisted_jargon.raw_content = db_jargon.raw_content + persisted_jargon.session_id_dict = db_jargon.session_id_dict + persisted_jargon.is_global = db_jargon.is_global + session.add(persisted_jargon) + else: + logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新") except Exception as e: logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}") diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 24cf09fc..bf85669b 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -612,7 +612,17 @@ class PluginRuntimeManager( return None if plugin_path is None else plugin_path / "config.toml" async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None: - """处理单个插件配置文件变化,并仅向目标插件推送配置更新。""" + """处理单个插件配置文件变化,并精确重载目标插件。 + + Args: + plugin_id: 发生配置变更的插件 ID。 + changes: 当前批次收集到的配置文件变更列表。 + + Notes: + 
这里选择“精确重载该插件”,而不是仅推送软性的配置更新通知。
+        这样可以保证没有实现 ``on_config_update()`` 的插件也能重新执行
+        ``on_load()``,让磁盘上的 ``config.toml`` 修改对插件运行态真正生效。
+        """
         if not self._started or not changes:
             return
 
@@ -626,18 +636,24 @@ class PluginRuntimeManager(
             return
 
         try:
-            await supervisor.notify_plugin_config_updated(
+            self._load_plugin_config_for_supervisor(supervisor, plugin_id)
+            reload_success = await supervisor.reload_plugin(
                 plugin_id=plugin_id,
-                config_data=self._load_plugin_config_for_supervisor(supervisor, plugin_id),
+                reason="config_file_changed",
             )
+            if reload_success:
+                self._refresh_plugin_config_watch_subscriptions()
+            else:
+                logger.warning(f"插件 {plugin_id} 配置文件变更后重载失败")
         except Exception as exc:
-            logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
+            logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
 
     async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None:
         """处理插件源码相关变化。
 
         这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
-        单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。
+        单独的 per-plugin watcher 处理,并精确重载对应插件,避免放大成
+        不必要的跨插件 reload。
         """
         if not self._started or not changes:
             return
 
diff --git a/src/plugins/built_in/napcat_adapter/codec_inbound.py b/src/plugins/built_in/napcat_adapter/codec_inbound.py
index b8065585..8fb020dc 100644
--- a/src/plugins/built_in/napcat_adapter/codec_inbound.py
+++ b/src/plugins/built_in/napcat_adapter/codec_inbound.py
@@ -5,11 +5,15 @@
 from uuid import uuid4
 
 import hashlib
 import json
+import re
 import time
 
 from napcat_adapter.qq_queries import NapCatQueryService
 
 
+_CQ_SEGMENT_PATTERN = re.compile(r"\[CQ:(?P<type>[a-zA-Z0-9_]+)(?P<params>(?:,[^\]]*)?)\]")
+
+
 class NapCatInboundCodec:
     """NapCat 入站消息编码器。"""
 
@@ -104,8 +108,12 @@ class NapCatInboundCodec:
         """
         message_payload = payload.get("message")
         if isinstance(message_payload, str):
-            normalized_text = message_payload.strip()
-            return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
+            parsed_message_payload = self._parse_cq_message_text(message_payload)
+            if
parsed_message_payload: + message_payload = parsed_message_payload + else: + normalized_text = self._decode_cq_entities(message_payload).strip() + return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False if not isinstance(message_payload, list): return [], False @@ -223,8 +231,8 @@ class NapCatInboundCodec: Returns: Dict[str, Any]: 转换后的图片或表情消息段。 """ - subtype = segment_data.get("sub_type") - actual_is_emoji = is_emoji or (isinstance(subtype, int) and subtype not in {0, 4, 9}) + subtype = self._normalize_numeric_segment_value(segment_data.get("sub_type")) + actual_is_emoji = is_emoji or (subtype is not None and subtype not in {0, 4, 9}) image_url = str(segment_data.get("url") or "").strip() binary_data = await self._query_service.download_binary(image_url) @@ -412,3 +420,91 @@ class NapCatInboundCodec: plain_text = "".join(part for part in plain_text_parts if part).strip() return plain_text or fallback_text or "[unsupported]" + + def _parse_cq_message_text(self, message_text: str) -> List[Dict[str, Any]]: + """将 CQ 码字符串解析为 OneBot 风格消息段列表。 + + Args: + message_text: NapCat 在字符串模式下返回的消息内容。 + + Returns: + List[Dict[str, Any]]: 解析后的 OneBot 风格消息段列表。 + """ + parsed_segments: List[Dict[str, Any]] = [] + current_index = 0 + + for match in _CQ_SEGMENT_PATTERN.finditer(message_text): + prefix_text = self._decode_cq_entities(message_text[current_index : match.start()]) + if prefix_text: + parsed_segments.append({"type": "text", "data": {"text": prefix_text}}) + + segment_type = str(match.group("type") or "").strip() + segment_data = self._parse_cq_segment_data(match.group("params") or "") + if segment_type: + parsed_segments.append({"type": segment_type, "data": segment_data}) + current_index = match.end() + + suffix_text = self._decode_cq_entities(message_text[current_index:]) + if suffix_text: + parsed_segments.append({"type": "text", "data": {"text": suffix_text}}) + + return parsed_segments + + def _parse_cq_segment_data(self, raw_params: str) 
-> Dict[str, Any]:
+        """解析单个 CQ 段中的参数串。
+
+        Args:
+            raw_params: 形如 ``,key=value,key2=value2`` 的原始参数字符串。
+
+        Returns:
+            Dict[str, Any]: 解析后的参数字典。
+        """
+        parsed_data: Dict[str, Any] = {}
+        if not raw_params:
+            return parsed_data
+
+        for item in raw_params.lstrip(",").split(","):
+            if not item or "=" not in item:
+                continue
+            key, value = item.split("=", 1)
+            normalized_key = key.strip()
+            if not normalized_key:
+                continue
+            decoded_value = self._decode_cq_entities(value)
+            parsed_data[normalized_key] = self._normalize_numeric_segment_value(decoded_value)
+
+        return parsed_data
+
+    @staticmethod
+    def _decode_cq_entities(text: str) -> str:
+        """解码 CQ 码中的 HTML 风格转义实体。
+
+        Args:
+            text: 待解码的 CQ 文本。
+
+        Returns:
+            str: 解码后的普通文本。
+        """
+        return (
+            text.replace("&amp;", "&")
+            .replace("&#91;", "[")
+            .replace("&#93;", "]")
+            .replace("&#44;", ",")
+        )
+
+    @staticmethod
+    def _normalize_numeric_segment_value(value: Any) -> Any:
+        """将可安全识别的数字字符串转为整数。
+
+        Args:
+            value: 原始字段值。
+
+        Returns:
+            Any: 规范化后的字段值。
+        """
+        if isinstance(value, str):
+            stripped_value = value.strip()
+            if stripped_value.isdigit():
+                return int(stripped_value)
+            return stripped_value
+        return value
diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py
index b1e9bc8c..50900c5d 100644
--- a/src/plugins/built_in/napcat_adapter/plugin.py
+++ b/src/plugins/built_in/napcat_adapter/plugin.py
@@ -71,6 +71,7 @@
             version: 配置版本号。
         """
         self.set_plugin_config(new_config)
+        self._settings = None
         if version:
            self.ctx.logger.debug(f"NapCat 适配器收到配置更新通知: {version}")
         await self._restart_connection_if_needed()