feat:同步本地非算法改动到上游基线

保留反馈纠错、WebUI 与运行时增强。移除不应提交的 algorithm_redesign 设计目录及其专项测试。
This commit is contained in:
A-Dawn
2026-04-16 13:57:07 +08:00
parent 6c22fdfdf9
commit 21b642d07d
10 changed files with 2244 additions and 34 deletions

View File

@@ -11,7 +11,7 @@ from __future__ import annotations
import asyncio
import time
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import aiohttp
import numpy as np
@@ -29,6 +29,9 @@ logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
class EmbeddingAPIAdapter:
"""适配宿主 embedding 请求接口。"""
_GLOBAL_DIMENSION_CACHE: Dict[str, int] = {}
_GLOBAL_TEXT_EMBEDDING_CACHE: Dict[Tuple[str, int, str], np.ndarray] = {}
def __init__(
self,
batch_size: int = 32,
@@ -232,10 +235,32 @@ class EmbeddingAPIAdapter:
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
return None
def _dimension_cache_key(self) -> str:
    """Build the process-wide cache key identifying this adapter's model config.

    The key joins the configured model name (falling back to ``"auto"``),
    the configured default dimension, and the comma-joined candidate model
    names with ``"|"`` so that adapters with identical configuration share
    one dimension-cache entry.
    """
    candidates = ",".join(self._resolve_candidate_model_names())
    model = str(self.model_name or "auto")
    dimension = str(self.default_dimension)
    return f"{model}|{dimension}|{candidates}"
def _embedding_cache_key(self, text: str, dimensions: Optional[int]) -> Tuple[str, int, str]:
    """Compose the global text-embedding cache key.

    The key is ``(model-config key, canonical dimension, text)`` so cached
    vectors are reused only for the same model configuration, requested
    dimension, and input text.
    """
    model_key = self._dimension_cache_key()
    canonical_dim = int(self._resolve_canonical_dimension(dimensions))
    return (model_key, canonical_dim, str(text or ""))
async def _detect_dimension(self) -> int:
if self._dimension_detected and self._dimension is not None:
return self._dimension
cache_key = self._dimension_cache_key()
cached_dimension = self._GLOBAL_DIMENSION_CACHE.get(cache_key)
if cached_dimension is not None:
self._dimension = int(cached_dimension)
self._dimension_detected = True
logger.info(f"嵌入维度命中进程缓存: {self._dimension}")
return self._dimension
logger.info("正在检测嵌入模型维度...")
try:
target_dim = self.default_dimension
@@ -251,6 +276,7 @@ class EmbeddingAPIAdapter:
)
self._dimension = detected_dim
self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
return detected_dim
except Exception as exc:
logger.debug(f"带维度参数探测失败: {exc},尝试不带维度参数探测")
@@ -261,6 +287,7 @@ class EmbeddingAPIAdapter:
detected_dim = len(test_embedding)
self._dimension = detected_dim
self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
return detected_dim
logger.warning(f"嵌入维度检测失败,使用 configured_dimension: {self.default_dimension}")
@@ -269,6 +296,7 @@ class EmbeddingAPIAdapter:
self._dimension = self.default_dimension
self._dimension_detected = True
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(self.default_dimension)
return self.default_dimension
async def encode(
@@ -336,6 +364,25 @@ class EmbeddingAPIAdapter:
all_embeddings: List[np.ndarray] = []
for offset in range(0, len(texts), batch_size):
batch = texts[offset : offset + batch_size]
batch_results: List[Tuple[int, np.ndarray]] = []
uncached_items: List[Tuple[int, str]] = []
if self.enable_cache:
for index, text in enumerate(batch):
cache_key = self._embedding_cache_key(text, dimensions)
cached_vector = self._GLOBAL_TEXT_EMBEDDING_CACHE.get(cache_key)
if cached_vector is None:
uncached_items.append((index, text))
else:
batch_results.append((index, cached_vector.copy()))
else:
uncached_items = list(enumerate(batch))
if not uncached_items:
batch_results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in batch_results)
continue
semaphore = asyncio.Semaphore(self.max_concurrent)
async def encode_with_semaphore(text: str, index: int):
@@ -351,11 +398,20 @@ class EmbeddingAPIAdapter:
tasks = [
encode_with_semaphore(text, offset + index)
for index, text in enumerate(batch)
for index, text in uncached_items
]
results = await asyncio.gather(*tasks)
results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in results)
normalized_results: List[Tuple[int, np.ndarray]] = []
for batch_index, vector in results:
normalized_results.append((batch_index, vector))
if self.enable_cache:
text = batch[batch_index]
cache_key = self._embedding_cache_key(text, dimensions)
self._GLOBAL_TEXT_EMBEDDING_CACHE[cache_key] = vector.copy()
batch_results.extend(normalized_results)
batch_results.sort(key=lambda item: item[0])
all_embeddings.extend(emb for _, emb in batch_results)
return np.array(all_embeddings, dtype=np.float32)

View File

@@ -34,7 +34,7 @@ except Exception:
logger = get_logger("A_Memorix.MetadataStore")
SCHEMA_VERSION = 10
SCHEMA_VERSION = 12
RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION = 9

View File

@@ -375,6 +375,30 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
"memory_feedback_tasks rollback columns missing under current schema version",
)
)
elif not has_stale_marks:
checks.append(
CheckItem(
"CP-15",
"error",
"paragraph_stale_relation_marks table missing under current schema version",
)
)
elif not has_profile_refresh_queue:
checks.append(
CheckItem(
"CP-16",
"error",
"person_profile_refresh_queue table missing under current schema version",
)
)
elif not has_feedback_rollback_status or not has_feedback_rollback_plan:
checks.append(
CheckItem(
"CP-17",
"error",
"memory_feedback_tasks rollback columns missing under current schema version",
)
)
if _sqlite_table_exists(conn, "relations"):
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()