Skip to content

Commit 9ca0b4b

Browse files
committed
Add the functionality to retrieve text chunks
1 parent 72551cd commit 9ca0b4b

File tree

7 files changed

+395
-94
lines changed

7 files changed

+395
-94
lines changed

examples/neuroticism/graph_chunk_entity_relation.graphml

Lines changed: 72 additions & 72 deletions
Large diffs are not rendered by default.

examples/neuroticism/kv_store_llm_response_cache.json

Lines changed: 70 additions & 14 deletions
Large diffs are not rendered by default.

examples/neuroticism/vdb_chunks.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

examples/neuroticism/vdb_entities.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

examples/neuroticism/vdb_relationships.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

examples/persona_rag.py

Lines changed: 242 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,132 @@
22
import argparse
33
import json
44
import networkx as nx
5+
import copy
6+
import asyncio
7+
import ollama
8+
import numpy as np
9+
from typing import Dict, List, Optional, Tuple, Union
10+
511

612
from lightrag import LightRAG, QueryParam
713
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
14+
from lightrag.operate import extract_keywords_only
15+
16+
def cosine_distance(vector1: np.ndarray, vector2: np.ndarray) -> float:
    """
    Calculate cosine similarity between two vectors.

    NOTE(review): despite the name, this returns cosine *similarity*
    (1.0 = most similar), matching the original docstring; the name is
    kept for backward compatibility with existing callers.

    Args:
        vector1: First vector
        vector2: Second vector

    Returns:
        float: Cosine similarity score (1.0 is most similar, 0.0 is least
        similar). Returns 0.0 when either vector has zero norm, instead of
        dividing by zero.
    """
    denom = np.linalg.norm(vector1) * np.linalg.norm(vector2)
    if denom == 0.0:
        # A zero vector has no direction; treat it as maximally dissimilar
        # rather than propagating a division-by-zero warning / NaN.
        return 0.0
    return float(np.dot(vector1, vector2) / denom)
28+
29+
def softmax(x: np.ndarray, temperature: float = 0.2) -> np.ndarray:
    """
    Apply softmax function with temperature control.

    Args:
        x: Input array of values
        temperature: Temperature parameter (lower = sharper distribution)

    Returns:
        numpy.ndarray: Softmax probabilities (non-negative, summing to 1.0)

    Raises:
        ValueError: If temperature is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError("Temperature must be positive.")

    x = np.array(x)
    # Subtract the max before exponentiating: mathematically a no-op for
    # softmax, but prevents np.exp overflow for large input values.
    exp_x = np.exp((x - np.max(x)) / temperature)
    return exp_x / np.sum(exp_x)
946

47+
# def compare_input_with_chunks(
48+
# user_input: str,
49+
# text_chunks: Dict[str, str],
50+
# embedding_model: str = "nomic-embed-text",
51+
# top_k: Optional[int] = None,
52+
# temperature: float = 0.2
53+
# ) -> List[Dict]:
54+
# """
55+
# Compare user input with text chunks using cosine distance.
56+
57+
# Args:
58+
# user_input: User query or input text
59+
# text_chunks: Dictionary of text chunks {chunk_id: content}
60+
# embedding_model: Name of the Ollama embedding model to use
61+
# top_k: Optional number of top chunks to return. If None, returns all chunks.
62+
# temperature: Temperature for softmax probability calculation
63+
64+
# Returns:
65+
# List of dictionaries containing chunk_id, content, similarity score and probability,
66+
# sorted by similarity (highest first)
67+
# """
68+
# # Handle empty input or chunks
69+
# if not user_input or not text_chunks:
70+
# print("Warning: Empty input or text chunks")
71+
# return []
72+
73+
# # Generate embedding for user input
74+
# try:
75+
# input_embedding_response = ollama.embed(
76+
# model=embedding_model,
77+
# input=user_input
78+
# )
79+
# input_embedding = np.array(input_embedding_response['embeddings'][0])
80+
# except Exception as e:
81+
# print(f"Error generating input embedding: {str(e)}")
82+
# return []
83+
84+
# # Calculate similarity for each chunk
85+
# chunk_similarities = []
86+
87+
# for chunk_id, content in text_chunks.items():
88+
# try:
89+
# # Generate embedding for chunk
90+
# chunk_embedding_response = ollama.embed(
91+
# model=embedding_model,
92+
# input=content
93+
# )
94+
# chunk_embedding = np.array(chunk_embedding_response['embeddings'][0])
95+
96+
# # Calculate cosine similarity
97+
# similarity = cosine_distance(input_embedding, chunk_embedding)
98+
99+
# chunk_similarities.append({
100+
# "chunk_id": chunk_id,
101+
# "content": content,
102+
# "similarity": similarity
103+
# })
104+
# except Exception as e:
105+
# print(f"Error processing chunk {chunk_id}: {str(e)}")
106+
# # Include the chunk with zero similarity in case of error
107+
# chunk_similarities.append({
108+
# "chunk_id": chunk_id,
109+
# "content": content,
110+
# "similarity": 0.0
111+
# })
112+
113+
# # Sort by similarity (highest first)
114+
# chunk_similarities.sort(key=lambda x: x["similarity"], reverse=True)
115+
116+
# # Apply softmax to get probability distribution
117+
# if chunk_similarities:
118+
# similarities = [item["similarity"] for item in chunk_similarities]
119+
# probabilities = softmax(similarities, temperature)
120+
121+
# # Add probabilities to results
122+
# for i, prob in enumerate(probabilities):
123+
# chunk_similarities[i]["probability"] = float(prob)
124+
125+
# # Return either top k or all chunks
126+
# if top_k is not None and top_k > 0:
127+
# return chunk_similarities[:min(top_k, len(chunk_similarities))]
128+
# else:
129+
# return chunk_similarities
130+
10131
def setup_rag(working_dir):
11132
"""Set up a LightRAG instance"""
12133
if not os.path.exists(working_dir):
@@ -194,6 +315,104 @@ def generate_chunk_content(chunk_id):
194315
# Default content
195316
return f"Content related to the knowledge graph node or relationship with ID {chunk_id}"
196317

318+
def get_query_related_chunks(rag, query, param=None):
    """
    Retrieve all text chunks related to a query.

    High-level and low-level keywords are extracted from the query; graph
    nodes whose name or description, and edges whose endpoints, description
    or keywords, contain any keyword contribute their ``source_id`` values.
    The matching chunks are then fetched from ``rag.text_chunks``.

    Args:
        rag: LightRAG instance.
        query: Query text.
        param: Query parameters; when None a new hybrid-mode QueryParam is
            created.

    Returns:
        dict: {source_id: chunk_content} for every related chunk found.
    """
    # Default query parameters.
    if param is None:
        param = QueryParam(mode="hybrid")

    # Extract the query's keywords (runs the async helper synchronously).
    loop = asyncio.get_event_loop()
    hl_keywords, ll_keywords = loop.run_until_complete(
        extract_keywords_only(
            text=query,
            param=param,
            global_config=rag.__dict__,
            hashing_kv=rag.llm_response_cache,
        )
    )

    print(f"查询关键词: HL={hl_keywords}, LL={ll_keywords}")

    # Pre-uppercase keywords once; all matching below is case-insensitive.
    keywords_upper = [kw.upper() for kw in hl_keywords + ll_keywords]

    related_source_ids = set()
    # NOTE(review): reaches into the private ``_graph`` attribute — verify
    # a public accessor is not available.
    graph = rag.chunk_entity_relation_graph._graph

    def _add_source_ids(raw):
        # A source_id field may hold several ids joined by "<SEP>";
        # str.split on an absent separator yields the single id unchanged.
        for sid in raw.split("<SEP>"):
            related_source_ids.add(sid.strip())

    # Collect source_ids from nodes whose name or description matches
    # any keyword.
    for node, data in graph.nodes(data=True):
        node_str = str(node).upper()
        node_desc = str(data.get("description", "")).upper()
        if any(kw in node_str or kw in node_desc for kw in keywords_upper):
            if "source_id" in data:
                _add_source_ids(data["source_id"])

    # Collect source_ids from edges whose endpoints, description or
    # keywords match any keyword.
    for src, tgt, edge_data in graph.edges(data=True):
        combined = (
            str(src) + str(tgt)
            + str(edge_data.get("description", ""))
            + str(edge_data.get("keywords", ""))
        ).upper()
        if any(kw in combined for kw in keywords_upper):
            if "source_id" in edge_data:
                _add_source_ids(edge_data["source_id"])

    # Fetch the chunk contents for the collected source_ids.
    result_chunks = {}

    # Prefer the in-memory client storage when available.
    if hasattr(rag.text_chunks, "client_storage") and "data" in rag.text_chunks.client_storage:
        chunks = rag.text_chunks.client_storage["data"]

        for chunk_id, chunk_data in chunks.items():
            # Match either the chunk's own id ...
            if chunk_id in related_source_ids:
                result_chunks[chunk_id] = chunk_data["content"]
            # ... or its source_id field.
            elif "source_id" in chunk_data and chunk_data["source_id"] in related_source_ids:
                result_chunks[chunk_data["source_id"]] = chunk_data["content"]
    else:
        # Fall back to the async storage backend.
        async def _fetch_chunks():
            found = {}
            for source_id in related_source_ids:
                chunk = await rag.text_chunks.get_by_id(source_id)
                if chunk and "content" in chunk:
                    found[source_id] = chunk["content"]
            return found

        result_chunks = loop.run_until_complete(_fetch_chunks())

    return result_chunks
415+
197416
def main():
198417
working_dir = "./neuroticism"
199418
graphml_path = "./big_five/neuroticism/graph_chunk_entity_relation.graphml"
@@ -229,6 +448,19 @@ def main():
229448
# print(f"Answer: {result}")
230449
# print("-" * 80)
231450

451+
# 先示范如何获取查询相关的文本块
452+
sample_query = "What does Scrooge think about Christmas?"
453+
print("\n=== Getting Related Text Chunks for Sample Query ===")
454+
print(f"Sample Query: {sample_query}")
455+
456+
related_chunks = get_query_related_chunks(rag, sample_query)
457+
458+
print(f"\nFound {len(related_chunks)} related text chunks:")
459+
for source_id, content in related_chunks.items():
460+
print(f"\nSource ID: {source_id}")
461+
print(f"Content: {content[:200]}..." if len(content) > 200 else f"Content: {content}")
462+
print("-" * 80)
463+
232464
system_prompt = "Please answer the following questions as character Scoorge. "
233465

234466
# Execute hybrid search
@@ -237,8 +469,17 @@ def main():
237469
print(f"\nQuestion: {question}")
238470
result = rag.query(system_prompt + question, param=QueryParam(mode="hybrid", conversation_history=empty_history))
239471
print(f"Answer: {result}")
472+
473+
# 获取并显示此问题相关的文本块
474+
print("\n--- Related Text Chunks ---")
475+
query_chunks = get_query_related_chunks(rag, question)
476+
print(f"Found {len(query_chunks)} related chunks")
477+
for source_id, content in query_chunks.items():
478+
print(f"Source ID: {source_id}")
479+
print(f"Content snippet: {content[:100]}..." if len(content) > 100 else f"Content: {content}")
480+
240481
print("-" * 80)
241482

242483
# If this script is run directly, execute the main function
243484
if __name__ == "__main__":
244-
main()
485+
main()

lightrag/lightrag.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -779,8 +779,10 @@ async def ainsert_custom_kg(self, custom_kg: dict):
779779
entity_type = entity_data.get("entity_type", "UNKNOWN")
780780
description = entity_data.get("description", "No description provided")
781781
# source_id = entity_data["source_id"]
782-
source_chunk_id = entity_data.get("source_id", "UNKNOWN")
783-
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
782+
# source_chunk_id = entity_data.get("source_id", "UNKNOWN")
783+
# source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
784+
785+
source_id = entity_data.get("source_id", "UNKNOWN")
784786

785787
# Log if source_id is UNKNOWN
786788
if source_id == "UNKNOWN":
@@ -811,8 +813,10 @@ async def ainsert_custom_kg(self, custom_kg: dict):
811813
keywords = relationship_data["keywords"]
812814
weight = relationship_data.get("weight", 1.0)
813815
# source_id = relationship_data["source_id"]
814-
source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
815-
source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
816+
# source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
817+
# source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
818+
819+
source_id = relationship_data.get("source_id", "UNKNOWN")
816820

817821
# Log if source_id is UNKNOWN
818822
if source_id == "UNKNOWN":

0 commit comments

Comments
 (0)