diff --git a/docs/openapi.json b/docs/openapi.json index d9ef710b5..01292a912 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -3,7 +3,7 @@ "info": { "title": "MemOS Server REST APIs", "description": "A REST API for managing multiple users with MemOS Server.", - "version": "1.0.1" + "version": "2.0.12" }, "paths": { "/product/search": { diff --git a/pyproject.toml b/pyproject.toml index ba00e62d5..e7fca38ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ ############################################################################## name = "MemoryOS" -version = "2.0.11" +version = "2.0.12" description = "Intelligence Begins with Memory" license = {text = "Apache-2.0"} readme = "README.md" diff --git a/src/memos/__init__.py b/src/memos/__init__.py index 8bc0c7b57..8687b9b9a 100644 --- a/src/memos/__init__.py +++ b/src/memos/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.11" +__version__ = "2.0.12" from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index e9ed4f955..b2b9c7d0c 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -8,6 +8,7 @@ from pydantic import validate_call from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.cube_scope import resolve_cube_ids from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse from memos.memories.textual.item import ( list_all_fields, @@ -120,10 +121,7 @@ def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]: 1) writable_cube_ids (deprecated mem_cube_id is converted to this in model validator) 2) fallback to user_id """ - if add_req.writable_cube_ids: - return list(dict.fromkeys(add_req.writable_cube_ids)) - - return [add_req.user_id] + return resolve_cube_ids(add_req.writable_cube_ids, add_req.user_id) def _build_cube_view(self, 
add_req: APIADDRequest) -> MemCubeView: cube_ids = self._resolve_cube_ids(add_req) diff --git a/src/memos/api/handlers/cube_scope.py b/src/memos/api/handlers/cube_scope.py new file mode 100644 index 000000000..fea8c1f44 --- /dev/null +++ b/src/memos/api/handlers/cube_scope.py @@ -0,0 +1,19 @@ +from __future__ import annotations + + +def resolve_cube_ids( + cube_ids: list[str] | None, + fallback_user_id: str, +) -> list[str]: + """ + Normalize cube ids for API handlers. + + Empty or duplicate entries are removed. If no cube ids are provided, the + request falls back to the caller's user id for backward compatibility. + """ + if cube_ids: + normalized = list(dict.fromkeys(cube_id for cube_id in cube_ids if cube_id)) + if normalized: + return normalized + + return [fallback_user_id] diff --git a/src/memos/api/handlers/feedback_handler.py b/src/memos/api/handlers/feedback_handler.py index 217bca7cd..60620a33f 100644 --- a/src/memos/api/handlers/feedback_handler.py +++ b/src/memos/api/handlers/feedback_handler.py @@ -3,6 +3,7 @@ """ from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.cube_scope import resolve_cube_ids from memos.api.product_models import APIFeedbackRequest, MemoryResponse from memos.log import get_logger from memos.multi_mem_cube.composite_cube import CompositeCubeView @@ -55,10 +56,7 @@ def _resolve_cube_ids(self, feedback_req: APIFeedbackRequest) -> list[str]: """ Normalize target cube ids from feedback_req. 
""" - if feedback_req.writable_cube_ids: - return list(dict.fromkeys(feedback_req.writable_cube_ids)) - - return [feedback_req.user_id] + return resolve_cube_ids(feedback_req.writable_cube_ids, feedback_req.user_id) def _build_cube_view(self, feedback_req: APIFeedbackRequest) -> MemCubeView: cube_ids = self._resolve_cube_ids(feedback_req) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index ba1c50b07..9d70e68e8 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -11,6 +11,7 @@ from typing import Any from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.cube_scope import resolve_cube_ids from memos.api.handlers.formatters_handler import rerank_knowledge_mem from memos.api.product_models import APISearchRequest, SearchResponse from memos.log import get_logger @@ -794,10 +795,7 @@ def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]: 1) readable_cube_ids (deprecated mem_cube_id is converted to this in model validator) 2) fallback to user_id """ - if search_req.readable_cube_ids: - return list(dict.fromkeys(search_req.readable_cube_ids)) - - return [search_req.user_id] + return resolve_cube_ids(search_req.readable_cube_ids, search_req.user_id) def _build_cube_view(self, search_req: APISearchRequest, searcher=None) -> MemCubeView: cube_ids = self._resolve_cube_ids(search_req) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 78dcfc797..6a4f80817 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -451,8 +451,8 @@ class APISearchRequest(BaseRequest): mem_cube_id: str | None = Field( None, description=( - "(Deprecated) Single cube ID to search in. " - "Prefer `readable_cube_ids` for multi-cube search." + "(Deprecated) Single cube ID to search in. " + "Prefer `readable_cube_ids` for multi-cube search." 
), ) @@ -758,6 +758,17 @@ class APIFeedbackRequest(BaseRequest): ), ) + @model_validator(mode="after") + def _convert_deprecated_fields(self) -> "APIFeedbackRequest": + if self.mem_cube_id and not self.writable_cube_ids: + logger.warning( + "APIFeedbackRequest.mem_cube_id is deprecated and will be removed in a future " + "version. Please use `writable_cube_ids` instead." + ) + self.writable_cube_ids = [self.mem_cube_id] + + return self + class APIChatCompleteRequest(BaseRequest): """Request model for chat operations.""" diff --git a/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py index a346622c5..1540b7219 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py +++ b/src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py @@ -28,11 +28,8 @@ def search( try: if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: assert isinstance(text_mem_base, TreeTextMemory) - session_id = search_args.get("session_id", "default_session") - target_session_id = session_id - search_priority = ( - {"session_id": target_session_id} if "session_id" in search_args else None - ) + session_id = search_args.get("session_id") + search_priority = {"session_id": session_id} if session_id else None search_filter = search_args.get("filter") search_source = search_args.get("source") plugin = bool(search_source is not None and search_source == "plugin") @@ -45,14 +42,13 @@ def search( "playground_search_goal_parser", False ) - info = search_args.get( - "info", - { - "user_id": user_id, - "session_id": target_session_id, - "chat_history": chat_history, - }, - ) + default_info = { + "user_id": user_id, + "chat_history": chat_history, + } + if session_id: + default_info["session_id"] = session_id + info = search_args.get("info", default_info) results_long_term = mem_cube.text_mem.search( query=query, diff --git a/src/memos/multi_mem_cube/single_cube.py 
b/src/memos/multi_mem_cube/single_cube.py index 6a91f436f..f6954e52b 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -22,7 +22,7 @@ ) from memos.memories.textual.item import TextualMemoryItem from memos.multi_mem_cube.views import MemCubeView -from memos.search import search_text_memories +from memos.search import build_search_context, search_text_memories from memos.templates.mem_reader_prompts import PROMPT_MAPPING from memos.types.general_types import ( FINE_STRATEGY, @@ -95,7 +95,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]: user_context = UserContext( user_id=search_req.user_id, mem_cube_id=self.cube_id, - session_id=search_req.session_id or "default_session", + session_id=search_req.session_id, ) self.logger.info(f"Search Req is: {search_req}") @@ -219,14 +219,10 @@ def _deep_search( search_req: APISearchRequest, user_context: UserContext, ) -> list: - target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } + search_ctx = build_search_context(search_req) + search_filter = dict(search_ctx.search_filter or {}) + if search_req.session_id: + search_filter["session_id"] = search_req.session_id enhanced_memories = self.searcher.deep_search( query=search_req.query, @@ -235,8 +231,8 @@ def _deep_search( mode=SearchMode.FINE, manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, - search_filter=search_filter, - info=info, + search_filter=search_filter or None, + info=search_ctx.info, ) return self._postformat_memories( enhanced_memories, @@ -281,15 +277,10 @@ def _fine_search( elif FINE_STRATEGY == FineStrategy.AGENTIC_SEARCH: return self._agentic_search(search_req=search_req, user_context=user_context) - target_session_id = 
search_req.session_id or "default_session" - search_priority = {"session_id": search_req.session_id} if search_req.session_id else None - search_filter = search_req.filter - - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } + search_ctx = build_search_context(search_req) + search_priority = search_ctx.search_priority + search_filter = search_ctx.search_filter + info = search_ctx.info # Fine retrieve raw_retrieved_memories = self.searcher.retrieve( diff --git a/src/memos/search/search_service.py b/src/memos/search/search_service.py index fa713a7d1..0ed60efc2 100644 --- a/src/memos/search/search_service.py +++ b/src/memos/search/search_service.py @@ -11,7 +11,7 @@ @dataclass(frozen=True) class SearchContext: - target_session_id: str + target_session_id: str | None search_priority: dict[str, Any] | None search_filter: dict[str, Any] | None info: dict[str, Any] @@ -21,17 +21,19 @@ class SearchContext: def build_search_context( search_req: APISearchRequest, ) -> SearchContext: - target_session_id = search_req.session_id or "default_session" search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + info: dict[str, Any] = { + "user_id": search_req.user_id, + "chat_history": search_req.chat_history, + } + if search_req.session_id: + info["session_id"] = search_req.session_id + return SearchContext( - target_session_id=target_session_id, + target_session_id=search_req.session_id, search_priority=search_priority, search_filter=search_req.filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, + info=info, plugin=bool(search_req.source is not None and search_req.source == "plugin"), ) diff --git a/tests/api/test_cube_scope.py b/tests/api/test_cube_scope.py new file mode 100644 index 000000000..a601b4751 --- /dev/null +++ b/tests/api/test_cube_scope.py @@ -0,0 +1,56 @@ +from unittest.mock 
import Mock + +from memos.api.handlers.add_handler import AddHandler +from memos.api.handlers.base_handler import HandlerDependencies +from memos.api.handlers.feedback_handler import FeedbackHandler +from memos.api.handlers.search_handler import SearchHandler +from memos.api.product_models import APIADDRequest, APIFeedbackRequest, APISearchRequest + + +def _make_dependencies() -> HandlerDependencies: + return HandlerDependencies( + naive_mem_cube=Mock(), + mem_reader=Mock(), + mem_scheduler=Mock(), + searcher=Mock(), + reranker=Mock(), + feedback_server=Mock(), + deepsearch_agent=Mock(), + ) + + +def test_search_handler_prefers_mem_cube_id_over_user_id_fallback(): + handler = SearchHandler(_make_dependencies()) + request = APISearchRequest(query="where is it", user_id="user_a", mem_cube_id="cube_a") + + assert handler._resolve_cube_ids(request) == ["cube_a"] + + +def test_search_handler_deduplicates_readable_cube_ids(): + handler = SearchHandler(_make_dependencies()) + request = APISearchRequest( + query="where is it", + user_id="user_a", + readable_cube_ids=["cube_a", "cube_a", "cube_b"], + ) + + assert handler._resolve_cube_ids(request) == ["cube_a", "cube_b"] + + +def test_add_handler_prefers_mem_cube_id_over_user_id_fallback(): + handler = AddHandler(_make_dependencies()) + request = APIADDRequest(user_id="user_a", mem_cube_id="cube_a", memory_content="remember this") + + assert handler._resolve_cube_ids(request) == ["cube_a"] + + +def test_feedback_handler_prefers_mem_cube_id_over_user_id_fallback(): + handler = FeedbackHandler(_make_dependencies()) + request = APIFeedbackRequest( + user_id="user_a", + mem_cube_id="cube_a", + history=[], + feedback_content="that memory is wrong", + ) + + assert handler._resolve_cube_ids(request) == ["cube_a"] diff --git a/tests/api/test_server_router.py b/tests/api/test_server_router.py index 5906697d9..b8f11c6dc 100644 --- a/tests/api/test_server_router.py +++ b/tests/api/test_server_router.py @@ -5,7 +5,7 @@ input request 
formats and return properly formatted responses. """ -from unittest.mock import Mock, patch +from unittest.mock import ANY, Mock, patch import pytest @@ -142,6 +142,7 @@ def test_search_valid_input_output(self, mock_handlers, client): assert isinstance(call_args, APISearchRequest) assert call_args.query == "test query" assert call_args.user_id == "test_user" + assert call_args.readable_cube_ids == ["test_cube"] def test_search_invalid_input_missing_query(self, mock_handlers, client): """Test search endpoint with missing required field.""" @@ -386,6 +387,45 @@ def test_get_all_with_search_query(self, mock_handlers, client): # Verify subgraph handler was called mock_handlers["memory"].handle_get_subgraph.assert_called_once() + def test_get_all_uses_first_mem_cube_id_for_handler_scope(self, mock_handlers, client): + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + "mem_cube_ids": ["cube_alpha", "cube_beta"], + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + mock_handlers["memory"].handle_get_all_memories.assert_called_once_with( + user_id="test_user", + mem_cube_id="cube_alpha", + memory_type="text_mem", + naive_mem_cube=ANY, + ) + + def test_get_all_search_query_uses_first_mem_cube_id_for_subgraph_scope( + self, mock_handlers, client + ): + request_data = { + "user_id": "test_user", + "memory_type": "text_mem", + "search_query": "important topic", + "mem_cube_ids": ["cube_alpha", "cube_beta"], + } + + response = client.post("/product/get_all", json=request_data) + + assert response.status_code == 200 + mock_handlers["memory"].handle_get_subgraph.assert_called_once_with( + user_id="test_user", + mem_cube_id="cube_alpha", + query="important topic", + top_k=200, + naive_mem_cube=ANY, + search_type="fulltext", + ) + def test_get_all_invalid_input_missing_user_id(self, mock_handlers, client): """Test get_all endpoint with missing required field.""" request_data = { diff --git 
a/tests/api/test_server_router_integration.py b/tests/api/test_server_router_integration.py new file mode 100644 index 000000000..35ada1abf --- /dev/null +++ b/tests/api/test_server_router_integration.py @@ -0,0 +1,356 @@ +""" +Integration tests for real /product/add -> /product/search behavior on Neo4j storage. + +These tests intentionally avoid importing the full MemOS server stack at module import time, +because the server boot path has optional external dependencies and heavy side effects. +The server app is imported only inside the integration fixture after: +1. Neo4j availability is confirmed +2. required env vars are patched +3. non-essential external components are stubbed/patched + +Goal: +- reproduce the old regression where `/product/search` could return empty results + when a stored memory had a session id but the search request omitted session id + (or used a different session id) +- keep real graph storage and real API routing in the loop +""" + +from __future__ import annotations + +from contextlib import suppress +import hashlib +import importlib +import math +import os +import sys +import time +import types +import uuid + +from typing import TYPE_CHECKING, Any + +import pytest + +from fastapi.testclient import TestClient + +if TYPE_CHECKING: + from collections.abc import Iterator + + +def _neo4j_integration_configured() -> bool: + try: + import fastapi # noqa: F401 + import neo4j # noqa: F401 + import openai # noqa: F401 + except ImportError: + return False + + return all(os.getenv(k) for k in ("NEO4J_URI", "NEO4J_USER", "NEO4J_PASSWORD")) + + +def _install_ollama_stub() -> None: + """ + `memos.embedders.factory` imports `memos.embedders.ollama` eagerly. + The real `ollama` package is not required for this test, so install a tiny stub + to keep module import deterministic. 
+ """ + if "ollama" in sys.modules: + return + + module = types.ModuleType("ollama") + + class _DummyEmbedResponse: + def __init__(self, embeddings: list[list[float]]): + self.embeddings = embeddings + + class Client: + def __init__(self, *args, **kwargs): + pass + + def list(self) -> dict[str, list[Any]]: + return {"models": []} + + def pull(self, *args, **kwargs) -> None: + return None + + def embed(self, model: str, input: list[str]): + dim = int(os.getenv("EMBEDDING_DIMENSION", "8")) + return _DummyEmbedResponse([[0.0] * dim for _ in input]) + + module.Client = Client + sys.modules["ollama"] = module + + +def _token_hash_embedding(texts: list[str]) -> list[list[float]]: + """ + Deterministic local embedding for integration tests. + + Properties: + - same token -> same dimension contribution + - query containing the same unique token as the stored memory will get a positive cosine score + - fixed dimension so Neo4j vector index config stays consistent + """ + dim = int(os.getenv("EMBEDDING_DIMENSION", "8")) + embeddings: list[list[float]] = [] + + for text in texts: + vector = [0.0] * dim + tokens = [tok for tok in text.lower().replace("\n", " ").split(" ") if tok.strip()] + if not tokens: + tokens = [text.lower().strip() or "__empty__"] + + for token in set(tokens): + digest = hashlib.sha256(token.encode("utf-8")).digest() + bucket = digest[0] % dim + # Positive-only hashed bag-of-words keeps overlap stable and easy to reason about. 
+ vector[bucket] += 1.0 + (digest[1] / 255.0) + + norm = math.sqrt(sum(v * v for v in vector)) or 1.0 + embeddings.append([v / norm for v in vector]) + + return embeddings + + +def _clear_module(module_name: str) -> None: + sys.modules.pop(module_name, None) + + +def _flatten_text_memories(search_payload: dict[str, Any]) -> list[dict[str, Any]]: + buckets = search_payload["data"].get("text_mem", []) + return [memory for bucket in buckets for memory in bucket.get("memories", [])] + + +def _search_until_found( + client: TestClient, + payload: dict[str, Any], + expected_token: str, + timeout_seconds: float = 5.0, +) -> dict[str, Any]: + deadline = time.time() + timeout_seconds + last_response: dict[str, Any] | None = None + + while time.time() < deadline: + response = client.post("/product/search", json=payload) + assert response.status_code == 200, response.text + last_response = response.json() + memories = _flatten_text_memories(last_response) + if any(expected_token in (memory.get("memory") or "") for memory in memories): + return last_response + time.sleep(0.2) + + assert last_response is not None + return last_response + + +@pytest.fixture(scope="module") +def integration_stack(tmp_path_factory) -> Iterator[dict[str, Any]]: + if not _neo4j_integration_configured(): + pytest.skip("Neo4j integration not configured (need NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)") + + monkeypatch = pytest.MonkeyPatch() + static_dir = tmp_path_factory.mktemp("memos-server-static") + + env_updates = { + "FILE_LOCAL_PATH": str(static_dir), + "GRAPH_DB_BACKEND": "neo4j", + "MOS_NEO4J_SHARED_DB": "true", + "EMBEDDING_DIMENSION": "8", + "ENABLE_INTERNET": "false", + "MEM_READER_BACKEND": "simple_struct", + "MOS_RERANKER_BACKEND": "cosine_local", + "MOS_FEEDBACK_RERANKER_BACKEND": "cosine_local", + "MOS_EMBEDDER_BACKEND": "universal_api", + "MOS_EMBEDDER_PROVIDER": "openai", + "MOS_EMBEDDER_API_KEY": "integration-test-key", + "MOS_EMBEDDER_API_BASE": "https://example.invalid/v1", + 
"MOS_EMBEDDER_MODEL": "integration-test-embedder", + "OPENAI_API_KEY": "integration-test-key", + "OPENAI_API_BASE": "https://example.invalid/v1", + "MEMRADER_API_KEY": "integration-test-key", + "MEMRADER_API_BASE": "https://example.invalid/v1", + "MEMREADER_GENERAL_API_KEY": "integration-test-key", + "MEMREADER_GENERAL_API_BASE": "https://example.invalid/v1", + } + for key, value in env_updates.items(): + monkeypatch.setenv(key, value) + + _install_ollama_stub() + + from memos.memos_tools.singleton import _factory_singleton + + _factory_singleton.clear_cache() + + from memos.embedders.universal_api import UniversalAPIEmbedder + from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, + ) + + monkeypatch.setattr( + UniversalAPIEmbedder, + "embed", + lambda self, texts: _token_hash_embedding([texts] if isinstance(texts, str) else list(texts)), + raising=True, + ) + monkeypatch.setattr(InternetRetrieverFactory, "from_config", lambda *args, **kwargs: None) + + # Import the server only after env + patches are in place. 
+ _clear_module("memos.api.handlers.component_init") + _clear_module("memos.api.handlers.config_builders") + _clear_module("memos.api.handlers") + _clear_module("memos.api.routers.server_router") + _clear_module("memos.api.server_api") + + server_api = importlib.import_module("memos.api.server_api") + server_router = importlib.import_module("memos.api.routers.server_router") + + client = TestClient(server_api.app) + graph_db = server_router.components["graph_db"] + + try: + yield { + "client": client, + "graph_db": graph_db, + } + finally: + client.close() + with suppress(Exception): + graph_db.driver.close() + _factory_singleton.clear_cache() + monkeypatch.undo() + + +@pytest.fixture +def isolated_cube(integration_stack) -> Iterator[dict[str, Any]]: + cube_id = f"it_cube_{uuid.uuid4().hex[:10]}" + user_id = f"it_user_{uuid.uuid4().hex[:10]}" + graph_db = integration_stack["graph_db"] + + graph_db.clear(user_name=cube_id) + try: + yield { + **integration_stack, + "cube_id": cube_id, + "user_id": user_id, + } + finally: + graph_db.clear(user_name=cube_id) + + +def _add_memory( + client: TestClient, + *, + user_id: str, + cube_id: str, + session_id: str, + unique_token: str, +) -> dict[str, Any]: + add_payload = { + "user_id": user_id, + "mem_cube_id": cube_id, + "session_id": session_id, + "async_mode": "sync", + "mode": "fast", + "messages": [ + { + "role": "user", + "content": ( + f"Integration regression memory for token {unique_token}. " + "This fact must remain searchable across session-scoping variations." 
+ ), + } + ], + } + + response = client.post("/product/add", json=add_payload) + assert response.status_code == 200, response.text + payload = response.json() + assert payload["data"], payload + return payload + + +def _build_search_payload( + *, + user_id: str, + cube_id: str, + unique_token: str, + session_id: str | None, +) -> dict[str, Any]: + payload = { + "user_id": user_id, + "mem_cube_id": cube_id, + "query": unique_token, + "mode": "fast", + "top_k": 5, + "relativity": 0, + "dedup": "no", + "include_preference": False, + "pref_top_k": 0, + "search_tool_memory": False, + "tool_mem_top_k": 0, + "include_skill_memory": False, + "skill_mem_top_k": 0, + } + if session_id is not None: + payload["session_id"] = session_id + return payload + + +class TestServerRouterAddSearchIntegration: + def test_search_without_session_id_finds_memory_added_with_real_session(self, isolated_cube): + client = isolated_cube["client"] + cube_id = isolated_cube["cube_id"] + user_id = isolated_cube["user_id"] + stored_session_id = "session-alpha" + unique_token = f"sessionless-regression-{uuid.uuid4().hex[:8]}" + + _add_memory( + client, + user_id=user_id, + cube_id=cube_id, + session_id=stored_session_id, + unique_token=unique_token, + ) + + search_payload = _build_search_payload( + user_id=user_id, + cube_id=cube_id, + unique_token=unique_token, + session_id=None, + ) + result = _search_until_found(client, search_payload, unique_token) + + memories = _flatten_text_memories(result) + assert memories, result + assert any(unique_token in memory["memory"] for memory in memories), result + assert any( + bucket["cube_id"] == cube_id and bucket.get("memories") + for bucket in result["data"]["text_mem"] + ), result + + def test_search_with_different_session_id_still_returns_memory(self, isolated_cube): + client = isolated_cube["client"] + cube_id = isolated_cube["cube_id"] + user_id = isolated_cube["user_id"] + stored_session_id = "session-alpha" + searched_session_id = "session-beta" + 
unique_token = f"cross-session-regression-{uuid.uuid4().hex[:8]}" + + _add_memory( + client, + user_id=user_id, + cube_id=cube_id, + session_id=stored_session_id, + unique_token=unique_token, + ) + + search_payload = _build_search_payload( + user_id=user_id, + cube_id=cube_id, + unique_token=unique_token, + session_id=searched_session_id, + ) + result = _search_until_found(client, search_payload, unique_token) + + memories = _flatten_text_memories(result) + assert memories, result + assert any(unique_token in memory["memory"] for memory in memories), result diff --git a/tests/search/test_search_service.py b/tests/search/test_search_service.py new file mode 100644 index 000000000..60eb4f429 --- /dev/null +++ b/tests/search/test_search_service.py @@ -0,0 +1,29 @@ +from memos.api.product_models import APISearchRequest +from memos.search import build_search_context + + +def test_build_search_context_without_session_id_does_not_inject_default_session(): + request = APISearchRequest(query="find memory", user_id="user_a") + + context = build_search_context(request) + + assert context.target_session_id is None + assert context.search_priority is None + assert context.info == { + "user_id": "user_a", + "chat_history": None, + } + + +def test_build_search_context_with_session_id_keeps_soft_priority_and_info(): + request = APISearchRequest(query="find memory", user_id="user_a", session_id="session_42") + + context = build_search_context(request) + + assert context.target_session_id == "session_42" + assert context.search_priority == {"session_id": "session_42"} + assert context.info == { + "user_id": "user_a", + "chat_history": None, + "session_id": "session_42", + }