|
10 | 10 | from tools.utils import re_ranking |
11 | 11 | from database.db_init import faiss_database_processing |
12 | 12 | from tools.graph_based_image_retrieval import retrieve_by_hashtags |
| 13 | +from tools.calculate_weighted_exploration import calculate_weighted_exploration |
13 | 14 |
|
14 | 15 | import clip |
15 | 16 |
|
@@ -172,6 +173,22 @@ def cached_results(query_text: str, hiddenHashtags: str, |
172 | 173 | refined_scores, refined_indexes = re_ranking(distances_hnsw, indices_hnsw, |
173 | 174 | graph_scores, graph_indices, |
174 | 175 | k_num=k, boost_amount = 2) |
| 176 | + |
| 177 | + # Decide whether to expand the top-k |
| 178 | + should_expand, k_new = calculate_weighted_exploration(refined_scores, k) |
| 179 | + |
| 180 | + if should_expand: |
| 181 | + # Re-run the query with k_new and return top k results |
| 182 | + logger.info("Expanding the search scope to improve the results...") |
| 183 | + distances_hnsw, indices_hnsw = k_image_search(query_vector, index_hnsw, device, k_nums=k_new) |
| 184 | + graph_scores, graph_indices = retrieve_by_hashtags(sparse_matrix, node_mapping, reverse_node_mapping, G, |
| 185 | + hashtags_list, hashtag_embeddings, hashtag_index, clip, device, model, |
| 186 | + k_num=k_new, max_depth=5, alpha=0.7, similarity_num = 10, |
| 187 | + min_score_threshold=0.01, max_keyframes=10000, max_iterations=10000) |
| 188 | + refined_scores, refined_indexes = re_ranking(distances_hnsw, indices_hnsw, |
| 189 | + graph_scores, graph_indices, |
| 190 | + k_num=k, boost_amount = 2) |
| 191 | + |
175 | 192 | #Filter and Display Results |
176 | 193 | results = display_option_results(display_option, |
177 | 194 | refined_scores, refined_indexes, |
|
0 commit comments