diff --git a/.github/workflows/pr-python.yml b/.github/workflows/pr-python.yml new file mode 100644 index 0000000..b699bc2 --- /dev/null +++ b/.github/workflows/pr-python.yml @@ -0,0 +1,31 @@ +name: PR Python Checks +on: + pull_request: + branches: [ main ] + paths: + - 'python/**' + - '.github/workflows/pr-python.yml' + +jobs: + python-quality: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Install deps + working-directory: python + run: | + python -m pip install --upgrade pip + pip install .[all] + - name: Ruff Lint + working-directory: python + run: python -m ruff check src tests + - name: Type check + working-directory: python + run: python -m mypy src/azurepg_entra/ + - name: Tests + if: ${{ always() }} # adjust if you add tests + working-directory: python + run: python -m pytest tests --import-mode=importlib -q \ No newline at end of file diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml deleted file mode 100644 index f0e7e46..0000000 --- a/.github/workflows/pr.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: PR Checks - -on: - pull_request: - branches: - - main - -jobs: - python-quality-checks: - runs-on: ubuntu-latest - - steps: - # Check out the repo - - name: Checkout code - uses: actions/checkout@v4 - - # Set up Python - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - - # Install dependencies - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install .[all] - - # Run mypy (type checker) - - name: Run mypy - run: mypy ./python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py \ No newline at end of file diff --git a/dotnet/dotnet.sln b/dotnet/dotnet.sln index 36674be..9205c04 100644 --- a/dotnet/dotnet.sln +++ b/dotnet/dotnet.sln @@ -22,7 +22,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Postgresql", "Postgresql", EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Npgsql", "Npgsql", "{93934517-16C9-C51A-8F2B-54760F50BDEB}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Azure.Data.Postgresql.Npgsql", "src\Azure\Data\Postgresql\Npgsql\Azure.Data.Postgresql.Npgsql.csproj", "{3E862DB4-B843-4361-94B5-8CF34402B511}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Azure.Data.Postgresql.Npgsql", "src\Azure\Data\PostgreSql\Npgsql\Azure.Data.Postgresql.Npgsql.csproj", "{3E862DB4-B843-4361-94B5-8CF34402B511}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{0AB3BF05-4346-4AA6-1389-037BE0695223}" EndProject diff --git a/python/README.md b/python/README.md index 5e1b369..8414de2 100644 --- a/python/README.md +++ b/python/README.md @@ -101,7 +101,7 @@ Choose the driver that best fits your project needs: - **psycopg3**: Modern PostgreSQL driver (recommended for new projects) - **psycopg2**: Legacy PostgreSQL driver (for existing projects) -- **SQLAlchemy**: High-level ORM/Core interface using psycopg3 backend +- **SQLAlchemy**: High-level ORM/Core interface --- @@ -109,85 +109,36 @@ Choose the driver that best fits your project needs: > **Note**: psycopg2 is in maintenance mode. For new projects, consider using psycopg3 instead. -The psycopg2 integration provides both synchronous (psycopg2) and asynchronous (aiopg) connection support with Azure Entra ID authentication. +The psycopg2 integration provides synchronous connection support with Azure Entra ID authentication through connection pooling. ### Installation ```bash pip install "azurepg-entra[psycopg2]" ``` -### Synchronous Connection (psycopg2) +### Connection Pooling (Recommended) ```python -from azurepg_entra.psycopg2 import connect_with_entra -from psycopg2 import pool - -def main(): - # Direct connection - conn = connect_with_entra( - host="your-server.postgres.database.azure.com", - port=5432, - dbname="your_database" - ) - - try: - with conn.cursor() as cur: - cur.execute("SELECT current_user, now()") - user, time = cur.fetchone() - print(f"Connected as: {user} at {time}") - finally: - conn.close() - - # Connection pooling - def entra_connection_factory(*args, **kwargs): - return connect_with_entra( - host="your-server.postgres.database.azure.com", - port=5432, - dbname="your_database" - ) - - connection_pool = pool.ThreadedConnectionPool( - minconn=1, maxconn=5, - connection_factory=entra_connection_factory - ) - - conn = connection_pool.getconn() - try: - with conn.cursor() as cur: - cur.execute("SELECT current_user") - print(f"Pool connection as: {cur.fetchone()[0]}") - finally: - connection_pool.putconn(conn) - connection_pool.closeall() - -if __name__ == "__main__": - main() +from azurepg_entra.psycopg2 import EntraConnection # import library +from psycopg2 import pool # import to use pooling + +with pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host="your-server.postgres.database.azure.com", + database="your_database", + connection_factory=EntraConnection +) as connection_pool: ``` -### Asynchronous Connection (aiopg) +### Direct Connection ```python -import asyncio -from azurepg_entra.psycopg2 import connect_with_entra_async - -async def main(): - # Direct async connection - conn = await connect_with_entra_async( - host="your-server.postgres.database.azure.com", - port=5432, - dbname="your_database" - ) - - try: - async with conn.cursor() as cur: - await cur.execute("SELECT current_user, now()") - user, time = await cur.fetchone() - print(f"Async connected as: {user} at {time}") - finally: - conn.close() - -if __name__ == "__main__": - asyncio.run(main()) +from azurepg_entra.psycopg2 import EntraConnection # import library + +with EntraConnection( + "postgresql://your-server.postgres.database.azure.com:5432/your_database" +) as conn ``` --- @@ -204,82 +155,38 @@ pip install "azurepg-entra[psycopg3]" ### Synchronous Connection ```python -from azurepg_entra.psycopg3 import SyncEntraConnection -from psycopg_pool import ConnectionPool - -def main(): - # Direct connection - with SyncEntraConnection.connect( - "postgresql://your-server.postgres.database.azure.com:5432/your_database" - ) as conn: - with conn.cursor() as cur: - cur.execute("SELECT current_user, now()") - user, time = cur.fetchone() - print(f"Connected as: {user} at {time}") - - # Connection pooling (recommended for production) - with ConnectionPool( - conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", - connection_class=SyncEntraConnection, - min_size=1, # keep at least 1 connection always open - max_size=5, # allow up to 5 concurrent connections - max_waiting=10, # seconds to wait if pool is full - ) as pool: - with pool.connection() as conn: - with conn.cursor() as cur: - cur.execute("SELECT current_user, now()") - user, time = cur.fetchone() - print(f"Pool connection as: {user} at {time}") - -if __name__ == "__main__": - main() +from azurepg_entra.psycopg3 import EntraConnection # import library +from psycopg_pool import ConnectionPool # import to use pooling + +with ConnectionPool( + conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", + connection_class=EntraConnection, + min_size=1, # keep at least 1 connection always open + max_size=5, # allow up to 5 concurrent connections +) as pool ``` ### Asynchronous Connection ```python -import asyncio -import sys -from azurepg_entra.psycopg3 import AsyncEntraConnection -from psycopg_pool import AsyncConnectionPool - -async def main(): - # Direct async connection - async with await AsyncEntraConnection.connect( - "postgresql://your-server.postgres.database.azure.com:5432/your_database" - ) as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT current_user, now()") - user, time = await cur.fetchone() - print(f"Async connected as: {user} at {time}") - - # Async connection pooling (recommended for production) - async with AsyncConnectionPool( - conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", - connection_class=AsyncEntraConnection, - min_size=1, # keep at least 1 connection always open - max_size=5, # allow up to 5 concurrent connections - max_waiting=10, # seconds to wait if pool is full - ) as pool: - async with pool.connection() as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT current_user, now()") - user, time = await cur.fetchone() - print(f"Pool connection as: {user} at {time}") - -if __name__ == "__main__": - # Windows compatibility for async operations - if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main()) +from azurepg_entra.psycopg3 import AsyncEntraConnection # import library +from psycopg_pool import AsyncConnectionPool # import to use pooling + +async with AsyncConnectionPool( + conninfo="postgresql://your-server.postgres.database.azure.com:5432/your_database", + connection_class=AsyncEntraConnection, + min_size=1, # keep at least 1 connection always open + max_size=5, # allow up to 5 concurrent connections +) as pool ``` --- ## SQLAlchemy Integration -SQLAlchemy integration uses psycopg3 as the backend driver with automatic Entra ID authentication. +SQLAlchemy integration uses psycopg3 as the backend driver with automatic Entra ID authentication through event listeners. + +> **For more information**: See SQLAlchemy's documentation on [controlling how parameters are passed to the DBAPI connect function](https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function). ### Installation ```bash @@ -289,73 +196,37 @@ pip install "azurepg-entra[sqlalchemy]" ### Synchronous Engine ```python -from azurepg_entra.sqlalchemy import create_entra_engine -from sqlalchemy import text - -def main(): - # Create synchronous engine with Entra ID authentication - engine = create_entra_engine( - "postgresql+psycopg://your-server.postgres.database.azure.com:5432/your_database" - ) +from sqlalchemy import create_engine +from azurepg_entra.sqlalchemy import enable_entra_authentication # import library + +with create_engine("postgresql+psycopg://your-server.postgres.database.azure.com/your_database") as engine: + # Enable Entra ID authentication + enable_entra_authentication(engine) # Core usage with engine.connect() as conn: - result = conn.execute(text("SELECT current_user, now()")) - user, time = result.fetchone() - print(f"SQLAlchemy connected as: {user} at {time}") - + # ORM usage from sqlalchemy.orm import sessionmaker Session = sessionmaker(bind=engine) - - with Session() as session: - result = session.execute(text("SELECT current_database()")) - db_name = result.scalar() - print(f"Connected to database: {db_name}") - - engine.dispose() - -if __name__ == "__main__": - main() ``` ### Asynchronous Engine ```python -import asyncio -import sys -from azurepg_entra.sqlalchemy import create_async_entra_engine -from sqlalchemy import text -from sqlalchemy.ext.asyncio import async_sessionmaker - -async def main(): - # Create asynchronous engine with Entra ID authentication - engine = await create_async_entra_engine( - "postgresql+psycopg://your-server.postgres.database.azure.com:5432/your_database" - ) +from sqlalchemy.ext.asyncio import create_async_engine +from azurepg_entra.sqlalchemy import enable_entra_authentication_async # import library + +async with create_async_engine("postgresql+psycopg://your-server.postgres.database.azure.com/your_database") as engine: + # Enable Entra ID authentication for async + enable_entra_authentication_async(engine) # Async Core usage async with engine.connect() as conn: - result = await conn.execute(text("SELECT current_user, now()")) - user, time = result.fetchone() - print(f"Async SQLAlchemy connected as: {user} at {time}") # Async ORM usage + from sqlalchemy.ext.asyncio import async_sessionmaker AsyncSession = async_sessionmaker(engine, expire_on_commit=False) - - async with AsyncSession() as session: - result = await session.execute(text("SELECT current_database()")) - db_name = result.scalar() - print(f"Async connected to database: {db_name}") - - await engine.dispose() - -if __name__ == "__main__": - # Windows compatibility for async operations - if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main()) ``` ## How It Works @@ -380,8 +251,6 @@ The package automatically requests the correct OAuth2 scopes: - **ā° Automatic expiration**: Tokens expire and are refreshed automatically - **šŸ›”ļø SSL enforcement**: All connections require SSL encryption - **šŸ”‘ Principle of least privilege**: Only database-specific scopes are requested -- **šŸ“‹ Audit logging**: Authentication events are logged by Azure Database for PostgreSQL - --- ## Troubleshooting diff --git a/python/pyproject.toml b/python/pyproject.toml index c04baea..ff34530 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -26,19 +26,18 @@ dependencies = [ # psycopg3 support psycopg3 = [ "psycopg[binary]>=3.1.0", - "psycopg-pool>=3.1.0" + "aiohttp>=3.8.0" ] # psycopg2 support psycopg2 = [ - "psycopg2-binary>=2.9.0", - "aiopg>=1.4.0" + "psycopg2-binary>=2.9.0" ] # SQLAlchemy support sqlalchemy = [ "sqlalchemy>=2.0.0", - "psycopg[binary]>=3.1.0" + "aiohttp>=3.8.0" ] # Development dependencies @@ -46,21 +45,25 @@ dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "python-dotenv>=1.0.0", - "mypy ~= 1.15" + "mypy ~= 1.15", + "ruff>=0.8.0", + "types-psycopg2>=2.9.0", + "psycopg-pool>=3.1.0" ] # All optional dependencies combined all = [ "psycopg[binary]>=3.1.0", - "psycopg-pool>=3.1.0", + "aiohttp>=3.8.0", "psycopg2-binary>=2.9.0", - "aiopg>=1.4.0", "sqlalchemy>=2.0.0", - "asyncpg>=0.28.0", "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "python-dotenv>=1.0.0", - "mypy ~= 1.15" + "mypy ~= 1.15", + "ruff>=0.8.0", + "types-psycopg2>=2.9.0", + "psycopg-pool>=3.1.0" ] [tool.setuptools] @@ -74,6 +77,19 @@ include = ["azurepg_entra*"] Homepage = "https://github.com/v-anarendra_microsoft/entra-id-integration-for-drivers" Issues = "https://github.com/v-anarendra_microsoft/entra-id-integration-for-drivers/issues" +# Ruff configuration +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "UP", "B", "I", "N"] +ignore = ["N806"] # Allow mixed case variable names + +[tool.ruff.lint.per-file-ignores] +"tests/**/*" = ["N"] +"samples/**/*" = ["T201"] # Allow print statements in samples + # Installation examples: # Basic package only: pip install -e . # With psycopg3 support: pip install -e ".[psycopg3]" diff --git a/python/samples/psycopg2/getting_started/async_pool_utils.py b/python/samples/psycopg2/getting_started/async_pool_utils.py deleted file mode 100644 index ac8661e..0000000 --- a/python/samples/psycopg2/getting_started/async_pool_utils.py +++ /dev/null @@ -1,118 +0,0 @@ -""" -Async connection pool utilities for psycopg2/aiopg with connection factory support. - -This module provides AsyncEntraConnectionPool, a custom async connection pool -that mimics psycopg2.pool.ThreadedConnectionPool's connection_factory pattern -for asynchronous connections. -""" - -import asyncio -from typing import Callable, Any, Awaitable - - -class AsyncEntraConnectionPool: - """ - Custom async connection pool that supports connection_factory pattern. - - Mimics psycopg2.pool.ThreadedConnectionPool API but for async connections. - This is needed because: - 1. psycopg2 pools are sync-only - 2. aiopg.create_pool doesn't support connection_factory parameter - - Usage: - async def my_connection_factory(): - return await connect_with_entra_async(host="...", dbname="...") - - async with AsyncEntraConnectionPool(my_connection_factory, minconn=1, maxconn=5) as pool: - conn = await pool.getconn() - try: - # Use connection - pass - finally: - pool.putconn(conn) - """ - - def __init__(self, connection_factory: Callable[[], Awaitable[Any]], minconn: int = 1, maxconn: int = 5): - """ - Initialize the async connection pool. - - Args: - connection_factory: Async function that creates and returns a new connection - minconn: Minimum number of connections to maintain in the pool - maxconn: Maximum number of connections allowed - """ - self.connection_factory = connection_factory - self.minconn = minconn - self.maxconn = maxconn - self._pool = [] - self._used = set() - self._lock = asyncio.Lock() - self._closed = False - - async def __aenter__(self): - """Context manager entry - pre-populate with minimum connections.""" - for _ in range(self.minconn): - conn = await self.connection_factory() - self._pool.append(conn) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Context manager exit - close all connections.""" - await self.closeall() - - async def getconn(self): - """ - Get a connection from the pool. - - Returns: - Connection object from the pool or newly created - - Raises: - Exception: If pool is exhausted (all maxconn connections in use) - """ - if self._closed: - raise Exception("Connection pool is closed") - - async with self._lock: - if self._pool: - conn = self._pool.pop() - elif len(self._used) < self.maxconn: - conn = await self.connection_factory() - else: - raise Exception("Connection pool exhausted") - - self._used.add(conn) - return conn - - def putconn(self, conn): - """ - Return a connection to the pool. - - Args: - conn: Connection to return to the pool - """ - if self._closed: - return - - if conn in self._used: - self._used.remove(conn) - if len(self._pool) < self.minconn: - self._pool.append(conn) - else: - # Pool is full, close the connection - conn.close() - - async def closeall(self): - """Close all connections in the pool.""" - self._closed = True - - # Close all pooled connections - for conn in self._pool: - conn.close() - - # Close all connections in use - for conn in list(self._used): - conn.close() - - self._pool.clear() - self._used.clear() \ No newline at end of file diff --git a/python/samples/psycopg2/getting_started/create_db_connection.py b/python/samples/psycopg2/getting_started/create_db_connection.py new file mode 100644 index 0000000..803a4ce --- /dev/null +++ b/python/samples/psycopg2/getting_started/create_db_connection.py @@ -0,0 +1,47 @@ +""" +Sample demonstrating psycopg2 connection with synchronous Entra ID authentication for Azure PostgreSQL. +""" + +import os + +from dotenv import load_dotenv +from psycopg2 import pool +from azurepg_entra.psycopg2 import EntraConnection + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def main() -> None: + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/docs/advanced.html#subclassing-connection + connection_pool = pool.ThreadedConnectionPool( + minconn=1, + maxconn=5, + host=SERVER, + database=DATABASE, + connection_factory=EntraConnection, + ) + + conn = connection_pool.getconn() + try: + with conn.cursor() as cur: + cur.execute("SELECT now()") + result = cur.fetchone() + print(f"Database time: {result[0]}") + + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Connected as: {user[0]}") + finally: + connection_pool.putconn(conn) + connection_pool.closeall() + + +if __name__ == "__main__": + main() diff --git a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py b/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py deleted file mode 100644 index 319371c..0000000 --- a/python/samples/psycopg2/getting_started/create_db_connection_psycopg2.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -Sample demonstrating both synchronous and asynchronous psycopg2 connections -with Azure Entra ID authentication for Azure PostgreSQL. - -This example shows: -1. Synchronous connection using psycopg2 with Entra ID authentication -2. Asynchronous connection using aiopg with Entra ID authentication - -Both examples use the same Azure Entra ID authentication mechanism to connect -to Azure Database for PostgreSQL. -""" - -from dotenv import load_dotenv -import argparse -import asyncio -import sys -import os -from psycopg2 import pool -from azurepg_entra.psycopg2 import connect_with_entra, connect_with_entra_async -from async_pool_utils import AsyncEntraConnectionPool - -# Load environment variables from .env file -load_dotenv() -SERVER = os.getenv("POSTGRES_SERVER") -DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") - -def main_sync(): - """Synchronous connection example using psycopg2 with Entra ID authentication and connection pooling.""" - - try: - # Create a wrapper function that explicitly passes our server parameters - def entra_connection_factory(*args, **kwargs): - # Ignore any arguments passed by ThreadedConnectionPool and use our explicit parameters - return connect_with_entra( - host=SERVER, - port=5432, - dbname=DATABASE - ) - - # Create a connection pool using psycopg2 with Entra ID authentication - connection_pool = pool.ThreadedConnectionPool( - minconn=1, - maxconn=5, - connection_factory=entra_connection_factory - ) - - # Get a connection from the pool - conn = connection_pool.getconn() - - try: - with conn.cursor() as cur: - cur.execute("SELECT now()") - result = cur.fetchone() - print(f"Sync - Database time: {result[0]}") - - # Test current user query - cur.execute("SELECT current_user") - user = cur.fetchone() - print(f"Sync - Connected as: {user[0]}") - finally: - # Return connection to pool - connection_pool.putconn(conn) - connection_pool.closeall() - - except Exception as e: - print(f"Sync - Error connecting to database: {e}") - raise - -async def main_async(): - """Asynchronous connection example with custom async pool using connection_factory pattern.""" - - try: - # Create async connection factory function (mirrors the sync version) - async def entra_async_connection_factory(*args, **kwargs): - # Ignore any arguments and use our explicit parameters - return await connect_with_entra_async( - host=SERVER, - port=5432, - dbname=DATABASE - ) - - # Use our custom async pool with connection factory - async with AsyncEntraConnectionPool(entra_async_connection_factory, minconn=1, maxconn=5) as connection_pool: - # Get a connection from the pool (mirrors sync pattern) - conn = await connection_pool.getconn() - - try: - async with conn.cursor() as cur: - await cur.execute("SELECT now()") - result = await cur.fetchone() - print(f"Async - Database time: {result[0]}") - - # Test current user query - await cur.execute("SELECT current_user") - user = await cur.fetchone() - print(f"Async - Connected as: {user[0]}") - finally: - # Return connection to pool (mirrors sync pattern) - connection_pool.putconn(conn) - - except Exception as e: - print(f"Async - Error connecting to database: {e}") - raise - -async def main(mode: str = "both"): - """Main function that runs sync and/or async examples based on mode. - - Args: - mode: "sync", "async", or "both" to determine which examples to run - """ - if mode in ("sync", "both"): - print("=== Running Synchronous Example ===") - try: - main_sync() - print("āœ… Sync example completed successfully!") - except Exception as e: - print(f"āŒ Sync example failed: {e}") - - if mode in ("async", "both"): - if mode == "both": - print("\n=== Running Asynchronous Example ===") - else: - print("=== Running Asynchronous Example ===") - try: - await main_async() - print("āœ… Async example completed successfully!") - except Exception as e: - print(f"āŒ Async example failed: {e}") - -if __name__ == "__main__": - # Parse command line arguments - parser = argparse.ArgumentParser( - description="Demonstrate psycopg2/aiopg connections with Azure Entra ID authentication" - ) - parser.add_argument( - "--mode", - choices=["sync", "async", "both"], - default="both", - help="Run synchronous, asynchronous, or both examples (default: both)" - ) - args = parser.parse_args() - - # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main(args.mode)) \ No newline at end of file diff --git a/python/samples/psycopg2/getting_started/create_everlasting_db_connection.py b/python/samples/psycopg2/getting_started/create_everlasting_db_connection.py new file mode 100644 index 0000000..13c85eb --- /dev/null +++ b/python/samples/psycopg2/getting_started/create_everlasting_db_connection.py @@ -0,0 +1,98 @@ +""" +Sample demonstrating an everlasting psycopg2 connection with Azure Entra ID authentication +for Azure PostgreSQL that runs queries indefinitely to test token refresh capabilities. +""" + +import argparse +import os +import sys +import time +from datetime import datetime +from dotenv import load_dotenv +from psycopg2.pool import ThreadedConnectionPool +from azurepg_entra.psycopg2 import EntraConnection + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def run_everlasting_queries(interval_minutes: int = 2) -> None: + """Run database queries indefinitely with psycopg2 and Entra authentication using ThreadedConnectionPool.""" + + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # Create connection string + conninfo = f"postgresql://{SERVER}:5432/{DATABASE}" + + # Create connection pool with EntraConnection factory + print("Creating connection pool with Entra ID authentication enabled...") + pool = ThreadedConnectionPool( + minconn=1, maxconn=3, dsn=conninfo, connection_factory=EntraConnection + ) + + execution_count = 0 + + # Get one connection and reuse it throughout the program + conn = pool.getconn() + + try: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + print(f"Execution #{execution_count} at {current_time}") + + try: + with conn.cursor() as cur: + # Query 1: Get PostgreSQL version + cur.execute("SELECT version()") + version = cur.fetchone() + print(f"Connected to PostgreSQL: {version[0][:50]}...") + + # Query 2: Get current user (shows the Entra username) + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Connected as: {user[0]}") + + # Query 3: Get current timestamp + cur.execute("SELECT now()") + timestamp = cur.fetchone() + print(f"Server time: {timestamp[0]}") + + print("Query execution successful!") + + except Exception as e: + print(f"Database error: {e}") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + finally: + pool.putconn(conn) + pool.closeall() + + +def main() -> None: + """Main function with command line argument parsing.""" + parser = argparse.ArgumentParser( + description="Demonstrate everlasting psycopg2 connection with Azure Entra ID authentication" + ) + parser.add_argument( + "--interval", + type=int, + default=2, + help="Query execution interval in minutes (default: 2)", + ) + args = parser.parse_args() + + # Validate environment variables + if not SERVER: + print("Error: POSTGRES_SERVER environment variable is required") + sys.exit(1) + # Run the everlasting queries + run_everlasting_queries(args.interval) + + +if __name__ == "__main__": + main() diff --git a/python/samples/psycopg3/getting_started/create_db_connection.py b/python/samples/psycopg3/getting_started/create_db_connection.py new file mode 100644 index 0000000..4ca0721 --- /dev/null +++ b/python/samples/psycopg3/getting_started/create_db_connection.py @@ -0,0 +1,118 @@ +""" +Sample demonstrating both synchronous and asynchronous psycopg3 connections +with Azure Entra ID authentication for Azure PostgreSQL. +""" + +import argparse +import asyncio +import os +import sys + +from dotenv import load_dotenv +from psycopg_pool import AsyncConnectionPool, ConnectionPool +from azurepg_entra.psycopg3 import AsyncEntraConnection, EntraConnection + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def main_sync() -> None: + """Synchronous connection example using psycopg with Entra ID authentication.""" + + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect + pool = ConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=5, + open=False, + connection_class=EntraConnection, + ) + with pool, pool.connection() as conn, conn.cursor() as cur: + # Query 1 + cur.execute("SELECT now()") + result = cur.fetchone() + print(f"Sync - Database time: {result}") + + # Query 2 + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Sync - Connected as: {user[0] if user else 'Unknown'}") + + +async def main_async() -> None: + """Asynchronous connection example using psycopg with Entra ID authentication.""" + + # We use the AsyncEntraConnection class to enable asynchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect + pool = AsyncConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=5, + open=False, + connection_class=AsyncEntraConnection, + ) + async with pool, pool.connection() as conn, conn.cursor() as cur: + # Query 1 + await cur.execute("SELECT now()") + result = await cur.fetchone() + print(f"Async - Database time: {result}") + + # Query 2 + await cur.execute("SELECT current_user") + user = await cur.fetchone() + print(f"Async - Connected as: {user[0] if user else 'Unknown'}") + + +async def main(mode: str = "async") -> None: + """Main function that runs sync and/or async examples based on mode. + + Args: + mode: "sync", "async", or "both" to determine which examples to run + """ + if mode in ("sync", "both"): + print("=== Running Synchronous Example ===") + try: + main_sync() + print("Sync example completed successfully!") + except Exception as e: + print(f"Sync example failed: {e}") + + if mode in ("async", "both"): + if mode == "both": + print("\n=== Running Asynchronous Example ===") + else: + print("=== Running Asynchronous Example ===") + try: + await main_async() + print("Async example completed successfully!") + except Exception as e: + print(f"Async example failed: {e}") + + +if __name__ == "__main__": + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Demonstrate psycopg connections with Azure Entra ID authentication" + ) + parser.add_argument( + "--mode", + choices=["sync", "async", "both"], + default="both", + help="Run synchronous, asynchronous, or both examples (default: both)", + ) + args = parser.parse_args() + + # Set Windows event loop policy for compatibility if needed + if sys.platform.startswith("win"): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main(args.mode)) diff --git a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py b/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py deleted file mode 100644 index 9d95eb0..0000000 --- a/python/samples/psycopg3/getting_started/create_db_connection_psycopg3.py +++ /dev/null @@ -1,118 +0,0 @@ -""" -Sample demonstrating both synchronous and asynchronous psycopg connections -with Azure Entra ID authentication for Azure PostgreSQL. - -This example shows: -1. Synchronous connection using SyncEntraConnection and ConnectionPool -2. Asynchronous connection using AsyncEntraConnection and AsyncConnectionPool - -Both examples use the same Azure Entra ID authentication mechanism to connect -to Azure Database for PostgreSQL. -""" - -from psycopg_pool import AsyncConnectionPool, ConnectionPool -from dotenv import load_dotenv -import argparse -import asyncio -import sys -import os -from azurepg_entra.psycopg3 import SyncEntraConnection, AsyncEntraConnection - -# Load environment variables from .env file -load_dotenv() -SERVER = os.getenv("POSTGRES_SERVER") -DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") - -def main_sync(): - """Synchronous connection example using psycopg with Entra ID authentication.""" - - try: - pool = ConnectionPool( - conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", - min_size=1, - max_size=5, - open=False, - connection_class=SyncEntraConnection - ) - pool.open() - with pool, pool.connection() as conn, conn.cursor() as cur: - cur.execute("SELECT now()") - result = cur.fetchone() - print(f"Sync - Database time: {result}") - - # Test current user query - cur.execute("SELECT current_user") - user = cur.fetchone() - print(f"Sync - Connected as: {user[0]}") - except Exception as e: - print(f"Sync - Error connecting to database: {e}") - raise - -async def main_async(): - """Asynchronous connection example using psycopg with Entra ID authentication.""" - - try: - pool = AsyncConnectionPool( - conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", - min_size=1, - max_size=5, - open=False, - connection_class=AsyncEntraConnection - ) - await pool.open() - async with pool, pool.connection() as conn, conn.cursor() as cur: - await cur.execute("SELECT now()") - result = await cur.fetchone() - print(f"Async - Database time: {result}") - - # Test current user query - await cur.execute("SELECT current_user") - user = await cur.fetchone() - print(f"Async - Connected as: {user[0]}") - except Exception as e: - print(f"Async - Error connecting to database: {e}") - raise - -async def main(mode: str = "both"): - """Main function that runs sync and/or async examples based on mode. - - Args: - mode: "sync", "async", or "both" to determine which examples to run - """ - if mode in ("sync", "both"): - print("=== Running Synchronous Example ===") - try: - main_sync() - print("āœ… Sync example completed successfully!") - except Exception as e: - print(f"āŒ Sync example failed: {e}") - - if mode in ("async", "both"): - if mode == "both": - print("\n=== Running Asynchronous Example ===") - else: - print("=== Running Asynchronous Example ===") - try: - await main_async() - print("āœ… Async example completed successfully!") - except Exception as e: - print(f"āŒ Async example failed: {e}") - -if __name__ == "__main__": - # Parse command line arguments - parser = argparse.ArgumentParser( - description="Demonstrate psycopg connections with Azure Entra ID authentication" - ) - parser.add_argument( - "--mode", - choices=["sync", "async", "both"], - default="both", - help="Run synchronous, asynchronous, or both examples (default: both)" - ) - args = parser.parse_args() - - # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main(args.mode)) \ No newline at end of file diff --git a/python/samples/psycopg3/getting_started/create_everlasting_db_connection.py b/python/samples/psycopg3/getting_started/create_everlasting_db_connection.py new file mode 100644 index 0000000..1231703 --- /dev/null +++ b/python/samples/psycopg3/getting_started/create_everlasting_db_connection.py @@ -0,0 +1,167 @@ +""" +Sample demonstrating everlasting synchronous and asynchronous psycopg3 connections +with Azure Entra ID authentication for Azure PostgreSQL that run queries indefinitely +to test token refresh capabilities. +""" + +import argparse +import asyncio +import os +import sys +import time + +from datetime import datetime +from dotenv import load_dotenv +from psycopg_pool import AsyncConnectionPool, ConnectionPool +from azurepg_entra.psycopg3 import AsyncEntraConnection, EntraConnection + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: + """Run synchronous database queries indefinitely with psycopg3 and Entra authentication.""" + + print("=== Running Everlasting Synchronous psycopg3 Connection Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # We use the EntraConnection class to enable synchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect + pool = ConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=3, + open=False, + connection_class=EntraConnection, + ) + + execution_count = 0 + with pool: + # Get connection from pool + conn = pool.getconn() + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + print(f"Sync Execution #{execution_count} at {current_time}") + with conn.cursor() as cur: + # Query 1: Get PostgreSQL version + cur.execute("SELECT version()") + version = cur.fetchone() + print( + f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}..." + ) + + # Query 2: Get current user + cur.execute("SELECT current_user") + user = cur.fetchone() + print(f"Connected as: {user[0] if user else 'Unknown'}") + + # Query 3: Get current timestamp + cur.execute("SELECT now()") + timestamp = cur.fetchone() + print(f"Server time: {timestamp[0] if timestamp else 'Unknown'}") + + print("Sync query execution successful!") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + + +async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: + """Run asynchronous database queries indefinitely with psycopg3 and Entra authentication.""" + + print("=== Running Everlasting Asynchronous psycopg3 Connection Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # We use the AsyncEntraConnection class to enable asynchronous Entra-based authentication for database access. + # This class is applied whenever the connection pool creates a new connection, ensuring that Entra + # authentication tokens are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://www.psycopg.org/psycopg3/docs/api/connections.html#psycopg.Connection.connect + pool = AsyncConnectionPool( + conninfo=f"postgresql://{SERVER}:5432/{DATABASE}", + min_size=1, + max_size=3, + open=False, + connection_class=AsyncEntraConnection, + ) + + execution_count = 0 + async with pool: + # Get connection from pool + conn = await pool.getconn() + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Async Execution #{execution_count} at {current_time}") + async with conn.cursor() as cur: + # Query 1: Get PostgreSQL version + await cur.execute("SELECT version()") + version = await cur.fetchone() + print( + f"Connected to PostgreSQL: {version[0][:50] if version else 'Unknown'}..." + ) + + # Query 2: Get current user + await cur.execute("SELECT current_user") + user = await cur.fetchone() + print(f"Connected as: {user[0] if user else 'Unknown'}") + + # Query 3: Get current timestamp + await cur.execute("SELECT now()") + timestamp = await cur.fetchone() + print(f"Server time: {timestamp[0] if timestamp else 'Unknown'}") + + print("Async query execution successful!") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + + +async def main() -> None: + """Main function with command line argument parsing.""" + parser = argparse.ArgumentParser( + description="Demonstrate everlasting psycopg3 connections with Azure Entra ID authentication" + ) + parser.add_argument( + "--mode", + choices=["sync", "async", "both"], + default="both", + help="Run synchronous, asynchronous, or both examples (default: both)", + ) + parser.add_argument( + "--interval", + type=int, + default=2, + help="Query execution interval in minutes (default: 2)", + ) + args = parser.parse_args() + + # Validate environment variables + if not SERVER: + print("Error: POSTGRES_SERVER environment variable is required") + sys.exit(1) + + if args.mode in ("sync", "both"): + run_everlasting_sync_queries(args.interval) + + if args.mode in ("async", "both"): + if args.mode == "both": + print("\n" + "=" * 60 + "\n") + await run_everlasting_async_queries(args.interval) + + +if __name__ == "__main__": + # Set Windows event loop policy for compatibility if needed + if sys.platform.startswith("win"): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main()) diff --git a/python/samples/sqlalchemy/getting_started/create_db_connection.py b/python/samples/sqlalchemy/getting_started/create_db_connection.py new file mode 100644 index 0000000..12d775d --- /dev/null +++ b/python/samples/sqlalchemy/getting_started/create_db_connection.py @@ -0,0 +1,127 @@ +""" +Sample demonstrating both synchronous and asynchronous SQLAlchemy connections +with Azure Entra ID authentication for Azure PostgreSQL. +""" + +import argparse +import asyncio +import os +import sys + +from dotenv import load_dotenv +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + +from azurepg_entra.sqlalchemy import ( + enable_entra_authentication, + enable_entra_authentication_async, +) + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def main_sync() -> None: + """Synchronous connection example using SQLAlchemy with Entra ID authentication.""" + + # Create a synchronous engine + engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + + # We add an event listener to the engine to enable synchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function + enable_entra_authentication(engine) + + with engine.connect() as conn: + # Query 1 + result = conn.execute(text("SELECT now()")) + row = result.fetchone() + print(f"Sync - Database time: {row[0] if row else 'Unknown'}") + + # Query 2 + result = conn.execute(text("SELECT current_user")) + row = result.fetchone() + print(f"Sync - Connected as: {row[0] if row else 'Unknown'}") + + # Clean up the engine + engine.dispose() + + +async def main_async() -> None: + """Asynchronous connection example using SQLAlchemy with Entra ID authentication.""" + + # Create an asynchronous engine + engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + + # We add an event listener to the engine to enable asynchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function + enable_entra_authentication_async(engine) + + async with engine.connect() as conn: + # Query 1 + result = await conn.execute(text("SELECT now()")) + row = result.fetchone() + print(f"Async Core - Database time: {row[0] if row else 'Unknown'}") + + # Query 2 + result = await conn.execute(text("SELECT current_user")) + row = result.fetchone() + print(f"Async Core - Connected as: {row[0] if row else 'Unknown'}") + + # Clean up the engine + await engine.dispose() + + +async def main(mode: str = "async") -> None: + """Main function that runs sync and/or async examples based on mode. + + Args: + mode: "sync", "async", or "both" to determine which examples to run + """ + if mode in ("sync", "both"): + print("=== Running Synchronous SQLAlchemy Example ===") + try: + main_sync() + print("Sync example completed successfully!") + except Exception as e: + print(f"Sync example failed: {e}") + + if mode in ("async", "both"): + if mode == "both": + print("\n=== Running Asynchronous SQLAlchemy Example ===") + else: + print("=== Running Asynchronous SQLAlchemy Example ===") + try: + await main_async() + print("Async example completed successfully!") + except Exception as e: + print(f"Async example failed: {e}") + + +if __name__ == "__main__": + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Demonstrate SQLAlchemy connections with Azure Entra ID authentication" + ) + parser.add_argument( + "--mode", + choices=["sync", "async", "both"], + default="both", + help="Run synchronous, asynchronous, or both examples (default: both)", + ) + args = parser.parse_args() + + # Set Windows event loop policy for compatibility if needed + if sys.platform.startswith("win"): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main(args.mode)) diff --git a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py b/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py deleted file mode 100644 index 918e2d7..0000000 --- a/python/samples/sqlalchemy/getting_started/create_db_connection_sqlalchemy.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -Sample demonstrating both synchronous and asynchronous SQLAlchemy connections -with Azure Entra ID authentication for Azure PostgreSQL. - -This example shows: -1. Synchronous connection using create_engine_with_entra and connection pooling -2. Asynchronous connection using create_async_engine_with_entra and async connection pooling - -Both examples use the same Azure Entra ID authentication mechanism to connect -to Azure Database for PostgreSQL. -""" - -from dotenv import load_dotenv -import argparse -import asyncio -import sys -import os -from sqlalchemy import text -from azurepg_entra.sqlalchemy import create_engine_with_entra, create_async_engine_with_entra - -# Load environment variables from .env file -load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), '.env')) -SERVER = os.getenv("POSTGRES_SERVER") -DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") - -def main_sync(): - """Synchronous connection example using SQLAlchemy with Entra ID authentication.""" - - try: - # Create engine with Entra ID authentication and connection pooling - engine = create_engine_with_entra( - f"postgresql+psycopg://{SERVER}:5432/{DATABASE}", - pool_size=5, - max_overflow=10, - pool_pre_ping=True, # Validate connections before use - echo=False # Set to True to see SQL queries - ) - - # Test connection - with engine.connect() as conn: - result = conn.execute(text("SELECT now()")) - db_time = result.fetchone() - print(f"Sync - Database time: {db_time[0]}") - - # Test current user query - result = conn.execute(text("SELECT current_user")) - user = result.fetchone() - print(f"Sync - Connected as: {user[0]}") - - # Test a simple query to verify functionality - result = conn.execute(text("SELECT 'SQLAlchemy Sync Entra Connection Working!' as message")) - message = result.fetchone() - print(f"Sync - Test message: {message[0]}") - - # Clean up the sync engine - engine.dispose() - - except Exception as e: - print(f"Sync - Error connecting to database: {e}") - raise - -async def main_async(): - """Asynchronous connection example using SQLAlchemy with Entra ID authentication.""" - - try: - # Create async engine with Entra ID authentication and connection pooling - engine = create_async_engine_with_entra( - f"postgresql+psycopg://{SERVER}:5432/{DATABASE}", - pool_size=5, - max_overflow=10, - pool_pre_ping=True, # Validate connections before use - echo=False # Set to True to see SQL queries - ) - - # Test async connection - async with engine.connect() as conn: - result = await conn.execute(text("SELECT now()")) - db_time = result.fetchone() - print(f"Async - Database time: {db_time[0]}") - - # Test current user query - result = await conn.execute(text("SELECT current_user")) - user = result.fetchone() - print(f"Async - Connected as: {user[0]}") - - # Test a simple query to verify functionality - result = await conn.execute(text("SELECT 'SQLAlchemy Async Entra Connection Working!' as message")) - message = result.fetchone() - print(f"Async - Test message: {message[0]}") - - # Clean up the async engine - await engine.dispose() - - except Exception as e: - print(f"Async - Error connecting to database: {e}") - raise - -def test_connection_pool_refresh(): - """Test that connection pool handles token refresh properly (sync version).""" - - try: - print("\n=== Testing Connection Pool Token Refresh (Sync) ===") - - engine = create_engine_with_entra( - f"postgresql+psycopg://{SERVER}:5432/{DATABASE}", - pool_size=2, - max_overflow=0, - echo=False - ) - - # Make multiple connections to test token refresh - for i in range(3): - with engine.connect() as conn: - result = conn.execute(text("SELECT current_user, now()")) - user, db_time = result.fetchone() - print(f"Connection {i+1} - User: {user}, Time: {db_time}") - - # Clean up the sync engine - engine.dispose() - print("āœ… Connection pool token refresh test completed successfully!") - - except Exception as e: - print(f"āŒ Connection pool test failed: {e}") - raise - -async def test_async_connection_pool_refresh(): - """Test that async connection pool handles token refresh properly.""" - - try: - print("\n=== Testing Async Connection Pool Token Refresh ===") - - engine = create_async_engine_with_entra( - f"postgresql+psycopg://{SERVER}:5432/{DATABASE}", - pool_size=2, - max_overflow=0, - echo=False - ) - - # Make multiple async connections to test token refresh - for i in range(3): - async with engine.connect() as conn: - result = await conn.execute(text("SELECT current_user, now()")) - user, db_time = result.fetchone() - print(f"Async Connection {i+1} - User: {user}, Time: {db_time}") - - await engine.dispose() - print("āœ… Async connection pool token refresh test completed successfully!") - - except Exception as e: - print(f"āŒ Async connection pool test failed: {e}") - raise - -async def main(mode: str = "both"): - """Main function that runs sync and/or async examples based on mode. - - Args: - mode: "sync", "async", or "both" to determine which examples to run - """ - if mode in ("sync", "both"): - print("=== Running Synchronous Example ===") - try: - main_sync() - print("āœ… Sync example completed successfully!") - - # Test connection pool behavior - test_connection_pool_refresh() - - except Exception as e: - print(f"āŒ Sync example failed: {e}") - - if mode in ("async", "both"): - if mode == "both": - print("\n=== Running Asynchronous Example ===") - else: - print("=== Running Asynchronous Example ===") - try: - await main_async() - print("āœ… Async example completed successfully!") - - # Test async connection pool behavior - await test_async_connection_pool_refresh() - - except Exception as e: - print(f"āŒ Async example failed: {e}") - -if __name__ == "__main__": - # Parse command line arguments - parser = argparse.ArgumentParser( - description="Demonstrate SQLAlchemy connections with Azure Entra ID authentication" - ) - parser.add_argument( - "--mode", - choices=["sync", "async", "both"], - default="both", - help="Run synchronous, asynchronous, or both examples (default: both)" - ) - args = parser.parse_args() - - # Set Windows event loop policy for compatibility if needed - if sys.platform.startswith('win'): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - - asyncio.run(main(args.mode)) \ No newline at end of file diff --git a/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py new file mode 100644 index 0000000..2732326 --- /dev/null +++ b/python/samples/sqlalchemy/getting_started/create_everlasting_db_connection.py @@ -0,0 +1,190 @@ +""" +Sample demonstrating everlasting synchronous and asynchronous SQLAlchemy connections +with Azure Entra ID authentication for Azure PostgreSQL that run queries indefinitely +to test token refresh capabilities. +""" + +import argparse +import asyncio +import os +import sys +import time +from datetime import datetime + +from dotenv import load_dotenv +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + +from azurepg_entra.sqlalchemy import ( + enable_entra_authentication, + enable_entra_authentication_async, +) + +# Load environment variables from .env file +load_dotenv() +SERVER = os.getenv("POSTGRES_SERVER") +DATABASE = os.getenv("POSTGRES_DATABASE", "postgres") + + +def run_everlasting_sync_queries(interval_minutes: int = 2) -> None: + """Run synchronous database queries indefinitely with SQLAlchemy and Entra authentication.""" + + print("=== Running Everlasting Synchronous SQLAlchemy Connection Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # Create synchronous engine + engine = create_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + + # We add an event listener to the engine to enable synchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function + enable_entra_authentication(engine) + + execution_count = 0 + + # Get one connection and reuse it throughout the program + with engine.connect() as conn: + try: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Sync Execution #{execution_count} at {current_time}") + + try: + # Query 1: Get PostgreSQL version + result = conn.execute(text("SELECT version()")) + row = result.fetchone() + version = row[0] if row else "Unknown" + print(f"Connected to PostgreSQL: {version[:50]}...") + + # Query 2: Get current user + result = conn.execute(text("SELECT current_user")) + row = result.fetchone() + user = row[0] if row else "Unknown" + print(f"Connected as: {user}") + + # Query 3: Get current timestamp + result = conn.execute(text("SELECT now()")) + row = result.fetchone() + timestamp = row[0] if row else "Unknown" + print(f"Server time: {timestamp}") + + print("Sync query execution successful!") + + except Exception as e: + print(f"Database error: {e}") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + finally: + engine.dispose() + + +async def run_everlasting_async_queries(interval_minutes: int = 2) -> None: + """Run asynchronous database queries indefinitely with SQLAlchemy and Entra authentication.""" + + print("=== Running Everlasting Asynchronous SQLAlchemy Connection Example ===") + print(f"Running queries every {interval_minutes} minutes...") + print("Press Ctrl+C to stop\n") + + # Create asynchronous engine + engine = create_async_engine(f"postgresql+psycopg://{SERVER}/{DATABASE}") + + # We add an event listener to the engine to enable asynchronous Entra authentication + # for database access. This event listener is triggered whenever the connection pool + # backing the engine creates a new connection, ensuring that Entra authentication tokens + # are properly managed and refreshed so that each connection uses a valid token. + # + # For more details, see: https://docs.sqlalchemy.org/en/20/core/engines.html#controlling-how-parameters-are-passed-to-the-dbapi-connect-function + enable_entra_authentication_async(engine) + + execution_count = 0 + + # Get one connection and reuse it throughout the program + async with engine.connect() as conn: + try: + while True: + execution_count += 1 + current_time = datetime.now().strftime("%H:%M:%S") + + print(f"Async Execution #{execution_count} at {current_time}") + + try: + # Query 1: Get PostgreSQL version + result = await conn.execute(text("SELECT version()")) + row = result.fetchone() + version = row[0] if row else "Unknown" + print(f"Connected to PostgreSQL: {version[:50]}...") + + # Query 2: Get current user + result = await conn.execute(text("SELECT current_user")) + row = result.fetchone() + user = row[0] if row else "Unknown" + print(f"Connected as: {user}") + + # Query 3: Get current timestamp + result = await conn.execute(text("SELECT now()")) + row = result.fetchone() + timestamp = row[0] if row else "Unknown" + print(f"Server time: {timestamp}") + + print("Async query execution successful!") + + except Exception as e: + print(f"Database error: {e}") + + print(f"Waiting {interval_minutes} minutes until next execution...\n") + time.sleep(interval_minutes * 60) + finally: + await engine.dispose() + + +async def main() -> None: + """Main function with command line argument parsing.""" + parser = argparse.ArgumentParser( + description="Demonstrate everlasting SQLAlchemy connections with Azure Entra ID authentication" + ) + parser.add_argument( + "--mode", + choices=["sync", "async", "both"], + default="both", + help="Run synchronous, asynchronous, or both examples (default: both)", + ) + parser.add_argument( + "--interval", + type=int, + default=2, + help="Query execution interval in minutes (default: 2)", + ) + args = parser.parse_args() + + # Validate environment variables + if not SERVER: + print("Error: POSTGRES_SERVER environment variable is required") + sys.exit(1) + + print(f"Target server: {SERVER}") + print(f"Target database: {DATABASE}") + print(f"Query interval: {args.interval} minutes") + print(f"Mode: {args.mode}\n") + + if args.mode in ("sync", "both"): + run_everlasting_sync_queries(args.interval) + + if args.mode in ("async", "both"): + if args.mode == "both": + print("\n" + "=" * 60 + "\n") + await run_everlasting_async_queries(args.interval) + + +if __name__ == "__main__": + # Set Windows event loop policy for compatibility if needed + if sys.platform.startswith("win"): + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(main()) diff --git a/python/src/azurepg_entra/__init__.py b/python/src/azurepg_entra/__init__.py index 4852663..a6d8259 100644 --- a/python/src/azurepg_entra/__init__.py +++ b/python/src/azurepg_entra/__init__.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. + """ Azure PostgreSQL Entra ID Integration Library -This library provides connection classes for using Azure Entra ID authentication +This library provides connection classes for using Azure Entra ID authentication with Azure Database for PostgreSQL across different PostgreSQL drivers. Available modules (with optional dependencies): @@ -16,4 +17,4 @@ """ __version__ = "0.1.0" -__author__ = "Microsoft Corporation" \ No newline at end of file +__author__ = "Microsoft Corporation" diff --git a/python/src/azurepg_entra/core.py b/python/src/azurepg_entra/core.py index 76d11e0..3eb2ad8 100644 --- a/python/src/azurepg_entra/core.py +++ b/python/src/azurepg_entra/core.py @@ -1,34 +1,45 @@ -import logging -import json +# Copyright (c) Microsoft. All rights reserved. + import base64 +import json from typing import Any, cast + from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import ClientAuthenticationError +from azure.identity import CredentialUnavailableError from azure.identity import DefaultAzureCredential as DefaultAzureCredential from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential -logger = logging.getLogger(__name__) +from azurepg_entra.errors import ( + ScopePermissionError, + TokenDecodeError, + UsernameExtractionError, +) + AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" AZURE_MANAGEMENT_SCOPE = "https://management.azure.com/.default" + def get_entra_token(credential: TokenCredential | None, scope: str) -> str: """Acquires an Entra authentication token for Azure PostgreSQL synchronously. Parameters: - credential (TokenCredential or None): Credential object used to obtain the token. + credential (TokenCredential or None): Credential object used to obtain the token. If None, the default Azure credentials are used. scope (str): The scope for the token request. Returns: str: The acquired authentication token to be used as the database password. """ - logger.info("Acquiring Entra token for postgres password") - credential = credential or DefaultAzureCredential() cred = credential.get_token(scope) return cred.token -async def get_entra_token_async(credential: AsyncTokenCredential | None, scope: str) -> str: + +async def get_entra_token_async( + credential: AsyncTokenCredential | None, scope: str +) -> str: """Asynchronously acquires an Entra authentication token for Azure PostgreSQL. Parameters: @@ -39,13 +50,12 @@ async def get_entra_token_async(credential: AsyncTokenCredential | None, scope: Returns: str: The acquired authentication token to be used as the database password. """ - logger.info("Acquiring Entra token for postgres password") - credential = credential or AsyncDefaultAzureCredential() async with credential: cred = await credential.get_token(scope) return cred.token - + + def decode_jwt(token: str) -> dict[str, Any]: """Decodes a JWT token to extract its payload claims. @@ -53,12 +63,19 @@ def decode_jwt(token: str) -> dict[str, Any]: token (str): The JWT token string in the standard three-part format. Returns: - dict: A dictionary containing the claims extracted from the token payload. + dict[str, Any]: A dictionary containing the claims extracted from the token payload. + + Raises: + TokenValueError: If the token format is invalid or cannot be decoded. """ - payload = token.split(".")[1] - padding = "=" * (4 - len(payload) % 4) - decoded_payload = base64.urlsafe_b64decode(payload + padding) - return cast(dict[str, Any], json.loads(decoded_payload)) + try: + payload = token.split(".")[1] + padding = "=" * (4 - len(payload) % 4) + decoded_payload = base64.urlsafe_b64decode(payload + padding) + return cast(dict[str, Any], json.loads(decoded_payload)) + except Exception as e: + raise TokenDecodeError("Invalid JWT token format") from e + def parse_principal_name(xms_mirid: str) -> str | None: """Parses the principal name from an Azure resource path. @@ -71,42 +88,57 @@ def parse_principal_name(xms_mirid: str) -> str | None: """ if not xms_mirid: return None - + # Parse the xms_mirid claim which looks like # /subscriptions/{subId}/resourcegroups/{resourceGroup}/providers/Microsoft.ManagedIdentity/userAssignedIdentities/{principalName} - last_slash_index = xms_mirid.rfind('/') + last_slash_index = xms_mirid.rfind("/") if last_slash_index == -1: return None beginning = xms_mirid[:last_slash_index] - principal_name = xms_mirid[last_slash_index + 1:] + principal_name = xms_mirid[last_slash_index + 1 :] - if not principal_name or not beginning.lower().endswith("providers/microsoft.managedidentity/userassignedidentities"): + if not principal_name or not beginning.lower().endswith( + "providers/microsoft.managedidentity/userassignedidentities" + ): return None return principal_name + def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: """Synchronously obtains connection information from Entra authentication for Azure PostgreSQL. + This function acquires an access token from Azure Entra ID and extracts the username + from the token claims. It tries multiple claim sources to determine the username. + Parameters: credential (TokenCredential or None): The credential used for token acquisition. - If None, the default Azure credentials are used. + If None, DefaultAzureCredential() is used to automatically discover credentials. Returns: - dict[str, str]: A dictionary with 'user' and 'password' keys containing the username and token. - + dict[str, str]: A dictionary with 'user' and 'password' keys, where: + - 'user': The extracted username from token claims + - 'password': The Entra ID access token for database authentication + Raises: - ValueError: If the username cannot be extracted from the token payload. + TokenDecodeError: If the JWT token cannot be decoded or is malformed. + UsernameExtractionError: If the username cannot be extracted from token claims. + ScopePermissionError: The token could not be acquired from the management scope, possibly due to insufficient permissions. """ credential = credential or DefaultAzureCredential() # Always get the DB-scope token for password db_token = get_entra_token(credential, AZURE_DB_FOR_POSTGRES_SCOPE) - db_claims = decode_jwt(db_token) + try: + db_claims = decode_jwt(db_token) + except TokenDecodeError: + raise xms_mirid = db_claims.get("xms_mirid") username = ( - parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or db_claims.get("upn") or db_claims.get("preferred_username") or db_claims.get("unique_name") @@ -114,61 +146,99 @@ def get_entra_conninfo(credential: TokenCredential | None) -> dict[str, str]: if not username: # Fall back to management scope ONLY to discover username - mgmt_token = get_entra_token(credential, AZURE_MANAGEMENT_SCOPE) - mgmt_claims = decode_jwt(mgmt_token) + try: + mgmt_token = get_entra_token(credential, AZURE_MANAGEMENT_SCOPE) + except (CredentialUnavailableError, ClientAuthenticationError) as e: + raise ScopePermissionError( + "Failed to acquire token from management scope" + ) from e + try: + mgmt_claims = decode_jwt(mgmt_token) + except TokenDecodeError: + raise xms_mirid = mgmt_claims.get("xms_mirid") username = ( - parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or mgmt_claims.get("upn") or mgmt_claims.get("preferred_username") or mgmt_claims.get("unique_name") ) if not username: - raise ValueError( + raise UsernameExtractionError( "Could not determine username from token claims. " "Ensure the identity has the proper Azure AD attributes." ) return {"user": username, "password": db_token} -async def get_entra_conninfo_async(credential: AsyncTokenCredential | None) -> dict[str, str]: + +async def get_entra_conninfo_async( + credential: AsyncTokenCredential | None, +) -> dict[str, str]: """Asynchronously obtains connection information from Entra authentication for Azure PostgreSQL. + This function acquires an access token from Azure Entra ID and extracts the username + from the token claims. It tries multiple claim sources to determine the username. + Parameters: credential (AsyncTokenCredential or None): The async credential used for token acquisition. - If None, the default Azure credentials are used. + If None, AsyncDefaultAzureCredential() is used to automatically discover credentials. Returns: - dict[str, str]: A dictionary with 'user' and 'password' keys containing the username and token. - + dict[str, str]: A dictionary with 'user' and 'password' keys, where: + - 'user': The extracted username from token claims + - 'password': The Entra ID access token for database authentication + Raises: - ValueError: If the username cannot be extracted from the token payload. + TokenDecodeError: If the JWT token cannot be decoded or is malformed. + UsernameExtractionError: If the username cannot be extracted from token claims. + ScopePermissionError: The token could not be acquired from the management scope, possibly due to insufficient permissions. """ credential = credential or AsyncDefaultAzureCredential() db_token = await get_entra_token_async(credential, AZURE_DB_FOR_POSTGRES_SCOPE) - db_claims = decode_jwt(db_token) + try: + db_claims = decode_jwt(db_token) + except TokenDecodeError: + raise xms_mirid = db_claims.get("xms_mirid") username = ( - parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or db_claims.get("upn") or db_claims.get("preferred_username") or db_claims.get("unique_name") ) if not username: - mgmt_token = await get_entra_token_async(credential, AZURE_MANAGEMENT_SCOPE) - mgmt_claims = decode_jwt(mgmt_token) + try: + mgmt_token = await get_entra_token_async(credential, AZURE_MANAGEMENT_SCOPE) + except (CredentialUnavailableError, ClientAuthenticationError) as e: + raise ScopePermissionError( + "Failed to acquire token from management scope" + ) from e + try: + mgmt_claims = decode_jwt(mgmt_token) + except TokenDecodeError: + raise xms_mirid = mgmt_claims.get("xms_mirid") username = ( - parse_principal_name(xms_mirid) if isinstance(xms_mirid, str) else None + parse_principal_name(xms_mirid) + if isinstance(xms_mirid, str) + else None or mgmt_claims.get("upn") or mgmt_claims.get("preferred_username") or mgmt_claims.get("unique_name") ) if not username: - raise ValueError("Could not determine username from token claims.") + raise UsernameExtractionError( + "Could not determine username from token claims. " + "Ensure the identity has the proper Azure AD attributes." + ) - return {"user": username, "password": db_token} \ No newline at end of file + return {"user": username, "password": db_token} diff --git a/python/src/azurepg_entra/errors.py b/python/src/azurepg_entra/errors.py new file mode 100644 index 0000000..3782e64 --- /dev/null +++ b/python/src/azurepg_entra/errors.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft. All rights reserved. + +class AzurePgEntraError(Exception): + """Base class for all custom exceptions in the project.""" + + pass + + +class TokenDecodeError(AzurePgEntraError): + """Raised when a token value is invalid.""" + + pass + + +class UsernameExtractionError(AzurePgEntraError): + """Raised when username cannot be extracted from token.""" + + pass + + +class CredentialValueError(AzurePgEntraError): + """Raised when token credential is invalid.""" + + pass + + +class EntraConnectionValueError(AzurePgEntraError): + """Raised when Entra connection credentials are invalid.""" + + pass + + +class ScopePermissionError(AzurePgEntraError): + """Raised when the provided scope does not have sufficient permissions.""" + + pass diff --git a/python/src/azurepg_entra/psycopg2/__init__.py b/python/src/azurepg_entra/psycopg2/__init__.py index ddcdf20..9b223bf 100644 --- a/python/src/azurepg_entra/psycopg2/__init__.py +++ b/python/src/azurepg_entra/psycopg2/__init__.py @@ -1,57 +1,25 @@ # Copyright (c) Microsoft. All rights reserved. + """ -Psycopg2 + aiopg support for Azure Entra ID authentication with Azure Database for PostgreSQL. +Psycopg2 support for Azure Entra ID authentication with Azure Database for PostgreSQL. -This module provides connection functions that handle Azure Entra ID token acquisition -and authentication for both synchronous (psycopg2) and asynchronous (aiopg) PostgreSQL connections. +This module provides a connection class that handles Azure Entra ID token acquisition +and authentication for synchronous PostgreSQL connections. Requirements: Install with: pip install azurepg-entra[psycopg2] - - This will install: - - psycopg2-binary>=2.8.0 - - aiopg>=1.3.0 (for async support) -Functions: - connect_with_entra: Synchronous connection function with Entra ID authentication (psycopg2) - connect_with_entra_async: Asynchronous connection function with Entra ID authentication (aiopg) - get_entra_conninfo: Synchronous function to get Entra authentication info - get_entra_conninfo_async: Asynchronous function to get Entra authentication info + This will install: + - psycopg2-binary>=2.9.0 -Example usage: - # Synchronous connection - from azurepg_entra.psycopg2 import connect_with_entra - - conn = connect_with_entra( - host="myserver.postgres.database.azure.com", - dbname="mydatabase", - port=5432 - ) - - # Asynchronous connection - from azurepg_entra.psycopg2 import connect_with_entra_async - - conn = await connect_with_entra_async( - host="myserver.postgres.database.azure.com", - dbname="mydatabase", - port=5432 - ) +Classes: + EntraConnection: Synchronous connection class with Entra ID authentication """ -try: - from .psycopg2_entra_id_extension import ( - connect_with_entra, - connect_with_entra_async - ) - - __all__ = [ - "connect_with_entra", - "connect_with_entra_async" - ] - -except ImportError as e: - # Provide a helpful error message if psycopg2/aiopg dependencies are missing - raise ImportError( - "psycopg2 dependencies are not installed. " - "Install them with: pip install azurepg-entra[psycopg2]" - ) from e \ No newline at end of file +from .entra_connection import ( + EntraConnection, +) + +__all__ = [ + "EntraConnection", +] diff --git a/python/src/azurepg_entra/psycopg2/entra_connection.py b/python/src/azurepg_entra/psycopg2/entra_connection.py new file mode 100644 index 0000000..cfb21d4 --- /dev/null +++ b/python/src/azurepg_entra/psycopg2/entra_connection.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Any + +from azure.core.credentials import TokenCredential + +from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + +try: + from psycopg2.extensions import connection, make_dsn, parse_dsn +except ImportError as e: + # Provide a helpful error message if psycopg2 dependencies are missing + raise ImportError( + "psycopg2 dependencies are not installed. " + "Install them with: pip install azurepg-entra[psycopg2]" + ) from e + +class EntraConnection(connection): + """Establishes a synchronous PostgreSQL connection using Entra authentication. + + This connection class automatically acquires Azure Entra ID credentials when user + or password are not provided in the DSN or connection parameters. Authentication + errors are printed to console for debugging purposes. + + Parameters: + dsn (str): PostgreSQL connection string (Data Source Name). + **kwargs: Additional keyword arguments including: + - credential (TokenCredential, optional): Azure credential for token acquisition. + If None, DefaultAzureCredential() is used. + - user (str, optional): Database username. If not provided, extracted from Entra token. + - password (str, optional): Database password. If not provided, uses Entra access token. + + Raises: + CredentialValueError: If the provided credential is not a valid TokenCredential. + EntraConnectionValueError: If Entra connection credentials cannot be retrieved + """ + + def __init__(self, dsn: str, **kwargs: Any) -> None: + # Extract current DSN params + dsn_params = parse_dsn(dsn) if dsn else {} + + credential = kwargs.pop("credential", None) + if credential and not isinstance(credential, (TokenCredential)): + raise CredentialValueError( + "credential must be a TokenCredential for sync connections" + ) + + # Check if user and password are already provided + has_user = "user" in dsn_params or "user" in kwargs + has_password = "password" in dsn_params or "password" in kwargs + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except (Exception) as e: + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e + + # Only update missing credentials + if not has_user and "user" in entra_creds: + dsn_params["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + dsn_params["password"] = entra_creds["password"] + + # Update DSN params with any kwargs (kwargs take precedence) + dsn_params.update(kwargs) + + # Create new DSN with updated credentials + new_dsn = make_dsn(**dsn_params) + + # Call parent constructor with updated DSN only + super().__init__(new_dsn) diff --git a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py b/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py deleted file mode 100644 index 4eeb1a4..0000000 --- a/python/src/azurepg_entra/psycopg2/psycopg2_entra_id_extension.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -""" -Connection classes for using Entra auth with Azure DB for PostgreSQL (psycopg2 + aiopg version). -This module provides both synchronous and asynchronous connection classes that allow you to connect to Azure DB for PostgreSQL -using Entra authentication. Uses psycopg2 for sync connections and aiopg for async connections. - -Sync Example (psycopg2): - from azurepg_entra.psycopg2 import connect_with_entra - - conn = connect_with_entra(host="myserver.postgres.database.azure.com", dbname="mydatabase") - -Async Example (aiopg): - from azurepg_entra.psycopg2 import connect_with_entra_async - - conn = await connect_with_entra_async(host="myserver.postgres.database.azure.com", dbname="mydatabase") - -Note: Async functionality requires aiopg: pip install aiopg -""" - -from typing import Any, Optional, TYPE_CHECKING, cast - -import psycopg2 - -from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async - -if TYPE_CHECKING: - import aiopg -else: - try: - import aiopg - except ImportError: - aiopg = None - -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential -from azure.identity import DefaultAzureCredential as DefaultAzureCredential -from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential - -AZURE_DB_FOR_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" -AZURE_MANAGEMENT_SCOPE = "https://management.azure.com/.default" - -def connect_with_entra(credential: Optional[TokenCredential] = None, **kwargs: Any) -> psycopg2.extensions.connection: - """Creates a synchronous PostgreSQL connection using Entra authentication. - - This function handles Azure Entra ID token acquisition and creates a psycopg2 connection - with the appropriate user and password parameters. - - Parameters: - credential (TokenCredential, optional): The credential used for token acquisition. - If None, the default Azure credentials are used. - **kwargs: Additional connection parameters (host, port, dbname, etc.) - - Returns: - psycopg2.extensions.connection: An open synchronous connection to PostgreSQL. - - Raises: - ValueError: If the provided credential is not a valid TokenCredential. - """ - credential = credential or DefaultAzureCredential() - if credential and not isinstance(credential, TokenCredential): - raise ValueError("credential must be a TokenCredential for synchronous connections") - - # Check if we need to acquire Entra authentication info - if not kwargs.get("user") or not kwargs.get("password"): - entra_conninfo = get_entra_conninfo(credential) - # Always use the token password when Entra authentication is needed - kwargs["password"] = entra_conninfo["password"] - if not kwargs.get("user"): - # If user isn't already set, use the username from the token - kwargs["user"] = entra_conninfo["user"] - - return cast(psycopg2.extensions.connection, psycopg2.connect(**kwargs)) - -async def connect_with_entra_async(credential: Optional[AsyncTokenCredential] = None, **kwargs: Any) -> aiopg.Connection: - """Creates an asynchronous PostgreSQL connection using Entra authentication. - - This function handles Azure Entra ID token acquisition and creates an aiopg connection - with the appropriate user and password parameters. - - Parameters: - credential (AsyncTokenCredential, optional): The async credential used for token acquisition. - If None, the default Azure credentials are used. - **kwargs: Additional connection parameters (host, port, dbname, etc.) - - Returns: - aiopg connection: An open asynchronous connection to PostgreSQL. - - Raises: - ImportError: If aiopg is not installed. - ValueError: If the provided credential is not a valid AsyncTokenCredential. - """ - if aiopg is None: - raise ImportError( - "aiopg is required for async connections. Install with: pip install aiopg" - ) - - credential = credential or AsyncDefaultAzureCredential() - if credential and not isinstance(credential, AsyncTokenCredential): - raise ValueError("credential must be an AsyncTokenCredential for async connections") - - # Check if we need to acquire Entra authentication info - if not kwargs.get("user") or not kwargs.get("password"): - entra_conninfo = await get_entra_conninfo_async(credential) - # Always use the token password when Entra authentication is needed - kwargs["password"] = entra_conninfo["password"] - if not kwargs.get("user"): - # If user isn't already set, use the username from the token - kwargs["user"] = entra_conninfo["user"] - - return await aiopg.connect(**kwargs) \ No newline at end of file diff --git a/python/src/azurepg_entra/psycopg3/__init__.py b/python/src/azurepg_entra/psycopg3/__init__.py index 4eace8d..86f8d8e 100644 --- a/python/src/azurepg_entra/psycopg3/__init__.py +++ b/python/src/azurepg_entra/psycopg3/__init__.py @@ -1,50 +1,24 @@ # Copyright (c) Microsoft. All rights reserved. + """ -Psycopg3 (psycopg) support for Azure Entra ID authentication with Azure Database for PostgreSQL. +Psycopg3 support for Azure Entra ID authentication with Azure Database for PostgreSQL. This module provides connection classes that extend psycopg's Connection and AsyncConnection to automatically handle Azure Entra ID token acquisition and authentication. Requirements: Install with: pip install azurepg-entra[psycopg3] - + This will install: - psycopg[binary]>=3.1.0 - - psycopg-pool>=3.1.0 + - aiohttp>=3.8.0 Classes: - SyncEntraConnection: Synchronous connection class with Entra ID authentication + EntraConnection: Synchronous connection class with Entra ID authentication AsyncEntraConnection: Asynchronous connection class with Entra ID authentication - -Example usage: - from azurepg_entra.psycopg3 import SyncEntraConnection, AsyncEntraConnection - from psycopg_pool import ConnectionPool, AsyncConnectionPool - - # Synchronous usage - pool = ConnectionPool( - conninfo="postgresql://myserver:5432/mydb", - connection_class=SyncEntraConnection - ) - - # Asynchronous usage - async_pool = AsyncConnectionPool( - conninfo="postgresql://myserver:5432/mydb", - connection_class=AsyncEntraConnection - ) """ -try: - from .psycopg3_entra_id_extension import ( - SyncEntraConnection, - AsyncEntraConnection - ) - __all__ = [ - "SyncEntraConnection", - "AsyncEntraConnection" - ] -except ImportError as e: - # Provide a helpful error message if psycopg dependencies are missing - raise ImportError( - "psycopg3 dependencies are not installed. " - "Install them with: pip install azurepg-entra[psycopg3]" - ) from e \ No newline at end of file +from .async_entra_connection import AsyncEntraConnection +from .entra_connection import EntraConnection + +__all__ = ["EntraConnection", "AsyncEntraConnection"] diff --git a/python/src/azurepg_entra/psycopg3/async_entra_connection.py b/python/src/azurepg_entra/psycopg3/async_entra_connection.py new file mode 100644 index 0000000..f694020 --- /dev/null +++ b/python/src/azurepg_entra/psycopg3/async_entra_connection.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Any + +from azure.core.credentials_async import AsyncTokenCredential + +try: + from psycopg import AsyncConnection +except ImportError as e: + raise ImportError( + "psycopg3 dependencies are not installed. " + "Install them with: pip install azurepg-entra[psycopg3]" + ) from e + +from azurepg_entra.core import get_entra_conninfo_async +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + + +class AsyncEntraConnection(AsyncConnection): + """Asynchronous connection class for using Entra authentication with Azure PostgreSQL.""" + + @classmethod + async def connect(cls, *args: Any, **kwargs: Any) -> "AsyncEntraConnection": + """Establishes an asynchronous PostgreSQL connection using Entra authentication. + + This method automatically acquires Azure Entra ID credentials when user or password + are not provided in the connection parameters. Authentication errors are printed to + console for debugging purposes. + + Parameters: + *args: Positional arguments to be forwarded to the parent connection method. + **kwargs: Keyword arguments including: + - credential (AsyncTokenCredential, optional): Async Azure credential for token acquisition. + - user (str, optional): Database username. If not provided, extracted from Entra token. + - password (str, optional): Database password. If not provided, uses Entra access token. + + Returns: + AsyncEntraConnection: An open asynchronous connection to the PostgreSQL database. + + Raises: + CredentialValueError: If the provided credential is not a valid AsyncTokenCredential. + EntraConnectionValueError: If Entra connection credentials are invalid. + """ + credential = kwargs.pop("credential", None) + if credential and not isinstance(credential, (AsyncTokenCredential)): + raise CredentialValueError( + "credential must be an AsyncTokenCredential for async connections" + ) + + # Check if we need to acquire Entra authentication info + if not kwargs.get("user") or not kwargs.get("password"): + try: + entra_conninfo = await get_entra_conninfo_async(credential) + except Exception as e: + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e + # Always use the token password when Entra authentication is needed + kwargs["password"] = entra_conninfo["password"] + if not kwargs.get("user"): + # If user isn't already set, use the username from the token + kwargs["user"] = entra_conninfo["user"] + return await super().connect(*args, **kwargs) diff --git a/python/src/azurepg_entra/psycopg3/entra_connection.py b/python/src/azurepg_entra/psycopg3/entra_connection.py new file mode 100644 index 0000000..9197736 --- /dev/null +++ b/python/src/azurepg_entra/psycopg3/entra_connection.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Any + +from azure.core.credentials import TokenCredential + +try: + from psycopg import Connection +except ImportError as e: + raise ImportError( + "psycopg3 dependencies are not installed. " + "Install them with: pip install azurepg-entra[psycopg3]" + ) from e + +from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + + +class EntraConnection(Connection): + """Synchronous connection class for using Entra authentication with Azure PostgreSQL.""" + + @classmethod + def connect(cls, *args: Any, **kwargs: Any) -> "EntraConnection": + """Establishes a synchronous PostgreSQL connection using Entra authentication. + + This method automatically acquires Azure Entra ID credentials when user or password + are not provided in the connection parameters. If authentication fails, the original + exception is re-raised to the caller. + + Parameters: + *args: Positional arguments to be forwarded to the parent connection method. + **kwargs: Keyword arguments including: + - credential (TokenCredential, optional): Azure credential for token acquisition. + - user (str, optional): Database username. If not provided, extracted from Entra token. + - password (str, optional): Database password. If not provided, uses Entra access token. + + Returns: + EntraConnection: An open synchronous connection to the PostgreSQL database. + + Raises: + CredentialValueError: If the provided credential is not a valid TokenCredential. + EntraConnectionValueError: If Entra connection credentials cannot be retrieved + """ + credential = kwargs.pop("credential", None) + if credential and not isinstance(credential, (TokenCredential)): + raise CredentialValueError( + "credential must be a TokenCredential for sync connections" + ) + + # Check if we need to acquire Entra authentication info + if not kwargs.get("user") or not kwargs.get("password"): + try: + entra_conninfo = get_entra_conninfo(credential) + except Exception as e: + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e + # Always use the token password when Entra authentication is needed + kwargs["password"] = entra_conninfo["password"] + if not kwargs.get("user"): + # If user isn't already set, use the username from the token + kwargs["user"] = entra_conninfo["user"] + return super().connect(*args, **kwargs) diff --git a/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py b/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py deleted file mode 100644 index e32c4c6..0000000 --- a/python/src/azurepg_entra/psycopg3/psycopg3_entra_id_extension.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -""" -Connection classes for using Entra auth with Azure DB for PostgreSQL. -This module provides both synchronous and asynchronous connection classes that allow you to connect to Azure DB for PostgreSQL -using Entra authentication with psycopg. It handles token acquisition and connection setup. It is not specific -to this repository and can be used in any project that requires Entra authentication with Azure DB for PostgreSQL. - -For example: - - from azure_pg_entra import AsyncEntraConnection - from psycopg_pool import AsyncConnectionPool - - async with AsyncConnectionPool("", connection_class=AsyncEntraConnection) as pool: - ... - -""" - -from typing import Any -try: - from typing import Self -except ImportError: - from typing_extensions import Self # fallback for older Python - -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential -from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async -from psycopg import AsyncConnection, Connection - -class SyncEntraConnection(Connection[tuple[Any, ...]]): - """Synchronous connection class for using Entra authentication with Azure PostgreSQL.""" - - @classmethod - def connect(cls, *args: Any, **kwargs: Any) -> Self: - """Establishes a synchronous PostgreSQL connection using Entra authentication. - - The method checks for provided credentials. If the 'user' or 'password' are not set - in the keyword arguments, it acquires them from Entra via the provided or default credential. - - Parameters: - *args: Positional arguments to be forwarded to the parent connection method. - **kwargs: Keyword arguments including optional 'credential', and optionally 'user' and 'password'. - - Returns: - SyncEntraConnection: An open synchronous connection to the PostgreSQL database. - - Raises: - ValueError: If the provided credential is not a valid TokenCredential. - """ - credential = kwargs.pop("credential", None) - if credential: - if isinstance(credential, AsyncTokenCredential): - raise ValueError("credential must be a TokenCredential for synchronous connections") - if not isinstance(credential, TokenCredential): - raise ValueError("credential must be a TokenCredential for synchronous connections") - - # Check if we need to acquire Entra authentication info - if not kwargs.get("user") or not kwargs.get("password"): - entra_conninfo = get_entra_conninfo(credential) - # Always use the token password when Entra authentication is needed - kwargs["password"] = entra_conninfo["password"] - if not kwargs.get("user"): - # If user isn't already set, use the username from the token - kwargs["user"] = entra_conninfo["user"] - return super().connect(*args, **kwargs) - - -class AsyncEntraConnection(AsyncConnection[tuple[Any, ...]]): - """Asynchronous connection class for using Entra authentication with Azure PostgreSQL.""" - - @classmethod - async def connect(cls, *args: Any, **kwargs: Any) -> Self: - """Establishes an asynchronous PostgreSQL connection using Entra authentication. - - The method checks for provided credentials. If the 'user' or 'password' are not set - in the keyword arguments, it acquires them from Entra via the provided or default credential. - - Parameters: - *args: Positional arguments to be forwarded to the parent connection method. - **kwargs: Keyword arguments including optional 'credential', and optionally 'user' and 'password'. - - Returns: - AsyncEntraConnection: An open asynchronous connection to the PostgreSQL database. - - Raises: - ValueError: If the provided credential is not a valid AsyncTokenCredential. - """ - credential = kwargs.pop("credential", None) - if credential and not isinstance(credential, (AsyncTokenCredential)): - raise ValueError("credential must be an AsyncTokenCredential for async connections") - - # Check if we need to acquire Entra authentication info - if not kwargs.get("user") or not kwargs.get("password"): - entra_conninfo = await get_entra_conninfo_async(credential) - # Always use the token password when Entra authentication is needed - kwargs["password"] = entra_conninfo["password"] - if not kwargs.get("user"): - # If user isn't already set, use the username from the token - kwargs["user"] = entra_conninfo["user"] - return await super().connect(*args, **kwargs) \ No newline at end of file diff --git a/python/src/azurepg_entra/py.typed b/python/src/azurepg_entra/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/python/src/azurepg_entra/sqlalchemy/__init__.py b/python/src/azurepg_entra/sqlalchemy/__init__.py index 7e91176..5767946 100644 --- a/python/src/azurepg_entra/sqlalchemy/__init__.py +++ b/python/src/azurepg_entra/sqlalchemy/__init__.py @@ -1,14 +1,28 @@ # Copyright (c) Microsoft. All rights reserved. + """ SQLAlchemy integration for Azure PostgreSQL with Entra ID authentication. + +This module provides integration between SQLAlchemy and Azure Entra ID +authentication for PostgreSQL connections. It automatically handles token acquisition +and credential injection through SQLAlchemy's event system. + +Requirements: + Install with: pip install azurepg-entra[sqlalchemy] + + This will install: + - sqlalchemy>=2.0.0 + - aiohttp>=3.8.0 + +Functions: + enable_entra_authentication: Enable Entra ID authentication for synchronous SQLAlchemy engines + enable_entra_authentication_async: Enable Entra ID authentication for asynchronous SQLAlchemy engines """ -from .sqlalchemy_entra_id_extension import ( - create_engine_with_entra, - create_async_engine_with_entra, -) +from .async_entra_connection import enable_entra_authentication_async +from .entra_connection import enable_entra_authentication __all__ = [ - "create_engine_with_entra", - "create_async_engine_with_entra", -] \ No newline at end of file + "enable_entra_authentication", + "enable_entra_authentication_async", +] diff --git a/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py new file mode 100644 index 0000000..fc4848c --- /dev/null +++ b/python/src/azurepg_entra/sqlalchemy/async_entra_connection.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Any + +from azure.core.credentials import TokenCredential + +try: + from sqlalchemy import event + from sqlalchemy.engine import Dialect + from sqlalchemy.ext.asyncio import AsyncEngine +except ImportError as e: + # Provide a helpful error message if SQLAlchemy dependencies are missing + raise ImportError( + "SQLAlchemy dependencies are not installed. " + "Install them with: pip install azurepg-entra[sqlalchemy]" + ) from e + +from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + + +def enable_entra_authentication_async(engine: AsyncEngine) -> None: + """ + Enable Azure Entra ID authentication for an async SQLAlchemy engine. + + This function registers an event listener that automatically provides + Entra ID credentials for each database connection if they are not already set. + Event handlers do not support async behavior so the token fetching will still + be synchronous. + + Args: + engine: The async SQLAlchemy Engine to enable Entra authentication for + """ + + @event.listens_for(engine.sync_engine, "do_connect") + def provide_token( + dialect: Dialect, conn_rec: Any, cargs: Any, cparams: dict[str, Any] + ) -> None: + """Event handler that provides Entra credentials for each sync connection. + + Raises: + CredentialValueError: If the provided credential is not a valid TokenCredential. + EntraConnectionValueError: If Entra connection credentials cannot be retrieved + """ + credential = cparams.get("credential", None) + if credential and not isinstance(credential, (TokenCredential)): + raise CredentialValueError( + "credential must be a TokenCredential for async connections" + ) + # Check if credentials are already present + has_user = "user" in cparams + has_password = "password" in cparams + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except Exception as e: + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e + # Only update missing credentials + if not has_user and "user" in entra_creds: + cparams["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + cparams["password"] = entra_creds["password"] diff --git a/python/src/azurepg_entra/sqlalchemy/entra_connection.py b/python/src/azurepg_entra/sqlalchemy/entra_connection.py new file mode 100644 index 0000000..f4cf80c --- /dev/null +++ b/python/src/azurepg_entra/sqlalchemy/entra_connection.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft. All rights reserved. + +from typing import Any + +from azure.core.credentials import TokenCredential + +try: + from sqlalchemy import Engine, event + from sqlalchemy.engine import Dialect +except ImportError as e: + raise ImportError( + "SQLAlchemy dependencies are not installed. " + "Install them with: pip install azurepg-entra[sqlalchemy]" + ) from e + +from azurepg_entra.core import get_entra_conninfo +from azurepg_entra.errors import ( + CredentialValueError, + EntraConnectionValueError, +) + + +def enable_entra_authentication(engine: Engine) -> None: + """ + Enable Azure Entra ID authentication for a SQLAlchemy engine. + + This function registers an event listener that automatically provides + Entra ID credentials for each database connection if they are not already set. + + Args: + engine: The SQLAlchemy Engine to enable Entra authentication for + """ + + @event.listens_for(engine, "do_connect") + def provide_token( + dialect: Dialect, conn_rec: Any, cargs: Any, cparams: dict[str, Any] + ) -> None: + """Event handler that provides Entra credentials for each connection. + + Raises: + CredentialValueError: If the provided credential is not a valid TokenCredential. + EntraConnectionValueError: If Entra connection credentials cannot be retrieved + """ + credential = cparams.get("credential", None) + if credential and not isinstance(credential, (TokenCredential)): + raise CredentialValueError( + "credential must be a TokenCredential for sync connections" + ) + # Check if credentials are already present + has_user = "user" in cparams + has_password = "password" in cparams + + # Only get Entra credentials if user or password is missing + if not has_user or not has_password: + try: + entra_creds = get_entra_conninfo(credential) + except Exception as e: + raise EntraConnectionValueError( + "Could not retrieve Entra credentials" + ) from e + # Only update missing credentials + if not has_user and "user" in entra_creds: + cparams["user"] = entra_creds["user"] + if not has_password and "password" in entra_creds: + cparams["password"] = entra_creds["password"] diff --git a/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py b/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py deleted file mode 100644 index 030ce0f..0000000 --- a/python/src/azurepg_entra/sqlalchemy/sqlalchemy_entra_id_extension.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -import psycopg -import logging -from typing import Optional, Any, TYPE_CHECKING -from urllib.parse import urlparse, urlunparse -from azurepg_entra.core import get_entra_conninfo, get_entra_conninfo_async -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential -from azure.identity import DefaultAzureCredential as DefaultAzureCredential -from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential -from sqlalchemy.engine.interfaces import DBAPIConnection -from sqlalchemy.ext.asyncio.engine import AsyncEngine - -try: - from sqlalchemy import create_engine, Engine -except ImportError: - raise ImportError("sqlalchemy is required. Install with: pip install sqlalchemy") - -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import create_async_engine -else: - try: - from sqlalchemy.ext.asyncio import create_async_engine - except ImportError: - create_async_engine = None - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def create_engine_with_entra( - url: str, - credential: Optional[TokenCredential] = None, - **kwargs: Any -) -> Engine: - """Creates a SQLAlchemy Engine using Entra authentication for Azure PostgreSQL. - - This function handles Azure Entra ID token acquisition and creates a SQLAlchemy engine - that automatically refreshes tokens for each new connection. This solves the token - expiration issue by acquiring fresh tokens on each connection attempt. - - Parameters: - url (str): The database URL. Username and password will be replaced with Entra credentials. - credential (TokenCredential, optional): The credential used for token acquisition. - If None, the default Azure credentials are used. - **kwargs: Additional engine creation parameters passed to create_engine() - - Returns: - Engine: A SQLAlchemy engine configured with Entra authentication and automatic token refresh. - - Raises: - ValueError: If the provided credential is not a valid TokenCredential. - - Example: - engine = create_engine_with_entra( - "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - ) - """ - credential = credential or DefaultAzureCredential() - if credential and not isinstance(credential, TokenCredential): - raise ValueError("credential must be a TokenCredential for synchronous engines") - - # Parse the original URL to extract connection parameters - parsed = urlparse(url) - - def connect_with_fresh_token() -> DBAPIConnection | None: - """Custom connection factory that gets a fresh token each time.""" - logger.info("Creating new connection with fresh Entra token") - - # Get fresh Entra authentication info for each connection - entra_conninfo = get_entra_conninfo(credential) - - # Build authenticated URL with fresh token - parsed_copy = parsed._replace( - netloc=f"{entra_conninfo['user']}:{entra_conninfo['password']}@{parsed.hostname}" + - (f":{parsed.port}" if parsed.port else "") - ) - auth_url = urlunparse(parsed_copy) - - # Create a temporary engine with the authenticated URL and get the DBAPI connection - temp_engine = create_engine(auth_url) - raw_conn = temp_engine.raw_connection() - # Return the underlying DBAPI connection, not the SQLAlchemy wrapper - return raw_conn.dbapi_connection - - # Create base URL without credentials for the engine - base_url = f"{parsed.scheme or 'postgresql'}://{parsed.hostname}" - if parsed.port: - base_url += f":{parsed.port}" - if parsed.path: - base_url += parsed.path - if parsed.query: - base_url += f"?{parsed.query}" - - # Create engine with custom connection factory - return create_engine(base_url, creator=connect_with_fresh_token, **kwargs) - -def create_async_engine_with_entra( - url: str, - credential: Optional[AsyncTokenCredential] = None, - **kwargs: Any -) -> AsyncEngine: - """Creates an async SQLAlchemy Engine using Entra authentication for Azure PostgreSQL. - - This function handles Azure Entra ID token acquisition and creates an async SQLAlchemy engine - that automatically refreshes tokens for each new connection. This solves the token - expiration issue by acquiring fresh tokens on each connection attempt. - - Parameters: - url (str): The database URL. Username and password will be replaced with Entra credentials. - credential (AsyncTokenCredential, optional): The async credential used for token acquisition. - If None, the default Azure credentials are used. - **kwargs: Additional engine creation parameters passed to create_async_engine() - - Returns: - AsyncEngine: An async SQLAlchemy engine configured with Entra authentication and automatic token refresh. - - Raises: - ImportError: If sqlalchemy.ext.asyncio is not available. - ValueError: If the provided credential is not a valid AsyncTokenCredential. - - Example: - engine = await create_async_engine_with_entra( - "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - ) - """ - if create_async_engine is None: - raise ImportError( - "sqlalchemy.ext.asyncio is required for async engines. " - "Install with: pip install sqlalchemy[asyncio]" - ) - - credential = credential or AsyncDefaultAzureCredential() - if credential and not isinstance(credential, AsyncTokenCredential): - raise ValueError("credential must be an AsyncTokenCredential for async engines") - - # Parse the original URL to extract connection parameters - parsed = urlparse(url) - - async def async_connect_with_fresh_token() -> psycopg.AsyncConnection: - """Custom async connection factory that gets a fresh token each time.""" - logger.info("Creating new async connection with fresh Entra token") - - # Get fresh Entra authentication info for each connection - entra_conninfo = await get_entra_conninfo_async(credential) - - # For async, we need to return the raw async connection directly - # Import the appropriate async driver and create connection directly - if parsed.scheme == 'postgresql+psycopg' or parsed.scheme == 'postgresql': - return await psycopg.AsyncConnection.connect( - host=parsed.hostname, - port=parsed.port or 5432, - dbname=parsed.path.lstrip('/') if parsed.path else 'postgres', - user=entra_conninfo['user'], - password=entra_conninfo['password'] - ) - else: - raise ValueError(f"Unsupported async URL scheme: {parsed.scheme}. Use postgresql+psycopg or postgresql") - - # Create base URL without credentials for the engine - base_url = f"{parsed.scheme}://{parsed.hostname}" - if parsed.port: - base_url += f":{parsed.port}" - if parsed.path: - base_url += parsed.path - if parsed.query: - base_url += f"?{parsed.query}" - - # Create async engine with custom connection factory - return create_async_engine(base_url, async_creator=async_connect_with_fresh_token, **kwargs) diff --git a/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py new file mode 100644 index 0000000..4a51038 --- /dev/null +++ b/python/tests/azure/data/postgresql/psycopg2/test_entra_id_extension.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft. All rights reserved. + +import base64 +import json +from unittest.mock import patch + +import pytest +from psycopg2.extensions import make_dsn, parse_dsn + + +def create_test_token(payload): + """Helper to create a test JWT token manually.""" + # Create a simple JWT-like token with header.payload.signature format + header = {"alg": "none", "typ": "JWT"} + header_encoded = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") + ) + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) + signature = "" + return f"{header_encoded}.{payload_encoded}.{signature}" + + +class TestEntraConnection: + def test_dsn_processing_adds_entra_credentials(self): + """Test that EntraConnection logic correctly merges Entra credentials into DSN.""" + payload = {"upn": "user@example.com"} + token = create_test_token(payload) + + with patch("azurepg_entra.core.get_entra_conninfo") as mock_get_creds: + mock_get_creds.return_value = { + "user": "user@example.com", + "password": token, + } + + # Test with existing DSN parameters + original_dsn = "host=localhost port=5432 dbname=testdb sslmode=require" + entra_creds = mock_get_creds(None) + + dsn_params = parse_dsn(original_dsn) if original_dsn else {} + dsn_params.update(entra_creds) + new_dsn = make_dsn(**dsn_params) + + mock_get_creds.assert_called_once_with(None) + + # Original params preserved + assert "host=localhost" in new_dsn + assert "port=5432" in new_dsn + assert "dbname=testdb" in new_dsn + assert "sslmode=require" in new_dsn + # Entra creds added + assert "user=user@example.com" in new_dsn + assert f"password={token}" in new_dsn + + +if __name__ == "__main__": + import sys + + exit_code = pytest.main([__file__, "-v", "--tb=short"]) + sys.exit(exit_code) diff --git a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py deleted file mode 100644 index bbc6dbf..0000000 --- a/python/tests/azure/data/postgresql/psycopg2/test_psycopg2_entra_id_extension.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -""" -Unit Tests for Azure PostgreSQL psycopg Entra ID Extension - -This test suite demonstrates and validates the Azure Entra ID authentication -functionality for PostgreSQL connections using psycopg2. These tests serve as -both validation and examples of how to use the extension. - -Test Categories: -1. JWT Token Decoding - Validates Azure token processing -2. Principal Name Parsing - Tests managed identity resource path parsing -3. Connection Info Generation - Tests core authentication logic -4. Sync/Async Connection Classes - Validates connection establishment - -Key Testing Patterns: -- Every synchronous test has an equivalent asynchronous test -- Comprehensive mocking to avoid external dependencies -- Edge case validation for robust error handling -- Clear test naming that describes expected behavior - -Usage: - # Run all tests - pytest test_psycopg_entra_id_extension.py - - # Run specific test class - pytest test_psycopg2_entra_id_extension.py::TestConnectWithEntra - - # Run with verbose output - pytest -v test_psycopg_entra_id_extension.py - -Dependencies: - pip install pytest pytest-asyncio - -For more information about Azure Entra ID authentication: -https://docs.microsoft.com/en-us/azure/postgresql/concepts-aad-authentication -""" - -import base64 -import json -import pytest -from unittest.mock import AsyncMock, Mock, patch -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential - -from azurepg_entra.psycopg2 import ( - connect_with_entra, - connect_with_entra_async, - decode_jwt, - get_entra_conninfo, - get_entra_conninfo_async, - parse_principal_name -) - -# Test Configuration -# These tests use mocking to avoid requiring actual Azure credentials or database connections. -# For integration testing with real Azure resources, see the samples/ directory. - - -def create_test_token(payload): - """Helper to create a test JWT token.""" - encoded_payload = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') - return f"header.{encoded_payload}.signature" - - -class TestDecodeJwt: - """ - Tests for JWT token decoding functionality. - - These tests validate that Azure Entra ID tokens are properly decoded - to extract user information. The extension supports various token formats - and claim structures that Azure may provide. - """ - - def test_decode_jwt_valid_token(self): - """Test decoding a valid JWT token with UPN (User Principal Name) claim.""" - # UPN is the most common claim for user identity in Azure AD tokens - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_with_padding(self): - """Test decoding JWT token that requires base64 padding.""" - payload = {"preferred_username": "testuser", "exp": 9999999999} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_minimal_payload(self): - """Test decoding JWT with minimal payload.""" - payload = {"unique_name": "user123"} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - -class TestParsePrincipalName: - """ - Tests for Azure resource path parsing. - - When using managed identities, Azure provides resource paths that need - to be parsed to extract the identity name for database authentication. - These tests ensure robust parsing of various path formats. - """ - - def test_parse_principal_name_valid_user_assigned(self): - """Test parsing a valid user-assigned identity resource path.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - def test_parse_principal_name_empty_string(self): - """Test parsing an empty string returns None.""" - assert parse_principal_name("") is None - - def test_parse_principal_name_none(self): - """Test parsing None returns None.""" - assert parse_principal_name(None) is None - - def test_parse_principal_name_no_slash(self): - """Test parsing a string without slashes returns None.""" - assert parse_principal_name("no-slashes-here") is None - - def test_parse_principal_name_invalid_path(self): - """Test parsing an invalid resource path returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/SomeOther/resource/my-identity") - assert result is None - - def test_parse_principal_name_missing_identity_name(self): - """Test parsing a path without identity name returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/") - assert result is None - - def test_parse_principal_name_case_insensitive(self): - """Test parsing with different case variations.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/MICROSOFT.MANAGEDIDENTITY/USERASSIGNEDIDENTITIES/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - -class TestGetEntraConninfo: - """ - Tests for the core authentication logic (sync and async versions). - - These functions handle the complete flow of: - 1. Requesting Azure tokens with appropriate scopes - 2. Decoding tokens to extract user information - 3. Handling fallback scenarios for managed identities - 4. Returning connection parameters for psycopg - - Both synchronous and asynchronous patterns are tested to ensure - consistent behavior across different usage scenarios. - """ - - # Sync tests - def test_get_entra_conninfo_with_credential(self): - """Test getting connection info with sync credential and upn claim.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - def test_get_entra_conninfo_no_username_claims(self): - """Test error when no username claims are present.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - get_entra_conninfo(mock_credential) - - def test_get_entra_conninfo_username_priority(self): - """Test that upn takes priority over other username claims.""" - mock_credential = Mock(spec=TokenCredential) - # Azure tokens may contain multiple username claims - test priority order - payload = { - "upn": "upn@example.com", # Highest priority - "preferred_username": "preferred@example.com", # Second priority - "unique_name": "unique@example.com" # Fallback option - } - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result["user"] == "upn@example.com" # Should use highest priority claim - - def test_get_entra_conninfo_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username.""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = get_entra_conninfo(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - # Async tests - mirror of sync tests - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_with_credential(self): - """Test getting connection info with async credential and upn claim.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_no_username_claims(self): - """Test error when no username claims are present (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token_async', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - await get_entra_conninfo_async(mock_credential) - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_username_priority(self): - """Test that upn takes priority over other username claims (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = { - "preferred_username": "preferred@example.com", - "unique_name": "unique@example.com", - "upn": "upn@example.com" - } - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result["user"] == "upn@example.com" - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username (async).""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_token_async') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = await get_entra_conninfo_async(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - -class TestConnectWithEntra: - """ - Tests for the connect_with_entra function. - - This function handles Azure Entra ID authentication for psycopg2 connections. - Tests validate: - - Proper credential handling and validation - - Fallback to standard authentication when credentials exist - - Integration with the underlying psycopg2.connect logic - """ - - def test_connect_with_user_and_password(self): - """Test connection when user and password are already provided (passthrough).""" - kwargs = { - "host": "localhost", - "port": 5432, - "user": "existing_user", - "password": "existing_password", - "dbname": "testdb" - } - - with patch('psycopg2.connect') as mock_connect: - mock_connection = Mock() - mock_connect.return_value = mock_connection - - result = connect_with_entra(**kwargs) - - mock_connect.assert_called_once() - call_args = mock_connect.call_args[1] - assert call_args["user"] == "existing_user" - assert call_args["password"] == "existing_password" - assert result == mock_connection - - def test_connect_without_user_password_with_credential(self): - """Test connection using Entra authentication with provided credential.""" - mock_credential = Mock(spec=TokenCredential) - kwargs = { - "host": "localhost", - "port": 5432, - "dbname": "testdb" - } - - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_conninfo', return_value=expected_conninfo) as mock_get_conninfo: - with patch('psycopg2.connect') as mock_connect: - mock_connection = Mock() - mock_connect.return_value = mock_connection - - result = connect_with_entra(credential=mock_credential, **kwargs) - - mock_get_conninfo.assert_called_once_with(mock_credential) - mock_connect.assert_called_once() - call_args = mock_connect.call_args[1] - assert call_args["user"] == "test@example.com" - assert call_args["password"] == "token123" - assert result == mock_connection - - def test_connect_invalid_credential_type(self): - """Test connection with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - kwargs = {"host": "localhost"} - - with pytest.raises(ValueError, match="credential must be a TokenCredential for synchronous connections"): - connect_with_entra(credential=invalid_credential, **kwargs) - - -class TestConnectWithEntraAsync: - """ - Tests for the connect_with_entra_async function. - - This function handles Azure Entra ID authentication for aiopg connections. - Tests mirror the synchronous tests to ensure consistent behavior between - sync/async usage patterns. - """ - - @pytest.mark.asyncio - async def test_connect_with_user_and_password(self): - """Test connection when user and password are already provided (passthrough).""" - kwargs = { - "host": "localhost", - "port": 5432, - "user": "existing_user", - "password": "existing_password", - "dbname": "testdb" - } - - with patch('aiopg.connect', new_callable=AsyncMock) as mock_connect: - mock_connection = Mock() - mock_connect.return_value = mock_connection - - result = await connect_with_entra_async(**kwargs) - - mock_connect.assert_called_once() - call_args = mock_connect.call_args[1] - assert call_args["user"] == "existing_user" - assert call_args["password"] == "existing_password" - assert result == mock_connection - - @pytest.mark.asyncio - async def test_connect_without_user_password_with_credential(self): - """Test connection using Entra authentication with provided credential.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - kwargs = { - "host": "localhost", - "port": 5432, - "dbname": "testdb" - } - - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - with patch('azurepg_entra.psycopg2.psycopg2_entra_id_extension.get_entra_conninfo_async', return_value=expected_conninfo) as mock_get_conninfo: - with patch('aiopg.connect', new_callable=AsyncMock) as mock_connect: - mock_connection = Mock() - mock_connect.return_value = mock_connection - - result = await connect_with_entra_async(credential=mock_credential, **kwargs) - - mock_get_conninfo.assert_called_once_with(mock_credential) - mock_connect.assert_called_once() - call_args = mock_connect.call_args[1] - assert call_args["user"] == "test@example.com" - assert call_args["password"] == "token123" - assert result == mock_connection - - @pytest.mark.asyncio - async def test_connect_invalid_credential_type(self): - """Test connection with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - kwargs = {"host": "localhost"} - - with pytest.raises(ValueError, match="credential must be an AsyncTokenCredential for async connections"): - await connect_with_entra_async(credential=invalid_credential, **kwargs) - - -# Example usage and test runner -if __name__ == "__main__": - """ - Direct execution example for development and validation. - - This runs all tests with verbose output, which is helpful for: - - Understanding test coverage - - Debugging test failures - - Learning expected behavior patterns - - For CI/CD or automated testing, use pytest directly: - pytest test_psycopg_entra_id_extension.py -v --tb=short - """ - import sys - - # Run with verbose output and short traceback format - exit_code = pytest.main([__file__, "-v", "--tb=short"]) - - # Provide helpful guidance based on results - if exit_code == 0: - print("\nāœ… All tests passed! The Azure Entra ID extension is working correctly.") - print("šŸ’” Next steps: Try running the samples in samples/psycopg/getting_started/") - else: - print(f"\nāŒ Some tests failed (exit code: {exit_code})") - print("šŸ’” Check the test output above for details on any failures.") - print("šŸ’” Ensure all dependencies are installed: pip install pytest pytest-asyncio") - - sys.exit(exit_code) - - -""" -Quick Start Guide for Understanding These Tests: - -1. Basic JWT Testing (TestDecodeJwt): - - Shows how Azure tokens are decoded to get user info - - Demonstrates different token claim formats - -2. Resource Path Parsing (TestParsePrincipalName): - - Tests managed identity resource path handling - - Shows how identity names are extracted for database auth - -3. Core Authentication Logic (TestGetEntraConninfo): - - Tests the main authentication flow - - Shows both sync and async patterns - - Demonstrates fallback mechanisms for edge cases - -4. Connection Functions (TestConnectWithEntra/TestConnectWithEntraAsync): - - Tests the customer-facing connection classes - - Shows how existing psycopg code can be easily adapted - - Validates credential handling and error cases - -Key Patterns to Notice: -- Every sync test has an async equivalent for consistency -- Mocking isolates tests from external dependencies -- Clear error messages help with debugging -- Tests serve as usage examples for developers - -For integration testing with real Azure resources, see the samples directory. -""" \ No newline at end of file diff --git a/python/tests/azure/data/postgresql/psycopg3/test_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg3/test_entra_id_extension.py new file mode 100644 index 0000000..c93214f --- /dev/null +++ b/python/tests/azure/data/postgresql/psycopg3/test_entra_id_extension.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft. All rights reserved. + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from azurepg_entra.errors import CredentialValueError +from azurepg_entra.psycopg3 import ( + AsyncEntraConnection, + EntraConnection, +) + + +class TestSyncConnection: + def test_connect_with_existing_credentials(self): + """Test that existing user/password credentials are used without fetching Entra credentials.""" + kwargs = { + "host": "localhost", + "user": "existing_user", + "password": "existing_password", + } + + with patch("psycopg.Connection.connect") as mock_connect: + mock_connection = Mock() + mock_connect.return_value = mock_connection + + result = EntraConnection.connect(**kwargs) + + assert result == mock_connection + call_args = mock_connect.call_args[1] + assert call_args["user"] == "existing_user" + assert call_args["password"] == "existing_password" + + def test_connect_with_entra_credential(self): + """Test that Entra credentials are fetched and used when no user/password provided.""" + mock_credential = Mock(spec=TokenCredential) + kwargs = {"host": "localhost", "credential": mock_credential} + + with patch( + "azurepg_entra.psycopg3.entra_connection.get_entra_conninfo", + return_value={"user": "test@example.com", "password": "token123"}, + ): + with patch("psycopg.Connection.connect") as mock_connect: + mock_connection = Mock() + mock_connect.return_value = mock_connection + + result = EntraConnection.connect(**kwargs) + + assert result == mock_connection + call_args = mock_connect.call_args[1] + assert call_args["user"] == "test@example.com" + assert call_args["password"] == "token123" + + def test_connect_invalid_credential_type_throws(self): + """Test that invalid credential type raises CredentialValueError.""" + with pytest.raises( + CredentialValueError, + match="credential must be a TokenCredential for sync connections", + ): + EntraConnection.connect(host="localhost", credential="invalid") + + +class TestAsyncConnection: + @pytest.mark.asyncio + async def test_connect_with_existing_credentials(self): + """Test that existing user/password credentials are used without fetching Entra credentials (async).""" + kwargs = { + "host": "localhost", + "user": "existing_user", + "password": "existing_password", + } + + with patch( + "psycopg.AsyncConnection.connect", new_callable=AsyncMock + ) as mock_connect: + mock_connection = Mock() + mock_connect.return_value = mock_connection + + result = await AsyncEntraConnection.connect(**kwargs) + + assert result == mock_connection + call_args = mock_connect.call_args[1] + assert call_args["user"] == "existing_user" + assert call_args["password"] == "existing_password" + + @pytest.mark.asyncio + async def test_connect_with_entra_credential(self): + """Test that Entra credentials are fetched and used when no user/password provided (async).""" + mock_credential = AsyncMock(spec=AsyncTokenCredential) + kwargs = {"host": "localhost", "credential": mock_credential} + + with patch( + "azurepg_entra.psycopg3.async_entra_connection.get_entra_conninfo_async", + new_callable=AsyncMock, + return_value={"user": "test@example.com", "password": "token123"}, + ): + with patch( + "psycopg.AsyncConnection.connect", new_callable=AsyncMock + ) as mock_connect: + mock_connection = Mock() + mock_connect.return_value = mock_connection + + result = await AsyncEntraConnection.connect(**kwargs) + + assert result == mock_connection + call_args = mock_connect.call_args[1] + assert call_args["user"] == "test@example.com" + assert call_args["password"] == "token123" + + @pytest.mark.asyncio + async def test_connect_invalid_credential_type_throws(self): + """Test that invalid credential type raises CredentialValueError (async).""" + with pytest.raises( + CredentialValueError, + match="credential must be an AsyncTokenCredential for async connections", + ): + await AsyncEntraConnection.connect(host="localhost", credential="invalid") + + +if __name__ == "__main__": + import sys + + exit_code = pytest.main([__file__, "-v", "--tb=short"]) + sys.exit(exit_code) diff --git a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py b/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py deleted file mode 100644 index 837141c..0000000 --- a/python/tests/azure/data/postgresql/psycopg3/test_psycopg3_entra_id_extension.py +++ /dev/null @@ -1,467 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -""" -Unit Tests for Azure PostgreSQL psycopg Entra ID Extension - -This test suite demonstrates and validates the Azure Entra ID authentication -functionality for PostgreSQL connections using psycopg3. These tests serve as -both validation and examples of how to use the extension. - -Test Categories: -1. JWT Token Decoding - Validates Azure token processing -2. Principal Name Parsing - Tests managed identity resource path parsing -3. Connection Info Generation - Tests core authentication logic -4. Sync/Async Connection Classes - Validates connection establishment - -Key Testing Patterns: -- Every synchronous test has an equivalent asynchronous test -- Comprehensive mocking to avoid external dependencies -- Edge case validation for robust error handling -- Clear test naming that describes expected behavior - -Usage: - # Run all tests - pytest test_psycopg_entra_id_extension.py - - # Run specific test class - pytest test_psycopg_entra_id_extension.py::TestSyncEntraConnection - - # Run with verbose output - pytest -v test_psycopg_entra_id_extension.py - -Dependencies: - pip install pytest pytest-asyncio - -For more information about Azure Entra ID authentication: -https://docs.microsoft.com/en-us/azure/postgresql/concepts-aad-authentication -""" - -import base64 -import json -import pytest -from unittest.mock import AsyncMock, Mock, patch -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential - -from azurepg_entra.psycopg3 import ( - AsyncEntraConnection, - SyncEntraConnection, - get_entra_conninfo, - get_entra_conninfo_async, - decode_jwt, - parse_principal_name, -) - -# Test Configuration -# These tests use mocking to avoid requiring actual Azure credentials or database connections. -# For integration testing with real Azure resources, see the samples/ directory. - - -def create_test_token(payload): - """Helper to create a test JWT token.""" - encoded_payload = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') - return f"header.{encoded_payload}.signature" - - -class TestDecodeJwt: - """ - Tests for JWT token decoding functionality. - - These tests validate that Azure Entra ID tokens are properly decoded - to extract user information. The extension supports various token formats - and claim structures that Azure may provide. - """ - - def test_decode_jwt_valid_token(self): - """Test decoding a valid JWT token with UPN (User Principal Name) claim.""" - # UPN is the most common claim for user identity in Azure AD tokens - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_with_padding(self): - """Test decoding JWT token that requires base64 padding.""" - payload = {"preferred_username": "testuser", "exp": 9999999999} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_minimal_payload(self): - """Test decoding JWT with minimal payload.""" - payload = {"unique_name": "user123"} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - -class TestParsePrincipalName: - """ - Tests for Azure resource path parsing. - - When using managed identities, Azure provides resource paths that need - to be parsed to extract the identity name for database authentication. - These tests ensure robust parsing of various path formats. - """ - - def test_parse_principal_name_valid_user_assigned(self): - """Test parsing a valid user-assigned identity resource path.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - def test_parse_principal_name_empty_string(self): - """Test parsing an empty string returns None.""" - assert parse_principal_name("") is None - - def test_parse_principal_name_none(self): - """Test parsing None returns None.""" - assert parse_principal_name(None) is None - - def test_parse_principal_name_no_slash(self): - """Test parsing a string without slashes returns None.""" - assert parse_principal_name("no-slashes-here") is None - - def test_parse_principal_name_invalid_path(self): - """Test parsing an invalid resource path returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/SomeOther/resource/my-identity") - assert result is None - - def test_parse_principal_name_missing_identity_name(self): - """Test parsing a path without identity name returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/") - assert result is None - - def test_parse_principal_name_case_insensitive(self): - """Test parsing with different case variations.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/MICROSOFT.MANAGEDIDENTITY/USERASSIGNEDIDENTITIES/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - -class TestGetEntraConninfo: - """ - Tests for the core authentication logic (sync and async versions). - - These functions handle the complete flow of: - 1. Requesting Azure tokens with appropriate scopes - 2. Decoding tokens to extract user information - 3. Handling fallback scenarios for managed identities - 4. Returning connection parameters for psycopg - - Both synchronous and asynchronous patterns are tested to ensure - consistent behavior across different usage scenarios. - """ - - # Sync tests - def test_get_entra_conninfo_with_credential(self): - """Test getting connection info with sync credential and upn claim.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - def test_get_entra_conninfo_no_username_claims(self): - """Test error when no username claims are present.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - get_entra_conninfo(mock_credential) - - def test_get_entra_conninfo_username_priority(self): - """Test that upn takes priority over other username claims.""" - mock_credential = Mock(spec=TokenCredential) - # Azure tokens may contain multiple username claims - test priority order - payload = { - "upn": "upn@example.com", # Highest priority - "preferred_username": "preferred@example.com", # Second priority - "unique_name": "unique@example.com" # Fallback option - } - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result["user"] == "upn@example.com" # Should use highest priority claim - - def test_get_entra_conninfo_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username.""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = get_entra_conninfo(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - # Async tests - mirror of sync tests - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_with_credential(self): - """Test getting connection info with async credential and upn claim.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_no_username_claims(self): - """Test error when no username claims are present (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token_async', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - await get_entra_conninfo_async(mock_credential) - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_username_priority(self): - """Test that upn takes priority over other username claims (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = { - "upn": "upn@example.com", - "preferred_username": "preferred@example.com", - "unique_name": "unique@example.com" - } - token = create_test_token(payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result["user"] == "upn@example.com" - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username (async).""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_token_async') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = await get_entra_conninfo_async(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - -class TestSyncEntraConnection: - """ - Tests for the SyncEntraConnection class. - - This class extends psycopg's Connection to automatically handle - Azure Entra ID authentication. Tests validate: - - Proper credential handling and validation - - Fallback to standard authentication when credentials exist - - Integration with the underlying psycopg connection logic - """ - - def test_connect_with_user_and_password(self): - """Test connection when user and password are already provided (passthrough).""" - kwargs = { - "host": "localhost", - "port": 5432, - "user": "existing_user", - "password": "existing_password", - "dbname": "testdb" - } - - with patch('psycopg.Connection.connect') as mock_super_connect: - mock_connection = Mock() - mock_super_connect.return_value = mock_connection - - result = SyncEntraConnection.connect(**kwargs) - - mock_super_connect.assert_called_once() - call_args = mock_super_connect.call_args[1] - assert call_args["user"] == "existing_user" - assert call_args["password"] == "existing_password" - assert "credential" not in call_args - assert result == mock_connection - - def test_connect_without_user_password_with_credential(self): - """Test connection using Entra authentication with provided credential.""" - mock_credential = Mock(spec=TokenCredential) - kwargs = { - "host": "localhost", - "port": 5432, - "dbname": "testdb", - "credential": mock_credential - } - - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_conninfo', return_value=expected_conninfo) as mock_get_conninfo: - with patch('psycopg.Connection.connect') as mock_super_connect: - mock_connection = Mock() - mock_super_connect.return_value = mock_connection - - result = SyncEntraConnection.connect(**kwargs) - - mock_get_conninfo.assert_called_once_with(mock_credential) - mock_super_connect.assert_called_once() - call_args = mock_super_connect.call_args[1] - assert call_args["user"] == "test@example.com" - assert call_args["password"] == "token123" - assert "credential" not in call_args - assert result == mock_connection - - def test_connect_invalid_credential_type(self): - """Test connection with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - kwargs = {"host": "localhost", "credential": invalid_credential} - - with pytest.raises(ValueError, match="credential must be a TokenCredential for synchronous connections"): - SyncEntraConnection.connect(**kwargs) - - -class TestAsyncEntraConnection: - """ - Tests for the AsyncEntraConnection class. - - This class extends psycopg's AsyncConnection to automatically handle - Azure Entra ID authentication in async contexts. Tests mirror the - synchronous tests to ensure consistent behavior between sync/async - usage patterns. - """ - - @pytest.mark.asyncio - async def test_connect_with_user_and_password(self): - """Test connection when user and password are already provided (passthrough).""" - kwargs = { - "host": "localhost", - "port": 5432, - "user": "existing_user", - "password": "existing_password", - "dbname": "testdb" - } - - with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_super_connect: - mock_connection = Mock() - mock_super_connect.return_value = mock_connection - - result = await AsyncEntraConnection.connect(**kwargs) - - mock_super_connect.assert_called_once() - call_args = mock_super_connect.call_args[1] - assert call_args["user"] == "existing_user" - assert call_args["password"] == "existing_password" - assert "credential" not in call_args - assert result == mock_connection - - @pytest.mark.asyncio - async def test_connect_without_user_password_with_credential(self): - """Test connection using Entra authentication with provided credential.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - kwargs = { - "host": "localhost", - "port": 5432, - "dbname": "testdb", - "credential": mock_credential - } - - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - with patch('azurepg_entra.psycopg3.psycopg3_entra_id_extension.get_entra_conninfo_async', return_value=expected_conninfo) as mock_get_conninfo: - with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock) as mock_super_connect: - mock_connection = Mock() - mock_super_connect.return_value = mock_connection - - result = await AsyncEntraConnection.connect(**kwargs) - - mock_get_conninfo.assert_called_once_with(mock_credential) - mock_super_connect.assert_called_once() - call_args = mock_super_connect.call_args[1] - assert call_args["user"] == "test@example.com" - assert call_args["password"] == "token123" - assert "credential" not in call_args - assert result == mock_connection - - @pytest.mark.asyncio - async def test_connect_invalid_credential_type(self): - """Test connection with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - kwargs = {"host": "localhost", "credential": invalid_credential} - - with pytest.raises(ValueError, match="credential must be an AsyncTokenCredential for async connections"): - await AsyncEntraConnection.connect(**kwargs) - - -# Example usage and test runner -if __name__ == "__main__": - """ - Direct execution example for development and validation. - - This runs all tests with verbose output, which is helpful for: - - Understanding test coverage - - Debugging test failures - - Learning expected behavior patterns - - For CI/CD or automated testing, use pytest directly: - pytest test_psycopg_entra_id_extension.py -v --tb=short - """ - import sys - - # Run with verbose output and short traceback format - exit_code = pytest.main([__file__, "-v", "--tb=short"]) - - # Provide helpful guidance based on results - if exit_code == 0: - print("\nāœ… All tests passed! The Azure Entra ID extension is working correctly.") - print("šŸ’” Next steps: Try running the samples in samples/psycopg/getting_started/") - else: - print(f"\nāŒ Some tests failed (exit code: {exit_code})") - print("šŸ’” Check the test output above for details on any failures.") - print("šŸ’” Ensure all dependencies are installed: pip install pytest pytest-asyncio") - - sys.exit(exit_code) - - -""" -Quick Start Guide for Understanding These Tests: - -1. Basic JWT Testing (TestDecodeJwt): - - Shows how Azure tokens are decoded to get user info - - Demonstrates different token claim formats - -2. Resource Path Parsing (TestParsePrincipalName): - - Tests managed identity resource path handling - - Shows how identity names are extracted for database auth - -3. Core Authentication Logic (TestGetEntraConninfo): - - Tests the main authentication flow - - Shows both sync and async patterns - - Demonstrates fallback mechanisms for edge cases - -4. Connection Classes (TestSyncEntraConnection/TestAsyncEntraConnection): - - Tests the customer-facing connection classes - - Shows how existing psycopg code can be easily adapted - - Validates credential handling and error cases - -Key Patterns to Notice: -- Every sync test has an async equivalent for consistency -- Mocking isolates tests from external dependencies -- Clear error messages help with debugging -- Tests serve as usage examples for developers - -For integration testing with real Azure resources, see the samples directory. -""" \ No newline at end of file diff --git a/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py b/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py new file mode 100644 index 0000000..f078704 --- /dev/null +++ b/python/tests/azure/data/postgresql/sqlalchemy/test_entra_id_extension.py @@ -0,0 +1,188 @@ +# Copyright (c) Microsoft. All rights reserved. + +from unittest.mock import Mock, patch + +import pytest + + +class TestEnableEntraAuthentication: + def test_sync_authentication_function_registration(self): + """Test that enable_entra_authentication registers event listener successfully.""" + mock_engine = Mock() + + with patch("sqlalchemy.event.listens_for") as mock_event_listener: + from azurepg_entra.sqlalchemy import enable_entra_authentication + + enable_entra_authentication(mock_engine) + + # Verify event listener was registered with correct parameters + mock_event_listener.assert_called_once_with(mock_engine, "do_connect") + + def test_provide_token_method(self): + """Test the provide_token event handler method directly.""" + mock_engine = Mock() + + # Capture the event handler function + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(func): + nonlocal captured_handler + captured_handler = func + return func + + return decorator + + with patch("sqlalchemy.event.listens_for", side_effect=capture_handler): + with patch( + "azurepg_entra.sqlalchemy.entra_connection.get_entra_conninfo" + ) as mock_get_creds: + mock_get_creds.return_value = { + "user": "test@example.com", + "password": "test_token", + } + + from azurepg_entra.sqlalchemy import enable_entra_authentication + + enable_entra_authentication(mock_engine) + + # Test the captured handler directly + mock_cparams = {} + captured_handler(None, None, None, mock_cparams) + + # Verify credentials were added + mock_get_creds.assert_called_once_with(None) + assert mock_cparams["user"] == "test@example.com" + assert mock_cparams["password"] == "test_token" + + def test_provide_token_skips_existing_credentials(self): + """Test that provide_token skips when credentials already exist.""" + mock_engine = Mock() + + # Capture the event handler function + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(func): + nonlocal captured_handler + captured_handler = func + return func + + return decorator + + with patch("sqlalchemy.event.listens_for", side_effect=capture_handler): + with patch( + "azurepg_entra.sqlalchemy.entra_connection.get_entra_conninfo" + ) as mock_get_creds: + from azurepg_entra.sqlalchemy import enable_entra_authentication + + enable_entra_authentication(mock_engine) + + # Test with existing credentials + mock_cparams = { + "user": "existing@example.com", + "password": "existing_password", + } + captured_handler(None, None, None, mock_cparams) + + # Verify get_entra_conninfo was not called + mock_get_creds.assert_not_called() + assert mock_cparams["user"] == "existing@example.com" + assert mock_cparams["password"] == "existing_password" + + def test_async_authentication_function_registration(self): + """Test that enable_entra_authentication_async registers event listener successfully.""" + mock_async_engine = Mock() + mock_sync_engine = Mock() + mock_async_engine.sync_engine = mock_sync_engine + + with patch("sqlalchemy.event.listens_for") as mock_event_listener: + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + + enable_entra_authentication_async(mock_async_engine) + + # Verify event listener was registered on sync_engine + mock_event_listener.assert_called_once_with(mock_sync_engine, "do_connect") + + def test_provide_token_async_method(self): + """Test the provide_token_async event handler method directly.""" + mock_async_engine = Mock() + mock_sync_engine = Mock() + mock_async_engine.sync_engine = mock_sync_engine + + # Capture the event handler function + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(func): + nonlocal captured_handler + captured_handler = func + return func + + return decorator + + with patch("sqlalchemy.event.listens_for", side_effect=capture_handler): + with patch( + "azurepg_entra.sqlalchemy.async_entra_connection.get_entra_conninfo" + ) as mock_get_creds: + mock_get_creds.return_value = { + "user": "test@example.com", + "password": "test_token", + } + + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + + enable_entra_authentication_async(mock_async_engine) + + # Test the captured handler directly + mock_cparams = {} + captured_handler(None, None, None, mock_cparams) + + # Verify credentials were added + mock_get_creds.assert_called_once_with(None) + assert mock_cparams["user"] == "test@example.com" + assert mock_cparams["password"] == "test_token" + + def test_provide_token_async_skips_existing_credentials(self): + """Test that provide_token_async skips when credentials already exist.""" + mock_async_engine = Mock() + mock_sync_engine = Mock() + mock_async_engine.sync_engine = mock_sync_engine + + # Capture the event handler function + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(func): + nonlocal captured_handler + captured_handler = func + return func + + return decorator + + with patch("sqlalchemy.event.listens_for", side_effect=capture_handler): + with patch( + "azurepg_entra.sqlalchemy.async_entra_connection.get_entra_conninfo" + ) as mock_get_creds: + from azurepg_entra.sqlalchemy import enable_entra_authentication_async + + enable_entra_authentication_async(mock_async_engine) + + # Test with existing credentials + mock_cparams = { + "user": "existing@example.com", + "password": "existing_password", + } + captured_handler(None, None, None, mock_cparams) + + # Verify get_entra_conninfo was not called (credentials already exist) + mock_get_creds.assert_not_called() + assert mock_cparams["user"] == "existing@example.com" + assert mock_cparams["password"] == "existing_password" + + +if __name__ == "__main__": + import sys + + exit_code = pytest.main([__file__, "-v", "--tb=short"]) + sys.exit(exit_code) diff --git a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py b/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py deleted file mode 100644 index 16964a8..0000000 --- a/python/tests/azure/data/postgresql/sqlalchemy/test_sqlalchemy_entra_id_extension.py +++ /dev/null @@ -1,688 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -""" -Unit Tests for Azure PostgreSQL SQLAlchemy Entra ID Extension - -This test suite demonstrates and validates the Azure Entra ID authentication -functionality for PostgreSQL connections using SQLAlchemy. These tests serve as -both validation and examples of how to use the extension. - -Test Categories: -1. JWT Token Decoding - Validates Azure token processing -2. Principal Name Parsing - Tests managed identity resource path parsing -3. Connection Info Generation - Tests core authentication logic -4. Engine Creation Functions - Validates sync/async engine creation with Entra auth -5. Connection Factory Behavior - Tests custom connection factories for token refresh - -Key Testing Patterns: -- Every synchronous test has an equivalent asynchronous test -- Comprehensive mocking to avoid external dependencies -- Edge case validation for robust error handling -- Clear test naming that describes expected behavior - -Usage: - # Run all tests - pytest test_sqlalchemy_entra_id_extension.py - - # Run specific test class - pytest test_sqlalchemy_entra_id_extension.py::TestCreateEngineWithEntra - - # Run with verbose output - pytest -v test_sqlalchemy_entra_id_extension.py - -Dependencies: - pip install pytest pytest-asyncio sqlalchemy - -For more information about Azure Entra ID authentication: -https://docs.microsoft.com/en-us/azure/postgresql/concepts-aad-authentication -""" - -import base64 -import json -import pytest -from unittest.mock import AsyncMock, Mock, patch, MagicMock -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential - -from azurepg_entra.sqlalchemy import ( - create_engine_with_entra, - create_async_engine_with_entra, - get_entra_conninfo, - get_entra_conninfo_async, - decode_jwt, - parse_principal_name, -) - -# Test Configuration -# These tests use mocking to avoid requiring actual Azure credentials or database connections. -# For integration testing with real Azure resources, see the samples/ directory. - - -def create_test_token(payload): - """Helper to create a test JWT token.""" - encoded_payload = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=') - return f"header.{encoded_payload}.signature" - - -class TestDecodeJwt: - """ - Tests for JWT token decoding functionality. - - These tests validate that Azure Entra ID tokens are properly decoded - to extract user information. The extension supports various token formats - and claim structures that Azure may provide. - """ - - def test_decode_jwt_valid_token(self): - """Test decoding a valid JWT token with UPN (User Principal Name) claim.""" - # UPN is the most common claim for user identity in Azure AD tokens - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_with_padding(self): - """Test decoding JWT token that requires base64 padding.""" - payload = {"preferred_username": "testuser", "exp": 9999999999} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - def test_decode_jwt_minimal_payload(self): - """Test decoding JWT with minimal payload.""" - payload = {"unique_name": "user123"} - token = create_test_token(payload) - result = decode_jwt(token) - assert result == payload - - -class TestParsePrincipalName: - """ - Tests for Azure resource path parsing. - - When using managed identities, Azure provides resource paths that need - to be parsed to extract the identity name for database authentication. - These tests ensure robust parsing of various path formats. - """ - - def test_parse_principal_name_valid_user_assigned(self): - """Test parsing a valid user-assigned identity resource path.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - def test_parse_principal_name_empty_string(self): - """Test parsing an empty string returns None.""" - assert parse_principal_name("") is None - - def test_parse_principal_name_none(self): - """Test parsing None returns None.""" - assert parse_principal_name(None) is None - - def test_parse_principal_name_no_slash(self): - """Test parsing a string without slashes returns None.""" - assert parse_principal_name("no-slashes-here") is None - - def test_parse_principal_name_invalid_path(self): - """Test parsing an invalid resource path returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/SomeOther/resource/my-identity") - assert result is None - - def test_parse_principal_name_missing_identity_name(self): - """Test parsing a path without identity name returns None.""" - result = parse_principal_name("/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/") - assert result is None - - def test_parse_principal_name_case_insensitive(self): - """Test parsing with different case variations.""" - resource_path = "/subscriptions/12345/resourcegroups/mygroup/providers/MICROSOFT.MANAGEDIDENTITY/USERASSIGNEDIDENTITIES/my-identity" - result = parse_principal_name(resource_path) - assert result == "my-identity" - - -class TestGetEntraConninfo: - """ - Tests for the core authentication logic (sync and async versions). - - These functions handle the complete flow of: - 1. Requesting Azure tokens with appropriate scopes - 2. Decoding tokens to extract user information - 3. Handling fallback scenarios for managed identities - 4. Returning connection parameters for SQLAlchemy - - Both synchronous and asynchronous patterns are tested to ensure - consistent behavior across different usage scenarios. - """ - - # Sync tests - def test_get_entra_conninfo_with_credential(self): - """Test getting connection info with sync credential and upn claim.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - def test_get_entra_conninfo_no_username_claims(self): - """Test error when no username claims are present.""" - mock_credential = Mock(spec=TokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - get_entra_conninfo(mock_credential) - - def test_get_entra_conninfo_username_priority(self): - """Test that upn takes priority over other username claims.""" - mock_credential = Mock(spec=TokenCredential) - # Azure tokens may contain multiple username claims - test priority order - payload = { - "upn": "upn@example.com", # Highest priority - "preferred_username": "preferred@example.com", # Second priority - "unique_name": "unique@example.com" # Fallback option - } - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token', return_value=token): - result = get_entra_conninfo(mock_credential) - assert result["user"] == "upn@example.com" # Should use highest priority claim - - def test_get_entra_conninfo_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username.""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = get_entra_conninfo(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - # Async tests - mirror of sync tests - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_with_credential(self): - """Test getting connection info with async credential and upn claim.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"upn": "user@example.com", "iat": 1234567890} - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result == {"user": "user@example.com", "password": token} - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_no_username_claims(self): - """Test error when no username claims are present (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = {"sub": "subject123", "iat": 1234567890} # No username claims - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token_async', return_value=token): - with pytest.raises(ValueError, match="Could not determine username from token claims"): - await get_entra_conninfo_async(mock_credential) - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_username_priority(self): - """Test that upn takes priority over other username claims (async).""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - payload = { - "upn": "upn@example.com", - "preferred_username": "preferred@example.com", - "unique_name": "unique@example.com" - } - token = create_test_token(payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token_async', return_value=token): - result = await get_entra_conninfo_async(mock_credential) - assert result["user"] == "upn@example.com" - - @pytest.mark.asyncio - async def test_get_entra_conninfo_async_fallback_to_management_scope(self): - """Test fallback to management scope when DB scope token has no username (async).""" - db_payload = {"sub": "subject123", "iat": 1234567890} - db_token = create_test_token(db_payload) - - mgmt_payload = { - "xms_mirid": "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/fallback-identity" - } - mgmt_token = create_test_token(mgmt_payload) - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_token_async') as mock_get_token: - mock_get_token.side_effect = [db_token, mgmt_token] - - result = await get_entra_conninfo_async(None) - assert result["user"] == "fallback-identity" - assert result["password"] == db_token - assert mock_get_token.call_count == 2 - - -class TestCreateEngineWithEntra: - """ - Tests for the create_engine_with_entra function. - - This function creates a synchronous SQLAlchemy engine with Entra authentication. - Tests validate: - - Proper engine creation with custom connection factory - - URL parsing and reconstruction - - Credential handling and validation - - Integration with SQLAlchemy's engine creation process - """ - - def test_create_engine_basic_url(self): - """Test engine creation with basic PostgreSQL URL.""" - mock_credential = Mock(spec=TokenCredential) - url = "postgresql://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - result = create_engine_with_entra(url, credential=mock_credential) - - # Verify create_engine was called with base URL and creator function - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - - # Check that base URL doesn't contain credentials - assert "myserver.postgres.database.azure.com" in call_args[0][0] - assert "creator" in call_args[1] - assert callable(call_args[1]["creator"]) - assert result == mock_engine - - def test_create_engine_with_psycopg_scheme(self): - """Test engine creation with psycopg+ scheme.""" - url = "postgresql+psycopg://myserver.postgres.database.azure.com:5432/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - result = create_engine_with_entra(url) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - - # Verify scheme is preserved in base URL - assert call_args[0][0].startswith("postgresql+psycopg://") - assert result == mock_engine - - def test_create_engine_with_query_parameters(self): - """Test engine creation preserves query parameters.""" - url = "postgresql://myserver.postgres.database.azure.com/mydatabase?sslmode=require&connect_timeout=30" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - result = create_engine_with_entra(url) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - - # Verify query parameters are preserved - assert "sslmode=require" in call_args[0][0] - assert "connect_timeout=30" in call_args[0][0] - assert result == mock_engine - - def test_create_engine_invalid_credential_type(self): - """Test engine creation with invalid credential type raises error.""" - invalid_credential = "not_a_credential" - url = "postgresql://myserver.postgres.database.azure.com/mydatabase" - - with pytest.raises(ValueError, match="credential must be a TokenCredential for synchronous engines"): - create_engine_with_entra(url, credential=invalid_credential) - - def test_create_engine_connection_factory_behavior(self): - """Test that the custom connection factory works correctly.""" - mock_credential = Mock(spec=TokenCredential) - url = "postgresql://myserver.postgres.database.azure.com/mydatabase" - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - # Mock the connection factory components - mock_temp_engine = Mock() - mock_raw_conn = Mock() - mock_dbapi_conn = Mock() - mock_raw_conn.dbapi_connection = mock_dbapi_conn - mock_temp_engine.raw_connection.return_value = mock_raw_conn - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo', return_value=expected_conninfo): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - # Mock the main create_engine call - mock_main_engine = Mock() - # Mock the temp engine creation inside connection factory - mock_create_engine.side_effect = [mock_main_engine, mock_temp_engine] - - result = create_engine_with_entra(url, credential=mock_credential) - assert result == mock_main_engine - - # Verify main engine creation - assert mock_create_engine.call_count >= 1 - main_call_args = mock_create_engine.call_args_list[0] - - # Extract and test the connection factory function - creator_func = main_call_args[1]["creator"] - assert callable(creator_func) - - # Test the connection factory function (still within the outer patch context) - conn_result = creator_func() - assert conn_result == mock_dbapi_conn - - def test_create_engine_with_kwargs(self): - """Test that additional kwargs are passed through to create_engine.""" - url = "postgresql://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - result = create_engine_with_entra( - url, - pool_size=10, - max_overflow=20, - echo=True - ) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args[1] - - # Verify additional kwargs are passed through - assert call_args["pool_size"] == 10 - assert call_args["max_overflow"] == 20 - assert call_args["echo"] == True - assert result == mock_engine - - -class TestCreateAsyncEngineWithEntra: - """ - Tests for the create_async_engine_with_entra function. - - This function creates an asynchronous SQLAlchemy engine with Entra authentication. - Tests validate: - - Proper async engine creation with custom connection factory - - URL parsing and psycopg3 async connection handling - - Credential handling for async operations - - Integration with SQLAlchemy's async engine creation process - """ - - def test_create_async_engine_import_error(self): - """Test error when sqlalchemy.ext.asyncio is not available.""" - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - - # Mock the absence of create_async_engine - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine', None): - with pytest.raises(ImportError, match="sqlalchemy.ext.asyncio is required for async engines"): - create_async_engine_with_entra(url) - - def test_create_async_engine_basic_url(self): - """Test async engine creation with basic PostgreSQL URL.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - mock_engine = Mock() - mock_create_async_engine.return_value = mock_engine - - result = create_async_engine_with_entra(url, credential=mock_credential) - - # Verify create_async_engine was called with base URL and async_creator function - mock_create_async_engine.assert_called_once() - call_args = mock_create_async_engine.call_args - - # Check that base URL doesn't contain credentials - assert "myserver.postgres.database.azure.com" in call_args[0][0] - assert "async_creator" in call_args[1] - assert callable(call_args[1]["async_creator"]) - assert result == mock_engine - - def test_create_async_engine_with_port_and_query(self): - """Test async engine creation preserves port and query parameters.""" - url = "postgresql+psycopg://myserver.postgres.database.azure.com:5432/mydatabase?sslmode=require" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - mock_engine = Mock() - mock_create_async_engine.return_value = mock_engine - - result = create_async_engine_with_entra(url) - - mock_create_async_engine.assert_called_once() - call_args = mock_create_async_engine.call_args - - # Verify port and query parameters are preserved - assert ":5432" in call_args[0][0] - assert "sslmode=require" in call_args[0][0] - assert result == mock_engine - - def test_create_async_engine_invalid_credential_type(self): - """Test async engine creation with invalid credential type raises error.""" - invalid_credential = "not_an_async_credential" - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - - with pytest.raises(ValueError, match="credential must be an AsyncTokenCredential for async engines"): - create_async_engine_with_entra(url, credential=invalid_credential) - - @pytest.mark.asyncio - async def test_create_async_engine_connection_factory_behavior(self): - """Test that the custom async connection factory works correctly.""" - mock_credential = AsyncMock(spec=AsyncTokenCredential) - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - expected_conninfo = {"user": "test@example.com", "password": "token123"} - - # Mock psycopg AsyncConnection - mock_async_conn = Mock() - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.get_entra_conninfo_async', return_value=expected_conninfo): - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - with patch('psycopg.AsyncConnection.connect', new_callable=AsyncMock, return_value=mock_async_conn) as mock_psycopg_connect: - mock_main_engine = Mock() - mock_create_async_engine.return_value = mock_main_engine - - result = create_async_engine_with_entra(url, credential=mock_credential) - - # Verify main engine creation - mock_create_async_engine.assert_called_once() - main_call_args = mock_create_async_engine.call_args - - # Extract and test the async connection factory function - async_creator_func = main_call_args[1]["async_creator"] - assert callable(async_creator_func) - - # Test the async connection factory function - conn_result = await async_creator_func() - - # Verify psycopg.AsyncConnection.connect was called with correct parameters - mock_psycopg_connect.assert_called_once() - connect_call_args = mock_psycopg_connect.call_args[1] - assert connect_call_args["host"] == "myserver.postgres.database.azure.com" - assert connect_call_args["port"] == 5432 - assert connect_call_args["dbname"] == "mydatabase" - assert connect_call_args["user"] == "test@example.com" - assert connect_call_args["password"] == "token123" - - assert conn_result == mock_async_conn - assert result == mock_main_engine - - def test_create_async_engine_unsupported_scheme(self): - """Test async engine creation with unsupported URL scheme.""" - url = "postgresql+asyncpg://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - mock_engine = Mock() - mock_create_async_engine.return_value = mock_engine - - # Create the engine (this should succeed) - result = create_async_engine_with_entra(url) - - # Extract the async_creator function - call_args = mock_create_async_engine.call_args - async_creator_func = call_args[1]["async_creator"] - - # Test that calling the async_creator with unsupported scheme raises error - with pytest.raises(ValueError, match="Unsupported async URL scheme: postgresql\\+asyncpg"): - # We need to create an async context to test this - import asyncio - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(async_creator_func()) - finally: - loop.close() - - def test_create_async_engine_with_kwargs(self): - """Test that additional kwargs are passed through to create_async_engine.""" - url = "postgresql+psycopg://myserver.postgres.database.azure.com/mydatabase" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_async_engine') as mock_create_async_engine: - mock_engine = Mock() - mock_create_async_engine.return_value = mock_engine - - result = create_async_engine_with_entra( - url, - pool_size=15, - max_overflow=25, - echo=True - ) - - mock_create_async_engine.assert_called_once() - call_args = mock_create_async_engine.call_args[1] - - # Verify additional kwargs are passed through - assert call_args["pool_size"] == 15 - assert call_args["max_overflow"] == 25 - assert call_args["echo"] == True - assert result == mock_engine - - -class TestUrlParsing: - """ - Tests for URL parsing and reconstruction logic. - - These tests validate that the URL parsing and reconstruction logic - works correctly for various URL formats and edge cases. - """ - - def test_url_parsing_with_different_schemes(self): - """Test URL parsing works with different PostgreSQL schemes.""" - test_urls = [ - "postgresql://server.com/db", - "postgresql+psycopg://server.com/db", - "postgresql+psycopg2://server.com/db" - ] - - for url in test_urls: - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - create_engine_with_entra(url) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - # Verify the scheme is preserved in the base URL - assert call_args[0][0].startswith(url.split('://')[0] + '://') - - def test_url_parsing_with_complex_parameters(self): - """Test URL parsing with complex query parameters and paths.""" - url = "postgresql://server.com:5432/complex_db_name?sslmode=require&application_name=test%20app&connect_timeout=30" - - with patch('azurepg_entra.sqlalchemy.sqlalchemy_entra_id_extension.create_engine') as mock_create_engine: - mock_engine = Mock() - mock_create_engine.return_value = mock_engine - - create_engine_with_entra(url) - - mock_create_engine.assert_called_once() - call_args = mock_create_engine.call_args - base_url = call_args[0][0] - - # Verify all components are preserved - assert "server.com:5432" in base_url - assert "/complex_db_name" in base_url - assert "sslmode=require" in base_url - assert "application_name=test%20app" in base_url - assert "connect_timeout=30" in base_url - - -# Example usage and test runner -if __name__ == "__main__": - """ - Direct execution example for development and validation. - - This runs all tests with verbose output, which is helpful for: - - Understanding test coverage - - Debugging test failures - - Learning expected behavior patterns - - For CI/CD or automated testing, use pytest directly: - pytest test_sqlalchemy_entra_id_extension.py -v --tb=short - """ - import sys - - # Run with verbose output and short traceback format - exit_code = pytest.main([__file__, "-v", "--tb=short"]) - - # Provide helpful guidance based on results - if exit_code == 0: - print("\nāœ… All tests passed! The SQLAlchemy Azure Entra ID extension is working correctly.") - print("šŸ’” Next steps: Try running the samples in samples/sqlalchemy/getting_started/") - else: - print(f"\nāŒ Some tests failed (exit code: {exit_code})") - print("šŸ’” Check the test output above for details on any failures.") - print("šŸ’” Ensure all dependencies are installed: pip install pytest pytest-asyncio sqlalchemy") - - sys.exit(exit_code) - - -""" -Quick Start Guide for Understanding These Tests: - -1. Basic JWT Testing (TestDecodeJwt): - - Shows how Azure tokens are decoded to get user info - - Demonstrates different token claim formats - -2. Resource Path Parsing (TestParsePrincipalName): - - Tests managed identity resource path handling - - Shows how identity names are extracted for database auth - -3. Core Authentication Logic (TestGetEntraConninfo): - - Tests the main authentication flow - - Shows both sync and async patterns - - Demonstrates fallback mechanisms for edge cases - -4. Engine Creation Functions (TestCreateEngineWithEntra/TestCreateAsyncEngineWithEntra): - - Tests the main SQLAlchemy engine creation functions - - Shows how custom connection factories work - - Validates URL parsing and reconstruction - - Tests both sync and async engine patterns - -5. URL Parsing (TestUrlParsing): - - Tests URL parsing and reconstruction logic - - Shows how different schemes and parameters are handled - - Validates complex URL scenarios - -Key Patterns to Notice: -- Every sync test has an async equivalent for consistency -- Mocking isolates tests from external dependencies -- Custom connection factories are thoroughly tested -- URL parsing handles various PostgreSQL driver schemes -- Clear error messages help with debugging -- Tests serve as usage examples for developers - -For integration testing with real Azure resources, see the samples directory. - -SQLAlchemy-Specific Features Tested: -- Custom connection factory (creator parameter) -- Custom async connection factory (async_creator parameter) -- Engine configuration parameter pass-through -- URL scheme handling for different PostgreSQL drivers -- Integration with SQLAlchemy's connection pooling -- Proper DBAPI connection object handling -""" \ No newline at end of file diff --git a/python/tests/azure/data/postgresql/test_core_functionality.py b/python/tests/azure/data/postgresql/test_core_functionality.py new file mode 100644 index 0000000..5f46818 --- /dev/null +++ b/python/tests/azure/data/postgresql/test_core_functionality.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft. All rights reserved. + +import base64 +import json +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from azurepg_entra.core import ( + decode_jwt, + get_entra_conninfo, + get_entra_conninfo_async, + parse_principal_name, +) +from azurepg_entra.errors import TokenDecodeError, UsernameExtractionError + + +def create_test_token(payload): + """Helper to create a test JWT token manually.""" + # Create a simple JWT-like token with header.payload.signature format + header = {"alg": "none", "typ": "JWT"} + header_encoded = ( + base64.urlsafe_b64encode(json.dumps(header).encode()).decode().rstrip("=") + ) + payload_encoded = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip("=") + ) + signature = "" + return f"{header_encoded}.{payload_encoded}.{signature}" + + +class TestJwtParsing: + def test_decode_jwt_with_upn(self): + payload = {"upn": "user@example.com"} + token = create_test_token(payload) + result = decode_jwt(token) + assert result == payload + + def test_decode_jwt_with_preferred_username(self): + payload = {"preferred_username": "testuser@example.com"} + token = create_test_token(payload) + result = decode_jwt(token) + assert result == payload + + def test_decode_jwt_invalid_format_raises_exception(self): + with pytest.raises(TokenDecodeError, match="Invalid JWT token format"): + decode_jwt("invalid.token") + + def test_parse_principal_name_valid_path(self): + path = "/subscriptions/12345/resourcegroups/mygroup/providers/Microsoft.ManagedIdentity/userAssignedIdentities/my-identity" + result = parse_principal_name(path) + assert result == "my-identity" + + def test_parse_principal_name_invalid_path_returns_none(self): + assert parse_principal_name("") is None + assert parse_principal_name(None) is None + assert parse_principal_name("/invalid/path") is None + + +class TestEntraAuthentication: + def test_get_entra_conninfo_with_upn(self): + mock_credential = Mock(spec=TokenCredential) + payload = {"upn": "user@example.com"} + token = create_test_token(payload) + + with patch("azurepg_entra.core.get_entra_token", return_value=token): + result = get_entra_conninfo(mock_credential) + assert result == {"user": "user@example.com", "password": token} + + def test_get_entra_conninfo_no_username_throws(self): + mock_credential = Mock(spec=TokenCredential) + payload = {"sub": "subject123"} + token = create_test_token(payload) + + # Mock both the DB token and the management token to have no username claims + with patch("azurepg_entra.core.get_entra_token", return_value=token): + with pytest.raises( + UsernameExtractionError, + match="Could not determine username from token claims", + ): + get_entra_conninfo(mock_credential) + + @pytest.mark.asyncio + async def test_get_entra_conninfo_async_with_upn(self): + mock_credential = AsyncMock(spec=AsyncTokenCredential) + payload = {"upn": "user@example.com"} + token = create_test_token(payload) + + with patch("azurepg_entra.core.get_entra_token_async", return_value=token): + result = await get_entra_conninfo_async(mock_credential) + assert result == {"user": "user@example.com", "password": token} + + @pytest.mark.asyncio + async def test_get_entra_conninfo_async_no_username_throws(self): + mock_credential = AsyncMock(spec=AsyncTokenCredential) + payload = {"sub": "subject123"} + token = create_test_token(payload) + + # Mock both the DB token and the management token to have no username claims + with patch("azurepg_entra.core.get_entra_token_async", return_value=token): + with pytest.raises( + UsernameExtractionError, + match="Could not determine username from token claims", + ): + await get_entra_conninfo_async(mock_credential) + + +if __name__ == "__main__": + import sys + + exit_code = pytest.main([__file__, "-v", "--tb=short"]) + sys.exit(exit_code) diff --git a/scripts/run-python-checks.ps1 b/scripts/run-python-checks.ps1 new file mode 100644 index 0000000..603b33a --- /dev/null +++ b/scripts/run-python-checks.ps1 @@ -0,0 +1,116 @@ +param( + [switch]$Verbose, + [switch]$Help, + [switch]$RecreateVenv +) + +<# +.SYNOPSIS + Run Python quality checks (lint, type, tests) locally. + +.DESCRIPTION + Mirrors the CI steps defined in pr-python.yml: install deps (.[all]), Ruff lint, mypy, and pytest. + +.PARAMETER Verbose + Show full tool output instead of suppressing it. + +.PARAMETER RecreateVenv + Delete and recreate the .venv before installing dependencies. + +.EXAMPLE + ./run-python-checks.ps1 + +.EXAMPLE + ./run-python-checks.ps1 -Verbose + +.EXAMPLE + ./run-python-checks.ps1 -RecreateVenv +#> + +if ($Help) { + Get-Help -Detailed -ErrorAction SilentlyContinue + Write-Host "Usage: ./run-python-checks.ps1 [-Verbose] [-RecreateVenv]" -ForegroundColor Cyan + exit 0 +} + +function Write-CheckResult { + param([string]$Name, [bool]$Success, [string]$Message = "") + if ($Success) { + Write-Host "PASS $Name" -ForegroundColor Green + if ($Message) { Write-Host " $Message" -ForegroundColor Gray } + } else { + Write-Host "FAIL $Name" -ForegroundColor Red + if ($Message) { Write-Host " $Message" -ForegroundColor Gray } + $script:OverallSuccess = $false + } +} + +$OverallSuccess = $true +$pythonRoot = Join-Path (Get-Location) "python" +if (-not (Test-Path $pythonRoot)) { + Write-Host "python/ directory not found" -ForegroundColor Red + exit 1 +} + +$venvPath = Join-Path $pythonRoot ".venv" +$venvPython = Join-Path $venvPath "Scripts/python.exe" + +Push-Location $pythonRoot +try { + # Use explicit parentheses so PowerShell doesn't mis-bind -and as a parameter to Test-Path on some shells + if ((Test-Path $venvPath) -and $RecreateVenv) { + Write-Host "Recreating virtual environment..." -ForegroundColor Blue + Remove-Item -Recurse -Force $venvPath + } + if (-not (Test-Path $venvPath)) { + Write-Host "Creating virtual environment" -ForegroundColor Blue + python -m venv $venvPath + if ($LASTEXITCODE -ne 0) { Write-CheckResult "Prepare venv" $false "python -m venv failed"; exit 1 } + Write-CheckResult "Prepare venv" $true "Created new venv" + } else { + Write-CheckResult "Prepare venv" $true "Reused existing venv" + } + + if (-not (Test-Path $venvPython)) { Write-CheckResult "Venv Python present" $false "Not found"; exit 1 } + $pythonVersion = & $venvPython --version 2>&1 + Write-Host "Using $pythonVersion" + + & $venvPython -m pip install --upgrade pip | Out-Null + Write-CheckResult "pip upgrade" ($LASTEXITCODE -eq 0) + + Write-Host "Installing project deps (.[all])" -ForegroundColor Blue + & $venvPython -m pip install .[all] | Out-Null + $depsOk = $LASTEXITCODE -eq 0 + if (-not $depsOk) { Write-CheckResult "Install deps" $false; exit 1 } else { Write-CheckResult "Install deps" $true } + + # Ruff + Write-Host "Running Ruff lint" -ForegroundColor Blue + if ($Verbose) { & $venvPython -m ruff check src tests } else { & $venvPython -m ruff check src tests *> $null } + Write-CheckResult "ruff lint" ($LASTEXITCODE -eq 0) + + # mypy + Write-Host "Running mypy type check" -ForegroundColor Blue + if ($Verbose) { & $venvPython -m mypy src/azurepg_entra/ } else { & $venvPython -m mypy src/azurepg_entra/ *> $null } + Write-CheckResult "mypy" ($LASTEXITCODE -eq 0) + + if (Test-Path "tests") { + Write-Host "Running pytest" -ForegroundColor Blue + + # Use importlib mode to avoid import collisions from files with same basename + # This allows pytest to handle multiple test_entra_id_extension.py files + if ($Verbose) { + & $venvPython -m pytest tests --import-mode=importlib -v + } else { + & $venvPython -m pytest tests --import-mode=importlib -q *> $null + } + $allTestsPass = ($LASTEXITCODE -eq 0) + Write-CheckResult "pytest" $allTestsPass + } else { + Write-Host "WARN No tests directory present" -ForegroundColor Yellow + } +} +finally { + Pop-Location +} + +if (-not $OverallSuccess) { exit 1 }