fix:允许表达方式全局共享
This commit is contained in:
91
pytests/common_test/test_maisaka_expression_selector.py
Normal file
91
pytests/common_test/test_maisaka_expression_selector.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import src.chat.replyer.maisaka_expression_selector as selector_module
|
||||||
|
from src.chat.replyer.maisaka_expression_selector import MaisakaExpressionSelector
|
||||||
|
from src.common.utils.utils_session import SessionUtils
|
||||||
|
|
||||||
|
|
||||||
|
def _build_target(platform: str, item_id: str, rule_type: str = "group") -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(platform=platform, item_id=item_id, rule_type=rule_type)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_expression_group_scope_returns_related_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||||
|
related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
selector_module,
|
||||||
|
"global_config",
|
||||||
|
SimpleNamespace(
|
||||||
|
expression=SimpleNamespace(
|
||||||
|
expression_groups=[
|
||||||
|
SimpleNamespace(
|
||||||
|
expression_groups=[
|
||||||
|
_build_target("qq", "10001"),
|
||||||
|
_build_target("qq", "10002"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
selector = MaisakaExpressionSelector()
|
||||||
|
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||||
|
|
||||||
|
assert related_session_ids == {current_session_id, related_session_id}
|
||||||
|
assert has_global_share is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_expression_group_scope_uses_star_as_global_share(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
selector_module,
|
||||||
|
"global_config",
|
||||||
|
SimpleNamespace(
|
||||||
|
expression=SimpleNamespace(
|
||||||
|
expression_groups=[
|
||||||
|
SimpleNamespace(
|
||||||
|
expression_groups=[
|
||||||
|
_build_target("*", "*"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
selector = MaisakaExpressionSelector()
|
||||||
|
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||||
|
|
||||||
|
assert related_session_ids == {current_session_id}
|
||||||
|
assert has_global_share is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_expression_group_scope_does_not_treat_empty_target_as_global(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
selector_module,
|
||||||
|
"global_config",
|
||||||
|
SimpleNamespace(
|
||||||
|
expression=SimpleNamespace(
|
||||||
|
expression_groups=[
|
||||||
|
SimpleNamespace(
|
||||||
|
expression_groups=[
|
||||||
|
_build_target("", ""),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
selector = MaisakaExpressionSelector()
|
||||||
|
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||||
|
|
||||||
|
assert related_session_ids == {current_session_id}
|
||||||
|
assert has_global_share is False
|
||||||
@@ -40,18 +40,27 @@ class MaisakaExpressionSelector:
|
|||||||
logger.error(f"检查表达方式使用开关失败: {exc}")
|
logger.error(f"检查表达方式使用开关失败: {exc}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_related_session_ids(self, session_id: str) -> List[str]:
|
@staticmethod
|
||||||
|
def _is_global_expression_group_marker(platform: str, item_id: str) -> bool:
|
||||||
|
return platform == "*" and item_id == "*"
|
||||||
|
|
||||||
|
def _resolve_expression_group_scope(self, session_id: str) -> tuple[set[str], bool]:
|
||||||
related_session_ids = {session_id}
|
related_session_ids = {session_id}
|
||||||
|
has_global_share = False
|
||||||
expression_groups = global_config.expression.expression_groups
|
expression_groups = global_config.expression.expression_groups
|
||||||
|
|
||||||
for expression_group in expression_groups:
|
for expression_group in expression_groups:
|
||||||
target_items = expression_group.expression_groups
|
target_items = expression_group.expression_groups
|
||||||
group_session_ids: set[str] = set()
|
group_session_ids: set[str] = set()
|
||||||
contains_current_session = False
|
contains_current_session = False
|
||||||
|
contains_global_share_marker = False
|
||||||
|
|
||||||
for target_item in target_items:
|
for target_item in target_items:
|
||||||
platform = target_item.platform.strip()
|
platform = target_item.platform.strip()
|
||||||
item_id = target_item.item_id.strip()
|
item_id = target_item.item_id.strip()
|
||||||
|
if self._is_global_expression_group_marker(platform, item_id):
|
||||||
|
contains_global_share_marker = True
|
||||||
|
continue
|
||||||
if not platform or not item_id:
|
if not platform or not item_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -65,19 +74,24 @@ class MaisakaExpressionSelector:
|
|||||||
if target_session_id == session_id:
|
if target_session_id == session_id:
|
||||||
contains_current_session = True
|
contains_current_session = True
|
||||||
|
|
||||||
|
if contains_global_share_marker:
|
||||||
|
has_global_share = True
|
||||||
if contains_current_session:
|
if contains_current_session:
|
||||||
related_session_ids.update(group_session_ids)
|
related_session_ids.update(group_session_ids)
|
||||||
|
|
||||||
return list(related_session_ids)
|
return related_session_ids, has_global_share
|
||||||
|
|
||||||
def _load_expression_candidates(self, session_id: str) -> List[dict[str, Any]]:
|
def _load_expression_candidates(self, session_id: str) -> List[dict[str, Any]]:
|
||||||
related_session_ids = self._get_related_session_ids(session_id)
|
related_session_ids, has_global_share = self._resolve_expression_group_scope(session_id)
|
||||||
|
|
||||||
with get_db_session(auto_commit=False) as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
base_query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||||
scoped_query = base_query.where(
|
if has_global_share:
|
||||||
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
scoped_query = base_query
|
||||||
)
|
else:
|
||||||
|
scoped_query = base_query.where(
|
||||||
|
(Expression.session_id.in_(related_session_ids)) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
if global_config.expression.expression_checked_only:
|
if global_config.expression.expression_checked_only:
|
||||||
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
scoped_query = scoped_query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||||
expressions = session.exec(scoped_query).all()
|
expressions = session.exec(scoped_query).all()
|
||||||
|
|||||||
Reference in New Issue
Block a user