Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions tests/unit/vertexai/genai/replays/test_structured_memories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types


def test_generate_and_retrieve_profile(client):
# TODO: Use prod once available.
client._api_client._http_options.base_url = (
"https://us-central1-autopush-aiplatform.sandbox.googleapis.com"
)
customization_config = {"disable_natural_language_memories": True}
memory_bank_customization_config = types.MemoryBankCustomizationConfig(
**customization_config
)
structured_memory_config = {
"scope_keys": ["user_id"],
"schema_configs": [
{
"id": "user-profile",
"schema": {
"properties": {
"name": {"description": "User's name", "type": "string"}
},
"type": "object",
},
}
],
}
structured_memory_config_obj = types.StructuredMemoryConfig(
**structured_memory_config
)
agent_engine = client.agent_engines.create(
config={
"context_spec": {
"memory_bank_config": {
"customization_configs": [memory_bank_customization_config],
"structured_memory_configs": [structured_memory_config_obj],
},
},
"http_options": {"api_version": "v1beta1"},
},
)
try:
agent_engine = client.agent_engines.get(name=agent_engine.api_resource.name)
memory_bank_config = agent_engine.api_resource.context_spec.memory_bank_config
assert memory_bank_config.customization_configs == [
memory_bank_customization_config
]
assert memory_bank_config.structured_memory_configs == [
structured_memory_config_obj
]

scope = {"user_id": "123"}
client.agent_engines.memories.generate(
name=agent_engine.api_resource.name,
scope=scope,
direct_contents_source={
"events": [{"content": {"parts": [{"text": "My name is Kim."}]}}]
},
)
memories = list(
client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope=scope,
config={"memory_types": ["STRUCTURED_PROFILE"]},
)
)
assert len(memories) >= 1
assert memories[0].memory.structured_content is not None

response = client.agent_engines.memories.retrieve_profiles(
name=agent_engine.api_resource.name, scope=scope
)
assert len(response.profiles) == 1

finally:
# Clean up resources.
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
test_method="agent_engines.retrieve_profiles",
)
151 changes: 151 additions & 0 deletions vertexai/_genai/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def _RetrieveAgentEngineMemoriesConfig_to_vertex(
[item for item in getv(from_object, ["filter_groups"])],
)

if getv(from_object, ["memory_types"]) is not None:
setv(parent_object, ["memoryTypes"], getv(from_object, ["memory_types"]))

return to_object


Expand Down Expand Up @@ -365,6 +368,20 @@ def _RetrieveAgentEngineMemoriesRequestParameters_to_vertex(
return to_object


def _RetrieveMemoryProfilesRequestParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ["name"]) is not None:
setv(to_object, ["_url", "name"], getv(from_object, ["name"]))

if getv(from_object, ["scope"]) is not None:
setv(to_object, ["scope"], getv(from_object, ["scope"]))

return to_object


def _RollbackAgentEngineMemoryRequestParameters_to_vertex(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -935,6 +952,72 @@ def _retrieve(
self._api_client._verify_response(return_value)
return return_value

def retrieve_profiles(
self,
*,
name: str,
scope: dict[str, str],
config: Optional[types.RetrieveMemoryProfilesConfigOrDict] = None,
) -> types.RetrieveProfilesResponse:
"""
Retrieves memory profiles for an Agent Engine.

Args:
name (str): Required. A fully-qualified resource name or ID such as
"projects/123/locations/us-central1/reasoningEngines/456".
scope (dict[str, str]): Required. The scope of the memories to retrieve.
A memory must have exactly the same scope as the scope provided here
to be retrieved (i.e. same keys and values). Order does not matter,
but it is case-sensitive.

"""

parameter_model = types._RetrieveMemoryProfilesRequestParameters(
name=name,
scope=scope,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
else:
request_dict = _RetrieveMemoryProfilesRequestParameters_to_vertex(
parameter_model
)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "{name}/memories:retrieveProfiles".format_map(request_url_dict)
else:
path = "{name}/memories:retrieveProfiles"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = self._api_client.request("post", path, request_dict, http_options)

response_dict = {} if not response.body else json.loads(response.body)

return_value = types.RetrieveProfilesResponse._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)

self._api_client._verify_response(return_value)
return return_value

def _rollback(
self,
*,
Expand Down Expand Up @@ -1909,6 +1992,74 @@ async def _retrieve(
self._api_client._verify_response(return_value)
return return_value

async def retrieve_profiles(
self,
*,
name: str,
scope: dict[str, str],
config: Optional[types.RetrieveMemoryProfilesConfigOrDict] = None,
) -> types.RetrieveProfilesResponse:
"""
Retrieves memory profiles for an Agent Engine.

Args:
name (str): Required. A fully-qualified resource name or ID such as
"projects/123/locations/us-central1/reasoningEngines/456".
scope (dict[str, str]): Required. The scope of the memories to retrieve.
A memory must have exactly the same scope as the scope provided here
to be retrieved (i.e. same keys and values). Order does not matter,
but it is case-sensitive.

"""

parameter_model = types._RetrieveMemoryProfilesRequestParameters(
name=name,
scope=scope,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if not self._api_client.vertexai:
raise ValueError("This method is only supported in the Vertex AI client.")
else:
request_dict = _RetrieveMemoryProfilesRequestParameters_to_vertex(
parameter_model
)
request_url_dict = request_dict.get("_url")
if request_url_dict:
path = "{name}/memories:retrieveProfiles".format_map(request_url_dict)
else:
path = "{name}/memories:retrieveProfiles"

query_params = request_dict.get("_query")
if query_params:
path = f"{path}?{urlencode(query_params)}"
# TODO: remove the hack that pops config.
request_dict.pop("config", None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = await self._api_client.async_request(
"post", path, request_dict, http_options
)

response_dict = {} if not response.body else json.loads(response.body)

return_value = types.RetrieveProfilesResponse._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)

self._api_client._verify_response(return_value)
return return_value

async def _rollback(
self,
*,
Expand Down
40 changes: 40 additions & 0 deletions vertexai/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
from .common import _QueryAgentEngineRequestParameters
from .common import _RestoreVersionRequestParameters
from .common import _RetrieveAgentEngineMemoriesRequestParameters
from .common import _RetrieveMemoryProfilesRequestParameters
from .common import _RollbackAgentEngineMemoryRequestParameters
from .common import _RunQueryJobAgentEngineConfig
from .common import _RunQueryJobAgentEngineConfigDict
Expand Down Expand Up @@ -689,12 +690,19 @@
from .common import MemoryMetadataValueDict
from .common import MemoryMetadataValueOrDict
from .common import MemoryOrDict
from .common import MemoryProfile
from .common import MemoryProfileDict
from .common import MemoryProfileOrDict
from .common import MemoryRevision
from .common import MemoryRevisionDict
from .common import MemoryRevisionOrDict
from .common import MemoryStructuredContent
from .common import MemoryStructuredContentDict
from .common import MemoryStructuredContentOrDict
from .common import MemoryTopicId
from .common import MemoryTopicIdDict
from .common import MemoryTopicIdOrDict
from .common import MemoryType
from .common import Message
from .common import MessageDict
from .common import Metadata
Expand Down Expand Up @@ -886,6 +894,12 @@
from .common import RetrieveMemoriesResponseRetrievedMemory
from .common import RetrieveMemoriesResponseRetrievedMemoryDict
from .common import RetrieveMemoriesResponseRetrievedMemoryOrDict
from .common import RetrieveMemoryProfilesConfig
from .common import RetrieveMemoryProfilesConfigDict
from .common import RetrieveMemoryProfilesConfigOrDict
from .common import RetrieveProfilesResponse
from .common import RetrieveProfilesResponseDict
from .common import RetrieveProfilesResponseOrDict
from .common import RollbackAgentEngineMemoryConfig
from .common import RollbackAgentEngineMemoryConfigDict
from .common import RollbackAgentEngineMemoryConfigOrDict
Expand Down Expand Up @@ -1041,6 +1055,12 @@
from .common import SessionOrDict
from .common import State
from .common import Strategy
from .common import StructuredMemoryConfig
from .common import StructuredMemoryConfigDict
from .common import StructuredMemoryConfigOrDict
from .common import StructuredMemoryConfigSchemaConfig
from .common import StructuredMemoryConfigSchemaConfigDict
from .common import StructuredMemoryConfigSchemaConfigOrDict
from .common import SummaryMetric
from .common import SummaryMetricDict
from .common import SummaryMetricOrDict
Expand Down Expand Up @@ -1653,6 +1673,12 @@
"ReasoningEngineContextSpecMemoryBankConfigTtlConfig",
"ReasoningEngineContextSpecMemoryBankConfigTtlConfigDict",
"ReasoningEngineContextSpecMemoryBankConfigTtlConfigOrDict",
"StructuredMemoryConfigSchemaConfig",
"StructuredMemoryConfigSchemaConfigDict",
"StructuredMemoryConfigSchemaConfigOrDict",
"StructuredMemoryConfig",
"StructuredMemoryConfigDict",
"StructuredMemoryConfigOrDict",
"ReasoningEngineContextSpecMemoryBankConfig",
"ReasoningEngineContextSpecMemoryBankConfigDict",
"ReasoningEngineContextSpecMemoryBankConfigOrDict",
Expand Down Expand Up @@ -1743,6 +1769,9 @@
"AgentEngineMemoryConfig",
"AgentEngineMemoryConfigDict",
"AgentEngineMemoryConfigOrDict",
"MemoryStructuredContent",
"MemoryStructuredContentDict",
"MemoryStructuredContentOrDict",
"Memory",
"MemoryDict",
"MemoryOrDict",
Expand Down Expand Up @@ -1812,6 +1841,15 @@
"RetrieveMemoriesResponse",
"RetrieveMemoriesResponseDict",
"RetrieveMemoriesResponseOrDict",
"RetrieveMemoryProfilesConfig",
"RetrieveMemoryProfilesConfigDict",
"RetrieveMemoryProfilesConfigOrDict",
"MemoryProfile",
"MemoryProfileDict",
"MemoryProfileOrDict",
"RetrieveProfilesResponse",
"RetrieveProfilesResponseDict",
"RetrieveProfilesResponseOrDict",
"RollbackAgentEngineMemoryConfig",
"RollbackAgentEngineMemoryConfigDict",
"RollbackAgentEngineMemoryConfigOrDict",
Expand Down Expand Up @@ -2197,6 +2235,7 @@
"Type",
"JobState",
"ManagedTopicEnum",
"MemoryType",
"IdentityType",
"AgentServerMode",
"Operator",
Expand Down Expand Up @@ -2271,6 +2310,7 @@
"_GetAgentEngineMemoryOperationParameters",
"_GetAgentEngineGenerateMemoriesOperationParameters",
"_RetrieveAgentEngineMemoriesRequestParameters",
"_RetrieveMemoryProfilesRequestParameters",
"_RollbackAgentEngineMemoryRequestParameters",
"_UpdateAgentEngineMemoryRequestParameters",
"_PurgeAgentEngineMemoriesRequestParameters",
Expand Down
Loading
Loading