Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..config import EndpointMethods
from ..utils.middleware import JsonResponseMiddleware
from ..utils.requests import find_match
from ..utils.stac import get_links
from ..utils.stac import ensure_type, get_links

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,7 +62,7 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:

def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
"""Augment the STAC Item with auth information."""
extensions = data.setdefault("stac_extensions", [])
extensions = ensure_type(data, "stac_extensions", list)
if self.extension_url not in extensions:
extensions.append(self.extension_url)

Expand All @@ -75,7 +75,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
# - Item Properties

scheme_loc = data["properties"] if "properties" in data else data
schemes = scheme_loc.setdefault("auth:schemes", {})
schemes = ensure_type(scheme_loc, "auth:schemes", dict)
schemes[self.auth_scheme_name] = {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_discovery_url,
Expand All @@ -96,6 +96,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
default_public=self.default_public,
)
if match.is_private:
link.setdefault("auth:refs", []).append(self.auth_scheme_name)
auth_refs = ensure_type(link, "auth:refs", list)
auth_refs.append(self.auth_scheme_name)

return data
10 changes: 5 additions & 5 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..config import EndpointMethods
from ..utils.middleware import JsonResponseMiddleware
from ..utils.requests import find_match
from ..utils.stac import ensure_type


@dataclass(frozen=True)
Expand Down Expand Up @@ -57,8 +58,8 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
data["servers"] = [{"url": self.root_path}]

# Add security scheme
components = data.setdefault("components", {})
securitySchemes = components.setdefault("securitySchemes", {})
components = ensure_type(data, "components", dict)
securitySchemes = ensure_type(components, "securitySchemes", dict)
securitySchemes[self.auth_scheme_name] = self.auth_scheme_override or {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_discovery_url,
Expand All @@ -78,7 +79,6 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
self.default_public,
)
if match.is_private:
config.setdefault("security", []).append(
{self.auth_scheme_name: match.required_scopes}
)
security = ensure_type(config, "security", list)
security.append({self.auth_scheme_name: match.required_scopes})
return data
55 changes: 55 additions & 0 deletions src/stac_auth_proxy/utils/stac.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,61 @@
"""STAC-specific utilities."""

import logging
from collections.abc import Callable
from itertools import chain
from typing import Any, TypeVar

logger = logging.getLogger(__name__)

T = TypeVar("T")


def ensure_type(
data: dict[str, Any],
key: str,
expected_type: type[T],
default_factory: Callable[[], T] | None = None,
) -> T:
"""
Ensure a dictionary value conforms to the expected type.

If the value doesn't exist or is not an instance of the expected type,
it will be replaced with the default value from default_factory.

Args:
data: The dictionary containing the value
key: The key to check
expected_type: The expected type class
default_factory: Optional callable that returns the default value.
If not provided, expected_type will be called with no arguments.

Returns:
The value from the dictionary if it's the correct type, otherwise the default value

Example:
>>> data = {"stac_extensions": None}
>>> extensions = ensure_type(data, "stac_extensions", list)
>>> # extensions is now [] and data["stac_extensions"] is []
>>>
>>> data = {"items": "invalid"}
>>> items = ensure_type(data, "items", list, lambda: ["default"])
>>> # items is now ["default"] with custom factory

"""
value = data.get(key)
if not isinstance(value, expected_type):
if value is not None:
logger.warning(
"Field '%s' expected %s but got %s: %r",
key,
expected_type.__name__,
type(value).__name__,
value,
)
factory = default_factory if default_factory is not None else expected_type
value = factory()
data[key] = value
return value


def get_links(data: dict) -> chain[dict]:
Expand Down
54 changes: 54 additions & 0 deletions tests/test_auth_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,57 @@ def test_transform_json_missing_oidc_metadata(middleware, request_scope):
transformed = middleware.transform_json(catalog, request)
# Should return unchanged when OIDC metadata is missing
assert transformed == catalog


def test_transform_json_with_null_stac_extensions(
middleware, request_scope, oidc_discovery_url
):
"""Test transforming when stac_extensions is None."""
request = Request(request_scope)

catalog = {
"stac_version": "1.0.0",
"id": "test-catalog",
"description": "Test catalog",
"stac_extensions": None,
}

transformed = middleware.transform_json(catalog, request)

assert "stac_extensions" in transformed
assert middleware.extension_url in transformed["stac_extensions"]
assert "auth:schemes" in transformed
assert "test_auth" in transformed["auth:schemes"]


@pytest.mark.parametrize(
"invalid_value",
[
"not-a-list",
42,
{"key": "value"},
3.14,
True,
],
)
def test_transform_json_with_invalid_stac_extensions_types(
middleware, request_scope, oidc_discovery_url, invalid_value
):
"""Test transforming when stac_extensions is an invalid type (string, int, dict, etc)."""
request = Request(request_scope)

catalog = {
"stac_version": "1.0.0",
"id": "test-catalog",
"description": "Test catalog",
"stac_extensions": invalid_value,
}

transformed = middleware.transform_json(catalog, request)

# Should replace invalid value with a proper list
assert "stac_extensions" in transformed
assert isinstance(transformed["stac_extensions"], list)
assert middleware.extension_url in transformed["stac_extensions"]
assert "auth:schemes" in transformed
assert "test_auth" in transformed["auth:schemes"]
93 changes: 93 additions & 0 deletions tests/test_stac_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Tests for STAC utility functions."""

import pytest

from stac_auth_proxy.utils.stac import ensure_type


@pytest.mark.parametrize(
"initial_value,expected_type,default_factory,expected_result",
[
# List type validation
(None, list, list, []),
("not-a-list", list, list, []),
(42, list, list, []),
({"key": "value"}, list, list, []),
(3.14, list, list, []),
(True, list, list, []),
(["existing", "items"], list, list, ["existing", "items"]),
# Dict type validation
(None, dict, dict, {}),
("not-a-dict", dict, dict, {}),
(42, dict, dict, {}),
(["list"], dict, dict, {}),
(3.14, dict, dict, {}),
(True, dict, dict, {}),
({"existing": "value"}, dict, dict, {"existing": "value"}),
],
)
def test_ensure_type(initial_value, expected_type, default_factory, expected_result):
"""Test ensure_type handles various invalid types and preserves valid values."""
data = {"field": initial_value}
result = ensure_type(data, "field", expected_type, default_factory)

assert result == expected_result
assert data["field"] == expected_result
assert isinstance(data["field"], expected_type)


def test_ensure_type_missing_key():
"""Test ensure_type when key doesn't exist in the dictionary."""
data = {}
result = ensure_type(data, "missing_field", list, list)

assert result == []
assert data["missing_field"] == []
assert isinstance(data["missing_field"], list)


def test_ensure_type_with_custom_factory():
"""Test ensure_type with a custom default factory."""
data = {"field": None}
default_value = ["default", "items"]
result = ensure_type(data, "field", list, lambda: default_value.copy())

assert result == ["default", "items"]
assert data["field"] == ["default", "items"]


def test_ensure_type_preserves_valid_value():
"""Test that ensure_type doesn't modify valid values."""
original_list = ["a", "b", "c"]
data = {"field": original_list}

result = ensure_type(data, "field", list, list)

# Should return the same list object, not create a new one
assert result is original_list
assert data["field"] is original_list


def test_ensure_type_without_factory():
"""Test ensure_type uses expected_type as factory when default_factory is not provided."""
# Test with list
data = {"extensions": None}
result = ensure_type(data, "extensions", list)
assert result == []
assert data["extensions"] == []

# Test with dict
data = {"schemes": "invalid"}
result = ensure_type(data, "schemes", dict)
assert result == {}
assert data["schemes"] == {}


def test_ensure_type_factory_precedence():
"""Test that explicit default_factory takes precedence over expected_type."""
data = {"field": None}
# Use a custom factory instead of the default list()
result = ensure_type(data, "field", list, lambda: ["custom", "default"])

assert result == ["custom", "default"]
assert data["field"] == ["custom", "default"]
Loading