Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/packages/purview/agent_framework_purview/_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,8 @@ async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> Proce
cache_key = create_protection_scopes_cache_key(ps_req)
cached_ps_resp = await self._cache.get(cache_key)

if cached_ps_resp is not None:
if isinstance(cached_ps_resp, ProtectionScopesResponse):
ps_resp = cached_ps_resp
if cached_ps_resp is not None and isinstance(cached_ps_resp, ProtectionScopesResponse):
ps_resp = cached_ps_resp
else:
try:
ps_resp = await self._client.get_protection_scopes(ps_req)
Expand Down
19 changes: 19 additions & 0 deletions python/packages/purview/tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,25 @@ class CustomObject:

assert result == obj

async def test_estimate_size_conservative_fallback_when_all_size_methods_fail(self, monkeypatch) -> None:
"""Test that the cache returns a conservative size estimate when all strategies fail."""
cache = InMemoryCacheProvider()

class BadString:
def __str__(self) -> str:
raise RuntimeError("boom")

def raise_getsizeof(_: object) -> int:
raise RuntimeError("no sizeof")

monkeypatch.setattr("agent_framework_purview._cache.sys.getsizeof", raise_getsizeof)

# Arrange/Act
size = cache._estimate_size(BadString())

# Assert
assert size == 1024

async def test_cache_multiple_updates(self) -> None:
"""Test that updating a key multiple times maintains correct size tracking."""
cache = InMemoryCacheProvider(max_size_bytes=1000)
Expand Down
88 changes: 88 additions & 0 deletions python/packages/purview/tests/test_chat_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,39 @@ async def mock_next(ctx: ChatContext) -> None:
with pytest.raises(PurviewPaymentRequiredError):
await middleware.process(context, mock_next)

async def test_chat_middleware_handles_payment_required_post_check(self, mock_credential: AsyncMock) -> None:
"""Test that 402 in post-check is raised when ignore_payment_required=False."""
from agent_framework_purview._exceptions import PurviewPaymentRequiredError

settings = PurviewSettings(app_name="Test App", ignore_payment_required=False)
middleware = PurviewChatPolicyMiddleware(mock_credential, settings)

chat_client = DummyChatClient()
chat_options = MagicMock()
chat_options.model = "test-model"
context = ChatContext(
chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options
)

call_count = 0

async def side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return (False, "user-123")
raise PurviewPaymentRequiredError("Payment required")

with patch.object(middleware._processor, "process_messages", side_effect=side_effect):

async def mock_next(ctx: ChatContext) -> None:
result = MagicMock()
result.messages = [ChatMessage(role=Role.ASSISTANT, text="OK")]
ctx.result = result

with pytest.raises(PurviewPaymentRequiredError):
await middleware.process(context, mock_next)

async def test_chat_middleware_ignores_payment_required_when_configured(self, mock_credential: AsyncMock) -> None:
"""Test that 402 is ignored when ignore_payment_required=True."""
from agent_framework_purview._exceptions import PurviewPaymentRequiredError
Expand Down Expand Up @@ -274,3 +307,58 @@ async def mock_next(ctx: ChatContext) -> None:
await middleware.process(context, mock_next)
# Next should have been called
assert context.result is not None

async def test_chat_middleware_raises_on_pre_check_exception_when_ignore_exceptions_false(
self, mock_credential: AsyncMock
) -> None:
"""Test that exceptions are propagated by default when ignore_exceptions=False."""
settings = PurviewSettings(app_name="Test App", ignore_exceptions=False)
middleware = PurviewChatPolicyMiddleware(mock_credential, settings)

chat_client = DummyChatClient()
chat_options = MagicMock()
chat_options.model = "test-model"
context = ChatContext(
chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options
)

with patch.object(middleware._processor, "process_messages", side_effect=ValueError("boom")):

async def mock_next(_: ChatContext) -> None:
raise AssertionError("next should not be called")

with pytest.raises(ValueError, match="boom"):
await middleware.process(context, mock_next)

async def test_chat_middleware_raises_on_post_check_exception_when_ignore_exceptions_false(
self, mock_credential: AsyncMock
) -> None:
"""Test that post-check exceptions are propagated by default."""
settings = PurviewSettings(app_name="Test App", ignore_exceptions=False)
middleware = PurviewChatPolicyMiddleware(mock_credential, settings)

chat_client = DummyChatClient()
chat_options = MagicMock()
chat_options.model = "test-model"
context = ChatContext(
chat_client=chat_client, messages=[ChatMessage(role=Role.USER, text="Hello")], options=chat_options
)

call_count = 0

async def side_effect(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return (False, "user-123")
raise ValueError("post")

with patch.object(middleware._processor, "process_messages", side_effect=side_effect):

async def mock_next(ctx: ChatContext) -> None:
result = MagicMock()
result.messages = [ChatMessage(role=Role.ASSISTANT, text="OK")]
ctx.result = result

with pytest.raises(ValueError, match="post"):
await middleware.process(context, mock_next)
216 changes: 215 additions & 1 deletion python/packages/purview/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

"""Tests for Purview client."""

from collections.abc import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
Expand All @@ -18,6 +19,8 @@
PurviewServiceError,
)
from agent_framework_purview._models import (
ContentActivitiesRequest,
ContentActivitiesResponse,
PolicyLocation,
ProcessContentRequest,
ProtectionScopesRequest,
Expand Down Expand Up @@ -47,7 +50,9 @@ def settings(self) -> PurviewSettings:
return PurviewSettings(app_name="Test App", tenant_id="test-tenant", default_user_id="test-user")

@pytest.fixture
async def client(self, mock_credential: MagicMock, settings: PurviewSettings) -> PurviewClient:
async def client(
self, mock_credential: MagicMock, settings: PurviewSettings
) -> AsyncGenerator[PurviewClient, None]:
"""Create a PurviewClient with mock credential."""
client = PurviewClient(mock_credential, settings, timeout=10.0)
yield client
Expand Down Expand Up @@ -185,6 +190,215 @@ async def test_get_protection_scopes_success(self, client: PurviewClient) -> Non
assert response.scope_identifier == "scope-123"
assert response.scopes == []

async def test_get_protection_scopes_uses_etag_header_when_present(self, client: PurviewClient) -> None:
"""Test that get_protection_scopes prefers the HTTP ETag header when present."""
from agent_framework_purview._models import ProtectionScopesResponse

location = PolicyLocation(**{"@odata.type": "microsoft.graph.policyLocationApplication", "value": "app-id"})
request = ProtectionScopesRequest(
user_id="user-123", tenant_id="tenant-456", locations=[location], correlation_id="corr-789"
)

response_obj = ProtectionScopesResponse(**{"scopeIdentifier": "scope-from-body", "value": []})

with patch.object(
client,
"_post",
return_value=(response_obj, {"etag": '"etag-from-header"'}),
):
response = await client.get_protection_scopes(request)

assert response.scope_identifier == "etag-from-header"

async def test_post_402_returns_empty_response_when_ignore_payment_required_enabled(
self, mock_credential: MagicMock
) -> None:
"""Test that 402 is suppressed when ignore_payment_required=True."""
from agent_framework_purview._models import ProcessContentResponse

settings = PurviewSettings(app_name="Test App", ignore_payment_required=True)
client = PurviewClient(mock_credential, settings)

request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[])

resp = httpx.Response(402, text="Payment required", request=httpx.Request("POST", "http://test"))

with patch.object(client._client, "post", return_value=resp):
result = await client._post("http://test", request, ProcessContentResponse, token="fake-token")

assert isinstance(result, ProcessContentResponse)
await client.close()

async def test_post_sets_request_and_response_correlation_id(self, client: PurviewClient) -> None:
"""Test that correlation_id is injected into request headers and hydrated from response headers."""
from agent_framework_purview._models import ProcessContentResponse

# correlation_id is optional and should be auto-generated when empty
request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[])
request.correlation_id = "" # force auto-generation branch

captured_headers: dict[str, str] = {}

async def fake_post(url: str, json=None, headers=None):
nonlocal captured_headers
captured_headers = dict(headers or {})
return httpx.Response(
200,
json={"id": "resp-1", "protectionScopeState": "notModified"},
headers={"client-request-id": "corr-from-response"},
request=httpx.Request("POST", url),
)

with patch.object(client._client, "post", side_effect=fake_post):
result_obj, result_headers = await client._post(
"http://test",
request,
ProcessContentResponse,
token="fake-token",
return_response=True,
)

assert "client-request-id" in captured_headers
assert captured_headers["client-request-id"]
assert result_headers["client-request-id"] == "corr-from-response"
assert result_obj.correlation_id == "corr-from-response"

async def test_process_content_402_returns_empty_when_ignored(self, mock_credential: MagicMock) -> None:
"""Test that process_content returns an empty response (non-tuple path) when 402 is ignored."""
from agent_framework_purview._models import ProcessContentResponse

settings = PurviewSettings(app_name="Test App", ignore_payment_required=True)
client = PurviewClient(mock_credential, settings)

req = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[])

mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 402
mock_response.text = "Payment required"

with patch.object(client._client, "post", return_value=mock_response):
response = await client.process_content(req)

assert isinstance(response, ProcessContentResponse)
await client.close()

async def test_post_sets_correlation_id_attribute_on_recording_span(self, client: PurviewClient) -> None:
"""Test that correlation_id is added to the active span when recording is enabled."""
from agent_framework_purview._models import ProcessContentResponse

request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[])
request.correlation_id = "corr-123"

class RecordingSpan:
def __init__(self) -> None:
self.attributes: dict[str, str] = {}

def is_recording(self) -> bool:
return True

def set_attribute(self, key: str, value: str) -> None:
self.attributes[key] = value

span = RecordingSpan()

with (
patch("agent_framework_purview._client.trace.get_current_span", return_value=span),
patch.object(
client._client,
"post",
return_value=httpx.Response(
200,
json={"id": "resp-1", "protectionScopeState": "notModified"},
headers={},
request=httpx.Request("POST", "http://test"),
),
),
):
await client._post("http://test", request, ProcessContentResponse, token="fake-token")

assert span.attributes["correlation_id"] == "corr-123"

async def test_post_uses_constructor_when_response_type_has_no_model_validate(self, client: PurviewClient) -> None:
"""Test that _post falls back to the response type constructor when model_validate is absent."""

class DummyResponse:
def __init__(self, **data):
self.data = data

request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[])
request.correlation_id = "corr-123"

with patch.object(
client._client,
"post",
return_value=httpx.Response(
200,
json={"hello": "world"},
headers={},
request=httpx.Request("POST", "http://test"),
),
):
result = await client._post("http://test", request, DummyResponse, token="fake-token")

assert isinstance(result, DummyResponse)
assert result.data == {"hello": "world"}

async def test_send_content_activities_success(self, client: PurviewClient, content_to_process_factory) -> None:
"""Test send_content_activities success path."""
request = ContentActivitiesRequest(
user_id="user-123",
tenant_id="tenant-456",
content_to_process=content_to_process_factory("hello"),
correlation_id="corr-1",
)

mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.headers = {}
mock_response.json.return_value = {"error": None}

with patch.object(client._client, "post", return_value=mock_response):
resp = await client.send_content_activities(request)

assert isinstance(resp, ContentActivitiesResponse)

async def test_post_handles_invalid_json_response_body(self, client: PurviewClient) -> None:
"""Test that invalid JSON bodies fall back to an empty dict."""
request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[])
request.correlation_id = "corr-123"

mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.headers = {}
mock_response.json.side_effect = ValueError("not json")

with patch.object(client._client, "post", return_value=mock_response):
result = await client._post("http://test", request, ContentActivitiesResponse, token="fake-token")

assert isinstance(result, ContentActivitiesResponse)

async def test_post_deserialization_failure_raises_purview_service_error(self, client: PurviewClient) -> None:
"""Test that response deserialization errors are wrapped as PurviewServiceError."""

class BadResponseType:
@classmethod
def model_validate(cls, value):
raise RuntimeError("boom")

request = ProcessContentRequest(user_id="user-123", tenant_id="tenant-456", content_to_process=[])
request.correlation_id = "corr-123"

mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.headers = {}
mock_response.json.return_value = {"any": "data"}

with (
patch.object(client._client, "post", return_value=mock_response),
pytest.raises(PurviewServiceError, match="Failed to deserialize Purview response"),
):
await client._post("http://test", request, BadResponseType, token="fake-token")

async def test_client_close(self, mock_credential: AsyncMock, settings: PurviewSettings) -> None:
"""Test client properly closes HTTP client."""
client = PurviewClient(mock_credential, settings)
Expand Down
Loading
Loading