feat:同步本地非算法改动到上游基线
保留反馈纠错、WebUI 与运行时增强。\n移除不应提交的 algorithm_redesign 设计目录及其专项测试。
This commit is contained in:
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
@@ -29,6 +29,9 @@ logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
|
||||
class EmbeddingAPIAdapter:
|
||||
"""适配宿主 embedding 请求接口。"""
|
||||
|
||||
_GLOBAL_DIMENSION_CACHE: Dict[str, int] = {}
|
||||
_GLOBAL_TEXT_EMBEDDING_CACHE: Dict[Tuple[str, int, str], np.ndarray] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = 32,
|
||||
@@ -232,10 +235,32 @@ class EmbeddingAPIAdapter:
|
||||
logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}")
|
||||
return None
|
||||
|
||||
def _dimension_cache_key(self) -> str:
|
||||
candidate_names = self._resolve_candidate_model_names()
|
||||
return "|".join(
|
||||
[
|
||||
str(self.model_name or "auto"),
|
||||
str(self.default_dimension),
|
||||
",".join(candidate_names),
|
||||
]
|
||||
)
|
||||
|
||||
def _embedding_cache_key(self, text: str, dimensions: Optional[int]) -> Tuple[str, int, str]:
|
||||
requested_dimension = self._resolve_canonical_dimension(dimensions)
|
||||
return (self._dimension_cache_key(), int(requested_dimension), str(text or ""))
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
if self._dimension_detected and self._dimension is not None:
|
||||
return self._dimension
|
||||
|
||||
cache_key = self._dimension_cache_key()
|
||||
cached_dimension = self._GLOBAL_DIMENSION_CACHE.get(cache_key)
|
||||
if cached_dimension is not None:
|
||||
self._dimension = int(cached_dimension)
|
||||
self._dimension_detected = True
|
||||
logger.info(f"嵌入维度命中进程缓存: {self._dimension}")
|
||||
return self._dimension
|
||||
|
||||
logger.info("正在检测嵌入模型维度...")
|
||||
try:
|
||||
target_dim = self.default_dimension
|
||||
@@ -251,6 +276,7 @@ class EmbeddingAPIAdapter:
|
||||
)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
|
||||
return detected_dim
|
||||
except Exception as exc:
|
||||
logger.debug(f"带维度参数探测失败: {exc},尝试不带维度参数探测")
|
||||
@@ -261,6 +287,7 @@ class EmbeddingAPIAdapter:
|
||||
detected_dim = len(test_embedding)
|
||||
self._dimension = detected_dim
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(detected_dim)
|
||||
logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}")
|
||||
return detected_dim
|
||||
logger.warning(f"嵌入维度检测失败,使用 configured_dimension: {self.default_dimension}")
|
||||
@@ -269,6 +296,7 @@ class EmbeddingAPIAdapter:
|
||||
|
||||
self._dimension = self.default_dimension
|
||||
self._dimension_detected = True
|
||||
self._GLOBAL_DIMENSION_CACHE[cache_key] = int(self.default_dimension)
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(
|
||||
@@ -336,6 +364,25 @@ class EmbeddingAPIAdapter:
|
||||
all_embeddings: List[np.ndarray] = []
|
||||
for offset in range(0, len(texts), batch_size):
|
||||
batch = texts[offset : offset + batch_size]
|
||||
batch_results: List[Tuple[int, np.ndarray]] = []
|
||||
uncached_items: List[Tuple[int, str]] = []
|
||||
|
||||
if self.enable_cache:
|
||||
for index, text in enumerate(batch):
|
||||
cache_key = self._embedding_cache_key(text, dimensions)
|
||||
cached_vector = self._GLOBAL_TEXT_EMBEDDING_CACHE.get(cache_key)
|
||||
if cached_vector is None:
|
||||
uncached_items.append((index, text))
|
||||
else:
|
||||
batch_results.append((index, cached_vector.copy()))
|
||||
else:
|
||||
uncached_items = list(enumerate(batch))
|
||||
|
||||
if not uncached_items:
|
||||
batch_results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in batch_results)
|
||||
continue
|
||||
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def encode_with_semaphore(text: str, index: int):
|
||||
@@ -351,11 +398,20 @@ class EmbeddingAPIAdapter:
|
||||
|
||||
tasks = [
|
||||
encode_with_semaphore(text, offset + index)
|
||||
for index, text in enumerate(batch)
|
||||
for index, text in uncached_items
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in results)
|
||||
normalized_results: List[Tuple[int, np.ndarray]] = []
|
||||
for batch_index, vector in results:
|
||||
normalized_results.append((batch_index, vector))
|
||||
if self.enable_cache:
|
||||
text = batch[batch_index]
|
||||
cache_key = self._embedding_cache_key(text, dimensions)
|
||||
self._GLOBAL_TEXT_EMBEDDING_CACHE[cache_key] = vector.copy()
|
||||
|
||||
batch_results.extend(normalized_results)
|
||||
batch_results.sort(key=lambda item: item[0])
|
||||
all_embeddings.extend(emb for _, emb in batch_results)
|
||||
|
||||
return np.array(all_embeddings, dtype=np.float32)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ except Exception:
|
||||
logger = get_logger("A_Memorix.MetadataStore")
|
||||
|
||||
|
||||
SCHEMA_VERSION = 10
|
||||
SCHEMA_VERSION = 12
|
||||
RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION = 9
|
||||
|
||||
|
||||
|
||||
@@ -375,6 +375,30 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
||||
"memory_feedback_tasks rollback columns missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_stale_marks:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-15",
|
||||
"error",
|
||||
"paragraph_stale_relation_marks table missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_profile_refresh_queue:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-16",
|
||||
"error",
|
||||
"person_profile_refresh_queue table missing under current schema version",
|
||||
)
|
||||
)
|
||||
elif not has_feedback_rollback_status or not has_feedback_rollback_plan:
|
||||
checks.append(
|
||||
CheckItem(
|
||||
"CP-17",
|
||||
"error",
|
||||
"memory_feedback_tasks rollback columns missing under current schema version",
|
||||
)
|
||||
)
|
||||
|
||||
if _sqlite_table_exists(conn, "relations"):
|
||||
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()
|
||||
|
||||
Reference in New Issue
Block a user