import hashlib
import json
import logging
import re
from typing import Any

from sage.benchmark.benchmark_memory.experiment.utils.triple_parser import TripleParser
from sage.common.core import MapFunction

# NOTE: DialogueParser, LLMGenerator, and EmbeddingGenerator are used below but
# their import paths are not shown in this excerpt; the paths here are assumed
# and should be adjusted to the actual package layout.
from sage.benchmark.benchmark_memory.experiment.utils.dialogue_parser import DialogueParser
from sage.benchmark.benchmark_memory.experiment.utils.generation import (
    EmbeddingGenerator,
    LLMGenerator,
)

logger = logging.getLogger(__name__)
class PreInsert(MapFunction):
"""记忆插入前的预处理算子
职责:
- 数据验证 (validate)
- 格式转换 (transform)
- 信息抽取 (extract)
- 重要性评分 (score)
- 多维编码 (multi_embed)
支持的 action:
- none: 直接透传
- tri_embed: 三元组提取 + embedding (HippoRAG)
- transform: 内容转换 (MemGPT, SeCom, LoCoMo)
- extract: 信息抽取 (A-mem, LD-Agent, LAPS)
- score: 重要性评分 (Generative Agents, EmotionalRAG)
- multi_embed: 多维向量编码 (EmotionalRAG)
- validate: 输入验证
注:短期记忆通常使用 none,长期记忆需要更多预处理
"""
def __init__(self, config):
"""初始化 PreInsert
Args:
config: RuntimeConfig 对象,从中获取 operators.pre_insert.* 配置
"""
super().__init__()
self.config = config
self.action = self._get_required_config("operators.pre_insert.action")
        # Initialize parsers
self.dialogue_parser = DialogueParser()
self.triple_parser = TripleParser()
        # Lazily initialized components
self._generator: LLMGenerator | None = None
self._embedding_generator: EmbeddingGenerator | None = None
        self._spacy_nlp = None  # spaCy model, used for NER and noun extraction
        # Cache used for duplicate detection
self._content_hashes: set[str] = set()
        # Initialize the components required for the configured action
self._init_for_action()
    def _get_required_config(self, key: str, context: str = "") -> Any:
        """Fetch a required config value, raising if it is missing.

        Args:
            key: Config key path, e.g. "operators.pre_insert.action"
            context: Context note used in the error message

        Returns:
            The config value

        Raises:
            ValueError: if the config value is missing
        """
value = self.config.get(key)
if value is None:
ctx = f" ({context})" if context else ""
raise ValueError(f"缺少必需配置: {key}{ctx}")
return value
def _init_for_action(self):
"""根据 action 类型初始化必要的组件"""
if self.action == "tri_embed":
self.triple_extraction_prompt = self._get_required_config(
"operators.pre_insert.triple_extraction_prompt", "action=tri_embed"
)
self._init_llm_generator()
self._init_embedding_generator()
elif self.action == "transform":
transform_type = self._get_required_config(
"operators.pre_insert.transform_type", "action=transform"
)
if transform_type in ["topic_segment", "fact_extract", "summarize"]:
self._init_llm_generator()
            # chunking and compress do not need an LLM
elif self.action == "extract":
extract_type = self._get_required_config(
"operators.pre_insert.extract_type", "action=extract"
)
if extract_type in ["keyword", "persona", "all"]:
self._init_llm_generator()
if extract_type in ["entity", "noun", "all"]:
self._init_spacy()
elif self.action == "score":
self._init_llm_generator()
elif self.action == "multi_embed":
self._init_embedding_generator()
elif self.action == "validate":
            # validate may need embeddings for duplicate detection
similarity_check = any(
r.get("type") == "duplicate"
for r in self.config.get("operators.pre_insert.rules", [])
)
if similarity_check:
self._init_embedding_generator()
def _init_llm_generator(self):
"""初始化 LLM 生成器"""
if self._generator is None:
self._generator = LLMGenerator.from_config(self.config)
def _init_embedding_generator(self):
"""初始化 Embedding 生成器"""
if self._embedding_generator is None:
self._embedding_generator = EmbeddingGenerator.from_config(self.config)
def _init_spacy(self):
"""初始化 spaCy 模型"""
if self._spacy_nlp is None:
try:
import spacy
model_name = self._get_required_config(
"operators.pre_insert.spacy_model", "entity/noun extraction"
)
try:
self._spacy_nlp = spacy.load(model_name)
except OSError:
logger.warning(f"spaCy model {model_name} not found, downloading...")
from spacy.cli import download
download(model_name)
self._spacy_nlp = spacy.load(model_name)
except ImportError:
logger.warning("spaCy not installed. Entity/noun extraction will be limited.")
self._spacy_nlp = None
@property
def generator(self) -> LLMGenerator:
"""获取 LLM 生成器(延迟初始化)"""
if self._generator is None:
self._init_llm_generator()
return self._generator
@property
def embedding_generator(self) -> EmbeddingGenerator:
"""获取 Embedding 生成器(延迟初始化)"""
if self._embedding_generator is None:
self._init_embedding_generator()
return self._embedding_generator
def execute(self, data: dict[str, Any]) -> dict[str, Any]:
"""执行预处理
Args:
data: 原始对话数据(字典格式)
Returns:
处理后的数据(字典格式),包含 memory_entries 队列
"""
# 根据 action 模式生成记忆条目队列
if self.action == "none":
entries = [data]
elif self.action == "tri_embed":
entries = self._extract_and_embed_triples(data)
elif self.action == "transform":
entries = self._execute_transform(data)
elif self.action == "extract":
entries = self._execute_extract(data)
elif self.action == "score":
entries = self._execute_score(data)
elif self.action == "multi_embed":
entries = self._execute_multi_embed(data)
elif self.action == "validate":
entries = self._execute_validate(data)
else:
            # Unknown action mode; pass through unchanged
logger.warning(f"Unknown action: {self.action}, passing through")
entries = [data]
        # Attach the memory_entries queue to the original dict
data["memory_entries"] = entries
return data
    # ========================================================================
    # Transform Action (D2-1)
    # Supports: chunking, topic_segment, fact_extract, summarize, compress
    # References: MemGPT, SeCom, LoCoMo
    # ========================================================================
def _execute_transform(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""执行内容转换
Args:
data: 原始数据
Returns:
转换后的记忆条目列表
"""
transform_type = self._get_required_config(
"operators.pre_insert.transform_type", "action=transform"
)
if transform_type == "chunking":
return self._transform_chunking(data)
elif transform_type == "topic_segment":
return self._transform_topic_segment(data)
elif transform_type == "fact_extract":
return self._transform_fact_extract(data)
elif transform_type == "summarize":
return self._transform_summarize(data)
elif transform_type == "compress":
return self._transform_compress(data)
else:
logger.warning(f"Unknown transform_type: {transform_type}, using chunking")
return self._transform_chunking(data)
def _transform_chunking(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""分块处理 - 参考 MemGPT
将长文本分割成固定大小的块,支持重叠窗口
必需配置参数:
- chunk_size: 每个块的最大字符数
- chunk_overlap: 块之间的重叠字符数
- chunk_strategy: 分块策略 (fixed/sentence/paragraph)
"""
chunk_size = self._get_required_config(
"operators.pre_insert.chunk_size", "transform_type=chunking"
)
chunk_overlap = self._get_required_config(
"operators.pre_insert.chunk_overlap", "transform_type=chunking"
)
chunk_strategy = self.config.get("operators.pre_insert.chunk_strategy", "fixed")
        # Get the text content
dialogs = data.get("dialogs", [])
text = self.dialogue_parser.format(dialogs)
if not text:
return [data]
        # Chunk according to the strategy
if chunk_strategy == "sentence":
chunks = self._chunk_by_sentence(text, chunk_size, chunk_overlap)
elif chunk_strategy == "paragraph":
chunks = self._chunk_by_paragraph(text, chunk_size, chunk_overlap)
else: # fixed
chunks = self._chunk_fixed(text, chunk_size, chunk_overlap)
        # Build entries
entries = []
for i, chunk in enumerate(chunks):
entry = data.copy()
entry["chunk_text"] = chunk
entry["chunk_index"] = i
entry["total_chunks"] = len(chunks)
entries.append(entry)
return entries if entries else [data]
    def _chunk_fixed(self, text: str, size: int, overlap: int) -> list[str]:
        """Fixed-size chunking."""
        # Guard: if overlap >= size, start would never advance and the loop
        # below would not terminate.
        if overlap >= size:
            raise ValueError("chunk_overlap must be smaller than chunk_size")
        chunks = []
        start = 0
        while start < len(text):
            end = start + size
            chunk = text[start:end]
            if chunk.strip():
                chunks.append(chunk)
            start = end - overlap
        return chunks
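    # Worked example (illustrative): with size=10 and overlap=3, the string
    # "abcdefghijklmnopqrst" is chunked at offsets 0, 7, 14, giving
    # ["abcdefghij", "hijklmnopq", "opqrst"]; each chunk repeats the last 3
    # characters of the previous one.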
def _chunk_by_sentence(self, text: str, size: int, overlap: int) -> list[str]:
"""按句子边界分块"""
# 简单的句子分割 (支持中英文)
sentence_endings = re.compile(r"[.!?。!?]+[\s]*")
sentences = sentence_endings.split(text)
sentences = [s.strip() for s in sentences if s.strip()]
chunks = []
current_chunk = []
current_size = 0
for sentence in sentences:
sentence_len = len(sentence)
if current_size + sentence_len > size and current_chunk:
chunks.append(" ".join(current_chunk))
                # Keep trailing sentences as the overlap
overlap_sentences = []
overlap_size = 0
for s in reversed(current_chunk):
if overlap_size + len(s) <= overlap:
overlap_sentences.insert(0, s)
overlap_size += len(s)
else:
break
current_chunk = overlap_sentences
current_size = overlap_size
current_chunk.append(sentence)
current_size += sentence_len
if current_chunk:
chunks.append(" ".join(current_chunk))
return chunks
def _chunk_by_paragraph(self, text: str, size: int, overlap: int) -> list[str]:
"""按段落边界分块"""
paragraphs = text.split("\n\n")
paragraphs = [p.strip() for p in paragraphs if p.strip()]
chunks = []
current_chunk = []
current_size = 0
for para in paragraphs:
para_len = len(para)
if current_size + para_len > size and current_chunk:
chunks.append("\n\n".join(current_chunk))
current_chunk = []
current_size = 0
current_chunk.append(para)
current_size += para_len
if current_chunk:
chunks.append("\n\n".join(current_chunk))
return chunks
def _transform_topic_segment(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""话题分段 - 参考 SeCom
使用 LLM 识别对话中的话题边界,将对话分成不同的话题段落
配置参数:
- segment_prompt: 分段 prompt 模板
- min_segment_size: 最小段落大小
- max_segment_size: 最大段落大小
"""
dialogs = data.get("dialogs", [])
if not dialogs:
return [data]
        # Format the dialogue (with exchange indices)
formatted_dialogs = []
for i, dialog in enumerate(dialogs):
speaker = dialog.get("speaker", "Unknown")
text = dialog.get("text", dialog.get("clean_text", ""))
formatted_dialogs.append(f"[Exchange {i}]: {speaker}: {text}")
dialogue_text = "\n".join(formatted_dialogs)
        # Use the LLM to identify topic boundaries
prompt_template = self._get_required_config(
"operators.pre_insert.segment_prompt", "transform_type=topic_segment"
)
prompt = prompt_template.replace("{dialogue}", dialogue_text)
try:
response = self.generator.generate(prompt)
segments = self._parse_json_response(response, default=[])
except Exception as e:
logger.warning(f"Topic segmentation failed: {e}, falling back to single segment")
return [data]
if not segments:
return [data]
        # Build segment entries
min_size = self.config.get("operators.pre_insert.min_segment_size", 100)
max_size = self.config.get("operators.pre_insert.max_segment_size", 500)
entries = []
for i, segment in enumerate(segments):
exchange_indices = segment.get("exchanges", [])
topic = segment.get("topic", f"segment_{i}")
            # Extract the dialogue for this segment
segment_dialogs = [dialogs[idx] for idx in exchange_indices if idx < len(dialogs)]
segment_text = self.dialogue_parser.format(segment_dialogs)
            # Enforce size constraints
if len(segment_text) < min_size:
continue
if len(segment_text) > max_size:
                # Chunk further
sub_entries = self._transform_chunking(
{**data, "dialogs": segment_dialogs}
)
for sub in sub_entries:
sub["topic"] = topic
sub["segment_index"] = i
entries.extend(sub_entries)
else:
entry = data.copy()
entry["segment_dialogs"] = segment_dialogs
entry["segment_text"] = segment_text
entry["topic"] = topic
entry["segment_index"] = i
entry["total_segments"] = len(segments)
entries.append(entry)
return entries if entries else [data]
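    # Expected LLM response shape for segmentation (implied by the
    # segment.get(...) calls above; the example values are illustrative):
    #   [{"topic": "weekend plans", "exchanges": [0, 1, 2]},
    #    {"topic": "work project", "exchanges": [3, 4]}]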
def _transform_fact_extract(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""事实提取 - 参考 LoCoMo
从对话中提取事实性陈述
配置参数:
- fact_prompt: 事实提取 prompt 模板
- fact_format: 输出格式 (statement/triple/json)
"""
dialogs = data.get("dialogs", [])
if not dialogs:
return [data]
        # Format the dialogue (with dialog_id)
formatted_dialogs = []
for i, dialog in enumerate(dialogs):
speaker = dialog.get("speaker", "Unknown")
text = dialog.get("text", dialog.get("clean_text", ""))
dia_id = dialog.get("dia_id", i)
formatted_dialogs.append(f"[{dia_id}] {speaker}: {text}")
dialogue_text = "\n".join(formatted_dialogs)
        # Use the LLM to extract facts
prompt_template = self._get_required_config(
"operators.pre_insert.fact_prompt", "transform_type=fact_extract"
)
prompt = prompt_template.replace("{dialogue}", dialogue_text)
try:
response = self.generator.generate(prompt)
facts = self._parse_json_response(response, default=[])
except Exception as e:
logger.warning(f"Fact extraction failed: {e}")
return [data]
if not facts:
return [data]
        # Build entries according to the format
fact_format = self.config.get("operators.pre_insert.fact_format", "statement")
entries = []
for fact_item in facts:
if isinstance(fact_item, str):
fact_text = fact_item
speaker = "general"
dialog_id = None
else:
fact_text = fact_item.get("fact", str(fact_item))
speaker = fact_item.get("speaker", "general")
dialog_id = fact_item.get("dialog_id")
entry = data.copy()
entry["fact"] = fact_text
entry["fact_speaker"] = speaker
entry["fact_dialog_id"] = dialog_id
entry["fact_format"] = fact_format
entries.append(entry)
return entries if entries else [data]
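    # Accepted fact-item shapes (implied by the parsing above): either a bare
    # string, e.g. "Alice moved to Berlin in 2021", or an object such as
    #   {"fact": "...", "speaker": "Alice", "dialog_id": "D1:3"}
    # (the example values here are hypothetical).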
def _transform_summarize(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""内容摘要
配置参数:
- summary_prompt: 摘要 prompt 模板
- summary_max_tokens: 摘要最大 token 数
"""
dialogs = data.get("dialogs", [])
dialogue_text = self.dialogue_parser.format(dialogs)
if not dialogue_text:
return [data]
        # Use the LLM to generate the summary
prompt_template = self._get_required_config(
"operators.pre_insert.summary_prompt", "transform_type=summarize"
)
prompt = prompt_template.replace("{dialogue}", dialogue_text)
try:
max_tokens = self.config.get("operators.pre_insert.summary_max_tokens", 200)
summary = self.generator.generate(prompt, max_tokens=max_tokens)
except Exception as e:
logger.warning(f"Summarization failed: {e}")
return [data]
entry = data.copy()
entry["summary"] = summary.strip()
entry["original_text"] = dialogue_text
return [entry]
def _transform_compress(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""内容压缩 - 参考 SeCom (LLMLingua)
使用压缩模型减少文本长度同时保留关键信息
配置参数:
- compression_ratio: 目标压缩率 (0-1)
- compression_model: 压缩模型名称
"""
dialogs = data.get("dialogs", [])
text = self.dialogue_parser.format(dialogs)
if not text:
return [data]
compression_ratio = self.config.get("operators.pre_insert.compression_ratio", 0.5)
try:
            # Try LLMLingua first
from llmlingua import PromptCompressor
model_name = self.config.get(
"operators.pre_insert.compression_model", "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank"
)
compressor = PromptCompressor(model_name, use_llmlingua2=True)
result = compressor.compress_prompt(
text,
rate=compression_ratio,
use_context_level_filter=False,
force_tokens=["\n", ".", "[human]", "[bot]"],
)
compressed_text = result.get("compressed_prompt", text)
except ImportError:
logger.warning("LLMLingua not installed, using simple truncation")
            # Fall back to simple truncation
target_len = int(len(text) * compression_ratio)
compressed_text = text[:target_len]
except Exception as e:
logger.warning(f"Compression failed: {e}, using original text")
compressed_text = text
entry = data.copy()
entry["compressed_text"] = compressed_text
entry["original_text"] = text
entry["compression_ratio"] = len(compressed_text) / len(text) if text else 1.0
return [entry]
    # ========================================================================
    # Extract Action (D2-2)
    # Supports: keyword, entity, noun, persona, all
    # References: A-mem, LD-Agent, LAPS, HippoRAG
    # ========================================================================
def _execute_extract(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""执行信息抽取
Args:
data: 原始数据
Returns:
添加了抽取信息的记忆条目列表
"""
extract_type = self._get_required_config(
"operators.pre_insert.extract_type", "action=extract"
)
entry = data.copy()
extracted_info: dict[str, Any] = {}
if extract_type == "keyword" or extract_type == "all":
extracted_info["keywords"] = self._extract_keywords(data)
if extract_type == "entity" or extract_type == "all":
extracted_info["entities"] = self._extract_entities(data)
if extract_type == "noun" or extract_type == "all":
extracted_info["nouns"] = self._extract_nouns(data)
if extract_type == "persona" or extract_type == "all":
extracted_info["personas"] = self._extract_personas(data)
        # Add to metadata or directly to the entry
add_to_metadata = self.config.get("operators.pre_insert.add_to_metadata", True)
if add_to_metadata:
entry.setdefault("metadata", {}).update(extracted_info)
else:
entry.update(extracted_info)
return [entry]
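    # Example result shape (illustrative values) with extract_type="all" and
    # add_to_metadata=True:
    #   entry["metadata"] == {
    #       "keywords": ["travel", "budget"],
    #       "entities": [{"text": "Paris", "type": "GPE"}],
    #       "nouns": ["trip", "hotel"],
    #       "personas": {"Alice": {"traits": [...], "preferences": [...]}},
    #   }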
def _extract_keywords(self, data: dict[str, Any]) -> list[str]:
"""关键词提取 - 参考 A-mem
使用 LLM 提取关键概念和术语
"""
dialogs = data.get("dialogs", [])
text = self.dialogue_parser.format(dialogs)
if not text:
return []
prompt_template = self._get_required_config(
"operators.pre_insert.keyword_prompt", "extract_type=keyword"
)
prompt = prompt_template.replace("{text}", text)
try:
response = self.generator.generate(prompt)
result = self._parse_json_response(response, default={"keywords": []})
keywords = result.get("keywords", [])
            # Cap the number of keywords
max_keywords = self.config.get("operators.pre_insert.max_keywords", 10)
return keywords[:max_keywords]
except Exception as e:
logger.warning(f"Keyword extraction failed: {e}")
return []
def _extract_entities(self, data: dict[str, Any]) -> list[dict[str, str]]:
"""实体抽取 - 参考 HippoRAG, LAPS
使用 spaCy 或 LLM 识别命名实体
"""
dialogs = data.get("dialogs", [])
text = self.dialogue_parser.format(dialogs)
if not text:
return []
ner_model = self.config.get("operators.pre_insert.ner_model", "spacy")
entity_types = self.config.get(
"operators.pre_insert.entity_types",
["PERSON", "ORG", "LOC", "EVENT", "GPE", "DATE", "TIME"],
)
entities = []
if ner_model == "spacy" and self._spacy_nlp:
doc = self._spacy_nlp(text)
for ent in doc.ents:
if ent.label_ in entity_types:
entities.append({"text": ent.text, "type": ent.label_})
elif ner_model == "llm":
            # Use the LLM for NER
prompt = f"""Extract named entities from the following text.
Entity types to extract: {', '.join(entity_types)}
Text: {text}
Return a JSON list of entities: [{{"text": "entity text", "type": "ENTITY_TYPE"}}]
Entities:"""
try:
response = self.generator.generate(prompt)
entities = self._parse_json_response(response, default=[])
except Exception as e:
logger.warning(f"LLM NER failed: {e}")
        # Deduplicate
seen = set()
unique_entities = []
for ent in entities:
key = (ent.get("text", "").lower(), ent.get("type", ""))
if key not in seen:
seen.add(key)
unique_entities.append(ent)
return unique_entities
def _extract_nouns(self, data: dict[str, Any]) -> list[str]:
"""名词提取 - 参考 LD-Agent
使用 spaCy 提取名词短语
"""
dialogs = data.get("dialogs", [])
text = self.dialogue_parser.format(dialogs)
if not text or not self._spacy_nlp:
return []
include_proper_nouns = self.config.get(
"operators.pre_insert.include_proper_nouns", True
)
doc = self._spacy_nlp(text)
nouns = []
for token in doc:
if token.pos_ == "NOUN":
nouns.append(token.lemma_)
elif token.pos_ == "PROPN" and include_proper_nouns:
nouns.append(token.text)
        # Deduplicate while preserving order
seen = set()
unique_nouns = []
for noun in nouns:
if noun.lower() not in seen:
seen.add(noun.lower())
unique_nouns.append(noun)
return unique_nouns
def _extract_personas(self, data: dict[str, Any]) -> dict[str, dict[str, list[str]]]:
"""人格特征提取 - 参考 LD-Agent
从对话中提取说话者的性格特征、偏好和事实信息
"""
dialogs = data.get("dialogs", [])
dialogue_text = self.dialogue_parser.format(dialogs)
if not dialogue_text:
return {}
prompt_template = self._get_required_config(
"operators.pre_insert.persona_prompt", "extract_type=persona"
)
prompt = prompt_template.replace("{dialogue}", dialogue_text)
persona_fields = self.config.get(
"operators.pre_insert.persona_fields", ["traits", "preferences", "facts"]
)
try:
response = self.generator.generate(prompt)
personas = self._parse_json_response(response, default={})
            # Keep only the configured fields
filtered_personas = {}
for speaker, info in personas.items():
if isinstance(info, dict):
filtered_personas[speaker] = {
k: v for k, v in info.items() if k in persona_fields
}
return filtered_personas
except Exception as e:
logger.warning(f"Persona extraction failed: {e}")
return {}
    # ========================================================================
    # Score Action (D2-3)
    # Supports: importance, emotion
    # References: Generative Agents, EmotionalRAG
    # ========================================================================
def _execute_score(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""执行重要性评分
Args:
data: 原始数据
Returns:
添加了评分的记忆条目列表
"""
score_type = self._get_required_config(
"operators.pre_insert.score_type", "action=score"
)
entry = data.copy()
if score_type == "importance":
score_result = self._score_importance(data)
elif score_type == "emotion":
score_result = self._score_emotion(data)
else:
logger.warning(f"Unknown score_type: {score_type}")
score_result = {}
        # Attach the score to metadata
add_to_metadata = self.config.get("operators.pre_insert.add_to_metadata", True)
score_field = self.config.get("operators.pre_insert.score_field", "importance_score")
if add_to_metadata:
entry.setdefault("metadata", {})[score_field] = score_result
else:
entry[score_field] = score_result
return [entry]
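    # Example (illustrative): with score_type="importance" and the default
    # score_field, the result lands at
    #   entry["metadata"]["importance_score"] == {"score": 7, "reason": "..."}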
def _score_importance(self, data: dict[str, Any]) -> dict[str, Any]:
"""重要性评分 - 参考 Generative Agents
使用 LLM 评估记忆的重要性 (1-10)
"""
dialogs = data.get("dialogs", [])
text = self.dialogue_parser.format(dialogs)
if not text:
return {"score": 5, "reason": "Empty content"}
prompt_template = self._get_required_config(
"operators.pre_insert.importance_prompt", "score_type=importance"
)
prompt = prompt_template.replace("{text}", text)
importance_scale = self.config.get("operators.pre_insert.importance_scale", [1, 10])
min_score, max_score = importance_scale
try:
response = self.generator.generate(prompt)
result = self._parse_json_response(
response, default={"score": 5, "reason": "Default score"}
)
score = result.get("score", 5)
            # Clamp the score to the configured range
score = max(min_score, min(max_score, int(score)))
result["score"] = score
return result
except Exception as e:
logger.warning(f"Importance scoring failed: {e}")
return {"score": 5, "reason": f"Scoring failed: {e}"}
def _score_emotion(self, data: dict[str, Any]) -> dict[str, Any]:
"""情感评分 - 参考 EmotionalRAG
识别文本的情感类别和强度
"""
dialogs = data.get("dialogs", [])
text = self.dialogue_parser.format(dialogs)
if not text:
return {"category": "neutral", "intensity": 0.5, "vector": None}
emotion_categories = self.config.get(
"operators.pre_insert.emotion_categories",
["joy", "sadness", "anger", "fear", "surprise", "neutral"],
)
        # Try an emotion classification model
emotion_model = self.config.get("operators.pre_insert.emotion_model", "llm")
if emotion_model == "llm":
            # Use the LLM for emotion classification
prompt = f"""Analyze the emotion in the following text.
Categories: {', '.join(emotion_categories)}
Text: {text}
Return a JSON object with:
- "category": the primary emotion (one of the categories)
- "intensity": emotion intensity from 0.0 to 1.0
- "secondary": (optional) secondary emotion if mixed
Result:"""
try:
response = self.generator.generate(prompt)
result = self._parse_json_response(
response,
default={"category": "neutral", "intensity": 0.5},
)
return result
except Exception as e:
logger.warning(f"Emotion scoring failed: {e}")
        # If an emotion embedding model is configured, generate an emotion vector
if self._embedding_generator and self._embedding_generator.is_available():
try:
emotion_vector = self._embedding_generator.embed(text)
return {
"category": "unknown",
"intensity": 0.5,
"vector": emotion_vector,
}
except Exception as e:
logger.warning(f"Emotion embedding failed: {e}")
return {"category": "neutral", "intensity": 0.5, "vector": None}
    # ========================================================================
    # Multi-Embed Action (D2-4)
    # Supports multi-dimensional vector encoding
    # Reference: EmotionalRAG
    # ========================================================================
def _execute_multi_embed(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""执行多维向量编码
生成多种类型的向量表示(语义、情感等)
配置参数:
- embeddings: 向量配置列表,每个包含 name, model, field
- output_format: 输出格式 (dict/concat/separate)
"""
embeddings_config = self.config.get("operators.pre_insert.embeddings", [])
if not embeddings_config:
            # Default config
embeddings_config = [
{"name": "semantic", "model": "default", "field": "content"}
]
dialogs = data.get("dialogs", [])
content = self.dialogue_parser.format(dialogs)
entry = data.copy()
embeddings_result: dict[str, list[float] | None] = {}
for emb_config in embeddings_config:
name = emb_config.get("name", "embedding")
field = emb_config.get("field", "content")
            # Pick the text to embed
if field == "content":
text_to_embed = content
elif field == "entities":
                # Get entities from metadata and join them into a string
entities = entry.get("metadata", {}).get("entities", [])
text_to_embed = " ".join([e.get("text", "") for e in entities])
elif field == "keywords":
keywords = entry.get("metadata", {}).get("keywords", [])
text_to_embed = " ".join(keywords)
else:
text_to_embed = str(entry.get(field, content))
if text_to_embed and self.embedding_generator.is_available():
try:
embedding = self.embedding_generator.embed(text_to_embed)
embeddings_result[name] = embedding
except Exception as e:
logger.warning(f"Embedding failed for {name}: {e}")
embeddings_result[name] = None
else:
embeddings_result[name] = None
        # Shape the result according to the output format
output_format = self.config.get("operators.pre_insert.output_format", "dict")
if output_format == "dict":
entry["embeddings"] = embeddings_result
elif output_format == "concat":
            # Concatenate all vectors
all_vectors = [v for v in embeddings_result.values() if v is not None]
if all_vectors:
concat_vector = []
for v in all_vectors:
concat_vector.extend(v)
entry["embedding"] = concat_vector
else:
entry["embedding"] = None
else: # separate
for name, vector in embeddings_result.items():
entry[f"embedding_{name}"] = vector
return [entry]
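    # Output shapes (illustrative, assuming two configured embeddings named
    # "semantic" and "emotion"):
    #   dict:     entry["embeddings"] = {"semantic": [...], "emotion": [...]}
    #   concat:   entry["embedding"] = semantic + emotion (joined end to end)
    #   separate: entry["embedding_semantic"], entry["embedding_emotion"]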
    # ========================================================================
    # Validate Action (D2-5)
    # Supports: length, language, content, duplicate
    # ========================================================================
def _execute_validate(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""执行输入验证
验证输入数据是否符合规则,不符合则根据 on_fail 策略处理
配置参数:
- rules: 验证规则列表
- on_fail: 失败处理策略 (skip/warn/error/transform)
- transform_action: on_fail=transform 时使用的转换动作
"""
rules = self.config.get("operators.pre_insert.rules", [])
on_fail = self.config.get("operators.pre_insert.on_fail", "skip")
dialogs = data.get("dialogs", [])
text = self.dialogue_parser.format(dialogs)
validation_errors = []
for rule in rules:
error = self._validate_rule(text, data, rule)
if error:
validation_errors.append(error)
if validation_errors:
if on_fail == "skip":
logger.info(f"Validation failed, skipping: {validation_errors}")
                return []  # Empty list: skip this entry
elif on_fail == "warn":
logger.warning(f"Validation warnings: {validation_errors}")
entry = data.copy()
entry["validation_warnings"] = validation_errors
return [entry]
elif on_fail == "error":
raise ValueError(f"Validation failed: {validation_errors}")
elif on_fail == "transform":
                # Handle via the configured transform action
transform_action = self.config.get(
"operators.pre_insert.transform_action", "summarize"
)
original_transform_type = self.config.get(
"operators.pre_insert.transform_type"
)
                # Temporarily patch the config
self.config._data.setdefault("operators", {}).setdefault(
"pre_insert", {}
)["transform_type"] = transform_action
try:
entries = self._execute_transform(data)
finally:
                    # Restore the original config
if original_transform_type:
self.config._data["operators"]["pre_insert"]["transform_type"] = (
original_transform_type
)
return entries
return [data]
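    # Illustrative rules config (hypothetical values):
    #   operators.pre_insert.rules:
    #     - {type: length, min: 20, max: 2000}
    #     - {type: language, allowed: [en, zh-cn]}
    #     - {type: content, blacklist: [spam]}
    #     - {type: duplicate, similarity_threshold: 0.95}
    #   operators.pre_insert.on_fail: skip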
def _validate_rule(
self, text: str, data: dict[str, Any], rule: dict[str, Any]
) -> str | None:
"""验证单个规则
Returns:
错误消息,如果验证通过则返回 None
"""
rule_type = rule.get("type")
if rule_type == "length":
min_len = rule.get("min", 0)
max_len = rule.get("max", float("inf"))
text_len = len(text)
if text_len < min_len:
return f"Text too short: {text_len} < {min_len}"
if text_len > max_len:
return f"Text too long: {text_len} > {max_len}"
elif rule_type == "language":
allowed = rule.get("allowed", [])
if allowed:
try:
from langdetect import detect
detected = detect(text)
if detected not in allowed:
return f"Language not allowed: {detected} not in {allowed}"
except ImportError:
logger.warning("langdetect not installed, skipping language check")
except Exception as e:
logger.warning(f"Language detection failed: {e}")
elif rule_type == "content":
blacklist = rule.get("blacklist", [])
for word in blacklist:
if word.lower() in text.lower():
return f"Blacklisted content found: {word}"
elif rule_type == "duplicate":
threshold = rule.get("similarity_threshold", 0.95)
content_hash = hashlib.md5(text.encode()).hexdigest()
            # Exact-duplicate detection
if content_hash in self._content_hashes:
return "Duplicate content detected"
self._content_hashes.add(content_hash)
            # If the threshold is < 1, run similarity detection
            if threshold < 1.0 and self._embedding_generator:
                # TODO: implement embedding-based similarity detection;
                # this requires maintaining an embedding index over existing memories
                pass
return None
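    # _parse_json_response is referenced throughout this class but its
    # definition is not part of this excerpt. A minimal sketch of the assumed
    # behavior (pull the first JSON value out of the LLM response, falling
    # back to `default` on any parse failure):
    def _parse_json_response(self, response: str, default: Any) -> Any:
        # Grab the outermost JSON object/array embedded in the response text
        match = re.search(r"[\[{].*[\]}]", response, re.DOTALL)
        if not match:
            return default
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            return default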
    # ========================================================================
    # Original tri_embed implementation
    # ========================================================================
def _extract_and_embed_triples(self, data: dict[str, Any]) -> list[dict[str, Any]]:
"""提取三元组并进行 Embedding
Args:
data: 对话数据(字典格式),包含 "dialogs" 字段
Returns:
记忆条目列表,每个条目包含 triple, refactor, embedding
"""
dialogs = data.get("dialogs", [])
dialogue = self.dialogue_parser.format(dialogs)
        # Use the LLM to extract triples
prompt = self.triple_extraction_prompt.replace("{dialogue}", dialogue)
triples_text = self.generator.generate(prompt)
        # Parse triples and refactor them into natural-language descriptions
triples, refactor_descriptions = self.triple_parser.parse_and_refactor(triples_text)
        # Deduplicate
unique_triples, unique_refactors = self.triple_parser.deduplicate(
triples, refactor_descriptions
)
if not unique_refactors:
return []
        # Generate embeddings
embeddings = self.embedding_generator.embed_batch(unique_refactors)
        # Build the memory entry list
memory_entries = []
for i, (triple, refactor) in enumerate(zip(unique_triples, unique_refactors)):
            entry = {
                "triple": triple,
                "refactor": refactor,
                "embedding": embeddings[i],
            }
            memory_entries.append(entry)
        return memory_entries