refactor(core): 精细化异步类型并改进配置处理的相关逻辑
精简协程工厂与执行器的类型注解(将 Awaitable 替换为 Coroutine,添加 Callable/Coroutine 导入,并在 create_task 前对 episode 循环进行类型转换) 优化 SDKMemoryKernel 的配置文件处理逻辑,规范应用配置时的配置文件载荷格式 调整 GraphStore.matrix_format ,现支持 SparseMatrixFormat 类型,初始化 VectorStore.min_train_threshold,并将 RelationWriteService.source_paragraph 重构为 Optional[str] 加固 Web UI 处理程序:安全解析 path_aliases,改为直接调用 get_raw_config_with_meta(),并在使用前规范化 tuning report 的载荷与报告结构
This commit is contained in:
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Callable, Coroutine, cast
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -106,7 +106,8 @@ def start_background_tasks(plugin: Any) -> None:
|
|||||||
and bool(plugin.get_config("episode.generation_enabled", True))
|
and bool(plugin.get_config("episode.generation_enabled", True))
|
||||||
and (episode_task is None or episode_task.done())
|
and (episode_task is None or episode_task.done())
|
||||||
):
|
):
|
||||||
plugin._episode_generation_task = asyncio.create_task(episode_loop())
|
episode_loop_fn = cast(Callable[[], Coroutine[Any, Any, Any]], episode_loop)
|
||||||
|
plugin._episode_generation_task = asyncio.create_task(episode_loop_fn())
|
||||||
|
|
||||||
|
|
||||||
async def cancel_background_tasks(plugin: Any) -> None:
|
async def cancel_background_tasks(plugin: Any) -> None:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Sequence
|
from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ class _KernelRuntimeFacade:
|
|||||||
async def execute_request_with_dedup(
|
async def execute_request_with_dedup(
|
||||||
self,
|
self,
|
||||||
request_key: str,
|
request_key: str,
|
||||||
executor: Callable[[], Awaitable[Dict[str, Any]]],
|
executor: Callable[[], Coroutine[Any, Any, Dict[str, Any]]],
|
||||||
) -> tuple[bool, Dict[str, Any]]:
|
) -> tuple[bool, Dict[str, Any]]:
|
||||||
return await self._kernel.execute_request_with_dedup(request_key, executor)
|
return await self._kernel.execute_request_with_dedup(request_key, executor)
|
||||||
|
|
||||||
@@ -769,7 +769,7 @@ class SDKMemoryKernel:
|
|||||||
async def execute_request_with_dedup(
|
async def execute_request_with_dedup(
|
||||||
self,
|
self,
|
||||||
request_key: str,
|
request_key: str,
|
||||||
executor: Callable[[], Awaitable[Dict[str, Any]]],
|
executor: Callable[[], Coroutine[Any, Any, Dict[str, Any]]],
|
||||||
) -> tuple[bool, Dict[str, Any]]:
|
) -> tuple[bool, Dict[str, Any]]:
|
||||||
token = str(request_key or "").strip()
|
token = str(request_key or "").strip()
|
||||||
if not token:
|
if not token:
|
||||||
@@ -1761,8 +1761,22 @@ class SDKMemoryKernel:
|
|||||||
profile = manager.get_profile_snapshot()
|
profile = manager.get_profile_snapshot()
|
||||||
return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)}
|
return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)}
|
||||||
if act == "apply_profile":
|
if act == "apply_profile":
|
||||||
profile = kwargs.get("profile") if isinstance(kwargs.get("profile"), dict) else kwargs
|
profile_raw = kwargs.get("profile")
|
||||||
return {"success": True, **await manager.apply_profile(profile, reason=str(kwargs.get("reason", "manual") or "manual"))}
|
if isinstance(profile_raw, dict):
|
||||||
|
profile_payload: Dict[str, Any] = dict(profile_raw)
|
||||||
|
else:
|
||||||
|
profile_payload = {
|
||||||
|
key: value
|
||||||
|
for key, value in kwargs.items()
|
||||||
|
if key not in {"reason", "profile"}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
**await manager.apply_profile(
|
||||||
|
profile_payload,
|
||||||
|
reason=str(kwargs.get("reason", "manual") or "manual"),
|
||||||
|
),
|
||||||
|
}
|
||||||
if act == "rollback_profile":
|
if act == "rollback_profile":
|
||||||
return {"success": True, **await manager.rollback_profile()}
|
return {"success": True, **await manager.rollback_profile()}
|
||||||
if act == "export_profile":
|
if act == "export_profile":
|
||||||
@@ -1999,7 +2013,11 @@ class SDKMemoryKernel:
|
|||||||
self._ensure_background_task("memory_maintenance", self._memory_maintenance_loop)
|
self._ensure_background_task("memory_maintenance", self._memory_maintenance_loop)
|
||||||
self._ensure_background_task("person_profile_refresh", self._person_profile_refresh_loop)
|
self._ensure_background_task("person_profile_refresh", self._person_profile_refresh_loop)
|
||||||
|
|
||||||
def _ensure_background_task(self, name: str, factory: Callable[[], Awaitable[None]]) -> None:
|
def _ensure_background_task(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
factory: Callable[[], Coroutine[Any, Any, None]],
|
||||||
|
) -> None:
|
||||||
task = self._background_tasks.get(name)
|
task = self._background_tasks.get(name)
|
||||||
if task is not None and not task.done():
|
if task is not None and not task.done():
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class GraphStore:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
matrix_format: str = "csr",
|
matrix_format: Union[str, SparseMatrixFormat] = "csr",
|
||||||
data_dir: Optional[Union[str, Path]] = None,
|
data_dir: Optional[Union[str, Path]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ class VectorStore:
|
|||||||
self.quantization_type = QuantizationType.INT8
|
self.quantization_type = QuantizationType.INT8
|
||||||
self.index_type = "sq8"
|
self.index_type = "sq8"
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
|
self.min_train_threshold = self.DEFAULT_MIN_TRAIN
|
||||||
|
|
||||||
self._index: Optional[faiss.IndexIDMap2] = None
|
self._index: Optional[faiss.IndexIDMap2] = None
|
||||||
self._init_index()
|
self._init_index()
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class RelationWriteService:
|
|||||||
predicate: str,
|
predicate: str,
|
||||||
obj: str,
|
obj: str,
|
||||||
confidence: float = 1.0,
|
confidence: float = 1.0,
|
||||||
source_paragraph: str = "",
|
source_paragraph: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
*,
|
*,
|
||||||
write_vector: bool = True,
|
write_vector: bool = True,
|
||||||
|
|||||||
@@ -125,7 +125,8 @@ class DeletePurgeRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
def _build_import_guide_markdown(settings: dict[str, Any]) -> str:
|
def _build_import_guide_markdown(settings: dict[str, Any]) -> str:
|
||||||
path_aliases = settings.get("path_aliases") if isinstance(settings.get("path_aliases"), dict) else {}
|
path_aliases_raw = settings.get("path_aliases")
|
||||||
|
path_aliases = path_aliases_raw if isinstance(path_aliases_raw, dict) else {}
|
||||||
alias_lines = [
|
alias_lines = [
|
||||||
f"- `{name}` -> `{path}`"
|
f"- `{name}` -> `{path}`"
|
||||||
for name, path in sorted(path_aliases.items())
|
for name, path in sorted(path_aliases.items())
|
||||||
@@ -394,15 +395,7 @@ async def _memory_config_get() -> dict:
|
|||||||
|
|
||||||
|
|
||||||
async def _memory_config_get_raw() -> dict:
|
async def _memory_config_get_raw() -> dict:
|
||||||
raw_payload_getter = getattr(a_memorix_host_service, "get_raw_config_with_meta", None)
|
raw_payload = a_memorix_host_service.get_raw_config_with_meta()
|
||||||
if callable(raw_payload_getter):
|
|
||||||
raw_payload = raw_payload_getter()
|
|
||||||
else:
|
|
||||||
raw_payload = {
|
|
||||||
"config": a_memorix_host_service.get_raw_config(),
|
|
||||||
"exists": bool(a_memorix_host_service.get_config_path().exists()),
|
|
||||||
"using_default": False,
|
|
||||||
}
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"config": str(raw_payload.get("config", "") or ""),
|
"config": str(raw_payload.get("config", "") or ""),
|
||||||
@@ -628,8 +621,10 @@ async def _tuning_apply_best(task_id: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
async def _tuning_report(task_id: str, fmt: str) -> dict:
|
async def _tuning_report(task_id: str, fmt: str) -> dict:
|
||||||
payload = await memory_service.tuning_admin(action="get_report", task_id=task_id, format=fmt)
|
payload_raw = await memory_service.tuning_admin(action="get_report", task_id=task_id, format=fmt)
|
||||||
report = payload.get("report") if isinstance(payload.get("report"), dict) else {}
|
payload = payload_raw if isinstance(payload_raw, dict) else {}
|
||||||
|
report_raw = payload.get("report")
|
||||||
|
report = report_raw if isinstance(report_raw, dict) else {}
|
||||||
return {
|
return {
|
||||||
"success": bool(payload.get("success", False)),
|
"success": bool(payload.get("success", False)),
|
||||||
"format": report.get("format", fmt),
|
"format": report.get("format", fmt),
|
||||||
|
|||||||
Reference in New Issue
Block a user