Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"info": {
"title": "MemOS Server REST APIs",
"description": "A REST API for managing multiple users with MemOS Server.",
"version": "1.0.1"
"version": "2.0.12"
},
"paths": {
"/product/search": {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
##############################################################################

name = "MemoryOS"
version = "2.0.11"
version = "2.0.12"
description = "Intelligence Begins with Memory"
license = {text = "Apache-2.0"}
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/memos/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.0.11"
__version__ = "2.0.12"

from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
Expand Down
6 changes: 2 additions & 4 deletions src/memos/api/handlers/add_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import validate_call

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.handlers.cube_scope import resolve_cube_ids
from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse
from memos.memories.textual.item import (
list_all_fields,
Expand Down Expand Up @@ -120,10 +121,7 @@ def _resolve_cube_ids(self, add_req: APIADDRequest) -> list[str]:
1) writable_cube_ids (deprecated mem_cube_id is converted to this in model validator)
2) fallback to user_id
"""
if add_req.writable_cube_ids:
return list(dict.fromkeys(add_req.writable_cube_ids))

return [add_req.user_id]
return resolve_cube_ids(add_req.writable_cube_ids, add_req.user_id)

def _build_cube_view(self, add_req: APIADDRequest) -> MemCubeView:
cube_ids = self._resolve_cube_ids(add_req)
Expand Down
19 changes: 19 additions & 0 deletions src/memos/api/handlers/cube_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations


def resolve_cube_ids(
cube_ids: list[str] | None,
fallback_user_id: str,
) -> list[str]:
"""
Normalize cube ids for API handlers.

Empty or duplicate entries are removed. If no cube ids are provided, the
request falls back to the caller's user id for backward compatibility.
"""
if cube_ids:
normalized = list(dict.fromkeys(cube_id for cube_id in cube_ids if cube_id))
if normalized:
return normalized

return [fallback_user_id]
6 changes: 2 additions & 4 deletions src/memos/api/handlers/feedback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.handlers.cube_scope import resolve_cube_ids
from memos.api.product_models import APIFeedbackRequest, MemoryResponse
from memos.log import get_logger
from memos.multi_mem_cube.composite_cube import CompositeCubeView
Expand Down Expand Up @@ -55,10 +56,7 @@ def _resolve_cube_ids(self, feedback_req: APIFeedbackRequest) -> list[str]:
"""
Normalize target cube ids from feedback_req.
"""
if feedback_req.writable_cube_ids:
return list(dict.fromkeys(feedback_req.writable_cube_ids))

return [feedback_req.user_id]
return resolve_cube_ids(feedback_req.writable_cube_ids, feedback_req.user_id)

def _build_cube_view(self, feedback_req: APIFeedbackRequest) -> MemCubeView:
cube_ids = self._resolve_cube_ids(feedback_req)
Expand Down
6 changes: 2 additions & 4 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any

from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
from memos.api.handlers.cube_scope import resolve_cube_ids
from memos.api.handlers.formatters_handler import rerank_knowledge_mem
from memos.api.product_models import APISearchRequest, SearchResponse
from memos.log import get_logger
Expand Down Expand Up @@ -794,10 +795,7 @@ def _resolve_cube_ids(self, search_req: APISearchRequest) -> list[str]:
1) readable_cube_ids (deprecated mem_cube_id is converted to this in model validator)
2) fallback to user_id
"""
if search_req.readable_cube_ids:
return list(dict.fromkeys(search_req.readable_cube_ids))

return [search_req.user_id]
return resolve_cube_ids(search_req.readable_cube_ids, search_req.user_id)

def _build_cube_view(self, search_req: APISearchRequest, searcher=None) -> MemCubeView:
cube_ids = self._resolve_cube_ids(search_req)
Expand Down
15 changes: 13 additions & 2 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ class APISearchRequest(BaseRequest):
mem_cube_id: str | None = Field(
None,
description=(
"(Deprecated) Single cube ID to search in. "
"Prefer `readable_cube_ids` for multi-cube search."
"(Deprecated) Single cube ID to write feedback into. "
"Prefer `writable_cube_ids` for multi-cube feedback."
),
)

Expand Down Expand Up @@ -758,6 +758,17 @@ class APIFeedbackRequest(BaseRequest):
),
)

@model_validator(mode="after")
def _convert_deprecated_fields(self) -> "APIFeedbackRequest":
if self.mem_cube_id and not self.writable_cube_ids:
logger.warning(
"APIFeedbackRequest.mem_cube_id is deprecated and will be removed in a future "
"version. Please use `writable_cube_ids` instead."
)
self.writable_cube_ids = [self.mem_cube_id]

return self


class APIChatCompleteRequest(BaseRequest):
"""Request model for chat operations."""
Expand Down
22 changes: 9 additions & 13 deletions src/memos/mem_scheduler/memory_manage_modules/search_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@ def search(
try:
if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]:
assert isinstance(text_mem_base, TreeTextMemory)
session_id = search_args.get("session_id", "default_session")
target_session_id = session_id
search_priority = (
{"session_id": target_session_id} if "session_id" in search_args else None
)
session_id = search_args.get("session_id")
search_priority = {"session_id": session_id} if session_id else None
search_filter = search_args.get("filter")
search_source = search_args.get("source")
plugin = bool(search_source is not None and search_source == "plugin")
Expand All @@ -45,14 +42,13 @@ def search(
"playground_search_goal_parser", False
)

info = search_args.get(
"info",
{
"user_id": user_id,
"session_id": target_session_id,
"chat_history": chat_history,
},
)
default_info = {
"user_id": user_id,
"chat_history": chat_history,
}
if session_id:
default_info["session_id"] = session_id
info = search_args.get("info", default_info)

results_long_term = mem_cube.text_mem.search(
query=query,
Expand Down
33 changes: 12 additions & 21 deletions src/memos/multi_mem_cube/single_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from memos.memories.textual.item import TextualMemoryItem
from memos.multi_mem_cube.views import MemCubeView
from memos.search import search_text_memories
from memos.search import build_search_context, search_text_memories
from memos.templates.mem_reader_prompts import PROMPT_MAPPING
from memos.types.general_types import (
FINE_STRATEGY,
Expand Down Expand Up @@ -95,7 +95,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
user_context = UserContext(
user_id=search_req.user_id,
mem_cube_id=self.cube_id,
session_id=search_req.session_id or "default_session",
session_id=search_req.session_id,
)
self.logger.info(f"Search Req is: {search_req}")

Expand Down Expand Up @@ -219,14 +219,10 @@ def _deep_search(
search_req: APISearchRequest,
user_context: UserContext,
) -> list:
target_session_id = search_req.session_id or "default_session"
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None

info = {
"user_id": search_req.user_id,
"session_id": target_session_id,
"chat_history": search_req.chat_history,
}
search_ctx = build_search_context(search_req)
search_filter = dict(search_ctx.search_filter or {})
if search_req.session_id:
search_filter["session_id"] = search_req.session_id

enhanced_memories = self.searcher.deep_search(
query=search_req.query,
Expand All @@ -235,8 +231,8 @@ def _deep_search(
mode=SearchMode.FINE,
manual_close_internet=not search_req.internet_search,
moscube=search_req.moscube,
search_filter=search_filter,
info=info,
search_filter=search_filter or None,
info=search_ctx.info,
)
return self._postformat_memories(
enhanced_memories,
Expand Down Expand Up @@ -281,15 +277,10 @@ def _fine_search(
elif FINE_STRATEGY == FineStrategy.AGENTIC_SEARCH:
return self._agentic_search(search_req=search_req, user_context=user_context)

target_session_id = search_req.session_id or "default_session"
search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
search_filter = search_req.filter

info = {
"user_id": search_req.user_id,
"session_id": target_session_id,
"chat_history": search_req.chat_history,
}
search_ctx = build_search_context(search_req)
search_priority = search_ctx.search_priority
search_filter = search_ctx.search_filter
info = search_ctx.info

# Fine retrieve
raw_retrieved_memories = self.searcher.retrieve(
Expand Down
18 changes: 10 additions & 8 deletions src/memos/search/search_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@dataclass(frozen=True)
class SearchContext:
target_session_id: str
target_session_id: str | None
search_priority: dict[str, Any] | None
search_filter: dict[str, Any] | None
info: dict[str, Any]
Expand All @@ -21,17 +21,19 @@ class SearchContext:
def build_search_context(
search_req: APISearchRequest,
) -> SearchContext:
target_session_id = search_req.session_id or "default_session"
search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
info: dict[str, Any] = {
"user_id": search_req.user_id,
"chat_history": search_req.chat_history,
}
if search_req.session_id:
info["session_id"] = search_req.session_id

return SearchContext(
target_session_id=target_session_id,
target_session_id=search_req.session_id,
search_priority=search_priority,
search_filter=search_req.filter,
info={
"user_id": search_req.user_id,
"session_id": target_session_id,
"chat_history": search_req.chat_history,
},
info=info,
plugin=bool(search_req.source is not None and search_req.source == "plugin"),
)

Expand Down
56 changes: 56 additions & 0 deletions tests/api/test_cube_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from unittest.mock import Mock

from memos.api.handlers.add_handler import AddHandler
from memos.api.handlers.base_handler import HandlerDependencies
from memos.api.handlers.feedback_handler import FeedbackHandler
from memos.api.handlers.search_handler import SearchHandler
from memos.api.product_models import APIADDRequest, APIFeedbackRequest, APISearchRequest


def _make_dependencies() -> HandlerDependencies:
return HandlerDependencies(
naive_mem_cube=Mock(),
mem_reader=Mock(),
mem_scheduler=Mock(),
searcher=Mock(),
reranker=Mock(),
feedback_server=Mock(),
deepsearch_agent=Mock(),
)


def test_search_handler_prefers_mem_cube_id_over_user_id_fallback():
handler = SearchHandler(_make_dependencies())
request = APISearchRequest(query="where is it", user_id="user_a", mem_cube_id="cube_a")

assert handler._resolve_cube_ids(request) == ["cube_a"]


def test_search_handler_deduplicates_readable_cube_ids():
handler = SearchHandler(_make_dependencies())
request = APISearchRequest(
query="where is it",
user_id="user_a",
readable_cube_ids=["cube_a", "cube_a", "cube_b"],
)

assert handler._resolve_cube_ids(request) == ["cube_a", "cube_b"]


def test_add_handler_prefers_mem_cube_id_over_user_id_fallback():
handler = AddHandler(_make_dependencies())
request = APIADDRequest(user_id="user_a", mem_cube_id="cube_a", memory_content="remember this")

assert handler._resolve_cube_ids(request) == ["cube_a"]


def test_feedback_handler_prefers_mem_cube_id_over_user_id_fallback():
handler = FeedbackHandler(_make_dependencies())
request = APIFeedbackRequest(
user_id="user_a",
mem_cube_id="cube_a",
history=[],
feedback_content="that memory is wrong",
)

assert handler._resolve_cube_ids(request) == ["cube_a"]
42 changes: 41 additions & 1 deletion tests/api/test_server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
input request formats and return properly formatted responses.
"""

from unittest.mock import Mock, patch
from unittest.mock import ANY, Mock, patch

import pytest

Expand Down Expand Up @@ -142,6 +142,7 @@ def test_search_valid_input_output(self, mock_handlers, client):
assert isinstance(call_args, APISearchRequest)
assert call_args.query == "test query"
assert call_args.user_id == "test_user"
assert call_args.readable_cube_ids == ["test_cube"]

def test_search_invalid_input_missing_query(self, mock_handlers, client):
"""Test search endpoint with missing required field."""
Expand Down Expand Up @@ -386,6 +387,45 @@ def test_get_all_with_search_query(self, mock_handlers, client):
# Verify subgraph handler was called
mock_handlers["memory"].handle_get_subgraph.assert_called_once()

def test_get_all_uses_first_mem_cube_id_for_handler_scope(self, mock_handlers, client):
request_data = {
"user_id": "test_user",
"memory_type": "text_mem",
"mem_cube_ids": ["cube_alpha", "cube_beta"],
}

response = client.post("/product/get_all", json=request_data)

assert response.status_code == 200
mock_handlers["memory"].handle_get_all_memories.assert_called_once_with(
user_id="test_user",
mem_cube_id="cube_alpha",
memory_type="text_mem",
naive_mem_cube=ANY,
)

def test_get_all_search_query_uses_first_mem_cube_id_for_subgraph_scope(
self, mock_handlers, client
):
request_data = {
"user_id": "test_user",
"memory_type": "text_mem",
"search_query": "important topic",
"mem_cube_ids": ["cube_alpha", "cube_beta"],
}

response = client.post("/product/get_all", json=request_data)

assert response.status_code == 200
mock_handlers["memory"].handle_get_subgraph.assert_called_once_with(
user_id="test_user",
mem_cube_id="cube_alpha",
query="important topic",
top_k=200,
naive_mem_cube=ANY,
search_type="fulltext",
)

def test_get_all_invalid_input_missing_user_id(self, mock_handlers, client):
"""Test get_all endpoint with missing required field."""
request_data = {
Expand Down
Loading
Loading