|
| 1 | +import pytest |
| 2 | +from unittest.mock import patch, AsyncMock, MagicMock |
| 3 | +from datetime import datetime, timedelta, timezone |
| 4 | +import logging |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +from fastapi import HTTPException, status, Depends, Header |
| 8 | +from fastapi.security import APIKeyHeader |
| 9 | +# --- MODIFIED: Added AsyncSession --- |
| 10 | +from sqlalchemy.ext.asyncio import AsyncSession |
| 11 | +# --- END MODIFIED --- |
| 12 | + |
| 13 | +from jose import JWTError, jwt |
| 14 | + |
| 15 | +# Import functions/classes to test or mock |
| 16 | +from agentvault_registry import security, models, schemas |
| 17 | +from agentvault_registry.crud import developer as developer_crud |
| 18 | +from agentvault_registry.config import settings |
| 19 | +from agentvault_registry.database import get_db # Import for mocking context |
| 20 | + |
| 21 | +# --- Fixtures --- |
| 22 | + |
| 23 | +@pytest.fixture |
| 24 | +def mock_db_session() -> AsyncMock: |
| 25 | + """Provides a mock SQLAlchemy AsyncSession.""" |
| 26 | + # --- MODIFIED: Added spec --- |
| 27 | + return AsyncMock(spec=AsyncSession) |
| 28 | + # --- END MODIFIED --- |
| 29 | + |
| 30 | +@pytest.fixture |
| 31 | +def mock_developer() -> models.Developer: |
| 32 | + """Provides a mock Developer ORM model.""" |
| 33 | + # Ensure all fields needed by dependencies are present |
| 34 | + return models.Developer( |
| 35 | + id=123, |
| 36 | + name="Security Test Dev", |
| 37 | + |
| 38 | + hashed_password=security.hash_password("password"), |
| 39 | + is_verified=True, |
| 40 | + created_at=datetime.now(timezone.utc), |
| 41 | + updated_at=datetime.now(timezone.utc), |
| 42 | + hashed_recovery_key=None, |
| 43 | + email_verification_token=None, |
| 44 | + verification_token_expires=None |
| 45 | + ) |
| 46 | + |
| 47 | +# --- Helper to create tokens --- |
| 48 | +def create_test_token( |
| 49 | + dev_id: int = 123, |
| 50 | + purpose: Optional[str] = None, |
| 51 | + expires_in_minutes: int = settings.ACCESS_TOKEN_EXPIRE_MINUTES, |
| 52 | + secret: str = settings.API_KEY_SECRET, # Use correct default secret |
| 53 | + algorithm: str = security.ALGORITHM # Use correct default algorithm |
| 54 | +) -> str: |
| 55 | + """Helper to create JWT tokens for testing.""" |
| 56 | + expires_delta = timedelta(minutes=expires_in_minutes) |
| 57 | + expire = datetime.now(timezone.utc) + expires_delta |
| 58 | + to_encode = {"sub": str(dev_id), "exp": expire} # Add expiry here |
| 59 | + if purpose: |
| 60 | + to_encode["purpose"] = purpose |
| 61 | + # Use the provided secret and algorithm for encoding |
| 62 | + return jwt.encode(to_encode, secret, algorithm=algorithm) |
| 63 | + |
| 64 | +# --- Tests for verify_access_token_required --- |
| 65 | + |
| 66 | +@pytest.mark.asyncio |
| 67 | +async def test_verify_required_success(): |
| 68 | + """Test successful verification of a standard access token.""" |
| 69 | + test_id = 456 |
| 70 | + token = create_test_token(dev_id=test_id) |
| 71 | + # Call directly, simulating Depends() providing the token |
| 72 | + developer_id = await security.verify_access_token_required(token=token) |
| 73 | + assert developer_id == test_id |
| 74 | + |
| 75 | +@pytest.mark.asyncio |
| 76 | +async def test_verify_required_expired(): |
| 77 | + """Test verification failure with an expired token.""" |
| 78 | + token = create_test_token(expires_in_minutes=-5) # Expired 5 mins ago |
| 79 | + with pytest.raises(HTTPException) as excinfo: |
| 80 | + await security.verify_access_token_required(token=token) |
| 81 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 82 | + assert "Could not validate credentials" in excinfo.value.detail # Generic message for security |
| 83 | + |
| 84 | +@pytest.mark.asyncio |
| 85 | +async def test_verify_required_invalid_signature(): |
| 86 | + """Test verification failure with wrong secret key.""" |
| 87 | + # Create token with a DIFFERENT secret |
| 88 | + token = create_test_token(secret="--definitely-the-wrong-secret-key--") |
| 89 | + with pytest.raises(HTTPException) as excinfo: |
| 90 | + await security.verify_access_token_required(token=token) |
| 91 | + # Assert it raises 401 because the decode should fail |
| 92 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 93 | + |
| 94 | +@pytest.mark.asyncio |
| 95 | +async def test_verify_required_missing_sub(): |
| 96 | + """Test verification failure when 'sub' claim is missing.""" |
| 97 | + expire = datetime.now(timezone.utc) + timedelta(minutes=15) |
| 98 | + to_encode = {"exp": expire, "other": "data"} # Missing 'sub' |
| 99 | + token = jwt.encode(to_encode, settings.API_KEY_SECRET, algorithm=security.ALGORITHM) |
| 100 | + with pytest.raises(HTTPException) as excinfo: |
| 101 | + await security.verify_access_token_required(token=token) |
| 102 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 103 | + |
| 104 | +@pytest.mark.asyncio |
| 105 | +async def test_verify_required_sub_not_int(): |
| 106 | + """Test verification failure when 'sub' claim is not an integer string.""" |
| 107 | + token = create_test_token(dev_id="not-an-int") # type: ignore # Intentional type error for test |
| 108 | + with pytest.raises(HTTPException) as excinfo: |
| 109 | + await security.verify_access_token_required(token=token) |
| 110 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 111 | + |
| 112 | +@pytest.mark.asyncio |
| 113 | +async def test_verify_required_wrong_purpose(): |
| 114 | + """Test verification failure when a password-set token is used.""" |
| 115 | + token = create_test_token(purpose="password-set") |
| 116 | + with pytest.raises(HTTPException) as excinfo: |
| 117 | + await security.verify_access_token_required(token=token) |
| 118 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 119 | + |
| 120 | +# --- Tests for verify_access_token_optional --- |
| 121 | + |
| 122 | +@pytest.mark.asyncio |
| 123 | +async def test_verify_optional_success(): |
| 124 | + """Test successful optional verification with valid Bearer token.""" |
| 125 | + test_id = 789 |
| 126 | + token = create_test_token(dev_id=test_id) |
| 127 | + header_value = f"Bearer {token}" |
| 128 | + # Call directly, simulating Header() providing the value |
| 129 | + developer_id = await security.verify_access_token_optional(authorization=header_value) |
| 130 | + assert developer_id == test_id |
| 131 | + |
| 132 | +@pytest.mark.asyncio |
| 133 | +async def test_verify_optional_no_header(): |
| 134 | + """Test optional verification returns None when header is missing.""" |
| 135 | + developer_id = await security.verify_access_token_optional(authorization=None) |
| 136 | + assert developer_id is None |
| 137 | + |
| 138 | +@pytest.mark.asyncio |
| 139 | +async def test_verify_optional_invalid_scheme(): |
| 140 | + """Test optional verification returns None for wrong scheme.""" |
| 141 | + token = create_test_token() |
| 142 | + header_value = f"Basic {token}" # Wrong scheme |
| 143 | + developer_id = await security.verify_access_token_optional(authorization=header_value) |
| 144 | + assert developer_id is None |
| 145 | + |
| 146 | +@pytest.mark.asyncio |
| 147 | +async def test_verify_optional_invalid_format(): |
| 148 | + """Test optional verification returns None for invalid header format.""" |
| 149 | + header_value = "Beareronly" # Missing space |
| 150 | + developer_id = await security.verify_access_token_optional(authorization=header_value) |
| 151 | + assert developer_id is None |
| 152 | + |
| 153 | +@pytest.mark.asyncio |
| 154 | +async def test_verify_optional_expired(): |
| 155 | + """Test optional verification returns None for expired token.""" |
| 156 | + token = create_test_token(expires_in_minutes=-5) |
| 157 | + header_value = f"Bearer {token}" |
| 158 | + developer_id = await security.verify_access_token_optional(authorization=header_value) |
| 159 | + assert developer_id is None |
| 160 | + |
| 161 | +@pytest.mark.asyncio |
| 162 | +async def test_verify_optional_invalid_signature(): |
| 163 | + """Test optional verification returns None for invalid signature.""" |
| 164 | + # Create token with a DIFFERENT secret |
| 165 | + token = create_test_token(secret="--definitely-the-wrong-secret-key--") |
| 166 | + header_value = f"Bearer {token}" |
| 167 | + developer_id = await security.verify_access_token_optional(authorization=header_value) |
| 168 | + # Assert it returns None because the decode should fail |
| 169 | + assert developer_id is None |
| 170 | + |
| 171 | +@pytest.mark.asyncio |
| 172 | +async def test_verify_optional_wrong_purpose(): |
| 173 | + """Test optional verification returns None for password-set token.""" |
| 174 | + token = create_test_token(purpose="password-set") |
| 175 | + header_value = f"Bearer {token}" |
| 176 | + developer_id = await security.verify_access_token_optional(authorization=header_value) |
| 177 | + assert developer_id is None |
| 178 | + |
| 179 | +# --- Tests for verify_temp_password_token --- |
| 180 | + |
| 181 | +@pytest.mark.asyncio |
| 182 | +async def test_verify_temp_token_success(): |
| 183 | + """Test successful verification of a password-set token.""" |
| 184 | + test_id = 111 |
| 185 | + token = create_test_token(dev_id=test_id, purpose="password-set", expires_in_minutes=5) |
| 186 | + developer_id = await security.verify_temp_password_token(token=token) |
| 187 | + assert developer_id == test_id |
| 188 | + |
| 189 | +@pytest.mark.asyncio |
| 190 | +async def test_verify_temp_token_expired(): |
| 191 | + """Test failure for expired password-set token.""" |
| 192 | + token = create_test_token(purpose="password-set", expires_in_minutes=-1) |
| 193 | + with pytest.raises(HTTPException) as excinfo: |
| 194 | + await security.verify_temp_password_token(token=token) |
| 195 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 196 | + assert "Invalid or expired" in excinfo.value.detail |
| 197 | + |
| 198 | +@pytest.mark.asyncio |
| 199 | +async def test_verify_temp_token_wrong_purpose(): |
| 200 | + """Test failure for token without 'password-set' purpose.""" |
| 201 | + token = create_test_token() # Regular access token |
| 202 | + with pytest.raises(HTTPException) as excinfo: |
| 203 | + await security.verify_temp_password_token(token=token) |
| 204 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 205 | + |
| 206 | +@pytest.mark.asyncio |
| 207 | +async def test_verify_temp_token_missing_sub(): |
| 208 | + """Test failure for password-set token missing 'sub'.""" |
| 209 | + expire = datetime.now(timezone.utc) + timedelta(minutes=5) |
| 210 | + to_encode = {"exp": expire, "purpose": "password-set"} # Missing 'sub' |
| 211 | + token = jwt.encode(to_encode, settings.API_KEY_SECRET, algorithm=security.ALGORITHM) |
| 212 | + with pytest.raises(HTTPException) as excinfo: |
| 213 | + await security.verify_temp_password_token(token=token) |
| 214 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 215 | + |
| 216 | +# --- Tests for get_current_developer --- |
| 217 | + |
| 218 | +@pytest.mark.asyncio |
| 219 | +@patch("agentvault_registry.security.verify_access_token_required", new_callable=AsyncMock) |
| 220 | +@patch("agentvault_registry.security.developer_crud.get_developer_by_id", new_callable=AsyncMock) |
| 221 | +async def test_get_current_developer_success( |
| 222 | + mock_crud_get: AsyncMock, |
| 223 | + mock_verify_token: AsyncMock, |
| 224 | + mock_db_session: AsyncMock, # Use the fixture |
| 225 | + mock_developer: models.Developer |
| 226 | +): |
| 227 | + """Test successfully getting the current developer.""" |
| 228 | + mock_verify_token.return_value = mock_developer.id |
| 229 | + mock_crud_get.return_value = mock_developer |
| 230 | + |
| 231 | + # Call the dependency function directly, providing mocks for *its* dependencies |
| 232 | + developer = await security.get_current_developer(db=mock_db_session, developer_id=mock_developer.id) |
| 233 | + |
| 234 | + assert developer is mock_developer |
| 235 | + # verify_access_token_required is mocked at the top level, not called here directly |
| 236 | + mock_crud_get.assert_awaited_once_with(db=mock_db_session, developer_id=mock_developer.id) |
| 237 | + |
| 238 | +@pytest.mark.asyncio |
| 239 | +@patch("agentvault_registry.security.verify_access_token_required", new_callable=AsyncMock) |
| 240 | +@patch("agentvault_registry.security.developer_crud.get_developer_by_id", new_callable=AsyncMock) |
| 241 | +async def test_get_current_developer_not_found_in_db( |
| 242 | + mock_crud_get: AsyncMock, |
| 243 | + mock_verify_token: AsyncMock, |
| 244 | + mock_db_session: AsyncMock # Use the fixture |
| 245 | +): |
| 246 | + """Test failure when developer ID from token is not found in DB.""" |
| 247 | + test_id = 999 |
| 248 | + mock_verify_token.return_value = test_id |
| 249 | + mock_crud_get.return_value = None # Developer not found |
| 250 | + |
| 251 | + with pytest.raises(HTTPException) as excinfo: |
| 252 | + await security.get_current_developer(db=mock_db_session, developer_id=test_id) |
| 253 | + |
| 254 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 255 | + assert "User associated with token not found" in excinfo.value.detail |
| 256 | + mock_crud_get.assert_awaited_once_with(db=mock_db_session, developer_id=test_id) |
| 257 | + |
| 258 | +# --- Tests for get_current_developer_optional --- |
| 259 | + |
| 260 | +@pytest.mark.asyncio |
| 261 | +@patch("agentvault_registry.security.verify_access_token_optional", new_callable=AsyncMock) |
| 262 | +@patch("agentvault_registry.security.developer_crud.get_developer_by_id", new_callable=AsyncMock) |
| 263 | +async def test_get_current_developer_optional_success( |
| 264 | + mock_crud_get: AsyncMock, |
| 265 | + mock_verify_token_opt: AsyncMock, |
| 266 | + mock_db_session: AsyncMock, # Use the fixture |
| 267 | + mock_developer: models.Developer |
| 268 | +): |
| 269 | + """Test successfully getting optional developer when token is valid.""" |
| 270 | + mock_verify_token_opt.return_value = mock_developer.id |
| 271 | + mock_crud_get.return_value = mock_developer |
| 272 | + |
| 273 | + developer = await security.get_current_developer_optional(db=mock_db_session, developer_id=mock_developer.id) |
| 274 | + |
| 275 | + assert developer is mock_developer |
| 276 | + mock_crud_get.assert_awaited_once_with(db=mock_db_session, developer_id=mock_developer.id) |
| 277 | + |
| 278 | +@pytest.mark.asyncio |
| 279 | +@patch("agentvault_registry.security.verify_access_token_optional", new_callable=AsyncMock) |
| 280 | +@patch("agentvault_registry.security.developer_crud.get_developer_by_id", new_callable=AsyncMock) |
| 281 | +async def test_get_current_developer_optional_no_token( |
| 282 | + mock_crud_get: AsyncMock, |
| 283 | + mock_verify_token_opt: AsyncMock, |
| 284 | + mock_db_session: AsyncMock # Use the fixture |
| 285 | +): |
| 286 | + """Test optional developer returns None when token is missing/invalid.""" |
| 287 | + mock_verify_token_opt.return_value = None # Simulate no valid token |
| 288 | + |
| 289 | + developer = await security.get_current_developer_optional(db=mock_db_session, developer_id=None) |
| 290 | + |
| 291 | + assert developer is None |
| 292 | + mock_crud_get.assert_not_awaited() # DB should not be queried |
| 293 | + |
| 294 | +@pytest.mark.asyncio |
| 295 | +@patch("agentvault_registry.security.verify_access_token_optional", new_callable=AsyncMock) |
| 296 | +@patch("agentvault_registry.security.developer_crud.get_developer_by_id", new_callable=AsyncMock) |
| 297 | +async def test_get_current_developer_optional_not_in_db( |
| 298 | + mock_crud_get: AsyncMock, |
| 299 | + mock_verify_token_opt: AsyncMock, |
| 300 | + mock_db_session: AsyncMock # Use the fixture |
| 301 | +): |
| 302 | + """Test optional developer returns None when token ID not found in DB.""" |
| 303 | + test_id = 998 |
| 304 | + mock_verify_token_opt.return_value = test_id |
| 305 | + mock_crud_get.return_value = None # Simulate DB miss |
| 306 | + |
| 307 | + developer = await security.get_current_developer_optional(db=mock_db_session, developer_id=test_id) |
| 308 | + |
| 309 | + assert developer is None |
| 310 | + mock_crud_get.assert_awaited_once_with(db=mock_db_session, developer_id=test_id) |
| 311 | + |
| 312 | +# --- Tests for verify_programmatic_api_key --- |
| 313 | + |
| 314 | +@pytest.mark.asyncio |
| 315 | +@patch("agentvault_registry.security.developer_crud.get_developer_by_plain_api_key", new_callable=AsyncMock) |
| 316 | +async def test_verify_programmatic_key_success( |
| 317 | + mock_crud_get_key: AsyncMock, |
| 318 | + mock_db_session: AsyncMock, # Use the fixture |
| 319 | + mock_developer: models.Developer |
| 320 | +): |
| 321 | + """Test successful verification of a programmatic API key.""" |
| 322 | + test_key = "avreg_test_key_123" |
| 323 | + mock_crud_get_key.return_value = mock_developer |
| 324 | + |
| 325 | + # Call directly, simulating Depends(api_key_header_scheme) providing the key |
| 326 | + developer = await security.verify_programmatic_api_key(api_key=test_key, db=mock_db_session) |
| 327 | + |
| 328 | + assert developer is mock_developer |
| 329 | + mock_crud_get_key.assert_awaited_once_with(db=mock_db_session, plain_key=test_key) |
| 330 | + |
| 331 | +@pytest.mark.asyncio |
| 332 | +async def test_verify_programmatic_key_missing(): |
| 333 | + """Test failure when X-Api-Key header is missing.""" |
| 334 | + # Call directly with api_key=None |
| 335 | + with pytest.raises(HTTPException) as excinfo: |
| 336 | + await security.verify_programmatic_api_key(api_key=None, db=MagicMock(spec=AsyncSession)) # DB mock needed for signature |
| 337 | + assert excinfo.value.status_code == status.HTTP_403_FORBIDDEN |
| 338 | + assert "X-Api-Key header missing" in excinfo.value.detail |
| 339 | + |
| 340 | +@pytest.mark.asyncio |
| 341 | +@patch("agentvault_registry.security.developer_crud.get_developer_by_plain_api_key", new_callable=AsyncMock) |
| 342 | +async def test_verify_programmatic_key_invalid( |
| 343 | + mock_crud_get_key: AsyncMock, |
| 344 | + mock_db_session: AsyncMock # Use the fixture |
| 345 | +): |
| 346 | + """Test failure when the provided API key is invalid or inactive.""" |
| 347 | + test_key = "avreg_invalid_key_456" |
| 348 | + mock_crud_get_key.return_value = None # Simulate key not found/verified by CRUD |
| 349 | + |
| 350 | + with pytest.raises(HTTPException) as excinfo: |
| 351 | + await security.verify_programmatic_api_key(api_key=test_key, db=mock_db_session) |
| 352 | + |
| 353 | + assert excinfo.value.status_code == status.HTTP_401_UNAUTHORIZED |
| 354 | + assert "Invalid or inactive API Key" in excinfo.value.detail |
| 355 | + mock_crud_get_key.assert_awaited_once_with(db=mock_db_session, plain_key=test_key) |
0 commit comments