修复typing问题,保证类型正确
This commit is contained in:
@@ -4,44 +4,43 @@ from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
||||
|
||||
|
||||
class FocusValueControl:
|
||||
def __init__(self,chat_id:str):
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.focus_value_adjust = 1
|
||||
|
||||
|
||||
self.focus_value_adjust: float = 1
|
||||
|
||||
def get_current_focus_value(self) -> float:
|
||||
return get_current_focus_value(self.chat_id) * self.focus_value_adjust
|
||||
|
||||
|
||||
|
||||
class FocusValueControlManager:
|
||||
def __init__(self):
|
||||
self.focus_value_controls = {}
|
||||
|
||||
def get_focus_value_control(self,chat_id:str) -> FocusValueControl:
|
||||
self.focus_value_controls: dict[str, FocusValueControl] = {}
|
||||
|
||||
def get_focus_value_control(self, chat_id: str) -> FocusValueControl:
|
||||
if chat_id not in self.focus_value_controls:
|
||||
self.focus_value_controls[chat_id] = FocusValueControl(chat_id)
|
||||
return self.focus_value_controls[chat_id]
|
||||
|
||||
|
||||
|
||||
def get_current_focus_value(chat_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 focus_value
|
||||
"""
|
||||
if not global_config.chat.focus_value_adjust:
|
||||
return global_config.chat.focus_value
|
||||
|
||||
|
||||
if chat_id:
|
||||
stream_focus_value = get_stream_specific_focus_value(chat_id)
|
||||
if stream_focus_value is not None:
|
||||
return stream_focus_value
|
||||
|
||||
|
||||
global_focus_value = get_global_focus_value()
|
||||
if global_focus_value is not None:
|
||||
return global_focus_value
|
||||
|
||||
|
||||
return global_config.chat.focus_value
|
||||
|
||||
|
||||
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
|
||||
"""
|
||||
获取特定聊天流在当前时间的专注度
|
||||
@@ -140,4 +139,5 @@ def get_global_focus_value() -> Optional[float]:
|
||||
|
||||
return None
|
||||
|
||||
focus_value_control = FocusValueControlManager()
|
||||
|
||||
focus_value_control = FocusValueControlManager()
|
||||
|
||||
@@ -2,20 +2,21 @@ from typing import Optional
|
||||
from src.config.config import global_config
|
||||
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
||||
|
||||
|
||||
class TalkFrequencyControl:
|
||||
def __init__(self,chat_id:str):
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.talk_frequency_adjust = 1
|
||||
|
||||
self.talk_frequency_adjust: float = 1
|
||||
|
||||
def get_current_talk_frequency(self) -> float:
|
||||
return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust
|
||||
|
||||
|
||||
|
||||
class TalkFrequencyControlManager:
|
||||
def __init__(self):
|
||||
self.talk_frequency_controls = {}
|
||||
|
||||
def get_talk_frequency_control(self,chat_id:str) -> TalkFrequencyControl:
|
||||
|
||||
def get_talk_frequency_control(self, chat_id: str) -> TalkFrequencyControl:
|
||||
if chat_id not in self.talk_frequency_controls:
|
||||
self.talk_frequency_controls[chat_id] = TalkFrequencyControl(chat_id)
|
||||
return self.talk_frequency_controls[chat_id]
|
||||
@@ -44,6 +45,7 @@ def get_current_talk_frequency(chat_id: Optional[str] = None) -> float:
|
||||
global_frequency = get_global_frequency()
|
||||
return global_config.chat.talk_frequency if global_frequency is None else global_frequency
|
||||
|
||||
|
||||
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
根据时间配置列表获取当前时段的频率
|
||||
@@ -124,6 +126,7 @@ def get_stream_specific_frequency(chat_stream_id: str):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_global_frequency() -> Optional[float]:
|
||||
"""
|
||||
获取全局默认频率配置
|
||||
@@ -141,4 +144,5 @@ def get_global_frequency() -> Optional[float]:
|
||||
|
||||
return None
|
||||
|
||||
talk_frequency_control = TalkFrequencyControlManager()
|
||||
|
||||
talk_frequency_control = TalkFrequencyControlManager()
|
||||
|
||||
@@ -30,9 +30,7 @@ def cosine_similarity(v1, v2):
|
||||
dot_product = np.dot(v1, v2)
|
||||
norm1 = np.linalg.norm(v1)
|
||||
norm2 = np.linalg.norm(v2)
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0
|
||||
return dot_product / (norm1 * norm2)
|
||||
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -142,11 +140,10 @@ class MemoryGraph:
|
||||
# 获取当前节点的记忆项
|
||||
node_data = self.get_dot(topic)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
_, data = node_data
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
# 直接使用完整的记忆内容
|
||||
if memory_items:
|
||||
if memory_items := data["memory_items"]:
|
||||
first_layer_items.append(memory_items)
|
||||
|
||||
# 只在depth=2时获取第二层记忆
|
||||
@@ -154,11 +151,10 @@ class MemoryGraph:
|
||||
# 获取相邻节点的记忆项
|
||||
for neighbor in neighbors:
|
||||
if node_data := self.get_dot(neighbor):
|
||||
concept, data = node_data
|
||||
_, data = node_data
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
# 直接使用完整的记忆内容
|
||||
if memory_items:
|
||||
if memory_items := data["memory_items"]:
|
||||
second_layer_items.append(memory_items)
|
||||
|
||||
return first_layer_items, second_layer_items
|
||||
@@ -224,27 +220,17 @@ class MemoryGraph:
|
||||
# 获取话题节点数据
|
||||
node_data = self.G.nodes[topic]
|
||||
|
||||
# 删除整个节点
|
||||
self.G.remove_node(topic)
|
||||
# 如果节点存在memory_items
|
||||
if "memory_items" in node_data:
|
||||
memory_items = node_data["memory_items"]
|
||||
|
||||
# 既然每个节点现在是一个完整的记忆内容,直接删除整个节点
|
||||
if memory_items:
|
||||
# 删除整个节点
|
||||
self.G.remove_node(topic)
|
||||
if memory_items := node_data["memory_items"]:
|
||||
return (
|
||||
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
|
||||
if len(memory_items) > 50
|
||||
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
|
||||
)
|
||||
else:
|
||||
# 如果没有记忆项,删除该节点
|
||||
self.G.remove_node(topic)
|
||||
return None
|
||||
else:
|
||||
# 如果没有memory_items字段,删除该节点
|
||||
self.G.remove_node(topic)
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
# 海马体
|
||||
@@ -392,9 +378,8 @@ class Hippocampus:
|
||||
# 如果相似度超过阈值,获取该节点的记忆
|
||||
if similarity >= 0.3: # 可以调整这个阈值
|
||||
node_data = self.memory_graph.G.nodes[node]
|
||||
memory_items = node_data.get("memory_items", "")
|
||||
# 直接使用完整的记忆内容
|
||||
if memory_items:
|
||||
if memory_items := node_data.get("memory_items", ""):
|
||||
memories.append((node, memory_items, similarity))
|
||||
|
||||
# 按相似度降序排序
|
||||
@@ -587,7 +572,7 @@ class Hippocampus:
|
||||
unique_memories = []
|
||||
for topic, memory_items, activation_value in all_memories:
|
||||
# memory_items现在是完整的字符串格式
|
||||
memory = memory_items if memory_items else ""
|
||||
memory = memory_items or ""
|
||||
if memory not in seen_memories:
|
||||
seen_memories.add(memory)
|
||||
unique_memories.append((topic, memory_items, activation_value))
|
||||
@@ -599,7 +584,7 @@ class Hippocampus:
|
||||
result = []
|
||||
for topic, memory_items, _ in unique_memories:
|
||||
# memory_items现在是完整的字符串格式
|
||||
memory = memory_items if memory_items else ""
|
||||
memory = memory_items or ""
|
||||
result.append((topic, memory))
|
||||
logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
|
||||
|
||||
@@ -1471,6 +1456,7 @@ class MemoryBuilder:
|
||||
self.last_processed_time: float = 0.0
|
||||
|
||||
def should_trigger_memory_build(self) -> bool:
|
||||
# sourcery skip: assign-if-exp, boolean-if-exp-identity, reintroduce-else
|
||||
"""检查是否应该触发记忆构建"""
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user