import argparse
import json
import networkx as nx
import copy
import asyncio
import ollama
import numpy as np
from typing import Dict, List, Optional, Tuple, Union

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.operate import extract_keywords_only

def cosine_distance(vector1: np.ndarray, vector2: np.ndarray) -> float:
    """
    Calculate the cosine similarity between two vectors.

    Note: despite the name, this returns cosine *similarity*, not a distance
    (higher means more similar).

    Args:
        vector1: First vector
        vector2: Second vector

    Returns:
        float: Cosine similarity score in [-1.0, 1.0] (1.0 is most similar)
    """
    return np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))
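
# Illustrative example (kept as a comment so nothing runs at import time):
# two 2-D vectors at a 45-degree angle have a cosine similarity of ~0.707.
#     cosine_distance(np.array([1.0, 0.0]), np.array([1.0, 1.0]))  # ~0.7071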

def softmax(x: np.ndarray, temperature: float = 0.2) -> np.ndarray:
    """
    Apply the softmax function with temperature control.

    Args:
        x: Input array of values
        temperature: Temperature parameter (lower = sharper distribution)

    Returns:
        numpy.ndarray: Softmax probabilities (sums to 1.0)
    """
    if temperature <= 0:
        raise ValueError("Temperature must be positive.")

    x = np.array(x)
    exp_x = np.exp(x / temperature)
    return exp_x / np.sum(exp_x)
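
# Illustrative example (kept as a comment so nothing runs at import time):
# with a low temperature, softmax sharpens the gap between similarity scores.
#     softmax(np.array([0.9, 0.7, 0.2]), temperature=0.2)  # ~[0.715, 0.263, 0.022]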

# def compare_input_with_chunks(
#     user_input: str,
#     text_chunks: Dict[str, str],
#     embedding_model: str = "nomic-embed-text",
#     top_k: Optional[int] = None,
#     temperature: float = 0.2
# ) -> List[Dict]:
#     """
#     Compare user input with text chunks using cosine distance.

#     Args:
#         user_input: User query or input text
#         text_chunks: Dictionary of text chunks {chunk_id: content}
#         embedding_model: Name of the Ollama embedding model to use
#         top_k: Optional number of top chunks to return. If None, returns all chunks.
#         temperature: Temperature for softmax probability calculation

#     Returns:
#         List of dictionaries containing chunk_id, content, similarity score and probability,
#         sorted by similarity (highest first)
#     """
#     # Handle empty input or chunks
#     if not user_input or not text_chunks:
#         print("Warning: Empty input or text chunks")
#         return []

#     # Generate embedding for user input
#     try:
#         input_embedding_response = ollama.embed(
#             model=embedding_model,
#             input=user_input
#         )
#         input_embedding = np.array(input_embedding_response['embeddings'][0])
#     except Exception as e:
#         print(f"Error generating input embedding: {str(e)}")
#         return []

#     # Calculate similarity for each chunk
#     chunk_similarities = []

#     for chunk_id, content in text_chunks.items():
#         try:
#             # Generate embedding for chunk
#             chunk_embedding_response = ollama.embed(
#                 model=embedding_model,
#                 input=content
#             )
#             chunk_embedding = np.array(chunk_embedding_response['embeddings'][0])

#             # Calculate cosine similarity
#             similarity = cosine_distance(input_embedding, chunk_embedding)

#             chunk_similarities.append({
#                 "chunk_id": chunk_id,
#                 "content": content,
#                 "similarity": similarity
#             })
#         except Exception as e:
#             print(f"Error processing chunk {chunk_id}: {str(e)}")
#             # Include the chunk with zero similarity in case of error
#             chunk_similarities.append({
#                 "chunk_id": chunk_id,
#                 "content": content,
#                 "similarity": 0.0
#             })

#     # Sort by similarity (highest first)
#     chunk_similarities.sort(key=lambda x: x["similarity"], reverse=True)

#     # Apply softmax to get probability distribution
#     if chunk_similarities:
#         similarities = [item["similarity"] for item in chunk_similarities]
#         probabilities = softmax(similarities, temperature)

#         # Add probabilities to results
#         for i, prob in enumerate(probabilities):
#             chunk_similarities[i]["probability"] = float(prob)

#     # Return either top k or all chunks
#     if top_k is not None and top_k > 0:
#         return chunk_similarities[:min(top_k, len(chunk_similarities))]
#     else:
#         return chunk_similarities
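
# The commented-out helper above is kept for reference; below is a compact,
# hedged sketch of the same idea that reuses cosine_distance and softmax:
# embed the query and every chunk with a local Ollama model, score each chunk
# by cosine similarity, and attach a softmax probability. The
# "nomic-embed-text" default is only an assumption carried over from the
# commented-out code; substitute whichever embedding model is pulled locally.
def rank_chunks_by_similarity(
    user_input: str,
    text_chunks: Dict[str, str],
    embedding_model: str = "nomic-embed-text",
    temperature: float = 0.2,
) -> List[Dict]:
    if not user_input or not text_chunks:
        return []

    # Embed the query once.
    query_vec = np.array(ollama.embed(model=embedding_model, input=user_input)["embeddings"][0])

    # Score every chunk against the query.
    ranked = []
    for chunk_id, content in text_chunks.items():
        chunk_vec = np.array(ollama.embed(model=embedding_model, input=content)["embeddings"][0])
        ranked.append({
            "chunk_id": chunk_id,
            "content": content,
            "similarity": cosine_distance(query_vec, chunk_vec),
        })

    # Sort by similarity (highest first) and attach softmax probabilities.
    ranked.sort(key=lambda item: item["similarity"], reverse=True)
    probabilities = softmax(np.array([item["similarity"] for item in ranked]), temperature)
    for item, prob in zip(ranked, probabilities):
        item["probability"] = float(prob)

    return ranked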
def setup_rag(working_dir):
    """Set up a LightRAG instance"""
    if not os.path.exists(working_dir):
@@ -194,6 +315,104 @@ def generate_chunk_content(chunk_id):
    # Default content
    return f"Content related to the knowledge graph node or relationship with ID {chunk_id}"

def get_query_related_chunks(rag, query, param=None):
    """
    Get all text chunks related to a query.

    Args:
        rag: LightRAG instance
        query: Query text
        param: Query parameters; if None, a new QueryParam is created

    Returns:
        dict: Mapping of source IDs to chunk contents {source_id: chunk_content}
    """
    # Set default query parameters
    if param is None:
        param = QueryParam(mode="hybrid")

    # Clone param to avoid modifying the original object
    debug_param = copy.deepcopy(param)
    # Enable debug mode so LightRAG returns internal information
    debug_param.debug = True

    # Extract the query's keywords
    loop = asyncio.get_event_loop()
    hl_keywords, ll_keywords = loop.run_until_complete(
        extract_keywords_only(
            text=query,
            param=param,
            global_config=rag.__dict__,
            hashing_kv=rag.llm_response_cache
        )
    )

    print(f"Query keywords: HL={hl_keywords}, LL={ll_keywords}")

    # Collect source_ids of related nodes from the knowledge graph
    related_source_ids = set()
    graph = rag.chunk_entity_relation_graph._graph

    # Match nodes against the keywords
    for node, data in graph.nodes(data=True):
        node_str = str(node).upper()
        node_desc = str(data.get("description", "")).upper()

        # Check whether the node name or description contains any keyword
        if any(kw.upper() in node_str for kw in hl_keywords + ll_keywords) or \
           any(kw.upper() in node_desc for kw in hl_keywords + ll_keywords):

            if "source_id" in data:
                if "<SEP>" in data["source_id"]:
                    for sid in data["source_id"].split("<SEP>"):
                        related_source_ids.add(sid.strip())
                else:
                    related_source_ids.add(data["source_id"].strip())

    # Also collect related source_ids from edges
    for src, tgt, edge_data in graph.edges(data=True):
        src_str = str(src).upper()
        tgt_str = str(tgt).upper()
        desc = str(edge_data.get("description", "")).upper()
        keywords = str(edge_data.get("keywords", "")).upper()

        # Check whether the edge information contains any keyword
        if any(kw.upper() in src_str + tgt_str + desc + keywords for kw in hl_keywords + ll_keywords):
            if "source_id" in edge_data:
                if "<SEP>" in edge_data["source_id"]:
                    for sid in edge_data["source_id"].split("<SEP>"):
                        related_source_ids.add(sid.strip())
                else:
                    related_source_ids.add(edge_data["source_id"].strip())

    # Fetch the matching text chunks
    result_chunks = {}

    # Prefer the in-memory client storage when it is available
    if hasattr(rag.text_chunks, "client_storage") and "data" in rag.text_chunks.client_storage:
        chunks = rag.text_chunks.client_storage["data"]

        for chunk_id, chunk_data in chunks.items():
            # Match the chunk_id directly
            if chunk_id in related_source_ids:
                result_chunks[chunk_id] = chunk_data["content"]
            # Or match the chunk's source_id field
            elif "source_id" in chunk_data and chunk_data["source_id"] in related_source_ids:
                result_chunks[chunk_data["source_id"]] = chunk_data["content"]
    else:
        # If client storage is unavailable, fall back to the key-value store
        async def get_chunks():
            chunks_data = {}
            for source_id in related_source_ids:
                chunk = await rag.text_chunks.get_by_id(source_id)
                if chunk and "content" in chunk:
                    chunks_data[source_id] = chunk["content"]
            return chunks_data

        result_chunks = loop.run_until_complete(get_chunks())

    return result_chunks

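# Illustrative usage (kept as a comment so nothing runs at import time): the
# chunks returned by get_query_related_chunks can be re-ranked against the
# original query with the rank_chunks_by_similarity sketch above, assuming a
# local Ollama embedding model is available.
#     chunks = get_query_related_chunks(rag, "What does Scrooge think about Christmas?")
#     for item in rank_chunks_by_similarity("What does Scrooge think about Christmas?", chunks)[:3]:
#         print(item["chunk_id"], round(item["similarity"], 3), round(item["probability"], 3))
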
def main():
    working_dir = "./neuroticism"
    graphml_path = "./big_five/neuroticism/graph_chunk_entity_relation.graphml"
@@ -229,6 +448,19 @@ def main():
    # print(f"Answer: {result}")
    # print("-" * 80)

    # First, demonstrate how to get the text chunks related to a query
    sample_query = "What does Scrooge think about Christmas?"
    print("\n=== Getting Related Text Chunks for Sample Query ===")
    print(f"Sample Query: {sample_query}")

    related_chunks = get_query_related_chunks(rag, sample_query)

    print(f"\nFound {len(related_chunks)} related text chunks:")
    for source_id, content in related_chunks.items():
        print(f"\nSource ID: {source_id}")
        print(f"Content: {content[:200]}..." if len(content) > 200 else f"Content: {content}")
        print("-" * 80)

    system_prompt = "Please answer the following questions as the character Scrooge. "

    # Execute hybrid search
@@ -237,8 +469,17 @@ def main():
        print(f"\nQuestion: {question}")
        result = rag.query(system_prompt + question, param=QueryParam(mode="hybrid", conversation_history=empty_history))
        print(f"Answer: {result}")

        # Fetch and display the text chunks related to this question
        print("\n--- Related Text Chunks ---")
        query_chunks = get_query_related_chunks(rag, question)
        print(f"Found {len(query_chunks)} related chunks")
        for source_id, content in query_chunks.items():
            print(f"Source ID: {source_id}")
            print(f"Content snippet: {content[:100]}..." if len(content) > 100 else f"Content: {content}")

        print("-" * 80)

# If this script is run directly, execute the main function
if __name__ == "__main__":
    main()